Abstract:Transformers can efficiently learn in-context from example demonstrations. Most existing theoretical analyses have studied the in-context learning (ICL) ability of transformers for linear function classes, where it is typically shown that the minimizer of the pretraining loss implements one gradient descent step on the least squares objective. However, this simplified linear setting arguably does not demonstrate the statistical efficiency of ICL, since the pretrained transformer does not outperform directly solving linear regression on the test prompt. In this paper, we study ICL of a nonlinear function class via a transformer with a nonlinear MLP layer: given a class of \textit{single-index} target functions $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$, where the index features $\boldsymbol{\beta}\in\mathbb{R}^d$ are drawn from an $r$-dimensional subspace, we show that a nonlinear transformer optimized by gradient descent (with a pretraining sample complexity that depends on the \textit{information exponent} of the link functions $\sigma_*$) learns $f_*$ in-context with a prompt length that depends only on the dimension $r$ of the distribution of target functions; in contrast, any algorithm that directly learns $f_*$ on the test prompt yields a statistical complexity that scales with the ambient dimension $d$. Our result highlights the adaptivity of the pretrained transformer to low-dimensional structures of the function class, which enables sample-efficient ICL that outperforms estimators that only have access to the in-context data.
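To make the task distribution above concrete, here is a minimal data-generation sketch. It rests on assumptions the abstract does not state: Gaussian inputs $\boldsymbol{x}\sim\mathcal{N}(0, I_d)$, a unit-norm $\boldsymbol{\beta}$ drawn uniformly from the unit sphere of one fixed subspace, and a hypothetical helper `sample_icl_prompt`. It only builds prompts and does not implement the transformer or its training.

```python
import numpy as np

def sample_icl_prompt(basis, n_ctx, link, rng):
    """Draw one in-context prompt for a single-index task f_*(x) = link(<x, beta>).

    basis: (d, r) matrix with orthonormal columns spanning the r-dimensional
           subspace that contains every index feature beta (hypothetical setup).
    Returns context inputs/labels, a query input, its held-out label, and beta.
    """
    d, r = basis.shape
    z = rng.standard_normal(r)
    beta = basis @ (z / np.linalg.norm(z))    # unit-norm beta inside span(basis)
    X = rng.standard_normal((n_ctx + 1, d))   # n_ctx context inputs + 1 query, x ~ N(0, I_d)
    y = link(X @ beta)
    return X[:-1], y[:-1], X[-1], y[-1], beta

rng = np.random.default_rng(0)
d, r = 256, 4
basis, _ = np.linalg.qr(rng.standard_normal((d, r)))   # one fixed subspace shared by all tasks
Xc, yc, xq, yq, beta = sample_icl_prompt(basis, n_ctx=32, link=lambda t: t**2 - 1, rng=rng)
```

Because every $\boldsymbol{\beta}$ lies in `span(basis)`, a learner that has identified this subspace during pretraining only has $r$ unknown coordinates left to infer from the prompt. Loosely, this is the intuition behind a prompt length that scales with $r$ rather than $d$.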
Abstract:We study the problem of learning multi-index models in high dimensions using a two-layer neural network trained with the mean-field Langevin algorithm. Under mild distributional assumptions on the data, we characterize the effective dimension $d_{\mathrm{eff}}$ that controls both sample and computational complexity by utilizing the adaptivity of neural networks to latent low-dimensional structures. When the data exhibit such a structure, $d_{\mathrm{eff}}$ can be significantly smaller than the ambient dimension. We prove that the sample complexity grows almost linearly with $d_{\mathrm{eff}}$, bypassing the limitations of the information and generative exponents that appeared in recent analyses of gradient-based feature learning. On the other hand, the computational complexity may inevitably grow exponentially with $d_{\mathrm{eff}}$ in the worst-case scenario. Motivated by improving computational complexity, we take the first steps towards polynomial time convergence of the mean-field Langevin algorithm by investigating a setting where the weights are constrained to be on a compact manifold with positive Ricci curvature, such as the hypersphere. There, we study assumptions under which polynomial time convergence is achievable, whereas similar assumptions in the Euclidean setting lead to exponential time complexity.
Abstract:We study the computational and sample complexity of learning a target function $f_*:\mathbb{R}^d\to\mathbb{R}$ with additive structure, that is, $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(\langle x, v_m\rangle)$, where $f_1,f_2,\dots,f_M:\mathbb{R}\to\mathbb{R}$ are nonlinear link functions of single-index models (ridge functions) with diverse and near-orthogonal index features $\{v_m\}_{m=1}^M$, and the number of additive tasks $M$ grows with the dimensionality as $M\asymp d^\gamma$ for $\gamma\ge 0$. This problem setting is motivated by the classical additive model literature, the recent representation learning theory of two-layer neural networks, and large-scale pretraining where the model simultaneously acquires a large number of "skills" that are often localized in distinct parts of the trained network. We prove that a large subset of polynomial $f_*$ can be efficiently learned by gradient descent training of a two-layer neural network, with a polynomial statistical and computational complexity that depends on the number of tasks $M$ and the information exponent of $f_m$, despite the unknown link function and $M$ growing with the dimensionality. We complement this learnability guarantee with computational hardness results by establishing statistical query (SQ) lower bounds for both the correlational SQ and full SQ algorithms.
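As a small illustration of the additive target $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m} f_m(\langle x, v_m\rangle)$, the sketch below evaluates it on Gaussian inputs. For simplicity it assumes exactly orthonormal index features (the abstract only requires near-orthogonality) and uses the Hermite polynomials $He_2, He_3$ as hypothetical link functions; `additive_target` is an illustrative name, not the paper's code. The $1/\sqrt{M}$ scaling keeps $\mathbb{E}[f_*^2]$ of order one as $M$ grows.

```python
import numpy as np

def additive_target(X, V, links):
    """Evaluate f_*(x) = M^{-1/2} * sum_m f_m(<x, v_m>) for each row x of X.

    X: (n, d) inputs; V: (d, M) index features stored as columns;
    links: list of M scalar link functions, applied elementwise.
    """
    Z = X @ V                                   # (n, M) projections <x, v_m>
    M = V.shape[1]
    total = sum(f(Z[:, m]) for m, f in enumerate(links))
    return total / np.sqrt(M)

rng = np.random.default_rng(0)
d, M = 64, 8
V, _ = np.linalg.qr(rng.standard_normal((d, M)))   # orthonormal v_m (hypothetical)
he2 = lambda z: z**2 - 1                           # Hermite He_2, E[He_2^2] = 2
he3 = lambda z: z**3 - 3 * z                       # Hermite He_3, E[He_3^2] = 6
links = [he2, he3] * (M // 2)
X = rng.standard_normal((50_000, d))
y = additive_target(X, V, links)
# Population value: the cross terms vanish, so E[f_*^2] = (4*2 + 4*6) / 8 = 4.
print(np.mean(y**2))
```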
Abstract:We study the problem of gradient descent learning of a single-index target function $f_*(\boldsymbol{x}) = \textstyle\sigma_*\left(\langle\boldsymbol{x},\boldsymbol{\theta}\rangle\right)$ under isotropic Gaussian data in $\mathbb{R}^d$, where the link function $\sigma_*:\mathbb{R}\to\mathbb{R}$ is an unknown degree-$q$ polynomial with information exponent $p$ (defined as the lowest degree in the Hermite expansion). Prior works showed that gradient-based training of neural networks can learn this target with $n\gtrsim d^{\Theta(p)}$ samples, and such statistical complexity is predicted to be necessary by the correlational statistical query lower bound. Surprisingly, we prove that a two-layer neural network optimized by an SGD-based algorithm learns $f_*$ with an arbitrary polynomial link function with a sample and runtime complexity of $n \asymp T \asymp C(q) \cdot d\,\mathrm{polylog}\, d$, where the constant $C(q)$ only depends on the degree of $\sigma_*$, regardless of the information exponent; this dimension dependence matches the information-theoretic limit up to polylogarithmic factors. Core to our analysis is the reuse of minibatches in the gradient computation, which gives rise to higher-order information beyond correlational queries.
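The information exponent of a polynomial link can be read off its expansion in probabilists' Hermite polynomials $He_k$, which are orthogonal under $\mathcal{N}(0,1)$. A minimal sketch using NumPy's `hermite_e` module follows; it assumes the usual convention of ignoring the constant ($k=0$) term, and `information_exponent` is a hypothetical helper name.

```python
import numpy as np
from numpy.polynomial import hermite_e as H

def information_exponent(power_coeffs, tol=1e-10):
    """Smallest k >= 1 with a nonzero Hermite coefficient of the link function.

    power_coeffs[i] is the coefficient of z**i in sigma_*(z). The constant
    term (k = 0) is ignored by convention; returns None for a constant link.
    """
    herm = H.poly2herme(np.asarray(power_coeffs, dtype=float))
    for k in range(1, len(herm)):
        if abs(herm[k]) > tol:
            return k
    return None

print(information_exponent([0, 0, 0, 1]))    # z^3 = He_3 + 3 He_1   -> 1
print(information_exponent([0, -3, 0, 1]))   # z^3 - 3z = He_3        -> 3
print(information_exponent([-1, 0, 1]))      # z^2 - 1 = He_2         -> 2
```

The first two examples share the degree $q = 3$ but have information exponents 1 and 3, which is exactly the gap between the degree-dependent $C(q)\cdot d\,\mathrm{polylog}\,d$ rate and the $d^{\Theta(p)}$ rate discussed above.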
Abstract:Many recent works have studied the eigenvalue spectrum of the Conjugate Kernel (CK) defined by the nonlinear feature map of a feedforward neural network. However, existing results only establish weak convergence of the empirical eigenvalue distribution, and fall short of providing precise quantitative characterizations of the "spike" eigenvalues and eigenvectors that often capture the low-dimensional signal structure of the learning problem. In this work, we characterize these signal eigenvalues and eigenvectors for a nonlinear version of the spiked covariance model, including the CK as a special case. Using this general result, we give a quantitative description of how spiked eigenstructure in the input data propagates through the hidden layers of a neural network with random weights. As a second application, we study a simple regime of representation learning where the weight matrix develops a rank-one signal component over training and characterize the alignment of the target function with the spike eigenvector of the CK on test data.
Abstract:Recent works have demonstrated that the sample complexity of gradient-based learning of single-index models, i.e., functions that depend on a 1-dimensional projection of the input data, is governed by their information exponent. However, these results are only concerned with isotropic data, while in practice the input often contains additional structure which can implicitly guide the algorithm. In this work, we investigate the effect of a spiked covariance structure and reveal several interesting phenomena. First, we show that in the anisotropic setting, the commonly used spherical gradient dynamics may fail to recover the true direction, even when the spike is perfectly aligned with the target direction. Next, we show that appropriate weight normalization that is reminiscent of batch normalization can alleviate this issue. Further, by exploiting the alignment between the (spiked) input covariance and the target, we obtain improved sample complexity compared to the isotropic case. In particular, under the spiked model with a suitably large spike, the sample complexity of gradient-based training can be made independent of the information exponent while also outperforming lower bounds for rotationally invariant kernel methods.
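The abstract does not spell out the normalization, so the following is only a hypothetical sketch of one natural reading: divide the preactivation $\langle w, x\rangle$ by its standard deviation $\sqrt{w^\top \Sigma w}$ under the spiked covariance $\Sigma = I_d + \kappa\,\theta\theta^\top$, assuming $\Sigma$ is known (batch normalization would instead use empirical batch statistics). It shows that on the unit sphere $\|w\| = 1$ the preactivation variance depends strongly on the direction of $w$, whereas after normalization it is 1 for every $w$. The names `normalized_preactivation`, `kappa`, and `theta` are illustrative.

```python
import numpy as np

def normalized_preactivation(w, X, Sigma):
    """Return w^T x / sqrt(w^T Sigma w) for each row x of X.

    With Sigma the input covariance, the output has unit variance for every w,
    whereas the spherical constraint ||w|| = 1 does not guarantee this.
    """
    return (X @ w) / np.sqrt(w @ Sigma @ w)

rng = np.random.default_rng(0)
d, n, kappa = 128, 20_000, 10.0
theta = np.zeros(d)
theta[0] = 1.0                                   # hypothetical target direction
# x ~ N(0, I_d + kappa * theta theta^T), with the spike aligned to the target.
X = rng.standard_normal((n, d)) + np.sqrt(kappa) * rng.standard_normal((n, 1)) * theta
Sigma = np.eye(d) + kappa * np.outer(theta, theta)

w_rand = rng.standard_normal(d)
w_rand /= np.linalg.norm(w_rand)                 # unit norm, i.e. on the sphere
print(np.var(X @ theta), np.var(X @ w_rand))     # very different variances on the sphere
print(np.var(normalized_preactivation(theta, X, Sigma)),
      np.var(normalized_preactivation(w_rand, X, Sigma)))  # both close to 1
```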
Abstract:The mean-field Langevin dynamics (MFLD) is a nonlinear generalization of the Langevin dynamics that incorporates a distribution-dependent drift, and it naturally arises from the optimization of two-layer neural networks via (noisy) gradient descent. Recent works have shown that MFLD globally minimizes an entropy-regularized convex functional in the space of measures. However, all prior analyses assumed the infinite-particle or continuous-time limit, and cannot handle stochastic gradient updates. We provide a general framework to prove a uniform-in-time propagation of chaos for MFLD that takes into account the errors due to finite-particle approximation, time-discretization, and stochastic gradient approximation. To demonstrate the wide applicability of this framework, we establish quantitative convergence rate guarantees to the regularized global optimal solution under (i) a wide range of learning problems such as neural networks in the mean-field regime and MMD minimization, and (ii) different gradient estimators including SGD and SVRG. Despite the generality of our results, we achieve an improved convergence rate in both the SGD and SVRG settings when specialized to the standard Langevin dynamics.
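A toy sketch of the finite-particle, discrete-time MFLD update $\theta_i \leftarrow \theta_i - \eta\,\nabla_\theta \frac{\delta F}{\delta\mu}(\theta_i) + \sqrt{2\eta\lambda}\,\xi_i$ is given below. The functional $F(\mu) = \frac{c}{2}\big(\mathbb{E}_\mu[\theta]\big)^2 + \frac{1}{2}\mathbb{E}_\mu[\theta^2]$ is a hypothetical convex example chosen because its entropy-regularized minimizer is known in closed form, $\mathcal{N}(0,\lambda)$; it is not one of the paper's learning problems, and full-batch gradients are used (no SGD or SVRG). It makes two of the error sources named above visible: the drift uses the particle mean in place of $\mathbb{E}_\mu[\theta]$, and the step size $\eta$ biases the stationary variance to $2\lambda/(2-\eta)$ instead of $\lambda$.

```python
import numpy as np

def mfld_particles(n_particles=20_000, c=2.0, lam=0.5, eta=0.1, steps=500, seed=0):
    """Noisy particle discretization of MFLD for a toy convex functional.

    F(mu) = (c/2) * m(mu)^2 + (1/2) * E_mu[theta^2], with m(mu) = E_mu[theta],
    regularized by lam times the negative entropy. The first variation
    dF/dmu(theta) = c * m(mu) * theta + theta^2 / 2 depends on mu through m(mu),
    which makes the drift distribution-dependent.
    """
    rng = np.random.default_rng(seed)
    theta = rng.normal(loc=3.0, scale=1.0, size=n_particles)   # arbitrary initialization
    for _ in range(steps):
        m = theta.mean()                 # finite-particle estimate of m(mu)
        drift = c * m + theta            # gradient in theta of the first variation
        noise = rng.standard_normal(n_particles)
        theta = theta - eta * drift + np.sqrt(2.0 * eta * lam) * noise
    return theta

for eta in (0.1, 0.5):
    th = mfld_particles(eta=eta)
    # Compare the particle variance with the discrete-time prediction 2*lam/(2 - eta).
    print(eta, th.mean(), th.var(), 2 * 0.5 / (2 - eta))
```

The step size must also satisfy $|1 - \eta(1 + c)| < 1$ for the particle mean to contract, which holds for both values used here.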
Abstract:The entropic fictitious play (EFP) is a recently proposed algorithm that minimizes the sum of a convex functional and entropy in the space of measures -- such an objective naturally arises in the optimization of a two-layer neural network in the mean-field regime. In this work, we provide a concise primal-dual analysis of EFP in the setting where the learning problem exhibits a finite-sum structure. We establish quantitative global convergence guarantees for both the continuous-time and discrete-time dynamics based on properties of a proximal Gibbs measure introduced in Nitanda et al. (2022). Furthermore, our primal-dual framework entails a memory-efficient particle-based implementation of the EFP update, and also suggests a connection to gradient boosting methods. We illustrate the efficiency of our novel implementation in experiments including neural network optimization and image synthesis.
Abstract:We study the first gradient descent step on the first-layer parameters $\boldsymbol{W}$ in a two-layer neural network: $f(\boldsymbol{x}) = \frac{1}{\sqrt{N}}\boldsymbol{a}^\top\sigma(\boldsymbol{W}^\top\boldsymbol{x})$, where $\boldsymbol{W}\in\mathbb{R}^{d\times N}, \boldsymbol{a}\in\mathbb{R}^{N}$ are randomly initialized, and the training objective is the empirical MSE loss: $\frac{1}{n}\sum_{i=1}^n (f(\boldsymbol{x}_i)-y_i)^2$. In the proportional asymptotic limit where $n,d,N\to\infty$ at the same rate, and in an idealized student-teacher setting, we show that the first gradient update contains a rank-1 "spike", which results in an alignment between the first-layer weights and the linear component of the teacher model $f^*$. To characterize the impact of this alignment, we compute the prediction risk of ridge regression on the conjugate kernel after one gradient step on $\boldsymbol{W}$ with learning rate $\eta$, when $f^*$ is a single-index model. We consider two scalings of the first step learning rate $\eta$. For small $\eta$, we establish a Gaussian equivalence property for the trained feature map, and prove that the learned kernel improves upon the initial random features model, but cannot defeat the best linear model on the input. In contrast, for sufficiently large $\eta$, we prove that for certain $f^*$, the same ridge estimator on trained features can go beyond this "linear regime" and outperform a wide range of random features and rotationally invariant kernels. Our results demonstrate that even one gradient step can lead to a considerable advantage over random features, and highlight the role of learning rate scaling in the initial phase of training.
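The rank-1 "spike" in the first gradient can be checked numerically at finite size. The sketch below assumes choices the abstract leaves open: activation $\sigma = \tanh$, a hypothetical ReLU single-index teacher, and initializations $W_{ij}\sim\mathcal{N}(0,1/d)$, $a_j\sim\mathcal{N}(0,1/N)$. It computes $\nabla_{\boldsymbol{W}}$ of the empirical MSE loss exactly and compares the top two singular values. Heuristically, writing $\sigma'(h) = \mu_1 + (\sigma'(h)-\mu_1)$ with $\mu_1 = \mathbb{E}[\sigma'(z)]$ splits the gradient into the exactly rank-1 matrix proportional to $\boldsymbol{X}^\top\boldsymbol{r}\,\boldsymbol{a}^\top$ plus a smaller remainder. This is an illustration only, not the paper's asymptotic analysis.

```python
import numpy as np

rng = np.random.default_rng(0)
n = d = N = 400                                  # proportional regime with n = d = N
X = rng.standard_normal((n, d))                  # x_i ~ N(0, I_d)
beta = rng.standard_normal(d)
beta /= np.linalg.norm(beta)
y = np.maximum(X @ beta, 0.0)                    # hypothetical single-index teacher: ReLU

W = rng.standard_normal((d, N)) / np.sqrt(d)     # assumed init: W_ij ~ N(0, 1/d)
a = rng.standard_normal(N) / np.sqrt(N)          # assumed init: a_j ~ N(0, 1/N)

H = X @ W                                        # (n, N) preactivations W^T x_i
f = np.tanh(H) @ a / np.sqrt(N)                  # f(x_i) = a^T tanh(W^T x_i) / sqrt(N)
r = f - y                                        # residuals
# Exact gradient of (1/n) * sum_i (f(x_i) - y_i)^2 with respect to W, shape (d, N).
G = (2.0 / (n * np.sqrt(N))) * X.T @ (r[:, None] * (1.0 - np.tanh(H) ** 2) * a[None, :])

U, s, _ = np.linalg.svd(G)
print(s[0] / s[1])              # large ratio: one dominant singular direction (the spike)
print(abs(U[:, 0] @ beta))      # overlap of the spike's left singular vector with beta
```

The left singular vector of the spike points along $\boldsymbol{X}^\top \boldsymbol{r}$, which correlates with $\boldsymbol{\beta}$ through the teacher's linear component; this is the alignment described in the abstract.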
Abstract:As an example of the nonlinear Fokker-Planck equation, the mean field Langevin dynamics has attracted attention due to its connection to (noisy) gradient descent on infinitely wide neural networks in the mean field regime, and hence the convergence property of the dynamics is of great theoretical interest. In this work, we give a simple and self-contained convergence rate analysis of the mean field Langevin dynamics with respect to the (regularized) objective function in both continuous and discrete time settings. The key ingredient of our proof is a proximal Gibbs distribution $p_q$ associated with the dynamics, which, in combination with techniques from [Vempala and Wibisono (2019)], allows us to develop a convergence theory parallel to classical results in convex optimization. Furthermore, we reveal that $p_q$ connects to the duality gap in the empirical risk minimization setting, which enables efficient empirical evaluation of the algorithm's convergence.
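For intuition about $p_q$, the sketch below assumes the definition commonly used in the mean-field Langevin literature, $p_q(\theta) \propto \exp\!\big(-\frac{1}{\lambda}\frac{\delta F(q)}{\delta q}(\theta)\big)$ with $\lambda$ the entropy-regularization strength, and evaluates it on a 1-D grid. The functional $F(q) = \frac{c}{2} m_q^2 + \frac{1}{2}\mathbb{E}_q[\theta^2]$ with $m_q = \mathbb{E}_q[\theta]$ is a hypothetical toy example, for which $p_q = \mathcal{N}(-c\,m_q, \lambda)$ in closed form. The helper names are illustrative.

```python
import numpy as np

def proximal_gibbs_on_grid(first_variation, lam, grid):
    """Density of p_q(theta) ∝ exp(-first_variation(theta) / lam) on a uniform 1-D grid."""
    log_w = -first_variation(grid) / lam
    w = np.exp(log_w - log_w.max())          # shift by the max for numerical stability
    dx = grid[1] - grid[0]
    return w / (w.sum() * dx)                # Riemann-sum normalization

# Assumed toy functional: F(q) = (c/2) * m_q^2 + (1/2) * E_q[theta^2], m_q = E_q[theta],
# with first variation dF/dq(theta) = c * m_q * theta + theta^2 / 2.
c, lam = 2.0, 0.5
grid = np.linspace(-10.0, 10.0, 4001)
dx = grid[1] - grid[0]

def first_variation_at(m_q):
    return lambda th: c * m_q * th + 0.5 * th**2

def mean_and_var(p):
    mean = np.sum(grid * p) * dx
    return mean, np.sum((grid - mean) ** 2 * p) * dx

print(mean_and_var(proximal_gibbs_on_grid(first_variation_at(1.0), lam, grid)))
print(mean_and_var(proximal_gibbs_on_grid(first_variation_at(0.0), lam, grid)))
```

For $m_q = 1$ the closed form predicts mean $-2$ and variance $0.5$. For $q = \mathcal{N}(0,\lambda)$ one gets $p_q = q$, so this $q$ is a fixed point of the map $q \mapsto p_q$ and hence the regularized minimizer of this toy objective.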