Abstract: BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but it also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these methods work.
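To make the "functional approach" concrete, here is a minimal sketch in plain JAX of a sampler written as a stateless kernel that only needs an un-normalized log density. It does not use or reproduce the BlackJAX API: `RWState`, `init`, `build_kernel`, `run_chain`, the random-walk Metropolis algorithm, and the standard normal target are all assumptions chosen for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical sampler state: current position and its log density."""
    position: jnp.ndarray
    logdensity: jnp.ndarray


def init(position, logdensity_fn):
    # The state is an immutable value; nothing is stored on an object.
    return RWState(position, logdensity_fn(position))


def build_kernel(logdensity_fn, step_size):
    """Return a pure function (rng_key, state) -> (new_state, accepted)."""

    def step(rng_key, state):
        key_proposal, key_accept = jax.random.split(rng_key)
        # Symmetric Gaussian random-walk proposal.
        noise = jax.random.normal(key_proposal, state.position.shape)
        proposal = state.position + step_size * noise
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis accept/reject; only log density differences are needed,
        # so the normalizing constant of the target never appears.
        log_ratio = proposal_logdensity - state.logdensity
        accepted = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_position = jnp.where(accepted, proposal, state.position)
        new_logdensity = jnp.where(accepted, proposal_logdensity, state.logdensity)
        return RWState(new_position, new_logdensity), accepted

    return step


def logdensity_fn(x):
    # Illustrative target: un-normalized standard normal.
    return -0.5 * jnp.sum(x ** 2)


def run_chain(rng_key, initial_position, num_steps, step_size=0.5):
    step = build_kernel(logdensity_fn, step_size)

    def one_step(state, key):
        state, accepted = step(key, state)
        return state, (state.position, accepted)

    keys = jax.random.split(rng_key, num_steps)
    initial_state = init(initial_position, logdensity_fn)
    _, (positions, accepted) = jax.lax.scan(one_step, initial_state, keys)
    return positions, accepted


# The whole chain compiles to a single XLA program.
positions, accepted = jax.jit(run_chain, static_argnames="num_steps")(
    jax.random.PRNGKey(0), jnp.zeros(2), num_steps=2000
)
```

Because the kernel is a pure function of a key and a state, it composes directly with `jax.jit`, `jax.lax.scan`, and `jax.vmap` (for example, to run several chains in parallel), which is the kind of composability the abstract describes.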
Abstract: Stochastic gradient Markov chain Monte Carlo (SGMCMC) is a popular class of algorithms for scalable Bayesian inference. However, these algorithms include hyperparameters, such as step size or batch size, that influence the accuracy of estimators based on the obtained samples. As a result, these hyperparameters must be tuned by the practitioner, and no principled, automated way to tune them currently exists. Standard MCMC tuning methods based on acceptance rates cannot be used for SGMCMC, so alternative tools and diagnostics are required. We propose a novel bandit-based algorithm that tunes SGMCMC hyperparameters to maximize the accuracy of the posterior approximation by minimizing the kernel Stein discrepancy (KSD). We provide theoretical results supporting this approach and assess alternative metrics to KSD. We support our results with experiments on both simulated and real datasets, and find that this method is practical for a wide range of application areas.
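As a rough illustration of the quantity being minimized, the sketch below estimates the KSD of a set of samples against a target known only through its un-normalized log density. The abstract does not specify the estimator or the kernel, so the V-statistic form, the inverse multiquadric (IMQ) base kernel with `c=1.0` and `beta=-0.5`, and all function names are assumptions; the bandit-based tuning procedure itself is not shown.

```python
import jax
import jax.numpy as jnp


def imq_kernel(x, y, c=1.0, beta=-0.5):
    # Inverse multiquadric base kernel (an assumed choice).
    return (c ** 2 + jnp.sum((x - y) ** 2)) ** beta


def stein_kernel(x, y, score_fn):
    """Langevin Stein kernel built from the base kernel and the score."""
    score_x, score_y = score_fn(x), score_fn(y)
    k = imq_kernel(x, y)
    grad_x = jax.grad(imq_kernel, argnums=0)(x, y)
    grad_y = jax.grad(imq_kernel, argnums=1)(x, y)
    # Matrix of mixed second derivatives d^2 k / (dx_j dy_i); only its trace is used.
    mixed = jax.jacfwd(jax.grad(imq_kernel, argnums=1), argnums=0)(x, y)
    return (
        k * (score_x @ score_y)
        + score_x @ grad_y
        + score_y @ grad_x
        + jnp.trace(mixed)
    )


def ksd(samples, logdensity_fn):
    """V-statistic estimate of the KSD; cost is quadratic in the sample count."""
    # The score is the gradient of the log density, so the normalizing
    # constant of the target is never needed.
    score_fn = jax.grad(logdensity_fn)
    gram = jax.vmap(
        lambda x: jax.vmap(lambda y: stein_kernel(x, y, score_fn))(samples)
    )(samples)
    # Clip at zero to guard against tiny negative values from rounding.
    return jnp.sqrt(jnp.maximum(jnp.mean(gram), 0.0))


def logdensity_fn(x):
    # Illustrative target: un-normalized standard normal.
    return -0.5 * jnp.sum(x ** 2)


good = jax.random.normal(jax.random.PRNGKey(0), (200, 2))  # drawn from the target
biased = good + 2.0                                        # shifted away from it

ksd_good = ksd(good, logdensity_fn)
ksd_biased = ksd(biased, logdensity_fn)
```

Under these assumptions the shifted samples should yield the larger discrepancy, which is the property that makes KSD usable as a tuning objective: a hyperparameter setting whose samples score a lower KSD is preferred, without reference to acceptance rates.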