Abstract: BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these work.
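To make the "functional approach" and the "(un-normalized) target log density" concrete, here is a minimal sketch of a random-walk Metropolis sampler written as two pure functions over an immutable state. It uses only the Python standard library rather than JAX, and every name in it (`RWState`, `init`, `step`, `logdensity_fn`, `step_size`) is a hypothetical choice for illustration, not BlackJAX's API.

```python
import math
import random
from typing import Callable, NamedTuple


class RWState(NamedTuple):
    """Immutable sampler state: current position and its log density."""
    position: float
    logdensity: float


def init(position: float, logdensity_fn: Callable[[float], float]) -> RWState:
    """Build the initial state from a starting position."""
    return RWState(position, logdensity_fn(position))


def step(rng: random.Random, state: RWState,
         logdensity_fn: Callable[[float], float], step_size: float) -> RWState:
    """One random-walk Metropolis transition; returns a new state."""
    proposal = state.position + step_size * rng.gauss(0.0, 1.0)
    proposal_logdensity = logdensity_fn(proposal)
    # Only a difference of log densities is needed, so the
    # normalizing constant of the target never has to be known.
    log_accept_ratio = proposal_logdensity - state.logdensity
    # 1 - random() lies in (0, 1], so the logarithm is always defined.
    if math.log(1.0 - rng.random()) < log_accept_ratio:
        return RWState(proposal, proposal_logdensity)
    return state


def logdensity_fn(x: float) -> float:
    """Un-normalized log density of a Normal(mean=1, sd=0.5) target."""
    return -0.5 * ((x - 1.0) / 0.5) ** 2


rng = random.Random(0)
state = init(0.0, logdensity_fn)
samples = []
for _ in range(20000):
    state = step(rng, state, logdensity_fn, step_size=1.0)
    samples.append(state.position)

kept = samples[2000:]  # discard an assumed burn-in period
sample_mean = sum(kept) / len(kept)
sample_var = sum((s - sample_mean) ** 2 for s in kept) / len(kept)
```

Because `step` takes the state and returns a new one without hidden mutation, the same kernel can be wrapped in a loop, a scan, or a larger composite algorithm, which is the kind of composability the abstract describes.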
Abstract: Parameter inference for ordinary differential equations (ODEs) is of fundamental importance in many scientific applications. While ODE solutions are typically approximated by deterministic algorithms, new research on probabilistic solvers indicates that they produce more reliable parameter estimates by better accounting for numerical errors. However, many ODE systems are highly sensitive to their parameter values. This produces deep local minima in the likelihood function, a problem that existing probabilistic solvers have yet to resolve. Here, we show that a Bayesian filtering paradigm for probabilistic ODE solution can dramatically reduce sensitivity to parameters by learning from the noisy ODE observations in a data-adaptive manner. Our method is applicable to ODEs with partially unobserved components and with arbitrary non-Gaussian noise. Several examples demonstrate that it is more accurate than existing probabilistic ODE solvers, and in some cases even than the exact ODE likelihood.
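As background for the "Bayesian filtering paradigm", the sketch below shows a generic Kalman-filter ODE solver for a scalar ODE x'(t) = f(x, t). It is not the data-adaptive method of the abstract: it conditions only on the ODE itself, not on noisy observations. The assumptions are a once-integrated Wiener process prior on (x, x'), a zeroth-order linearization of f at the predicted mean, and a hypothetical scale parameter `sigma`; the function name `kalman_ode_solve` and the test problem x' = -theta * x are likewise illustrative.

```python
import math
from typing import Callable, List, Tuple


def kalman_ode_solve(f: Callable[[float, float], float], x0: float,
                     t0: float, t1: float, n_steps: int,
                     sigma: float = 1.0) -> List[Tuple[float, float, float]]:
    """Return a list of (t, posterior mean of x, posterior variance of x)."""
    h = (t1 - t0) / n_steps
    # Process noise of a once-integrated Wiener process over a step h.
    q00 = sigma ** 2 * h ** 3 / 3.0
    q01 = sigma ** 2 * h ** 2 / 2.0
    q11 = sigma ** 2 * h
    # State mean (x, v) with v = x'; the initial condition is known exactly.
    x, v = x0, f(x0, t0)
    p00 = p01 = p11 = 0.0
    path = [(t0, x, p00)]
    for k in range(n_steps):
        t = t0 + (k + 1) * h
        # Predict with transition matrix A = [[1, h], [0, 1]].
        x_pred, v_pred = x + h * v, v
        s00 = p00 + 2.0 * h * p01 + h * h * p11 + q00
        s01 = p01 + h * p11 + q01
        s11 = p11 + q11
        # "Observe" that the derivative should equal f at the predicted mean.
        residual = f(x_pred, t) - v_pred
        gain_x = s01 / s11  # Kalman gain for x; the gain for v is 1
        x = x_pred + gain_x * residual
        v = v_pred + residual
        # Posterior covariance: the noise-free observation pins v exactly.
        p00 = s00 - gain_x * s01
        p01 = 0.0
        p11 = 0.0
        path.append((t, x, p00))
    return path


theta = 1.0  # hypothetical ODE parameter
path = kalman_ode_solve(lambda x, t: -theta * x, x0=1.0,
                        t0=0.0, t1=1.0, n_steps=100)
t_end, x_end, var_end = path[-1]
exact_end = math.exp(-theta * 1.0)  # closed-form solution at t = 1
```

The returned variance quantifies the solver's own discretization uncertainty. In a parameter-inference setting, a filtering-based approach can fold such uncertainty into the likelihood for `theta` instead of treating the numerical solution as exact.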