Abstract: Second-order stochastic optimizers allow the parameter update step size and direction to adapt to loss curvature, but have traditionally required too much memory and compute for deep learning. Recently, Shampoo [Gupta et al., 2018] introduced a Kronecker-factored preconditioner to reduce these requirements: it is used for large deep models [Anil et al., 2020] and in production [Anil et al., 2022]. However, it takes inverse matrix roots of ill-conditioned matrices. This requires 64-bit precision, imposing strong hardware constraints. In this paper, we propose a novel factorization, Kronecker Approximation-Domination (KrAD). Using KrAD, we update a matrix that directly approximates the inverse empirical Fisher matrix (like full-matrix AdaGrad), avoiding inversion and hence the need for 64-bit precision. We then propose KrADagrad$^\star$, which has similar computational costs to Shampoo and the same regret. Synthetic ill-conditioned experiments show improved performance over Shampoo in 32-bit precision, while on several real datasets generalization is comparable or better.
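For context, the inverse matrix roots that the abstract identifies as the precision bottleneck can be sketched as follows. This is a minimal NumPy illustration of a Shampoo-style update for a single matrix-shaped parameter, assuming the common form with statistics L and R and inverse fourth roots. It is not the KrAD or KrADagrad$^\star$ update, which the abstract does not specify. The function names, the `eps` damping value, and the eigendecomposition-based root are illustrative assumptions.

```python
import numpy as np


def inverse_pth_root(mat, p, eps=1e-6):
    """Inverse p-th root of a symmetric PSD matrix via eigendecomposition.

    eps is an assumed damping term; when mat is ill-conditioned, small
    eigenvalues are amplified by the negative power, which is where
    low-precision arithmetic tends to break down.
    """
    w, v = np.linalg.eigh(mat)
    w = np.maximum(w, 0.0) + eps          # clamp round-off negatives, then damp
    return (v * w ** (-1.0 / p)) @ v.T    # V diag(w^(-1/p)) V^T


def shampoo_step(G, L, R, eps=1e-6):
    """One Shampoo-style preconditioning step for a matrix gradient G (m x n).

    L (m x m) and R (n x n) accumulate the two Kronecker-factor statistics.
    Returns the preconditioned gradient and the updated statistics.
    """
    L = L + G @ G.T
    R = R + G.T @ G
    update = inverse_pth_root(L, 4, eps) @ G @ inverse_pth_root(R, 4, eps)
    return update, L, R


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    G = rng.standard_normal((3, 4))
    update, L, R = shampoo_step(G, np.zeros((3, 3)), np.zeros((4, 4)))
    print(update.shape)
```

With row-major flattening, the returned update equals `np.kron(A, B) @ G.reshape(-1)` for `A = inverse_pth_root(L, 4)` and `B = inverse_pth_root(R, 4)`, which is the sense in which the preconditioner is Kronecker factored: the full (mn x mn) matrix is never formed.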
Abstract: Toeplitz Neural Networks (TNNs) (Qin et al., 2023) are a recent class of sequence models with impressive results. They require O(n log n) computational complexity and O(n) relative positional encoder (RPE) multi-layer perceptron (MLP) and decay bias calls. We aim to reduce both. We first note that the RPE is a non-SPD (symmetric positive definite) kernel and the Toeplitz matrices are pseudo-Gram matrices. Further, 1) the learned kernels display spiky behavior near the main diagonals with otherwise smooth behavior; 2) the RPE MLP is slow. For bidirectional models, this motivates a sparse plus low-rank Toeplitz matrix decomposition. For the sparse component's action, we do a small 1D convolution. For the low-rank component, we replace the RPE MLP with linear interpolation and use asymmetric Structured Kernel Interpolation (SKI) (Wilson et al., 2015) for O(n) complexity: we provide rigorous error analysis. For causal models, "fast" causal masking (Katharopoulos et al., 2020) negates SKI's benefits. Working in the frequency domain, we avoid an explicit decay bias. To enforce causality, we represent the kernel via the real part of its frequency response using the RPE and compute the imaginary part via a Hilbert transform. This maintains O(n log n) complexity but achieves an absolute speedup. Modeling the frequency response directly is also competitive for bidirectional training, using one fewer FFT. We set a new state of the art for speed on Long Range Arena (Tay et al., 2020) with minimal score degradation.
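Two of the building blocks mentioned above can be illustrated with a short NumPy sketch, under stated assumptions. The first function computes a Toeplitz matrix-vector product in O(n log n) via the standard circulant embedding and FFT. The second recovers a real causal kernel from only the real part of its zero-padded frequency response, which is the discrete counterpart of obtaining the imaginary part through a Hilbert transform. Function names and array layouts are hypothetical, and this is not the authors' implementation: the RPE, the decay bias, and SKI are not modeled.

```python
import numpy as np


def toeplitz_matvec(t_pos, t_neg, x):
    """Compute y = T @ x with T[i, j] = t[i - j] in O(n log n).

    t_pos[k] holds t[k] and t_neg[k] holds t[-k] for k = 0..n-1
    (t_neg[0] is unused). T is embedded in a 2n x 2n circulant matrix,
    whose action is a circular convolution computed with real FFTs.
    """
    n = x.shape[0]
    # First column of the circulant: [t_0..t_{n-1}, 0, t_{-(n-1)}..t_{-1}]
    c = np.concatenate([t_pos, [0.0], t_neg[:0:-1]])
    y = np.fft.irfft(np.fft.rfft(c) * np.fft.rfft(x, n=2 * n), n=2 * n)
    return y[:n]


def causal_kernel_from_real_response(re_H):
    """Recover a real causal kernel from the real part of its DFT.

    re_H is the real part of the length-2n DFT of a kernel that is zero on
    its second half (indices n..2n-1). The inverse DFT of re_H is the even
    part of the kernel; doubling the strictly positive lags and zeroing the
    negative lags restores the causal kernel.
    """
    m = re_H.shape[0]
    n = m // 2
    even_part = np.fft.ifft(re_H).real
    window = np.zeros(m)
    window[0] = 1.0
    window[1:n] = 2.0
    window[n] = 1.0
    return even_part * window


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 8
    h = np.concatenate([rng.standard_normal(n), np.zeros(n)])
    H = np.fft.fft(h)
    h_rec = causal_kernel_from_real_response(H.real)
    print(np.allclose(h_rec, h))
```

For a causal model, the lower-triangular Toeplitz action corresponds to calling `toeplitz_matvec` with `t_neg` set to zeros. Taking the FFT of the kernel returned by `causal_kernel_from_real_response` yields the full complex response, whose imaginary part is determined by the real part.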