Abstract: JAX is widely used in machine learning and scientific computing, the latter of which often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing the existing code in JAX is often impractical, and the existing interface in JAX for binding custom code requires deep knowledge of JAX and its C++ backend. The goal of JAXbind is to drastically reduce the effort required to bind custom functions implemented in other programming languages to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives that support arbitrary JAX transformations.
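For context, JAX itself offers a lighter-weight route for calling non-JAX code: `jax.pure_callback` combined with `jax.custom_jvp`. The sketch below uses only that built-in JAX API. It is not JAXbind's interface, and `_external_sin` is a hypothetical stand-in for existing high-performance code (here just NumPy).

The sketch shows the hand-written plumbing such a binding needs. The callback runs outside JAX tracing, so its output shape and dtype must be declared up front. Derivatives must also be supplied manually: `jax.jit` works directly, but `jax.grad` works only because the JVP rule below is written in JAX operations. Unlike a genuine JAX primitive, the callback itself stays opaque to the compiler.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _external_sin(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for existing non-JAX code (e.g. a compiled library).
    # It receives and returns plain NumPy arrays, outside of JAX tracing.
    return np.sin(x).astype(x.dtype)


@jax.custom_jvp
def external_sin(x):
    # Declare the output shape/dtype so JAX can trace without running the callback.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_external_sin, out_type, x)


@external_sin.defjvp
def _external_sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = external_sin(x)
    # The derivative is supplied by hand; writing it with JAX ops (linear in
    # x_dot) is what lets reverse mode (jax.grad) transpose it.
    return y, jnp.cos(x) * x_dot


x = jnp.linspace(0.0, 1.0, 5)
print(jax.jit(external_sin)(x))
print(jax.grad(lambda t: external_sin(t).sum())(x))  # equals cos(x)
```

Batching with `jax.vmap` would need further configuration of the callback and is deliberately left out here.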
Abstract: Imaging is the process of transforming noisy, incomplete data into a space that humans can interpret. NIFTy is a Bayesian framework for imaging and has already been successfully applied to many fields in astrophysics. Previous design decisions held back both the performance of NIFTy and the development of methods within it. We present a rewrite of NIFTy, dubbed NIFTy.re, which reworks the modeling principle, extends the inference strategies, and outsources much of the heavy lifting to JAX. The rewrite dramatically accelerates models written in NIFTy, lays the foundation for new types of inference machinery, improves maintainability, and enables interoperability between NIFTy and the JAX machine learning ecosystem.
Abstract: Gaussian processes (GPs) are highly expressive probabilistic models. A major limitation is their computational complexity: naively, exact GP inference requires $\mathcal{O}(N^3)$ computations, with $N$ denoting the number of modeled points. Current approaches to overcome this limitation rely on sparse, structured, or stochastic representations of either the data or the kernel, and usually involve nested optimizations to evaluate a GP. We present a new, generative method named Iterative Charted Refinement (ICR) to model GPs on nearly arbitrarily spaced points in $\mathcal{O}(N)$ time for decaying kernels, without nested optimizations. ICR represents long- as well as short-range correlations by combining views of the modeled locations at varying resolutions with a user-provided coordinate chart. In our experiment with points whose spacings vary over two orders of magnitude, ICR's accuracy is comparable to that of state-of-the-art GP methods. ICR outperforms existing methods in computational speed by one order of magnitude on both CPU and GPU and has already been successfully applied to model a GP with $122$ billion parameters.
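To make the naive $\mathcal{O}(N^3)$ baseline concrete, here is a minimal NumPy sketch of exact GP sampling and regression. It is the dense approach that ICR is compared against, not ICR itself. Both functions build the full $N \times N$ kernel matrix and factorize it with a Cholesky decomposition, which costs $\mathcal{O}(N^3)$ time and $\mathcal{O}(N^2)$ memory.

The squared-exponential kernel, its length scale, the jitter, and all function names are illustrative assumptions.

```python
import numpy as np


def sq_exp_kernel(x1, x2, scale=1.0, length=0.1):
    # Squared-exponential kernel: an illustrative choice of decaying kernel.
    d = x1[:, None] - x2[None, :]
    return scale**2 * np.exp(-0.5 * (d / length) ** 2)


def naive_gp_sample(x, n_samples=1, jitter=1e-6, seed=0):
    # Dense N x N covariance (O(N^2) memory); jitter keeps it numerically PD.
    K = sq_exp_kernel(x, x) + jitter * np.eye(len(x))
    # Dense Cholesky factorization K = L L^T: the O(N^3) step.
    L = np.linalg.cholesky(K)
    # Generative form s = L xi with standard-normal xi gives s ~ N(0, K).
    xi = np.random.default_rng(seed).standard_normal((len(x), n_samples))
    return L @ xi


def naive_gp_posterior_mean(x_train, y_train, x_test, noise_std=0.1):
    # Exact GP regression: solve (K + sigma^2 I) alpha = y via Cholesky.
    K = sq_exp_kernel(x_train, x_train) + noise_std**2 * np.eye(len(x_train))
    L = np.linalg.cholesky(K)  # the O(N^3) step
    # np.linalg.solve does not exploit triangularity, so these solves are also
    # cubic here; a dedicated triangular solver would make them O(N^2).
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    return sq_exp_kernel(x_test, x_train) @ alpha


x = np.linspace(0.0, 1.0, 200)
samples = naive_gp_sample(x, n_samples=3)
print(samples.shape)  # (200, 3)
```

Because of the cubic factorization, doubling $N$ multiplies the dominant cost by roughly eight. This scaling is what makes the dense approach infeasible at the scale of billions of parameters mentioned above.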
Abstract: The viral load of patients infected with SARS-CoV-2 varies on logarithmic scales and possibly with age. Controversial claims have been made in the literature regarding whether the viral load distribution actually depends on the age of the patients. Such a dependence would have implications for the COVID-19 spreading mechanism, for the age-dependent immune system reaction, and thus for policymaking. Here, we develop a method to analyze viral-load distribution data as a function of the patients' age within a flexible, non-parametric, hierarchical, Bayesian, and causal model. This method can be applied to other contexts as well and is therefore made freely available. The developed reconstruction method also allows testing for bias in the data, which could arise from, e.g., biased patient testing and data collection or systematic errors in the measurement of the viral load. We perform these tests by calculating the Bayesian evidence for each implied possible causal direction. Applying these tests to publicly available age and SARS-CoV-2 viral load data, we find a statistically significant increase in the viral load with age, but only for one of the two analyzed datasets. For this dataset, and based on the current understanding of viral load's impact on patients' infectivity, we expect a non-negligible difference in the infectivity of different age groups. This difference is nonetheless too small to justify considering any age group as noninfectious.
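As a generic illustration of comparing causal directions through Bayesian evidence, consider the standard two-variable form below. The two-variable setup and the notation are assumptions for exposition and do not reproduce the paper's hierarchical model. Here $A$ denotes age, $V$ viral load, $d$ the data, and $\theta$ the parameters of a given model $M$. Each causal hypothesis fixes a factorization of the joint distribution, and the hypotheses are compared through their evidences $\mathcal{Z}_M$:

```latex
\begin{aligned}
M_{A \to V} &:\quad p(A, V \mid \theta) = p(A \mid \theta)\, p(V \mid A, \theta), \\
M_{V \to A} &:\quad p(A, V \mid \theta) = p(V \mid \theta)\, p(A \mid V, \theta), \\
\mathcal{Z}_M &= p(d \mid M) = \int p(d \mid \theta, M)\, p(\theta \mid M)\, \mathrm{d}\theta, \\
\ln \frac{p(M_{A \to V} \mid d)}{p(M_{V \to A} \mid d)} &= \ln \mathcal{Z}_{A \to V} - \ln \mathcal{Z}_{V \to A} + \ln \frac{p(M_{A \to V})}{p(M_{V \to A})}.
\end{aligned}
```

With equal prior odds, the last term vanishes. The log-evidence difference alone then indicates which direction the data favor, and by how much.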