Celine
Abstract:Recently, many foundation models for medical image analysis, such as MedSAM and SwinUNETR, have been released and proven useful in multiple tasks. However, given the inherent heterogeneity and inhomogeneity of real-world medical data, directly applying these models to specific medical image segmentation tasks often leads to negative domain shift effects, which can severely weaken the model's segmentation capabilities. To this end, we propose an adaptive knowledge amalgamation framework that aims to train a versatile foundation model to handle the joint goals of multiple expert models, each specialized for a distinct task. Specifically, we first train an nnUNet-based expert model for each task and reuse the pre-trained SwinUNETR as the target foundation model. Then, the input data for all challenging tasks are encoded by the foundation model and the expert models, respectively, and their backbone features are jointly projected into the adaptive amalgamation layer. Within the hidden layer, hierarchical attention mechanisms adaptively merge the hidden-layer feature knowledge of all experts into the target model, which significantly reduces the domain shift arising from inter-task differences. Finally, the gold amalgamated features and the prompt features are fed into the mask decoder to obtain the segmentation results. Extensive experiments conducted on these challenging tasks demonstrate the effectiveness and adaptability of our foundation model for real-world medical image segmentation.
Abstract:In Large Language Model (LLM) inference, the output length of an LLM request is typically regarded as not known a priori. Consequently, most LLM serving systems employ a simple First-come-first-serve (FCFS) scheduling strategy, leading to Head-Of-Line (HOL) blocking and reduced throughput and service quality. In this paper, we reexamine this assumption -- we show that, although predicting the exact generation length of each request is infeasible, it is possible to predict the relative ranks of output lengths in a batch of requests, using learning to rank. The ranking information offers valuable guidance for scheduling requests. Building on this insight, we develop a novel scheduler for LLM inference and serving that can approximate the shortest-job-first (SJF) schedule better than existing approaches. We integrate this scheduler with the state-of-the-art LLM serving system and show significant performance improvement in several important applications: 2.8x lower latency in chatbot serving and 6.5x higher throughput in synthetic data generation. Our code is available at https://github.com/hao-ai-lab/vllm-ltr.git
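To illustrate why relative ranks alone can be enough for scheduling, here is a minimal sketch that compares first-come-first-serve with a schedule ordered by predicted rank. It assumes a single server that runs one request at a time, which is much simpler than a real serving system with continuous batching. The function names, the request lengths, and the predictor scores are all hypothetical and are not taken from the paper's code.

```python
from typing import List, Sequence


def average_completion_time(lengths: Sequence[int], order: Sequence[int]) -> float:
    """Mean completion time when requests run one at a time in `order`."""
    clock = 0
    total = 0
    for idx in order:
        clock += lengths[idx]  # request idx finishes at this time
        total += clock
    return total / len(order)


def schedule_by_rank(scores: Sequence[float]) -> List[int]:
    """Request indices sorted by predicted score (smaller means shorter)."""
    return sorted(range(len(scores)), key=lambda i: scores[i])


# Hypothetical true output lengths (in tokens), listed in arrival order.
true_lengths = [900, 20, 300, 50]

# Hypothetical scores from a ranking predictor. Only their ordering matters,
# and the ordering is deliberately imperfect (requests 1 and 3 are swapped).
predicted_scores = [0.9, 0.2, 0.3, 0.1]

fcfs_order = list(range(len(true_lengths)))
ranked_order = schedule_by_rank(predicted_scores)
oracle_order = schedule_by_rank(true_lengths)  # exact shortest-job-first

fcfs_latency = average_completion_time(true_lengths, fcfs_order)
ranked_latency = average_completion_time(true_lengths, ranked_order)
oracle_latency = average_completion_time(true_lengths, oracle_order)

print(fcfs_latency, ranked_latency, oracle_latency)
```

In this toy setting the long first request no longer blocks the short ones, and an imperfect ranking still lands close to the exact shortest-job-first schedule.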
Abstract:Autoregressive Large Language Models (LLMs) have achieved impressive performance in language tasks but face two significant bottlenecks: (1) quadratic complexity in the attention module as the number of tokens increases, and (2) limited efficiency due to the sequential processing nature of autoregressive LLMs during generation. While linear attention and speculative decoding offer potential solutions, their applicability and synergistic potential for enhancing autoregressive LLMs remain uncertain. We conduct the first comprehensive study on the efficacy of existing linear attention methods for autoregressive LLMs, integrating them with speculative decoding. We introduce an augmentation technique for linear attention that ensures compatibility with speculative decoding, enabling more efficient training and serving of LLMs. Extensive experiments and ablation studies involving seven existing linear attention models and five encoder/decoder-based LLMs consistently validate the effectiveness of our augmented linearized LLMs. Notably, our approach achieves up to a 6.67 reduction in perplexity on the LLaMA model and up to a 2$\times$ speedup during generation compared to prior linear attention methods. Codes and models are available at https://github.com/GATECH-EIC/Linearized-LLM.
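As background for the first bottleneck, the following is a minimal sketch of generic causal linear attention, not the augmented variant proposed in the abstract. It assumes the commonly used feature map elu(x) + 1 and keeps a running state, so each new token costs a fixed amount of work instead of attending over all previous tokens. The function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def feature_map(x: Vector) -> Vector:
    """elu(v) + 1 applied elementwise, which is always positive."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def causal_linear_attention(
    queries: List[Vector], keys: List[Vector], values: List[Vector]
) -> List[Vector]:
    """Causal linear attention computed with a running state."""
    d_k = len(keys[0])
    d_v = len(values[0])
    # state[i][j] accumulates phi(k)[i] * v[j] over the tokens seen so far.
    state = [[0.0] * d_v for _ in range(d_k)]
    # norm[i] accumulates phi(k)[i] over the tokens seen so far.
    norm = [0.0] * d_k
    outputs = []
    for q, k, v in zip(queries, keys, values):
        phi_k = feature_map(k)
        for i in range(d_k):
            norm[i] += phi_k[i]
            for j in range(d_v):
                state[i][j] += phi_k[i] * v[j]
        phi_q = feature_map(q)
        denom = sum(a * b for a, b in zip(phi_q, norm))
        outputs.append(
            [sum(phi_q[i] * state[i][j] for i in range(d_k)) / denom
             for j in range(d_v)]
        )
    return outputs


# Tiny hypothetical example: three tokens, key size 2, value size 2.
qs = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
ks = [[0.2, 0.1], [-0.3, 0.5], [0.7, -0.1]]
vs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outs = causal_linear_attention(qs, ks, vs)
```

The per-token cost depends only on the key and value dimensions, so the total cost grows linearly with sequence length rather than quadratically.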
Abstract:Large language models (LLMs) have shown impressive performance on language tasks but face challenges when deployed on resource-constrained devices due to their extensive parameters and reliance on dense multiplications, resulting in high memory demands and latency bottlenecks. Shift-and-add reparameterization offers a promising solution by replacing costly multiplications with hardware-friendly primitives in both the attention and multi-layer perceptron (MLP) layers of an LLM. However, current reparameterization techniques require training from scratch or full parameter fine-tuning to restore accuracy, which is resource-intensive for LLMs. To address this, we propose accelerating pretrained LLMs through post-training shift-and-add reparameterization, creating efficient multiplication-free models, dubbed ShiftAddLLM. Specifically, we quantize each weight matrix into binary matrices paired with group-wise scaling factors. The associated multiplications are reparameterized into (1) shifts between activations and scaling factors and (2) queries and adds according to the binary matrices. To reduce accuracy loss, we present a multi-objective optimization method to minimize both weight and output activation reparameterization errors. Additionally, based on varying sensitivity across layers to reparameterization, we develop an automated bit allocation strategy to further reduce memory usage and latency. Experiments on five LLM families and eight tasks consistently validate the effectiveness of ShiftAddLLM, achieving average perplexity improvements of 5.6 and 22.7 points at comparable or lower latency compared to the most competitive quantized LLMs at 3 and 2 bits, respectively, and more than 80% memory and energy reductions over the original LLMs. Codes and models are available at https://github.com/GATECH-EIC/ShiftAddLLM.
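The idea of replacing multiplications with shifts and adds can be sketched as follows. This is a simplified illustration under stated assumptions, not the ShiftAddLLM implementation: it uses greedy residual binarization, rounds each scaling factor to a power of two so that scaling becomes an exponent shift, uses one scaling factor per weight vector instead of group-wise factors, and replaces the lookup-table queries mentioned in the abstract with direct additions and subtractions. The names `binary_code` and `shift_add_dot` are hypothetical.

```python
import math
from typing import List, Tuple

Code = Tuple[int, List[int]]  # (power-of-two exponent, list of +1/-1 signs)


def binary_code(weights: List[float], num_bits: int) -> List[Code]:
    """Greedily approximate weights as sum_i 2**e_i * b_i with b_i in {-1, +1}."""
    residual = list(weights)
    codes: List[Code] = []
    for _ in range(num_bits):
        mean_abs = sum(abs(r) for r in residual) / len(residual)
        if mean_abs == 0.0:
            break  # the weights are already represented exactly
        signs = [1 if r >= 0 else -1 for r in residual]
        # Assumption: round the scale to the nearest power of two.
        exponent = round(math.log2(mean_abs))
        scale = math.ldexp(1.0, exponent)
        residual = [r - scale * s for r, s in zip(residual, signs)]
        codes.append((exponent, signs))
    return codes


def reconstruct(codes: List[Code], size: int) -> List[float]:
    """Rebuild the approximate weight vector from its binary codes."""
    approx = [0.0] * size
    for exponent, signs in codes:
        scale = math.ldexp(1.0, exponent)
        approx = [a + scale * s for a, s in zip(approx, signs)]
    return approx


def shift_add_dot(codes: List[Code], activations: List[float]) -> float:
    """Dot product that uses only additions, subtractions, and exponent shifts."""
    total = 0.0
    for exponent, signs in codes:
        # Binary matrix part: add or subtract each activation.
        acc = sum(x if s > 0 else -x for s, x in zip(signs, activations))
        # Scaling part: multiplying by 2**exponent is a shift of the exponent.
        total += math.ldexp(acc, exponent)
    return total


# Hypothetical weights and activations.
weights = [0.5, -0.25, 0.75, -1.0]
activations = [1.0, 2.0, 3.0, 4.0]
codes = binary_code(weights, num_bits=3)
approx_weights = reconstruct(codes, len(weights))
approx_output = shift_add_dot(codes, activations)
```

The sketch minimizes only the weight reconstruction error, one bit at a time. The abstract describes a multi-objective optimization that also accounts for output activation error, which this example does not attempt.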
Abstract:Autoregressive decoding of large language models (LLMs) is memory-bandwidth bound, resulting in high latency and significant waste of the parallel processing power of modern accelerators. Existing methods for accelerating LLM decoding often require a draft model (e.g., speculative decoding), which is nontrivial to obtain and unable to generalize. In this paper, we introduce Lookahead decoding, an exact, parallel decoding algorithm that accelerates LLM decoding without needing auxiliary models or data stores. It allows trading per-step log(FLOPs) to reduce the number of total decoding steps, is more parallelizable on single or multiple modern accelerators, and is compatible with concurrent memory-efficient attention (e.g., FlashAttention). Our implementation of Lookahead decoding can speed up autoregressive decoding by up to 1.8x on MT-bench and 4x with strong scaling on multiple GPUs in code completion tasks. Our code is available at https://github.com/hao-ai-lab/LookaheadDecoding
Abstract:Matrix factorization (MF) is a classical collaborative filtering algorithm for recommender systems. It decomposes the user-item interaction matrix into the product of a low-dimensional user representation matrix and an item representation matrix. In typical recommendation scenarios, the user-item interaction paradigm is usually a two-stage process and requires static clustering analysis of the obtained user and item representations. This process, however, is time-consuming and computationally intensive, making it difficult to apply in real time to e-commerce or Internet of Things environments with billions of users and trillions of items. To address this, we propose a unified matrix factorization method based on dynamic multi-view clustering (MFDMC) that employs an end-to-end training paradigm. Specifically, in each view, a user/item representation is regarded as a weighted projection of all clusters. The representation of each cluster is learnable, enabling the dynamic discarding of bad clusters. Furthermore, we employ multi-view clustering to represent multiple roles of users/items, effectively utilizing the representation space and improving the interpretability of the user/item representations for downstream tasks. Extensive experiments show that our proposed MFDMC achieves state-of-the-art performance on real-world recommendation datasets. Additionally, comprehensive visualization and ablation studies confirm that our method provides meaningful and interpretable user/item representations for downstream tasks.
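The phrase "a weighted projection of all clusters" can be pictured with the minimal sketch below. It is an illustration under assumptions the abstract does not confirm: the weights come from a softmax over per-user (or per-item) logits, the views are combined by concatenation, and the interaction score is a dot product. All function names and numbers are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def view_representation(logits: Vector, centers: List[Vector]) -> Vector:
    """One view: a weighted combination of all cluster centers."""
    weights = softmax(logits)
    dim = len(centers[0])
    return [sum(w * c[d] for w, c in zip(weights, centers)) for d in range(dim)]


def multi_view_representation(
    logits_per_view: List[Vector], centers_per_view: List[List[Vector]]
) -> Vector:
    """Assumption: views are combined by concatenation."""
    rep: Vector = []
    for logits, centers in zip(logits_per_view, centers_per_view):
        rep.extend(view_representation(logits, centers))
    return rep


def predict(user_rep: Vector, item_rep: Vector) -> float:
    """Matrix-factorization style score: dot product of the representations."""
    return sum(u * i for u, i in zip(user_rep, item_rep))


# Hypothetical setup: two views, two clusters per view, two dimensions each.
user_centers = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
item_centers = [[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]
user_logits = [[0.0, 0.0], [0.0, 0.0]]  # equal weight on every cluster
item_logits = [[0.0, 0.0], [0.0, 0.0]]

user_rep = multi_view_representation(user_logits, user_centers)
item_rep = multi_view_representation(item_logits, item_centers)
score = predict(user_rep, item_rep)
```

In a trained model the logits and the cluster centers would both be learned end to end. A cluster whose weight stays near zero for every user or item contributes nothing, which is one way to picture the dynamic discarding of bad clusters.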
Abstract:Knowledge amalgamation (KA) aims to learn a compact student model to handle the joint objective from multiple teacher models that are each specialized for their own tasks. Current methods focus on coarsely aligning teachers and students in the common representation space, making it difficult for the student to learn the proper decision boundaries from a set of heterogeneous teachers. Besides, the KL divergence in previous works only minimizes the probability distribution difference between teachers and the student, ignoring the intrinsic characteristics of teachers. Therefore, we propose a novel Contrastive Knowledge Amalgamation (CKA) framework, which introduces contrastive losses and an alignment loss to achieve intra-class cohesion and inter-class separation. Intra- and inter-model contrastive losses are designed to widen the distance between representations of different classes. The alignment loss is introduced to minimize the sample-level distribution differences of teacher-student models in the common representation space. Furthermore, the student learns heterogeneous unsupervised classification tasks through soft targets efficiently and flexibly in the task-level amalgamation. Extensive experiments on benchmarks demonstrate the generalization capability of CKA in the amalgamation of a specific task as well as multiple tasks. Comprehensive ablation studies provide further insight into CKA.
Abstract:Deep neural networks have been widely used in many critical applications, such as autonomous vehicles and medical diagnosis. However, their security is threatened by backdoor attacks, which are carried out by adding artificial patterns to specific training data. Existing defense strategies primarily focus on using reverse engineering to reproduce the backdoor trigger generated by attackers and subsequently repair the DNN model by adding the trigger into inputs and fine-tuning the model with ground-truth labels. However, when the trigger generated by the attackers is complex and invisible, the defender cannot successfully reproduce it. Consequently, the DNN model will not be repaired, since the trigger is not effectively removed. In this work, we propose Feature Map Testing~(FMT). Different from existing defense strategies, which focus on reproducing backdoor triggers, FMT tries to detect the backdoor feature maps, which are trained to extract backdoor information from the inputs. After detecting these backdoor feature maps, FMT erases them and then fine-tunes the model with a secure subset of training data. Our experiments demonstrate three advantages. First, compared to existing defense strategies, FMT can effectively reduce the Attack Success Rate (ASR) even against the most complex and invisible attack triggers. Second, unlike conventional defense methods that tend to exhibit low Robust Accuracy (RA, i.e., the model's accuracy on the poisoned data), FMT achieves higher RA, indicating its superiority in maintaining model performance while mitigating the effects of backdoor attacks~(e.g., FMT obtains 87.40\% RA in CIFAR10). Third, compared to existing feature map pruning techniques, FMT can cover more backdoor feature maps~(e.g., FMT removes 83.33\% of backdoor feature maps from the model in the CIFAR10 \& BadNet scenario).
Abstract:Due to the widespread application of deep neural networks~(DNNs) in safety-critical tasks, deep learning testing has drawn increasing attention. During the testing process, test cases that have been fuzzed or selected using test metrics are fed into the model to find fault-inducing test units (e.g., neurons and feature maps, activating which will almost certainly result in a model error) and report them to the DNN developer, who subsequently repairs them~(e.g., by retraining the model with test cases). Current test metrics, however, are primarily concerned with neurons, which means that test cases discovered either by guided fuzzing or by selection with these metrics focus on detecting fault-inducing neurons while failing to detect fault-inducing feature maps. In this work, we propose DeepFeature, which tests DNNs at the feature map level. When testing is conducted, DeepFeature scrutinizes every internal feature map in the model and identifies vulnerabilities that can be repaired to increase the model's overall performance. Exhaustive experiments are conducted to demonstrate that (1) DeepFeature is a strong tool for detecting the model's vulnerable feature maps; (2) DeepFeature's test case selection has a high fault detection rate and can detect more types of faults~(compared to coverage-guided selection techniques, DeepFeature increases the fault detection rate by 49.32\%); and (3) DeepFeature's fuzzer also outperforms current fuzzing techniques and generates valuable test cases more efficiently.
Abstract:Deep Neural Networks~(DNNs) have been widely deployed in software to address various tasks~(e.g., autonomous driving, medical diagnosis). However, they can also produce incorrect behaviors that result in financial losses and even threaten human safety. To reveal the incorrect behaviors in DNNs and repair them, DNN developers often collect rich unlabeled datasets from the natural world and label them to test the DNN models. However, properly labeling a large number of unlabeled datasets is a highly expensive and time-consuming task. To address this problem, we propose NSS, Neuron Sensitivity guided test case Selection, which can reduce the labeling time by selecting valuable test cases from unlabeled datasets. NSS leverages the internal neurons' information induced by test cases to select valuable test cases, which have high confidence in causing the model to behave incorrectly. We evaluate NSS with four widely used datasets and four well-designed DNN models against SOTA baseline methods. The results show that NSS performs well in assessing the test cases' probability of fault triggering and model improvement capabilities. Specifically, compared with baseline approaches, NSS obtains a higher fault detection rate~(e.g., when selecting 5\% of the test cases from the unlabeled dataset in the MNIST \& LeNet1 experiment, NSS obtains an 81.8\% fault detection rate, 20\% higher than baselines).