Abstract:Training a generalizable agent to continually learn a sequence of tasks from offline trajectories is a natural requirement for long-lived agents, yet remains a significant challenge for current offline reinforcement learning (RL) algorithms. Specifically, an agent must be able to rapidly adapt to new tasks using newly collected trajectories (plasticity), while retaining knowledge from previously learned tasks (stability). However, systematic analyses of this setting are scarce, and it remains unclear whether conventional continual learning (CL) methods are effective in continual offline RL (CORL) scenarios. In this study, we develop the Offline Continual World benchmark and demonstrate that traditional CL methods struggle with catastrophic forgetting, primarily due to the unique distribution shifts inherent to CORL scenarios. To address this challenge, we introduce CompoFormer, a structure-based continual transformer model that adaptively composes previous policies via a meta-policy network. Upon encountering a new task, CompoFormer leverages semantic correlations to selectively integrate relevant prior policies alongside newly trained parameters, thereby enhancing knowledge sharing and accelerating the learning process. Our experiments reveal that CompoFormer outperforms conventional CL methods, particularly in longer task sequences, showcasing a promising balance between plasticity and stability.
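The idea of composing prior policies by relevance can be illustrated with a minimal sketch. It assumes, purely for illustration, that relevance is a softmax over dot products between a task embedding and one key per stored policy; the names `compose`, `task_query`, `policy_keys`, and `policy_outputs` are hypothetical and are not taken from CompoFormer itself.

```python
import math

def compose(task_query, policy_keys, policy_outputs):
    """Blend the outputs of stored policies using softmax attention weights.

    task_query:     embedding of the new task (list of floats)
    policy_keys:    one key vector per stored policy
    policy_outputs: one output vector per stored policy (same length each)
    """
    # Dot-product similarity between the task and each stored policy.
    scores = [sum(q * k for q, k in zip(task_query, key)) for key in policy_keys]
    # Numerically stable softmax over the similarity scores.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the stored policies' outputs.
    dim = len(policy_outputs[0])
    composed = [sum(w * out[d] for w, out in zip(weights, policy_outputs))
                for d in range(dim)]
    return weights, composed
```

In this sketch a policy whose key is more similar to the task embedding contributes more to the composed output; how the actual model combines the composed output with newly trained parameters is not specified here.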
Abstract:Offline reinforcement learning (RL) methods harness previous experiences to derive an optimal policy, forming the foundation for pre-trained large-scale models (PLMs). When encountering tasks not seen before, PLMs often utilize several expert trajectories as prompts to expedite their adaptation to new requirements. Though a range of prompt-tuning methods have been proposed to enhance the quality of prompts, these methods often face optimization restrictions due to prompt initialization, which can significantly constrain the exploration domain and potentially lead to suboptimal solutions. To eliminate the reliance on the initial prompt, we shift our perspective towards the generative model, framing the prompt-tuning process as a form of conditional generative modeling, where prompts are generated from random noise. Our innovation, the Prompt Diffuser, leverages a conditional diffusion model to produce prompts of exceptional quality. Central to our framework is the approach to trajectory reconstruction and the meticulous integration of downstream task guidance during the training phase. Further experimental results underscore the potency of the Prompt Diffuser as a robust and effective tool for the prompt-tuning process, demonstrating strong performance in the meta-RL tasks.
Abstract:The purpose of offline multi-task reinforcement learning (MTRL) is to develop a unified policy applicable to diverse tasks without the need for online environmental interaction. Recent advancements approach this through sequence modeling, leveraging the Transformer architecture's scalability and the benefits of parameter sharing to exploit task similarities. However, variations in task content and complexity pose significant challenges in policy formulation, necessitating judicious parameter sharing and management of conflicting gradients for optimal policy performance. Furthermore, identifying the optimal parameter subspace for each task often necessitates prior knowledge of the task identifier during inference, limiting applicability in real-world scenarios with variable task content and unknown current tasks. In this work, we introduce the Harmony Multi-Task Decision Transformer (HarmoDT), a novel solution designed to identify an optimal harmony subspace of parameters for each task. We formulate this as a bi-level optimization problem within a meta-learning framework, where the upper level learns masks to define the harmony subspace, while the inner level focuses on updating parameters to improve the overall performance of the unified policy. To eliminate the need for task identifiers, we further design a group-wise variant (G-HarmoDT) that clusters tasks into coherent groups based on gradient information, and utilizes a gating network to determine task identifiers during inference. Empirical evaluations across various benchmarks highlight the superiority of our approach, demonstrating its effectiveness in the multi-task context with improvements of 8% in task-provided settings, 5% in task-agnostic settings, and 10% in unseen settings.
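The role of a task-specific mask over shared parameters can be sketched as follows. This is a minimal illustration under stated assumptions: the mask is binary, it is applied elementwise, and the inner-level update is plain gradient descent. The function names `apply_mask` and `masked_update` are hypothetical, and the upper-level procedure that learns the mask is not shown.

```python
def apply_mask(params, mask):
    """Select a task's parameter subspace: entries with mask 0 are zeroed out."""
    return [p * m for p, m in zip(params, mask)]

def masked_update(params, grad, mask, lr):
    """Inner-level step (assumed to be plain gradient descent).

    Only parameters inside the task's subspace (mask == 1) are changed,
    so a task's gradient cannot overwrite weights it does not use.
    """
    return [p - lr * g * m for p, g, m in zip(params, grad, mask)]

# Usage: one shared parameter vector, one mask for a single task.
shared = [1.0, 2.0, 3.0]
task_mask = [1, 0, 1]
task_view = apply_mask(shared, task_mask)
shared = masked_update(shared, grad=[1.0, 1.0, 1.0], mask=task_mask, lr=0.5)
```

Restricting each task's update to its own subspace is one simple way to limit conflicting gradients between tasks while still sharing the unmasked overlap.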
Abstract:In numerous artificial intelligence applications, the collaborative efforts of multiple intelligent agents are imperative for the successful attainment of target objectives. To enhance coordination among these agents, a distributed communication framework is often employed. However, indiscriminate information sharing among all agents can be resource-intensive, and the adoption of manually pre-defined communication architectures imposes constraints on inter-agent communication, thus limiting the potential for effective collaboration. Moreover, the communication framework often remains static during inference, which may result in sustained high resource consumption, as in most cases, only key decisions necessitate information sharing among agents. In this study, we introduce a novel approach wherein we conceptualize the communication architecture among agents as a learnable graph. We formulate this problem as the task of determining the communication graph while enabling the architecture parameters to update normally, thus necessitating a bi-level optimization process. Utilizing continuous relaxation of the graph representation and incorporating attention units, our proposed approach, CommFormer, efficiently optimizes the communication graph and concurrently refines architectural parameters through gradient descent in an end-to-end manner. Additionally, we introduce a temporal gating mechanism for each agent, enabling dynamic decisions on whether to receive shared information at a given time, based on current observations, thus improving decision-making efficiency. Extensive experiments on a variety of cooperative tasks substantiate the robustness of our model across diverse cooperative scenarios, where agents are able to develop more coordinated and sophisticated strategies regardless of changes in the number of agents.
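A relaxed communication graph combined with a per-agent receive gate can be sketched as below. The sketch assumes scalar messages, a sigmoid relaxation of each edge, and a gate that is already decided as a boolean; these choices, and the names `sigmoid` and `aggregate`, are illustrative assumptions rather than the CommFormer implementation.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def aggregate(messages, edge_logits, gate_open):
    """Compute what each agent receives from the others.

    messages[j]:       scalar message sent by agent j
    edge_logits[i][j]: learnable logit for the edge from agent j to agent i;
                       sigmoid(logit) is the relaxed (soft) edge weight
    gate_open[i]:      True if agent i chooses to receive at this time step
    """
    n = len(messages)
    received = []
    for i in range(n):
        if not gate_open[i]:
            # A closed gate means the agent ignores all incoming messages.
            received.append(0.0)
            continue
        total = sum(sigmoid(edge_logits[i][j]) * messages[j]
                    for j in range(n) if j != i)
        received.append(total)
    return received
```

Because the edge weights are continuous, they can in principle be optimized by gradient descent alongside the other parameters, and the gate lets an agent skip communication at time steps where it is not needed.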
Abstract:Continual offline reinforcement learning (CORL) has shown impressive ability in diffusion-based lifelong learning systems by modeling the joint distributions of trajectories. However, most research only focuses on limited continual task settings where the tasks have the same observation and action space, which deviates from the realistic demands of training agents in various environments. In view of this, we propose Vector-Quantized Continual Diffuser, named VQ-CD, to break the barrier of different spaces between various tasks. Specifically, our method contains two complementary components, where quantized spaces alignment provides a unified basis for selective weights activation. For quantized spaces alignment, we leverage vector quantization to align the different state and action spaces of various tasks, facilitating continual training in the same space. For selective weights activation, we use a unified diffusion model coupled with an inverse dynamics model to master all tasks by selectively activating different weights according to task-related sparse masks. Finally, we conduct extensive experiments on 15 continual learning (CL) tasks, including conventional CL task settings (identical state and action spaces) and general CL task settings (various state and action spaces). Compared with 16 baselines, our method achieves state-of-the-art (SOTA) performance.
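The core vector quantization step, mapping a continuous vector to its nearest codebook entry, can be sketched as follows. The sketch assumes the input has already been encoded to the codebook's width; how states and actions of different sizes are brought to that common width is not shown, and the name `quantize` is hypothetical.

```python
def quantize(vector, codebook):
    """Return (index, code) for the codebook entry closest to `vector`.

    Distance is squared Euclidean; `vector` and every code must have
    the same length.
    """
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    index = min(range(len(codebook)), key=lambda i: sq_dist(vector, codebook[i]))
    return index, codebook[index]

# Usage: a three-entry codebook over 2-D latent vectors.
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
index, code = quantize([0.9, 0.1], codebook)
```

Once every task's states and actions are expressed as codes from a shared codebook, downstream training can operate in one common space regardless of the original dimensions.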
Abstract:Structured pruning is a promising hardware-friendly compression technique for large language models (LLMs), which is expected to be retraining-free to avoid the enormous retraining cost. This retraining-free paradigm involves (1) pruning criteria to define the architecture and (2) distortion reconstruction to restore performance. However, existing methods often emphasize pruning criteria while using reconstruction techniques that are specific to certain modules or criteria, resulting in limited generalizability. To address this, we introduce the Linear Interpolation-based Adaptive Reconstruction (LIAR) framework, which is both efficient and effective. LIAR does not require back-propagation or retraining and is compatible with various pruning criteria and modules. By applying linear interpolation to the preserved weights, LIAR minimizes reconstruction error and effectively reconstructs the pruned output. Our evaluations on benchmarks such as GLUE, SQuAD, WikiText, and common sense reasoning show that LIAR enables a BERT model to maintain 98% accuracy even after removing 50% of its parameters and achieves top performance for LLaMA in just a few minutes.
Abstract:In federated learning (FL), the multi-step update and data heterogeneity among clients often lead to a loss landscape with sharper minima, degrading the performance of the resulting global model. Prevalent federated approaches incorporate sharpness-aware minimization (SAM) into local training to mitigate this problem. However, the local loss landscapes may not accurately reflect the flatness of the global loss landscape in heterogeneous environments; as a result, minimizing local sharpness and calculating perturbations on client data might not bring the efficacy of SAM in FL in line with that of centralized training. To overcome this challenge, we propose FedLESAM, a novel algorithm that locally estimates the direction of the global perturbation on the client side as the difference between the global models received in the previous active round and the current round. Besides the improved quality, FedLESAM also speeds up federated SAM-based approaches, since it performs backpropagation only once in each iteration. Theoretically, we prove a slightly tighter bound than that of the original FedSAM by ensuring consistent perturbation. Empirically, we conduct comprehensive experiments on four federated benchmark datasets under three partition strategies to demonstrate the superior performance and efficiency of FedLESAM.
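The perturbation estimate described above can be sketched in a few lines. The sketch assumes the direction is the previous global model minus the current one, scaled to a radius `rho`, and that the local step is plain gradient descent evaluated at the perturbed point; the sign convention, the scaling, and the names `estimated_perturbation`, `local_step`, and `grad_fn` are assumptions for illustration.

```python
import math

def estimated_perturbation(prev_global, curr_global, rho):
    """Estimate a SAM-style perturbation from two global models.

    Assumed direction: previous global model minus current global model,
    normalized and scaled to length `rho`. No gradient is computed here.
    """
    diff = [p - c for p, c in zip(prev_global, curr_global)]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm == 0.0:
        return [0.0] * len(diff)
    return [rho * d / norm for d in diff]

def local_step(weights, perturbation, grad_fn, lr):
    """One local update using a single gradient evaluation.

    grad_fn is a user-supplied function returning the loss gradient
    at the given weights (hypothetical interface).
    """
    perturbed = [w + e for w, e in zip(weights, perturbation)]
    grad = grad_fn(perturbed)  # the only gradient evaluation in this step
    return [w - lr * g for w, g in zip(weights, grad)]
```

Standard SAM needs one gradient pass to find the perturbation and a second to compute the update; because the perturbation here comes from models the client already holds, each step needs only the second pass.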
Abstract:The purpose of offline multi-task reinforcement learning (MTRL) is to develop a unified policy applicable to diverse tasks without the need for online environmental interaction. Recent advancements approach this through sequence modeling, leveraging the Transformer architecture's scalability and the benefits of parameter sharing to exploit task similarities. However, variations in task content and complexity pose significant challenges in policy formulation, necessitating judicious parameter sharing and management of conflicting gradients for optimal policy performance. In this work, we introduce the Harmony Multi-Task Decision Transformer (HarmoDT), a novel solution designed to identify an optimal harmony subspace of parameters for each task. We approach this as a bi-level optimization problem, employing a meta-learning framework that leverages gradient-based techniques. The upper level of this framework is dedicated to learning a task-specific mask that delineates the harmony subspace, while the inner level focuses on updating parameters to enhance the overall performance of the unified policy. Empirical evaluations on a series of benchmarks demonstrate the superiority of HarmoDT, verifying the effectiveness of our approach.
Abstract:Recent advancements in offline reinforcement learning (RL) have underscored the capabilities of Conditional Sequence Modeling (CSM), a paradigm that learns the action distribution based on the trajectory history and target returns for each state. However, these methods often struggle with stitching together optimal trajectories from sub-optimal ones due to the inconsistency between the sampled returns within individual trajectories and the optimal returns across multiple trajectories. Fortunately, Dynamic Programming (DP) methods offer a solution by leveraging a value function to approximate optimal future returns for each state, although these techniques are prone to unstable learning behaviors, particularly in long-horizon and sparse-reward scenarios. Building upon these insights, we propose the Q-value regularized Transformer (QT), which combines the trajectory modeling ability of the Transformer with the predictability of optimal future returns from DP methods. QT learns an action-value function and integrates a term maximizing action-values into the training loss of CSM, which aims to seek optimal actions that align closely with the behavior policy. Empirical evaluations on D4RL benchmark datasets demonstrate the superiority of QT over traditional DP and CSM methods, highlighting the potential of QT to enhance the state-of-the-art in offline RL.
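The combination of a sequence-modeling loss with a Q-value maximization term can be sketched as a single objective. The sketch assumes the CSM loss is a mean squared error between predicted and dataset actions and that the two terms are balanced by a fixed coefficient `eta`; both choices, and the names `qt_policy_loss` and `q_fn`, are illustrative assumptions rather than the exact QT objective.

```python
def qt_policy_loss(pred_actions, data_actions, states, q_fn, eta):
    """Imitation loss minus a weighted mean Q-value of the predicted actions.

    pred_actions, data_actions: lists of action vectors (same shapes)
    states:                     list of state vectors, one per action
    q_fn(state, action):        learned action-value estimate (hypothetical interface)
    eta:                        trade-off between imitation and Q maximization
    """
    n = len(pred_actions)
    # Imitation term: keeps predicted actions close to the behavior policy.
    imitation = sum(
        sum((p - d) ** 2 for p, d in zip(pa, da))
        for pa, da in zip(pred_actions, data_actions)
    ) / n
    # Value term: minimizing the loss raises the Q-value of predicted actions.
    q_mean = sum(q_fn(s, a) for s, a in zip(states, pred_actions)) / n
    return imitation - eta * q_mean
```

With `eta` set to zero this reduces to pure imitation of the dataset; larger values push the predicted actions toward those the value function rates highly.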
Abstract:Transformer-based trajectory optimization methods have demonstrated exceptional performance in offline Reinforcement Learning (offline RL), yet they pose challenges due to substantial parameter size and limited scalability, which is particularly critical in sequential decision-making scenarios where resources are constrained, such as in robots and drones with limited computational power. Mamba, a promising new linear-time sequence model, offers performance on par with transformers while requiring substantially fewer parameters on long sequences. As it remains unclear whether Mamba is compatible with trajectory optimization, this work aims to conduct comprehensive experiments to explore the potential of Decision Mamba in offline RL (dubbed DeMa) from the aspect of data structures and network architectures with the following insights: (1) Long sequences impose a significant computational burden without contributing to performance improvements, because DeMa's focus on sequences diminishes approximately exponentially. Consequently, we introduce a Transformer-like DeMa as opposed to an RNN-like DeMa. (2) For the components of DeMa, we identify that the hidden attention mechanism is key to its success, which can also work well with other residual structures and does not require position embedding. Extensive evaluations on eight Atari games demonstrate that our specially designed DeMa is compatible with trajectory optimization and surpasses previous state-of-the-art methods, outdoing Decision Transformer (DT) by 80% with 30% fewer parameters, and exceeds DT in MuJoCo with only a quarter of the parameters.