Abstract: We study matrix sensing, the problem of reconstructing a low-rank matrix from a few linear measurements. It can be formulated as an overparameterized regression problem, which can be solved by factorized gradient descent when starting from a small random initialization. Linear neural networks, and in particular matrix sensing by factorized gradient descent, serve as prototypical models of non-convex problems in modern machine learning, where complex phenomena can be disentangled and studied in detail. Much research has been devoted to studying special cases of asymmetric matrix sensing, such as asymmetric matrix factorization and symmetric positive semi-definite matrix sensing. Our key contribution is introducing a continuous differential equation that we call the $\textit{perturbed gradient flow}$. We prove that the perturbed gradient flow converges quickly to the true target matrix whenever the perturbation is sufficiently bounded. The dynamics of gradient descent for matrix sensing can be reduced to this formulation, yielding a novel convergence proof for asymmetric matrix sensing with factorized gradient descent. Compared to directly analyzing the dynamics of gradient descent, the continuous formulation allows bounding key quantities by considering their derivatives, often simplifying the proofs. We believe the general proof technique may prove useful in other settings as well.
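The setup described above can be illustrated with a minimal NumPy sketch. It is not the paper's construction: the Gaussian measurement operator, the problem sizes, the step size, and the function name `factorized_gd_matrix_sensing` are all assumptions chosen for illustration. The sketch runs plain gradient descent on the factors $U, V$ of the loss $\tfrac{1}{2}\|\mathcal{A}(UV^\top) - y\|^2$, starting from a small random initialization, with fewer measurements than matrix entries.

```python
import numpy as np

def factorized_gd_matrix_sensing(n=30, rank=1, num_meas=600, width=5,
                                 init_scale=1e-3, step=0.05, iters=1500, seed=0):
    """Recover a low-rank n x n matrix from num_meas < n*n linear measurements.

    All defaults are illustrative assumptions, not values from the abstract.
    """
    rng = np.random.default_rng(seed)
    # Hypothetical ground truth: a random low-rank matrix with spectral norm 1.
    X_true = rng.standard_normal((n, rank)) @ rng.standard_normal((rank, n))
    X_true /= np.linalg.norm(X_true, 2)
    # Gaussian measurement operator stored as a (num_meas, n*n) matrix,
    # scaled so that A.T @ A equals the identity in expectation.
    A = rng.standard_normal((num_meas, n * n)) / np.sqrt(num_meas)
    y = A @ X_true.ravel()
    # Overparameterized factors (width > rank) with small random initialization.
    U = init_scale * rng.standard_normal((n, width))
    V = init_scale * rng.standard_normal((n, width))
    for _ in range(iters):
        residual = A @ (U @ V.T).ravel() - y
        G = (A.T @ residual).reshape(n, n)  # gradient w.r.t. the product U V^T
        # Chain rule: dL/dU = G V and dL/dV = G^T U (simultaneous update).
        U, V = U - step * G @ V, V - step * G.T @ U
    return X_true, U @ V.T

X_true, X_hat = factorized_gd_matrix_sensing()
rel_err = np.linalg.norm(X_hat - X_true) / np.linalg.norm(X_true)
print(f"relative recovery error: {rel_err:.2e}")
```

The factor width is deliberately larger than the true rank, so the small initialization is the only thing steering the iterates toward a low-rank solution.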
Abstract: Understanding the implicit regularization imposed by neural network architectures and gradient-based optimization methods is a key challenge in deep learning and AI. In this work we provide sharp results for the implicit regularization imposed by the gradient flow of Diagonal Linear Networks (DLNs) in the over-parameterized regression setting and, potentially surprisingly, link this to the phenomenon of phase transitions in generalized hardness of approximation (GHA). GHA generalizes the phenomenon of hardness of approximation from computer science to, among others, continuous and robust optimization. It is well known that the $\ell^1$-norm of the gradient flow of DLNs with tiny initialization converges to the objective function of basis pursuit. We improve upon these results by showing that the gradient flow of DLNs with tiny initialization approximates minimizers of the basis pursuit optimization problem (as opposed to just the objective function), and we obtain new and sharp convergence bounds with respect to the initialization size. Non-sharpness of our results would imply that the GHA phenomenon would not occur for the basis pursuit optimization problem, which is a contradiction, thus implying sharpness. Moreover, we characterize $\textit{which}$ $\ell^1$ minimizer of the basis pursuit problem is chosen by the gradient flow whenever the minimizer is not unique. Interestingly, this depends on the depth of the DLN.
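The link between tiny initialization and basis pursuit can be seen on a toy problem. The sketch below is an illustration under stated assumptions, not the paper's setting: it uses a depth-2 network with the parameterization $x = u \odot u - v \odot v$ (one common choice, which the abstract does not specify), approximates the gradient flow by gradient descent with a small step, and uses the hypothetical helper name `dln_gradient_descent`. For the single equation $x_1 + 2x_2 = 2$, the unique $\ell^1$ minimizer is $(0, 1)$, while the minimum $\ell^2$-norm solution is $(0.4, 0.8)$.

```python
import numpy as np

def dln_gradient_descent(A, y, alpha=1e-3, step=0.01, iters=5000):
    """Gradient descent on 0.5*||A x - y||^2 with x = u*u - v*v (depth-2 DLN).

    alpha is the (tiny) initialization size; all defaults are assumptions.
    """
    n = A.shape[1]
    u = np.full(n, alpha)
    v = np.full(n, alpha)
    for _ in range(iters):
        x = u * u - v * v
        g = A.T @ (A @ x - y)  # gradient of the loss with respect to x
        # Chain rule: dL/du = 2*u*g and dL/dv = -2*v*g (simultaneous update).
        u, v = u - step * 2.0 * u * g, v + step * 2.0 * v * g
    return u * u - v * v

# Underdetermined toy system: x1 + 2*x2 = 2.
A = np.array([[1.0, 2.0]])
y = np.array([2.0])

x_dln = dln_gradient_descent(A, y)  # should land close to the l1 minimizer (0, 1)
x_l2 = np.linalg.pinv(A) @ y        # minimum l2-norm solution (0.4, 0.8)
print("DLN with tiny init:", x_dln)
print("min l2-norm       :", x_l2)
```

Increasing `alpha` moves the DLN solution away from $(0, 1)$, which gives a rough feel for why convergence bounds in the initialization size matter.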
Abstract: Transformers have revolutionized almost all natural language processing (NLP) tasks but suffer from memory and computational complexity that scales quadratically with sequence length. In contrast, recurrent neural networks (RNNs) exhibit linear scaling in memory and computational requirements but struggle to match the performance of Transformers due to limitations in parallelization and scalability. We propose a novel model architecture, Receptance Weighted Key Value (RWKV), that combines the efficient parallelizable training of Transformers with the efficient inference of RNNs. Our approach leverages a linear attention mechanism and allows us to formulate the model as either a Transformer or an RNN. This parallelizes computations during training and keeps computational and memory complexity constant during inference, leading to the first non-Transformer architecture to be scaled to tens of billions of parameters. Our experiments reveal that RWKV performs on par with similarly sized Transformers, suggesting that future work can leverage this architecture to create more efficient models. This work presents a significant step towards reconciling the trade-offs between computational efficiency and model performance in sequence processing tasks.
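The idea that one linear attention mechanism admits both a parallel and a recurrent form can be sketched as follows. This is a generic exponentially decayed linear attention written for illustration, not RWKV's exact operator (which the abstract does not spell out); the function names and the per-channel decay vector `w` are assumptions. The first function computes every time step independently from the full history, so the steps could run in parallel during training. The second carries only a fixed-size state from step to step, so memory and per-token cost stay constant at inference.

```python
import numpy as np

def decayed_attention_parallel(k, v, w):
    """k, v: (T, D) keys and values; w: (D,) positive per-channel decay rates.

    Each output row depends only on the inputs, not on other output rows.
    """
    T = k.shape[0]
    out = np.zeros_like(v)
    for t in range(T):  # iterations are independent of each other
        ages = (t - np.arange(t + 1))[:, None]    # (t+1, 1) distance to step t
        weights = np.exp(-ages * w + k[: t + 1])  # (t+1, D) decayed key weights
        out[t] = (weights * v[: t + 1]).sum(axis=0) / weights.sum(axis=0)
    return out

def decayed_attention_recurrent(k, v, w):
    """Same quantity computed with a constant-size running state (num, den)."""
    D = k.shape[1]
    num = np.zeros(D)  # running weighted sum of values
    den = np.zeros(D)  # running sum of weights
    decay = np.exp(-w)
    out = np.zeros_like(v)
    for t in range(k.shape[0]):
        num = decay * num + np.exp(k[t]) * v[t]
        den = decay * den + np.exp(k[t])
        out[t] = num / den
    return out

rng = np.random.default_rng(0)
k = rng.standard_normal((6, 4))
v = rng.standard_normal((6, 4))
w = np.array([0.1, 0.5, 1.0, 2.0])
print(np.allclose(decayed_attention_parallel(k, v, w),
                  decayed_attention_recurrent(k, v, w)))
```

The sketch omits the numerical-stability handling a real implementation would need for large key values.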