Abstract:Neural network representations of simple models, such as linear regression, are increasingly being studied to better understand the underlying principles of deep learning algorithms. However, neural representations of distributional regression models, such as the Cox model, have received little attention so far. We close this gap by proposing a framework for distributional regression using inverse flow transformations (DRIFT), which includes neural representations of the aforementioned models. We empirically demonstrate that the neural representations of models in DRIFT can serve as a substitute for their classical statistical counterparts in several applications involving continuous, ordered, time-series, and survival outcomes. We confirm that models in DRIFT empirically match the performance of several statistical methods in terms of estimation of partial effects, prediction, and aleatoric uncertainty quantification. DRIFT covers both interpretable statistical models and flexible neural networks, opening up new avenues in both statistical modeling and deep learning.
Abstract:Survival Analysis provides critical insights for partially incomplete time-to-event data in various domains. It is also an important example of probabilistic machine learning. The probabilistic nature of the predictions can be exploited by using (proper) scoring rules in the model fitting process instead of likelihood-based optimization. Our proposal does so in a generic manner and can be used for a variety of model classes. We establish different parametric and non-parametric sub-frameworks that allow different degrees of flexibility. Incorporated into neural networks, it leads to a computationally efficient and scalable optimization routine, yielding state-of-the-art predictive performance. Finally, we show that using our framework, we can recover various parametric models and demonstrate that optimization works equally well when compared to likelihood-based methods.
Abstract:The influx of deep learning (DL) techniques into the field of survival analysis in recent years, coupled with the increasing availability of high-dimensional omics data and unstructured data like images or text, has led to substantial methodological progress, for instance in learning from such high-dimensional or unstructured data. Numerous modern DL-based survival methods have been developed since the mid-2010s; however, they often address only a small subset of scenarios in the time-to-event data setting - e.g., single-risk right-censored survival tasks - and neglect to incorporate more complex (and common) settings. In part, this is due to a lack of exchange between experts in the respective fields. In this work, we provide a comprehensive systematic review of DL-based methods for time-to-event analysis, characterizing them according to both survival- and DL-related attributes. In doing so, we hope to provide a helpful overview to practitioners who are interested in DL techniques applicable to their specific use case as well as to enable researchers from both fields to identify directions for future investigation. We provide a detailed characterization of the methods included in this review as an open-source, interactive table: https://survival-org.github.io/DL4Survival. As this research area is advancing rapidly, we encourage the research community to contribute to keeping the information up to date.
Abstract:Survival analysis (SA) is an active field of research that is concerned with time-to-event outcomes and is prevalent in many domains, particularly biomedical applications. Despite its importance, SA remains challenging due to small-scale data sets and complex outcome distributions, concealed by truncation and censoring processes. The piecewise exponential additive mixed model (PAMM) is a model class addressing many of these challenges, yet PAMMs are not applicable in high-dimensional feature settings or in the case of unstructured or multimodal data. We unify existing approaches by proposing DeepPAMM, a versatile deep learning framework that is well-founded from a statistical point of view, yet with enough flexibility for modeling complex hazard structures. Through benchmark experiments and an extended case study, we illustrate that DeepPAMM is competitive with other machine learning approaches with respect to predictive performance while maintaining interpretability.
Abstract:This paper describes the implementation of semi-structured deep distributional regression, a flexible framework to learn distributions based on a combination of additive regression models and deep neural networks. deepregression is implemented in both R and Python, using the deep learning libraries TensorFlow and PyTorch, respectively. The implementation consists of (1) a modular neural network building system for the combination of various statistical and deep learning approaches, (2) an orthogonalization cell to allow for an interpretable combination of different subnetworks, as well as (3) pre-processing steps necessary to initialize such models. The software package allows users to define models in a user-friendly manner using distribution definitions via a formula environment that is inspired by classical statistical model frameworks such as mgcv. The packages' modular design and functionality provide a unique resource for rapid and reproducible prototyping of complex statistical and deep learning models while simultaneously retaining the indispensable interpretability of classical statistical models.
Abstract:We propose a versatile framework for survival analysis that combines advanced concepts from statistics with deep learning. The presented framework is based on piecewise exponential models and thereby supports various survival tasks, such as competing risks and multi-state modeling, and further allows for estimation of time-varying effects and time-varying features. To also incorporate multiple data sources and higher-order interaction effects into the model, we embed the model class in a neural network and thereby enable the simultaneous estimation of both inherently interpretable structured regression inputs and deep neural network components, which can potentially process additional unstructured data sources. A proof of concept is provided by using the framework to predict Alzheimer's disease progression based on tabular and 3D point cloud data and by applying it to synthetic data.
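As a rough illustration of the piecewise exponential idea mentioned in the abstracts above, the following minimal sketch splits each subject's follow-up time into intervals and evaluates a Poisson-type log-likelihood on the expanded data. It is not taken from any of the papers or packages listed here: the names `PedRow`, `to_ped`, and `poisson_loglik`, the cut points, and the toy data are all hypothetical, and the hazard is assumed constant within each interval with no features.

```python
import math
from typing import List, NamedTuple, Sequence


class PedRow(NamedTuple):
    """One row of the expanded (piecewise exponential) data. Hypothetical layout."""
    subject: int     # index of the original observation
    interval: int    # index j of the interval (cuts[j], cuts[j + 1]]
    exposure: float  # time spent at risk inside the interval
    status: int      # 1 if the event occurred in this interval, else 0


def to_ped(times: Sequence[float], events: Sequence[int],
           cuts: Sequence[float]) -> List[PedRow]:
    """Split right-censored survival data at the given cut points.

    Subjects whose time exceeds the last cut point are treated as
    censored at that cut point (an assumption of this sketch).
    """
    rows: List[PedRow] = []
    for i, (t, d) in enumerate(zip(times, events)):
        for j in range(len(cuts) - 1):
            lo, hi = cuts[j], cuts[j + 1]
            if t <= lo:
                break  # subject is no longer at risk in this interval
            exposure = min(t, hi) - lo
            status = int(bool(d) and t <= hi)
            rows.append(PedRow(i, j, exposure, status))
    return rows


def poisson_loglik(rows: Sequence[PedRow], log_hazards: Sequence[float]) -> float:
    """Poisson log-likelihood with the exposure as offset, up to a constant.

    With one log-hazard per interval, this equals the piecewise
    exponential log-likelihood of the original survival data.
    """
    total = 0.0
    for r in rows:
        log_h = log_hazards[r.interval]
        total += r.status * log_h - math.exp(log_h) * r.exposure
    return total


# Toy data: subject 0 has an event at t = 2.5, subject 1 is censored at t = 1.0.
rows = to_ped(times=[2.5, 1.0], events=[1, 0], cuts=[0.0, 1.0, 2.0, 3.0])
loglik = poisson_loglik(rows, log_hazards=[0.0, 0.0, 0.0])
```

In a full model the per-interval log-hazard would be replaced by a predictor that depends on features (structured effects, a neural network, or both), while the data expansion step stays the same.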