Abstract:Conventional nonlinear RNNs are not naturally parallelizable across the sequence length, whereas transformers and linear RNNs are. Lim et al. [2024] therefore tackle parallelized evaluation of nonlinear RNNs by posing it as a fixed point problem, solved with Newton's method. By deriving and applying a parallelized form of Newton's method, they achieve huge speedups over sequential evaluation. However, their approach inherits cubic computational complexity and numerical instability. We tackle these weaknesses. To reduce the computational complexity, we apply quasi-Newton approximations and show they converge comparably to full-Newton, use less memory, and are faster. To stabilize Newton's method, we leverage a connection between Newton's method damped with trust regions and Kalman smoothing. This connection allows us to stabilize Newton's method, per the trust region, while using efficient parallelized Kalman algorithms to retain performance. We compare these methods empirically, and highlight the use cases where each algorithm excels.
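To make the fixed-point view in the abstract above concrete, here is a minimal sketch for a scalar toy recurrence; it is not the authors' implementation. The cell `step`, its weight `w`, and all function names are hypothetical. Each Newton iteration linearizes every time step around the current trajectory guess, which leaves a linear recurrence in the unknown states. That linear recurrence is the part a parallel scan could evaluate; here it is solved with a plain loop for readability. Because the state is scalar, the Jacobian is a single number, so this toy cannot show any difference between full-Newton and quasi-Newton updates.

```python
import math


def step(x_prev, u, w=0.8):
    """Hypothetical scalar RNN cell: x_t = tanh(w * x_{t-1} + u_t)."""
    return math.tanh(w * x_prev + u)


def step_grad(x_prev, u, w=0.8):
    """Derivative of `step` with respect to x_prev."""
    return w * (1.0 - math.tanh(w * x_prev + u) ** 2)


def sequential_rollout(inputs, x0=0.0):
    """Reference: evaluate the recurrence one step at a time."""
    xs, x = [], x0
    for u in inputs:
        x = step(x, u)
        xs.append(x)
    return xs


def newton_rollout(inputs, x0=0.0, tol=1e-12):
    """Solve x_t - step(x_{t-1}, u_t) = 0 for all t jointly with Newton's method."""
    n_steps = len(inputs)
    xs = [0.0] * n_steps  # initial guess for the whole trajectory
    iterations = 0
    for iterations in range(1, n_steps + 1):
        prev = [x0] + xs[:-1]  # current guess for each x_{t-1}
        # Linearize every step around the current guess.
        jac = [step_grad(p, u) for p, u in zip(prev, inputs)]
        off = [step(p, u) - j * p for p, u, j in zip(prev, inputs, jac)]
        # The Newton update is the solution of the linear recurrence
        # x_t = jac_t * x_{t-1} + off_t (solved sequentially in this sketch).
        new, x = [], x0
        for j, o in zip(jac, off):
            x = j * x + o
            new.append(x)
        delta = max(abs(a - b) for a, b in zip(new, xs))
        xs = new
        if delta < tol:
            break
    return xs, iterations
```

After k iterations the first k states are exact, so the loop needs at most `len(inputs)` iterations to reproduce the sequential rollout.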
Abstract:An important problem in time-series analysis is modeling systems with time-varying dynamics. Probabilistic models with joint continuous and discrete latent states offer interpretable, efficient, and experimentally useful descriptions of such data. Commonly used models include autoregressive hidden Markov models (ARHMMs) and switching linear dynamical systems (SLDSs), each with its own advantages and disadvantages. ARHMMs permit exact inference and easy parameter estimation, but are parameter intensive when modeling long dependencies, and hence are prone to overfitting. In contrast, SLDSs can capture long-range dependencies in a parameter efficient way through Markovian latent dynamics, but present an intractable likelihood and a challenging parameter estimation task. In this paper, we propose switching autoregressive low-rank tensor (SALT) models, which retain the advantages of both approaches while ameliorating the weaknesses. SALT parameterizes the tensor of an ARHMM with a low-rank factorization to control the number of parameters and allow longer range dependencies without overfitting. We prove theoretical and discuss practical connections between SALT, linear dynamical systems, and SLDSs. We empirically demonstrate quantitative advantages of SALT models on a range of simulated and real prediction tasks, including behavioral and neural datasets. Furthermore, the learned low-rank tensor provides novel insights into temporal dependencies within each discrete state.
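A rough parameter count illustrates why a low-rank factorization can control model size. The abstract above does not say which factorization SALT uses, so this sketch assumes a CP factorization (a sum of `rank` outer products) of a per-state autoregressive tensor of shape (N, N, L), where N is the observation dimension and L the number of lags. Bias terms are ignored, and the function names are hypothetical.

```python
def full_tensor_params(n_obs, n_lags):
    """Entries in an unconstrained (n_obs, n_obs, n_lags) autoregressive tensor."""
    return n_obs * n_obs * n_lags


def cp_params(n_obs, n_lags, rank):
    """Entries in an assumed rank-`rank` CP factorization of the same tensor.

    One factor matrix per tensor mode: output (n_obs x rank),
    input (n_obs x rank), and lag (n_lags x rank).
    """
    return rank * (n_obs + n_obs + n_lags)


# Example: 10-dimensional observations, 20 lags, rank 3.
print(full_tensor_params(10, 20))  # 2000
print(cp_params(10, 20, 3))        # 120
```

Under these assumptions the full tensor grows as N²L, while the factorized form grows only linearly in N and L for a fixed rank.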
Abstract:Efficiently modeling long-range dependencies is an important goal in sequence modeling. Recently, models using structured state space sequence (S4) layers achieved state-of-the-art performance on many long-range tasks. The S4 layer combines linear state space models (SSMs) with deep learning techniques and leverages the HiPPO framework for online function approximation to achieve high performance. However, this framework led to architectural constraints and computational difficulties that make the S4 approach complicated to understand and implement. We revisit the idea that closely following the HiPPO framework is necessary for high performance. Specifically, we introduce the S5 layer, which replaces the bank of many independent single-input, single-output (SISO) SSMs the S4 layer uses with one multi-input, multi-output (MIMO) SSM with a reduced latent dimension. The reduced latent dimension of the MIMO system allows for the use of efficient parallel scans, which simplify the computations required to apply the S5 layer as a sequence-to-sequence transformation. In addition, we initialize the state matrix of the S5 SSM with an approximation to the HiPPO-LegS matrix used by S4's SSMs and show that this serves as an effective initialization for the MIMO setting. S5 matches S4's performance on long-range tasks, including achieving an average of 82.46% on the suite of Long Range Arena benchmarks compared to S4's 80.48% and the best transformer variant's 61.41%.
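The parallel scans mentioned in the abstract above rely on the fact that composing affine maps is associative. The following is a minimal sketch of that idea for a scalar linear recurrence x_t = a_t · x_{t-1} + b_t; it is a stand-in for illustration, not the S5 implementation, and all names are hypothetical. The scan uses recursive doubling: within each round every `combine` call is independent of the others, which is what makes the computation parallelizable, even though this pure-Python version runs them one after another.

```python
def combine(left, right):
    """Compose two affine maps x -> a*x + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b, x0=0.0):
    """Reference: evaluate x_t = a_t * x_{t-1} + b_t step by step."""
    xs, x = [], x0
    for a_t, b_t in zip(a, b):
        x = a_t * x + b_t
        xs.append(x)
    return xs


def doubling_scan(a, b, x0=0.0):
    """Inclusive scan by recursive doubling (about log2(T) rounds)."""
    elems = list(zip(a, b))
    offset = 1
    while offset < len(elems):
        # Every entry of this list can be computed independently.
        elems = [
            elems[i] if i < offset else combine(elems[i - offset], elems[i])
            for i in range(len(elems))
        ]
        offset *= 2
    # elems[t] is now the affine map taking x0 directly to x_t.
    return [a_t * x0 + b_t for a_t, b_t in elems]
```

Both functions return the same states up to floating-point rounding; the doubling version trades extra arithmetic for a number of rounds that grows logarithmically, rather than linearly, in the sequence length.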
Abstract:Sequential Monte Carlo (SMC) is an inference algorithm for state space models that approximates the posterior by sampling from a sequence of target distributions. The target distributions are often chosen to be the filtering distributions, but these ignore information from future observations, leading to practical and theoretical limitations in inference and model learning. We introduce SIXO, a method that instead learns targets that approximate the smoothing distributions, incorporating information from all observations. The key idea is to use density ratio estimation to fit functions that warp the filtering distributions into the smoothing distributions. We then use SMC with these learned targets to define a variational objective for model and proposal learning. SIXO yields provably tighter log marginal lower bounds and offers significantly more accurate posterior inferences and parameter estimates in a variety of domains.
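For context on the baseline that the abstract above contrasts with, here is a minimal sketch of SMC with filtering targets (a bootstrap particle filter) on a scalar linear-Gaussian model. This is not SIXO: it has no learned twisting functions. The model, its parameters `a`, `q`, `r`, and the function names are all assumptions chosen so that the exact log marginal likelihood is available from a Kalman filter for comparison.

```python
import math
import random


def bootstrap_filter(ys, n_particles, a=0.9, q=1.0, r=0.5, seed=0):
    """Estimate log p(y_{1:T}) for x_1 ~ N(0, q), x_t = a*x_{t-1} + N(0, q),
    y_t = x_t + N(0, r), using filtering distributions as SMC targets."""
    rng = random.Random(seed)
    particles = [rng.gauss(0.0, math.sqrt(q)) for _ in range(n_particles)]
    log_z = 0.0
    for t, y in enumerate(ys):
        if t > 0:
            # Propose from the transition model (the "bootstrap" proposal).
            particles = [a * x + rng.gauss(0.0, math.sqrt(q)) for x in particles]
        # Weight by the current observation only; future ys are ignored.
        logw = [-0.5 * math.log(2 * math.pi * r) - 0.5 * (y - x) ** 2 / r
                for x in particles]
        m = max(logw)
        w = [math.exp(lw - m) for lw in logw]
        log_z += m + math.log(sum(w) / n_particles)
        # Multinomial resampling.
        particles = rng.choices(particles, weights=w, k=n_particles)
    return log_z


def kalman_loglik(ys, a=0.9, q=1.0, r=0.5):
    """Exact log p(y_{1:T}) for the same model, for comparison."""
    mean, var, ll = 0.0, q, 0.0
    for t, y in enumerate(ys):
        if t > 0:
            mean, var = a * mean, a * a * var + q
        s = var + r  # predictive variance of y_t
        ll += -0.5 * math.log(2 * math.pi * s) - 0.5 * (y - mean) ** 2 / s
        gain = var / s
        mean, var = mean + gain * (y - mean), (1.0 - gain) * var
    return ll
```

The particle estimate of the marginal likelihood is unbiased, so by Jensen's inequality its logarithm is, in expectation, a lower bound on the exact value returned by `kalman_loglik`.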
Abstract:Policies for partially observed Markov decision processes can be efficiently learned by imitating policies for the corresponding fully observed Markov decision processes. Unfortunately, existing approaches for this kind of imitation learning have a serious flaw: the expert does not know what the trainee cannot see, and so may encourage actions that are sub-optimal, even unsafe, under partial information. We derive an objective to instead train the expert to maximize the expected reward of the imitating agent policy, and use it to construct an efficient algorithm, adaptive asymmetric DAgger (A2D), that jointly trains the expert and the agent. We show that A2D produces an expert policy that the agent can safely imitate, in turn outperforming policies learned by imitating a fixed expert.
Abstract:In this work we demonstrate how existing software tools can be used to automate parts of infectious disease-control policy-making by performing inference in existing epidemiological dynamics models. The inference tasks undertaken include computing, for planning purposes, the posterior distribution over simulation model parameters that are putatively controllable via direct policy-making choices and that give rise to acceptable disease progression outcomes. Neither the full capabilities of such inference automation software tools nor their utility for planning is widely disseminated at the current time. Timely gains in understanding about these tools and how they can be used may lead to more fine-grained and less economically damaging policy prescriptions, particularly during the current COVID-19 pandemic.
Abstract:Deterministic models are approximations of reality that are easy to interpret and often easier to build than stochastic alternatives. Unfortunately, as nature is capricious, observational data can never be fully explained by deterministic models in practice. Observation and process noise need to be added to adapt deterministic models to behave stochastically, such that they are capable of explaining and extrapolating from noisy data. We investigate and address computational inefficiencies that arise from adding process noise to deterministic simulators that fail to return for certain inputs, a property we describe as "brittle." We show how to train a conditional normalizing flow to propose perturbations such that the simulator succeeds with high probability, increasing computational efficiency.
Abstract:We develop a stochastic whole-brain and body simulator of the nematode roundworm Caenorhabditis elegans (C. elegans) and show that it is sufficiently regularizing to allow imputation of latent membrane potentials from partial calcium fluorescence imaging observations. This is the first attempt we know of to "complete the circle," where an anatomically grounded whole-connectome simulator is used to impute a time-varying "brain" state at single-cell fidelity from covariates that are measurable in practice. The sequential Monte Carlo (SMC) method we employ not only enables imputation of said latent states but also presents a strategy for learning simulator parameters via variational optimization of the noisy model evidence approximation provided by SMC. Our imputation and parameter estimation experiments were conducted on distributed systems using novel implementations of the aforementioned techniques applied to synthetic data of dimension and type representative of that which is currently measured in laboratories.
Abstract:Many problems in machine learning and statistics involve nested expectations and thus do not permit conventional Monte Carlo (MC) estimation. For such problems, one must nest estimators, such that terms in an outer estimator themselves involve the calculation of a separate, nested estimate. We investigate the statistical implications of nesting MC estimators, including cases of multiple levels of nesting, and establish the conditions under which they converge. We derive corresponding rates of convergence and provide empirical evidence that these rates are observed in practice. We further establish a number of pitfalls that can arise from naive nesting of MC estimators, provide guidelines about how these can be avoided, and lay out novel methods for reformulating certain classes of nested expectation problems into single expectations, leading to improved convergence rates. We demonstrate the applicability of our work by using our results to develop a new estimator for discrete Bayesian experimental design problems and derive error bounds for a class of variational objectives.
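A small worked example shows what nesting MC estimators looks like and why it needs care. This is a toy chosen for illustration, not taken from the paper: with y and z independent standard normals, the target is E_y[(E_z[y + z])²] = E[y²] = 1. The naive nested estimator below replaces the inner expectation with an average over `n_inner` samples; because the outer function (squaring) is nonlinear, its expectation is 1 + 1/`n_inner` rather than 1, so the bias shrinks only as the inner sample size grows. The function name and arguments are hypothetical.

```python
import random


def nested_mc(n_outer, n_inner, seed=0):
    """Naive nested MC estimate of E_y[(E_z[y + z])^2], y, z ~ N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_outer):
        y = rng.gauss(0.0, 1.0)
        # Inner estimator: sample average standing in for E_z[y + z] = y.
        inner = sum(y + rng.gauss(0.0, 1.0) for _ in range(n_inner)) / n_inner
        # Outer estimator applies a nonlinear function to the inner estimate.
        total += inner ** 2
    return total / n_outer


# With one inner sample the estimate concentrates near 2, not the true value 1;
# with 50 inner samples it concentrates near 1.02.
print(nested_mc(5000, 1))
print(nested_mc(5000, 50))
```

Increasing `n_outer` alone reduces the variance but leaves the 1/`n_inner` bias untouched, which is why both sample sizes have to grow for the estimator to converge.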
Abstract:We present an instance of the optimal sensor scheduling problem with the additional relaxation that our observer makes active choices about whether and how to observe. We mask the nodes in a directed acyclic graph of the model that are observable, effectively optimising whether or not an observation should be made at each time step. The reason for this is simple: it is prudent to seek to reduce sensor costs, since resources (e.g. hardware, personnel and time) are finite. Consequently, rather than treating our plant as if it had infinite sensing resources, we seek to jointly maximise the utility of each perception. This reduces resource expenditure by explicitly minimising an observation-associated cost (e.g. battery use) while also facilitating the potential to yield better state estimates by virtue of being able to use more perceptions in noisy or unpredictable regions of state-space (e.g. a busy traffic junction). We present a general formalisation and notation of this problem, capable of encompassing much of the prior art. To illustrate our formulation, we pose and solve two example problems in this domain. Finally, we suggest active areas of research to improve and further generalise this approach.