Abstract: Continuous normalizing flows (CNFs) learn the probability path between a reference and a target density by modeling the vector field generating said path using neural networks. Recently, Lipman et al. (2022) introduced a simple and inexpensive method for training CNFs in generative modeling, termed flow matching (FM). In this paper, we re-purpose this method for probabilistic inference by incorporating Markovian sampling methods in evaluating the FM objective and using the learned probability path to improve Monte Carlo sampling. We propose a sequential method, which uses samples from a Markov chain to fix the probability path defining the FM objective. We augment this scheme with an adaptive tempering mechanism that allows the discovery of multiple modes in the target. Under mild assumptions, we establish convergence to a local optimum of the FM objective, discuss improvements in the convergence rate, and illustrate our methods on synthetic and real-world examples.
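For context, the FM objective mentioned in this abstract can be written, in the conditional form of Lipman et al. (2022), roughly as follows. This is a sketch under assumed notation, since the abstract itself fixes none: $v_\theta$ is the neural vector field, $q$ the target distribution, $p_t(\cdot \mid x_1)$ a conditional probability path, and $u_t(\cdot \mid x_1)$ the vector field that generates it. The abstract suggests that, in the sequential method, the samples $x_1$ would be drawn from a Markov chain rather than from a fixed dataset.

```latex
\mathcal{L}_{\mathrm{CFM}}(\theta)
  = \mathbb{E}_{t \sim \mathcal{U}[0,1],\; x_1 \sim q,\; x \sim p_t(\cdot \mid x_1)}
    \left\| v_\theta(t, x) - u_t(x \mid x_1) \right\|^2
```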
Abstract: BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these work.
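To illustrate what a functional sampler that works directly with an un-normalized log density might look like, here is a minimal sketch in plain NumPy. It is not BlackJAX's API: the names `RWState`, `init`, and `build_kernel` are hypothetical, and a simple random-walk Metropolis kernel stands in for the library's samplers. The point is only the design: the state is an immutable value, and the step function is a pure function of a random generator and a state.

```python
from typing import NamedTuple

import numpy as np


class RWState(NamedTuple):
    """Immutable sampler state (hypothetical name, for illustration only)."""
    position: np.ndarray
    logdensity: float


def init(position, logdensity_fn):
    """Build the initial state from a starting position."""
    position = np.asarray(position, dtype=float)
    return RWState(position, float(logdensity_fn(position)))


def build_kernel(logdensity_fn, step_size):
    """Return a random-walk Metropolis step function for the given target."""

    def step(rng, state):
        # Symmetric Gaussian proposal centred on the current position.
        noise = rng.standard_normal(state.position.shape)
        proposal = state.position + step_size * noise
        proposal_logdensity = float(logdensity_fn(proposal))
        # Only a difference of log densities is needed, so the
        # normalizing constant of the target never appears.
        log_accept = proposal_logdensity - state.logdensity
        if np.log(rng.uniform()) < log_accept:
            return RWState(proposal, proposal_logdensity), True
        return state, False

    return step


def logdensity_fn(x):
    """Un-normalized log density of a standard Gaussian."""
    return -0.5 * np.sum(x ** 2)


rng = np.random.default_rng(0)
state = init(np.zeros(2), logdensity_fn)
step = build_kernel(logdensity_fn, step_size=1.0)

samples = []
for _ in range(5000):
    state, accepted = step(rng, state)
    samples.append(state.position)
samples = np.stack(samples)
```

Because the kernel is just a function from one state to the next, it can be composed with other building blocks (adaptation, tempering, other kernels) without shared mutable objects, which is the kind of modularity the abstract describes.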