Abstract: BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but it also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these methods work.
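The functional style described above can be illustrated with a minimal sketch in plain JAX. This is not BlackJAX's own API: the names `RWState`, `logdensity_fn`, `init`, `step`, and `run_chain` are hypothetical, and a random-walk Metropolis kernel is used only because it is short. The sketch shows a sampler built as a pure state-to-state function that needs nothing beyond an un-normalized log density, which is what makes it easy to compile and compose.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical sampler state: current position and its log density."""
    position: jnp.ndarray
    logdensity: jnp.ndarray


def logdensity_fn(x):
    # Un-normalized log density of a standard normal (constant dropped).
    # Assumed target, chosen only for illustration.
    return -0.5 * jnp.sum(x ** 2)


def init(position):
    # Build the initial state from a starting position.
    return RWState(position, logdensity_fn(position))


def step(rng_key, state, step_size=0.5):
    # One random-walk Metropolis transition: a pure function of (key, state).
    key_proposal, key_accept = jax.random.split(rng_key)
    noise = jax.random.normal(key_proposal, state.position.shape)
    proposal = state.position + step_size * noise
    proposal_logdensity = logdensity_fn(proposal)
    # Accept with probability min(1, p(proposal) / p(current)).
    log_ratio = proposal_logdensity - state.logdensity
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    new_position = jnp.where(accept, proposal, state.position)
    new_logdensity = jnp.where(accept, proposal_logdensity, state.logdensity)
    return RWState(new_position, new_logdensity), accept


def run_chain(rng_key, initial_position, num_steps):
    # Compose the kernel into an inference loop with jax.lax.scan.
    keys = jax.random.split(rng_key, num_steps)

    def one_step(state, key):
        state, accepted = step(key, state)
        return state, (state.position, accepted)

    _, (positions, accepted) = jax.lax.scan(one_step, init(initial_position), keys)
    return positions, accepted


# Compile the whole chain; num_steps is static because it fixes array shapes.
positions, accepted = jax.jit(run_chain, static_argnums=2)(
    jax.random.PRNGKey(0), jnp.zeros(2), 2000
)
```

Because `step` carries no hidden state, the same kernel could in principle be wrapped in `jax.vmap` to run several chains at once, or swapped for another kernel with the same signature without changing `run_chain`.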
Abstract: The success of deep learning models deployed in the real world depends critically on their ability to generalize well across diverse data domains. Here, we address a fundamental challenge with selective classification during automated diagnosis with domain-shifted medical images. In this scenario, models must learn to avoid making predictions when label confidence is low, especially when tested with samples far removed from the training set (covariate shift). Such uncertain cases are typically referred to the clinician for further analysis and evaluation. Yet, we show that even state-of-the-art domain generalization approaches fail severely during referral when tested on medical images acquired from a different demographic or using a different technology. We examine two benchmark diagnostic medical imaging datasets exhibiting strong covariate shifts: i) diabetic retinopathy prediction with retinal fundus images and ii) multilabel disease prediction with chest X-ray images. We show that predictive uncertainty estimates do not generalize well under covariate shifts, leading to non-monotonic referral curves and severe drops in performance (up to 50%) at high referral rates (>70%). We evaluate novel combinations of robust generalization and post hoc referral approaches that rescue these failures and achieve significant performance improvements, typically >10%, over baseline methods. Our study identifies a critical challenge with referral in domain-shifted medical images and finds key applications in reliable, automated disease diagnosis.
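To make the notion of a referral curve concrete, here is a minimal sketch in JAX. It assumes accuracy on the retained cases as the performance metric and uses made-up toy data; the function name `referral_curve` and all values are hypothetical and are not taken from the study. At each referral rate, the most uncertain fraction of cases is handed to the clinician and the model is scored on the rest.

```python
import jax.numpy as jnp


def referral_curve(uncertainty, correct, referral_rates):
    """Accuracy on retained cases after referring the most uncertain ones.

    uncertainty: per-sample uncertainty scores (higher = less confident).
    correct: per-sample 0/1 indicator of a correct prediction.
    referral_rates: fractions of samples referred to the clinician.
    """
    order = jnp.argsort(uncertainty)  # most confident samples first
    sorted_correct = correct[order].astype(jnp.float32)
    n = uncertainty.shape[0]
    accuracies = []
    for rate in referral_rates:
        # Keep the (1 - rate) most confident samples, always at least one.
        n_keep = max(1, int(round(n * (1.0 - rate))))
        accuracies.append(float(sorted_correct[:n_keep].mean()))
    return accuracies


# Toy data (assumed): errors concentrate among the uncertain samples,
# so accuracy rises as more cases are referred.
uncertainty = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
correct = jnp.array([1, 1, 1, 1, 0, 1, 0, 1, 0, 0])
rates = [0.0, 0.5, 0.8]
curve = referral_curve(uncertainty, correct, rates)
```

When uncertainty tracks errors well, as in this toy data, the curve increases with the referral rate. A non-monotonic curve of the kind reported under covariate shift would arise if confidently predicted samples were disproportionately wrong.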