Abstract:Attention is a key part of the transformer architecture. It is a sequence-to-sequence mapping that transforms each sequence element into a weighted sum of values. The weights are typically obtained as the softmax of dot products between keys and queries. Recent work has explored alternatives to softmax attention in transformers, such as ReLU and sigmoid activations. In this work, we revisit sigmoid attention and conduct an in-depth theoretical and empirical analysis. Theoretically, we prove that transformers with sigmoid attention are universal function approximators and benefit from improved regularity compared to softmax attention. Through detailed empirical analysis, we identify stabilization of large initial attention norms during the early stages of training as a crucial factor for the successful training of models with sigmoid attention, outperforming prior attempts. We also introduce FLASHSIGMOID, a hardware-aware and memory-efficient implementation of sigmoid attention yielding a 17% inference kernel speed-up over FLASHATTENTION2 on H100 GPUs. Experiments across language, vision, and speech show that properly normalized sigmoid attention matches the strong performance of softmax attention on a wide range of domains and scales, which previous attempts at sigmoid attention were unable to fully achieve. Our work unifies prior art and establishes best practices for sigmoid attention as a drop-in softmax replacement in transformers.
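To make the contrast described above concrete, the following is a minimal sketch, not the FLASHSIGMOID kernel, of softmax attention next to an element-wise sigmoid variant. The function names and the default bias of -log(n) (n being the number of keys) are assumptions introduced here to stand in for the normalization the abstract alludes to; the abstract itself does not specify the exact form.

```python
import math


def _sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def _scores(queries, keys):
    # Scaled dot products between every query and every key.
    d = len(keys[0])
    return [
        [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        for q in queries
    ]


def _weighted_sum(weights, values):
    # Each output row is a weighted sum of the value vectors.
    dim = len(values[0])
    return [
        [sum(w * v[j] for w, v in zip(row, values)) for j in range(dim)]
        for row in weights
    ]


def softmax_attention(queries, keys, values):
    weights = [_softmax(row) for row in _scores(queries, keys)]
    return _weighted_sum(weights, values)


def sigmoid_attention(queries, keys, values, bias=None):
    # Assumption: a scalar bias of -log(n) keeps the initial output norm small.
    if bias is None:
        bias = -math.log(len(keys))
    weights = [[_sigmoid(s + bias) for s in row] for row in _scores(queries, keys)]
    return _weighted_sum(weights, values)
```

Unlike softmax, the sigmoid weights in a row are computed independently and need not sum to one, which is why some form of bias or normalization matters for the scale of the output.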
Abstract:Contrastive learning typically matches pairs of related views among a number of unrelated negative views. Views can be generated (e.g. by augmentations) or be observed. We investigate matching when there are more than two related views which we call poly-view tasks, and derive new representation learning objectives using information maximization and sufficient statistics. We show that with unlimited computation, one should maximize the number of related views, and with a fixed compute budget, it is beneficial to decrease the number of unique samples whilst increasing the number of views of those samples. In particular, poly-view contrastive models trained for 128 epochs with batch size 256 outperform SimCLR trained for 1024 epochs at batch size 4096 on ImageNet1k, challenging the belief that contrastive models require large batch sizes and many training epochs.
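For orientation, the pairwise matching task mentioned in the first sentence above is commonly written as an InfoNCE loss. The sketch below shows that standard two-view loss only; it is not the poly-view objective derived in the paper, and the function name and the temperature value are illustrative assumptions.

```python
import math


def info_nce(anchor, positive, negatives, temperature=0.1):
    """Standard pairwise InfoNCE: identify the positive view among negatives."""

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # The positive logit comes first, followed by one logit per negative view.
    logits = [dot(anchor, positive) / temperature]
    logits += [dot(anchor, neg) / temperature for neg in negatives]

    # Cross-entropy with the positive as the target, via a stable log-sum-exp.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[0]
```

When every view is equally similar to the anchor, the loss equals the log of the number of candidates, which is the chance level for the matching task.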
Abstract:Understanding model uncertainty is important for many applications. We propose Bootstrap Your Own Variance (BYOV), combining Bootstrap Your Own Latent (BYOL), a negative-free Self-Supervised Learning (SSL) algorithm, with Bayes by Backprop (BBB), a Bayesian method for estimating model posteriors. We find that the learned predictive standard deviation of BYOV vs. a supervised BBB model is well captured by a Gaussian distribution, providing preliminary evidence that the learned parameter posterior is useful for label-free uncertainty estimation. BYOV improves upon the deterministic BYOL baseline (+2.83% test ECE, +1.03% test Brier) and presents better calibration and reliability when tested with various augmentations (e.g., +2.4% test ECE, +1.2% test Brier for Salt & Pepper noise).
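The two calibration metrics quoted above, Expected Calibration Error (ECE) and the Brier score, can be sketched as follows. This is a generic illustration using their common textbook definitions; the bin count and the multi-class form of the Brier score are assumptions, since the abstract does not state which variants were used.

```python
def brier_score(probs, labels):
    """Mean squared error between predicted class probabilities and one-hot labels."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(probs)


def expected_calibration_error(probs, labels, n_bins=10):
    """Weighted average gap between confidence and accuracy over equal-width bins."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        confidence = max(p)
        prediction = p.index(confidence)
        # Clamp so that a confidence of exactly 1.0 falls in the last bin.
        idx = min(int(confidence * n_bins), n_bins - 1)
        bins[idx].append((confidence, 1.0 if prediction == y else 0.0))

    n = len(probs)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(a for _, a in members) / len(members)
        ece += (len(members) / n) * abs(accuracy - avg_conf)
    return ece
```

Both metrics are zero for a perfectly confident and perfectly correct model, and lower values indicate better calibration.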
Abstract:Preserving training dynamics across batch sizes is an important tool for practical machine learning as it enables the trade-off between batch size and wall-clock time. This trade-off is typically enabled by a scaling rule; for example, in stochastic gradient descent, one should scale the learning rate linearly with the batch size. Another important tool for practical machine learning is the model Exponential Moving Average (EMA), which is a model copy that does not receive gradient information, but instead follows its target model with some momentum. This model EMA can improve the robustness and generalization properties of supervised learning, stabilize pseudo-labeling, and provide a learning signal for Self-Supervised Learning (SSL). Prior works have treated the model EMA separately from optimization, leading to different training dynamics across batch sizes and lower model performance. In this work, we provide a scaling rule for optimization in the presence of model EMAs and demonstrate its validity across a range of architectures, optimizers, and data modalities. We also show the rule's validity where the model EMA contributes to the optimization of the target model, enabling us to train EMA-based pseudo-labeling and SSL methods at small and large batch sizes. For SSL, we enable training of BYOL up to batch size 24,576 without sacrificing performance, optimally a 6$\times$ wall-clock time reduction.
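The sketch below illustrates the kind of rule discussed above. The linear learning-rate scaling for SGD comes from the abstract; the exponentiated momentum (momentum raised to the power of the batch-size factor) is an assumption about the form of the EMA rule, which the abstract does not spell out. The names `ema_update`, `scale_hyperparameters`, and `kappa` are hypothetical.

```python
def ema_update(ema_params, target_params, momentum):
    """One EMA step: the copy follows the target model with the given momentum."""
    return [
        momentum * e + (1.0 - momentum) * t
        for e, t in zip(ema_params, target_params)
    ]


def scale_hyperparameters(learning_rate, momentum, kappa):
    """Adjust hyperparameters when the batch size is multiplied by kappa.

    Linear learning-rate scaling is the SGD rule named in the abstract.
    Assumption: the EMA momentum is raised to the power kappa, so that one
    large-batch step matches kappa small-batch EMA steps.
    """
    return learning_rate * kappa, momentum ** kappa
```

The intuition behind the exponent: if the target parameters stay fixed, kappa EMA steps with momentum rho leave the copy at the same point as a single step with momentum rho to the power kappa.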
Abstract:Self-supervised representation learning (SSL) methods provide an effective label-free initial condition for fine-tuning downstream tasks. However, in numerous realistic scenarios, the downstream task might be biased with respect to the target label distribution. This in turn moves the learned fine-tuned model posterior away from the initial (label) bias-free self-supervised model posterior. In this work, we re-interpret SSL fine-tuning under the lens of Bayesian continual learning and consider regularization through the Elastic Weight Consolidation (EWC) framework. We demonstrate that self-regularization against an initial SSL backbone improves worst sub-group performance in Waterbirds by 5% and Celeb-A by 2% when using the ViT-B/16 architecture. Furthermore, to help simplify the use of EWC with SSL, we pre-compute and publicly release the Fisher Information Matrix (FIM), evaluated with 10,000 ImageNet-1K variates on large modern SSL architectures including ViT-B/16 and ResNet50 trained with DINO.
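As a rough illustration of the regularizer referred to above, the standard EWC penalty with a diagonal Fisher approximation can be sketched as follows. The diagonal form, the flat parameter lists, and the `strength` coefficient are assumptions made for this example; the format of the released FIM is not described in the abstract.

```python
def ewc_penalty(params, anchor_params, fisher_diag, strength=1.0):
    """Quadratic penalty pulling params toward the anchor (e.g. SSL backbone) weights.

    Each squared deviation is weighted by the matching diagonal Fisher entry,
    so parameters the anchor model is sensitive to are held more tightly.
    """
    return 0.5 * strength * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher_diag)
    )


def ewc_gradient(params, anchor_params, fisher_diag, strength=1.0):
    """Gradient of the penalty, to be added to the fine-tuning loss gradient."""
    return [
        strength * f * (p - a)
        for p, a, f in zip(params, anchor_params, fisher_diag)
    ]
```

During fine-tuning, the penalty would be added to the downstream task loss, with the anchor set to the pre-trained SSL weights.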
Abstract:While state-of-the-art contrastive Self-Supervised Learning (SSL) models produce results competitive with their supervised counterparts, they lack the ability to infer latent variables. In contrast, prescribed latent variable (LV) models enable attributing uncertainty, inducing task-specific compression, and in general allow for more interpretable representations. In this work, we introduce LV approximations to large-scale contrastive SSL models. We demonstrate that this addition improves downstream performance (resulting in 96.42% and 77.49% test top-1 fine-tuned performance on CIFAR10 and ImageNet respectively with a ResNet50) and produces highly compressed representations (588x reduction) that are useful for interpretability, classification and regression downstream tasks.
Abstract:In this work we examine how fine-tuning impacts the fairness of contrastive Self-Supervised Learning (SSL) models. Our findings indicate that Batch Normalization (BN) statistics play a crucial role, and that updating only the BN statistics of a pre-trained SSL backbone improves its downstream fairness (36% worst subgroup, 25% mean subgroup gap). This procedure is competitive with supervised learning, while taking 4.4x less time to train and requiring only 0.35% as many parameters to be updated. Finally, inspired by recent work in supervised learning, we find that updating BN statistics and training residual skip connections (12.3% of the parameters) achieves parity with a fully fine-tuned model, while taking 1.33x less time to train.
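A minimal sketch of what updating only the BN statistics could look like for a single feature channel is given below. It assumes the usual running-average formulation of Batch Normalization; the momentum value, the use of the biased batch variance, and the function names are illustrative assumptions rather than the paper's exact procedure.

```python
import math


def update_bn_statistics(running_mean, running_var, batch, momentum=0.1):
    """Refresh the running mean and variance from a batch of downstream data.

    No learned weights are touched: only the stored statistics move.
    Assumption: the biased (population) batch variance is used.
    """
    n = len(batch)
    mean = sum(batch) / n
    var = sum((x - mean) ** 2 for x in batch) / n
    new_mean = (1.0 - momentum) * running_mean + momentum * mean
    new_var = (1.0 - momentum) * running_var + momentum * var
    return new_mean, new_var


def batch_norm_eval(x, running_mean, running_var, gamma=1.0, beta=0.0, eps=1e-5):
    """Inference-time normalization using the stored running statistics."""
    return gamma * (x - running_mean) / math.sqrt(running_var + eps) + beta
```

Passing the downstream data through the frozen backbone while calling the update on each BN layer adapts the normalization to the new data distribution without any gradient steps.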
Abstract:Despite the success of a number of recent techniques for visual self-supervised deep learning, there remains limited investigation into the representations that are ultimately learned. By using recent advances in comparing neural representations, we explore this direction by comparing a contrastive self-supervised algorithm (SimCLR) to supervision for simple image data in a common architecture. We find that the methods learn similar intermediate representations through dissimilar means, and that the representations diverge rapidly in the final few layers. We investigate this divergence, finding that it is caused by these layers strongly fitting to the distinct learning objectives. We also find that SimCLR's objective implicitly fits the supervised objective in intermediate layers, but that the reverse is not true. Our work particularly highlights the importance of the learned intermediate representations, and raises important questions for auxiliary task design.
Abstract:In this work, we introduce a new method for imitation learning from video demonstrations. Our method, Relational Mimic (RM), improves on previous visual imitation learning methods by combining generative adversarial networks and relational learning. RM is flexible and can be used in conjunction with other recent advances in generative adversarial imitation learning to better address the need for more robust and sample-efficient approaches. In addition, we introduce a new neural network architecture that improves upon the previous state-of-the-art in reinforcement learning and illustrate how increasing the relational reasoning capabilities of the agent enables the latter to achieve increasingly higher performance in a challenging locomotion task with pixel inputs. Finally, we study the effects and contributions of relational learning in policy evaluation, policy improvement and reward learning through ablation studies.
Abstract:Modern neural network training relies on piece-wise (sub-)differentiable functions in order to use backpropagation for efficient calculation of gradients. In this work, we introduce a novel method to allow for non-differentiable functions at intermediary layers of deep neural networks. We do so through the introduction of a differentiable approximation bridge (DAB) neural network which provides smooth approximations to the gradient of the non-differentiable function. We present strong empirical results (performing over 600 experiments) in three different domains: unsupervised (image) representation learning, image classification, and sequence sorting to demonstrate that our proposed method improves state-of-the-art performance. We demonstrate that utilizing non-differentiable functions in unsupervised (image) representation learning improves reconstruction quality and posterior linear separability by 10%. We also observe an accuracy improvement of 77% in neural sequence sorting and a 25% improvement against the straight-through estimator [3] in an image classification setting with the sort non-linearity. This work enables the usage of functions that were previously not usable in neural networks.