Abstract: This work suggests fundamentally rethinking the current practice of pruning large language models (LLMs). The prevailing approach is divide and conquer: split the model into submodels, sequentially prune them, and reconstruct the predictions of their dense counterparts on small calibration data, one submodel at a time; the final model is obtained simply by putting the resulting sparse submodels together. While this approach enables pruning under memory constraints, it generates high reconstruction errors. In this work, we first present an array of reconstruction techniques that can significantly reduce this error, by more than $90\%$. Unexpectedly, however, we discover that minimizing reconstruction error is not always ideal and can overfit the given calibration data, resulting in increased language perplexity and poor performance on downstream tasks. We find that a strategy of self-generating calibration data can mitigate this trade-off between reconstruction and generalization, suggesting new directions for pruning LLMs that account for both the benefits and the pitfalls of reconstruction.
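The reconstruction step described above can be illustrated with a minimal sketch: magnitude-prune a linear layer and then refit the surviving weights by least squares so that the sparse layer reproduces the dense layer's outputs on a small calibration set. The function name, the plain magnitude criterion, and the per-row refit are illustrative assumptions, not the paper's exact techniques.

```python
import numpy as np

def prune_and_reconstruct(W, X, sparsity=0.5):
    """Magnitude-prune a linear layer W (out x in) and refit its surviving
    weights so that the sparse layer reproduces the dense outputs W @ X
    on calibration inputs X (in x n_samples). Illustrative sketch only."""
    Y = W @ X                                                  # dense outputs to reconstruct
    mask = np.abs(W) >= np.quantile(np.abs(W), sparsity)       # keep the largest weights
    W_sparse = np.zeros_like(W)
    for i in range(W.shape[0]):                                # refit each output row separately
        keep = mask[i]
        if keep.any():
            # least-squares fit of the kept weights to the dense row's outputs
            W_sparse[i, keep] = np.linalg.lstsq(X[keep].T, Y[i], rcond=None)[0]
    return W_sparse, mask

# toy usage on random calibration data
rng = np.random.default_rng(0)
W = rng.normal(size=(16, 32))
X = rng.normal(size=(32, 128))
W_s, mask = prune_and_reconstruct(W, X, sparsity=0.5)
err = np.linalg.norm(W_s @ X - W @ X) / np.linalg.norm(W @ X)
print(f"relative reconstruction error: {err:.3f}")
```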
Abstract: Training an overparameterized neural network can yield minimizers with the same training loss yet different generalization capabilities. Motivated by evidence of a correlation between the sharpness of minima and their generalization error, increasing effort has gone into developing optimization methods that explicitly seek flat minima as more generalizable solutions. How overparameterization actually affects the behavior of this sharpness-aware minimization (SAM) strategy, however, has received little study. In this work, we analyze SAM under varying degrees of overparameterization and present both empirical and theoretical results that suggest a critical influence of overparameterization on SAM. Specifically, we first use standard techniques in optimization to prove that SAM can achieve a linear convergence rate under overparameterization in a stochastic setting. We also show that the linearly stable minima found by SAM are indeed flatter and have more uniformly distributed Hessian moments compared to those of SGD. These results are corroborated by our experiments, which reveal a consistent trend: the generalization improvement made by SAM continues to grow as the model becomes more overparameterized. We further show that sparsity can open up an avenue for effective overparameterization in practice.
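The two-step structure of a SAM update can be sketched as follows for a generic PyTorch model and loss, assuming every parameter receives a gradient; the hyperparameter names are placeholders rather than the paper's exact settings.

```python
import torch

def sam_step(model, loss_fn, x, y, optimizer, rho=0.05):
    """One sharpness-aware minimization step (illustrative sketch):
    1) ascend to the worst-case point within an L2 ball of radius rho,
    2) compute the gradient there, 3) descend from the original weights."""
    # first pass: gradient at the current weights
    loss = loss_fn(model(x), y)
    loss.backward()
    grads = [p.grad.detach().clone() for p in model.parameters()]
    grad_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads)) + 1e-12

    # perturb weights along the normalized gradient (epsilon = rho * g / ||g||)
    with torch.no_grad():
        eps = [rho * g / grad_norm for g in grads]
        for p, e in zip(model.parameters(), eps):
            p.add_(e)

    # second pass: gradient at the perturbed weights
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()

    # undo the perturbation and update with the perturbed gradient
    with torch.no_grad():
        for p, e in zip(model.parameters(), eps):
            p.sub_(e)
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```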
Abstract: In federated learning (FL), clients with limited resources can hamper training efficiency. A potential solution to this problem is to leverage a learning procedure that does not rely on backpropagation (BP). We present a novel approach to FL called FedFwd that employs a recent BP-free method, the Forward-Forward algorithm of Hinton (2022), in the local training process. FedFwd significantly reduces the computation required to update parameters by performing layer-wise local updates and therefore does not need to store intermediate activations during training. We conduct various experiments to evaluate FedFwd on standard datasets including MNIST and CIFAR-10, and show that it performs competitively with BP-based FL methods.
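The layer-wise local update can be sketched with a single Forward-Forward layer step, assuming the squared-activation "goodness" and logistic loss of Hinton (2022); the threshold value and the way positive/negative samples are constructed are simplified assumptions, not FedFwd's exact recipe.

```python
import torch
import torch.nn.functional as F

def ff_layer_update(layer, opt, x_pos, x_neg, theta=2.0):
    """Train one layer with the Forward-Forward objective: goodness
    (sum of squared activations) should be high for positive data and
    low for negative data. The layer is updated locally, so no activations
    need to be stored for a backward pass through the whole network."""
    h_pos = F.relu(layer(x_pos))
    h_neg = F.relu(layer(x_neg))
    g_pos = h_pos.pow(2).sum(dim=1)          # goodness of positive samples
    g_neg = h_neg.pow(2).sum(dim=1)          # goodness of negative samples
    # logistic loss pushing g_pos above and g_neg below the threshold theta
    loss = F.softplus(torch.cat([theta - g_pos, g_neg - theta])).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # detach so the next layer's update does not backpropagate through this one
    return h_pos.detach(), h_neg.detach(), loss.item()
```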
Abstract: This paper introduces JaxPruner, an open-source, JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which in turn enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration with examples in four different codebases (Scenic, t5x, Dopamine, and FedJAX) and provide baseline experiments on popular benchmarks.
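The Optax integration pattern can be sketched as a custom gradient transformation that holds a fixed sparsity mask and zeroes the updates of pruned weights. This is a hypothetical illustration of the wrapping idea using only standard Optax primitives; it is not JaxPruner's actual API, and the toy mask is arbitrary.

```python
import jax
import jax.numpy as jnp
import optax

def masked_updates(mask_tree):
    """A toy optax.GradientTransformation that zeroes updates for pruned
    weights (hypothetical sketch of how pruning composes with Optax)."""
    def init_fn(params):
        del params
        return optax.EmptyState()
    def update_fn(updates, state, params=None):
        del params
        updates = jax.tree_util.tree_map(lambda u, m: u * m, updates, mask_tree)
        return updates, state
    return optax.GradientTransformation(init_fn, update_fn)

# usage: chain the masking transform with a standard Optax optimizer
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
masks = {"w": (jnp.arange(16).reshape(4, 4) % 2).astype(jnp.float32),  # toy mask
         "b": jnp.ones(4)}
tx = optax.chain(optax.adam(1e-3), masked_updates(masks))
opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```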
Abstract: Concept bottleneck models (CBMs) are a class of interpretable neural network models that predict the target response of a given input based on its high-level concepts. Unlike standard end-to-end models, CBMs enable domain experts to intervene on the predicted concepts and rectify mistakes at test time, so that more accurate task predictions can be made at the end. While such intervenability provides a powerful avenue of control, many aspects of the intervention procedure remain largely unexplored. In this work, we develop various ways of selecting intervening concepts to improve intervention effectiveness and conduct an array of in-depth analyses of how they behave under different circumstances. Specifically, we find that an informed intervention strategy can reduce the task error more than tenfold compared to the current baseline for the same number of interventions in realistic settings, and yet this gain can vary substantially with the intervention granularity. We verify our findings through comprehensive evaluations, not only on standard real datasets but also on synthetic datasets that we generate from a set of different causal graphs. We further discover major pitfalls of current practice which, if left unaddressed, raise concerns about the reliability and fairness of the intervention procedure.
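Test-time intervention in a CBM can be sketched as replacing a subset of the predicted concepts with their ground-truth values before the label predictor runs. The uncertainty-based selection rule below is just one illustrative "informed" strategy, and all names are placeholders rather than the paper's exact policies.

```python
import torch

def intervene(concept_probs, true_concepts, label_head, budget=3):
    """Replace the `budget` most uncertain predicted concepts with their
    ground-truth values, then recompute the task prediction.
    Illustrative sketch of an uncertainty-based intervention policy."""
    uncertainty = -(concept_probs - 0.5).abs()          # highest near p = 0.5
    chosen = uncertainty.topk(budget, dim=1).indices    # concepts to intervene on
    corrected = concept_probs.clone()
    corrected.scatter_(1, chosen, true_concepts.gather(1, chosen))
    return label_head(corrected)                        # updated task prediction

# toy usage with a linear label predictor over 8 concepts
label_head = torch.nn.Linear(8, 5)
concept_probs = torch.rand(2, 8)                 # predicted concept probabilities
true_concepts = (torch.rand(2, 8) > 0.5).float() # expert-provided ground truth
logits = intervene(concept_probs, true_concepts, label_head, budget=3)
```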
Abstract: Knowledge distillation is a popular and effective regularization technique for training lightweight models, but it also adds significant overhead to the training cost. The drawback is most pronounced when we use large-scale models as our teachers, such as vision transformers (ViTs). We present MaskedKD, a simple yet effective method for reducing the training cost of ViT distillation. MaskedKD masks a fraction of the image patch tokens fed to the teacher to save teacher inference cost. The tokens to mask are determined from the last-layer attention scores of the student model, which receives the full image. Without requiring any architectural change to the teacher or sacrificing student performance, MaskedKD dramatically reduces the computation and time required for distilling ViTs. We demonstrate that MaskedKD can save up to $50\%$ of the cost of running inference on the teacher model without any performance drop on the student, leading to an approximately $28\%$ reduction in the combined teacher and student compute.
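The core token-selection step can be sketched as scoring each patch by the student's last-layer attention from the [CLS] token, keeping the top-k patches, and passing only those to the teacher. The tensor shapes and the 50% keep ratio below are illustrative assumptions.

```python
import torch

def select_patches_for_teacher(student_attn, patch_tokens, keep_ratio=0.5):
    """Pick the patches the teacher will see, using the student's last-layer
    attention from the [CLS] token (averaged over heads) as the saliency score.
    student_attn: (batch, heads, 1 + n_patches, 1 + n_patches)
    patch_tokens: (batch, n_patches, dim) -- teacher-side patch embeddings
    Returns the kept patch tokens of shape (batch, k, dim). Sketch only."""
    cls_to_patch = student_attn[:, :, 0, 1:].mean(dim=1)   # (batch, n_patches)
    k = max(1, int(keep_ratio * cls_to_patch.shape[1]))
    top = cls_to_patch.topk(k, dim=1).indices              # most attended patches
    idx = top.unsqueeze(-1).expand(-1, -1, patch_tokens.shape[-1])
    return patch_tokens.gather(1, idx)                     # (batch, k, dim)

# toy usage: ViT-style attention with 196 patches plus a [CLS] token
attn = torch.rand(2, 6, 197, 197)
tokens = torch.rand(2, 196, 768)
kept = select_patches_for_teacher(attn, tokens, keep_ratio=0.5)  # (2, 98, 768)
```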
Abstract: Learning dynamical systems is a promising avenue for scientific discovery. However, capturing the governing dynamics across multiple environments remains a challenge: model-based approaches rely on the fidelity of assumptions made for a single environment, whereas data-driven approaches based on neural networks are often fragile when extrapolating into the future. In this work, we develop a sparse regression method dubbed SpReME to discover the major dynamics that underlie multiple environments. Specifically, SpReME shares a sparse structure of ordinary differential equation (ODE) terms across environments while allowing each environment to keep its own coefficients for those terms. We demonstrate that the proposed model captures the correct dynamics from multiple environments on four different dynamical systems with improved prediction performance.
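A minimal sketch of the shared-support idea: candidate ODE terms form a library Theta(X), the binary support over terms is shared across environments, and each environment fits its own coefficients on the kept terms. The simple alternating least-squares-and-thresholding loop here is an illustrative stand-in for the paper's actual optimization.

```python
import numpy as np

def shared_support_regression(Thetas, dXdts, n_iters=10, tau=0.1):
    """Fit dX/dt ~ Theta(X) @ Xi_e for each environment e, with one binary
    support (which library terms are active) shared across environments
    but environment-specific coefficients. Illustrative thresholding sketch."""
    n_terms = Thetas[0].shape[1]
    support = np.ones(n_terms, dtype=bool)
    coefs = [np.zeros(n_terms) for _ in Thetas]
    for _ in range(n_iters):
        if not support.any():
            break
        # per-environment least squares restricted to the shared support
        for e, (Theta, dxdt) in enumerate(zip(Thetas, dXdts)):
            coefs[e][:] = 0.0
            coefs[e][support] = np.linalg.lstsq(Theta[:, support], dxdt,
                                                rcond=None)[0]
        # drop terms that are small in every environment (shared sparsity)
        magnitudes = np.max(np.abs(np.stack(coefs)), axis=0)
        support = magnitudes > tau
    return support, coefs

# toy usage: two environments, library [1, x, x^2], true dynamics dx/dt = a*x
x1, x2 = np.linspace(0.1, 2, 50), np.linspace(0.1, 2, 50)
Thetas = [np.stack([np.ones_like(x1), x1, x1**2], axis=1),
          np.stack([np.ones_like(x2), x2, x2**2], axis=1)]
dXdts = [1.5 * x1, -0.7 * x2]        # different coefficients, same structure
support, coefs = shared_support_regression(Thetas, dXdts)
print(support, [np.round(c, 2) for c in coefs])
```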
Abstract: Implicit neural representations are a promising new avenue for representing general signals by learning a continuous function that, parameterized as a neural network, maps the domain of a signal to its codomain; for example, the mapping from the spatial coordinates of an image to its pixel values. Being capable of conveying fine details of a high-dimensional signal independently of its domain, implicit neural representations offer many advantages over conventional discrete representations. However, the current approach is difficult to scale to a large number of signals or a data set, since learning a neural representation -- which is parameter-heavy by itself -- for each signal individually requires a lot of memory and computation. To address this issue, we propose to leverage a meta-learning approach combined with network compression under a sparsity constraint, so that it yields a well-initialized sparse parameterization that evolves quickly to represent a set of unseen signals in the subsequent training. We empirically demonstrate that meta-learned sparse neural representations achieve a much smaller loss than dense meta-learned models with the same number of parameters, when trained to fit each signal with the same number of optimization steps.
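The adapt-from-a-sparse-init step can be sketched as follows: magnitude-prune a dense initialization and fit a new signal while updating only the surviving weights. Here a randomly initialized MLP stands in for the meta-learned initialization, and the ReLU architecture, pruning criterion, and hyperparameters are illustrative assumptions only.

```python
import torch
import torch.nn as nn

def sparse_finetune(init_model, coords, values, sparsity=0.8, steps=100, lr=1e-3):
    """Magnitude-prune a dense initialization (stand-in for a meta-learned one)
    and fit a new signal while updating only the surviving weights.
    Illustrative sketch of the sparse-init-then-adapt idea."""
    masks = {}
    for name, p in init_model.named_parameters():
        if p.dim() > 1:                              # prune weight matrices only
            thresh = p.abs().flatten().quantile(sparsity)
            masks[name] = (p.abs() >= thresh).float()
            p.data.mul_(masks[name])                 # zero out pruned weights
    opt = torch.optim.Adam(init_model.parameters(), lr=lr)
    for _ in range(steps):
        loss = ((init_model(coords) - values) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        for name, p in init_model.named_parameters():
            if name in masks:
                p.grad.mul_(masks[name])             # keep pruned weights frozen
        opt.step()
    return init_model, loss.item()

# toy usage: adapt a small MLP "INR" to a 1-D signal
inr = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64),
                    nn.ReLU(), nn.Linear(64, 1))
coords = torch.linspace(-1, 1, 256).unsqueeze(1)
values = torch.sin(3.14 * coords)
inr, final_loss = sparse_finetune(inr, coords, values, sparsity=0.8)
```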
Abstract: Network pruning is an effective methodology for compressing large neural networks, and the sparse neural networks obtained by pruning benefit from reduced memory and computational costs at deployment. Notably, recent advances have shown that it is possible to find a trainable sparse neural network at random initialization, prior to training; the obtained sparse network then only needs to be trained. While this approach of pruning at initialization has turned out to be highly effective, little has been studied about the training aspects of these sparse neural networks. In this work, we focus on measuring the effects of data parallelism on training sparse neural networks. We find that data parallelism in training sparse neural networks is no worse than in training densely parameterized neural networks, despite the general difficulty of training sparse neural networks. Moreover, when training sparse networks using SGD with momentum, the breakdown of the perfect scaling regime occurs at even larger batch sizes than for their dense counterparts.
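The measurement protocol can be sketched as counting the number of SGD-with-momentum steps needed to reach a fixed training loss at each batch size; the toy data, model, and thresholds below are illustrative assumptions, and a pruned network would be substituted for the dense one in practice.

```python
import torch
import torch.nn as nn

def steps_to_target(model, data, labels, batch_size, target_loss=0.3,
                    lr=0.1, momentum=0.9, max_steps=10_000):
    """Count SGD-with-momentum steps needed to reach a target training loss
    at a given batch size; sweeping batch_size traces the scaling curve
    (perfect scaling: steps halve when the batch size doubles). Sketch only."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    loss_fn = nn.CrossEntropyLoss()
    n = data.shape[0]
    for step in range(1, max_steps + 1):
        idx = torch.randint(0, n, (batch_size,))
        loss = loss_fn(model(data[idx]), labels[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if loss.item() < target_loss:
            return step
    return max_steps

# toy sweep on a simple, learnable classification task
data = torch.randn(4096, 32)
labels = (data[:, :2].sum(dim=1) > 0).long()
for bs in [32, 64, 128, 256]:
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))
    print(bs, steps_to_target(model, data, labels, batch_size=bs))
```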
Abstract: Network pruning is a promising avenue for compressing deep neural networks. A typical approach to pruning starts by training a model and then removing unnecessary parameters while minimizing the impact on what is learned. Alternatively, a recent line of work shows that pruning can be done at initialization, prior to training. However, it remains unclear exactly why pruning an untrained, randomly initialized neural network is effective. In this work, we consider the pruning problem from a signal propagation perspective, formally characterizing initialization conditions that ensure faithful signal propagation throughout a network. Based on the singular values of a network's input-output Jacobian, we find that orthogonal initialization enables more faithful signal propagation than other initialization schemes, thereby enhancing pruning results on a range of modern architectures and datasets. We also empirically study the effect of supervision for pruning at initialization, and show that unsupervised pruning can often be as effective as supervised pruning. Furthermore, we demonstrate that our signal propagation perspective, combined with unsupervised pruning, can indeed be useful in various scenarios where pruning is applied to non-standard, arbitrarily designed architectures.
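The signal-propagation diagnostic can be sketched by computing the singular values of a network's input-output Jacobian under different initializations; a well-conditioned (near-uniform) spectrum indicates faithful signal propagation. The tiny tanh MLP and the particular schemes compared are illustrative assumptions, not the paper's exact experimental setup.

```python
import torch
import torch.nn as nn

def jacobian_singular_values(model, x):
    """Singular values of the input-output Jacobian at a single input x.
    A small max/min ratio indicates faithful signal propagation."""
    J = torch.autograd.functional.jacobian(lambda inp: model(inp), x)
    J = J.reshape(model(x).numel(), x.numel())
    return torch.linalg.svdvals(J)

def make_mlp(init):
    layers, dims = [], [64, 64, 64, 64, 10]
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        lin = nn.Linear(d_in, d_out)
        init(lin.weight)                       # apply the chosen init scheme
        layers += [lin, nn.Tanh()]
    return nn.Sequential(*layers[:-1])         # drop the final nonlinearity

# compare Jacobian conditioning under orthogonal vs. plain normal init
x = torch.randn(64)
for name, init in [("orthogonal", nn.init.orthogonal_),
                   ("normal", lambda w: nn.init.normal_(w, std=0.1))]:
    sv = jacobian_singular_values(make_mlp(init), x)
    print(f"{name:>10}: max/min singular value ratio = {sv.max() / sv.min():.2f}")
```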