Abstract: The holy grail of machine learning is to enable Continual Federated Learning (CFL) to enhance the efficiency, privacy, and scalability of AI systems while learning from streaming data. The primary challenge of a CFL system is to overcome global catastrophic forgetting, wherein the accuracy of the global model on old tasks declines as it is trained on new tasks. In this work, we propose Continual Federated Learning with Aggregated Gradients (C-FLAG), a novel replay-memory-based federated strategy consisting of edge-based gradient updates on memory and aggregated gradients on the current data. We provide a convergence analysis of the C-FLAG approach, which addresses forgetting and bias while converging at a rate of $O(1/\sqrt{T})$ over $T$ communication rounds. We formulate an optimization sub-problem that minimizes catastrophic forgetting, translating CFL into an iterative algorithm with adaptive learning rates that ensure seamless learning across tasks. We empirically show that C-FLAG outperforms several state-of-the-art baselines in both task-incremental and class-incremental settings with respect to metrics such as accuracy and forgetting.
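To make the idea of combining memory gradients with aggregated current-data gradients concrete, the following is a minimal toy sketch, not the C-FLAG algorithm itself. It assumes a scalar model, quadratic losses, and fixed step sizes `alpha` and `beta` (the abstract describes adaptive learning rates, which are not reproduced here). The names `grad_sq` and `cfl_round` are hypothetical and chosen only for illustration.

```python
from typing import List


def grad_sq(w: float, targets: List[float]) -> float:
    """Gradient of the mean of 0.5 * (w - t)^2 over the given targets."""
    return sum(w - t for t in targets) / len(targets)


def cfl_round(
    w: float,
    memories: List[List[float]],   # per-client replay memory (old-task samples)
    currents: List[List[float]],   # per-client current-task samples
    alpha: float,                  # assumed fixed step size on the memory gradient
    beta: float,                   # assumed fixed step size on the aggregated gradient
) -> float:
    """One toy communication round: aggregate current gradients, step locally, average."""
    # Aggregated gradient on current data, averaged over all clients.
    agg = sum(grad_sq(w, c) for c in currents) / len(currents)
    # Each client steps on its own memory gradient plus the aggregated gradient.
    local = [w - alpha * grad_sq(w, m) - beta * agg for m in memories]
    # Plain averaging of the local models (an assumption of this sketch).
    return sum(local) / len(local)


def run(rounds: int) -> float:
    """Old task pulls w toward 0.0, new task pulls w toward 2.0."""
    w = 0.0
    for _ in range(rounds):
        w = cfl_round(w, [[0.0], [0.0]], [[2.0], [2.0]], alpha=0.1, beta=0.1)
    return w
```

With equal step sizes, the iterate in this toy setting settles between the old-task and new-task optima rather than moving entirely to the new one, which is the qualitative behavior a replay memory is meant to provide.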
Abstract: Privacy, security, and bandwidth constraints have led to the adoption of federated learning (FL) in wireless systems, where a machine learning (ML) model is trained collaboratively without sharing raw data. Often, such collaborative FL strategies necessitate model aggregation at a server. In contrast, decentralized FL requires participating clients to reach a consensus ML model by exchanging parameter updates. In this work, we propose the over-the-air clustered wireless FL (CWFL) strategy, which eliminates the need for a strong central server yet achieves accuracy similar to that of the server-based strategy while requiring fewer channel uses than decentralized FL. We theoretically show that the convergence rate of CWFL per cluster is $O(1/T)$ while mitigating the impact of noise. Using the MNIST and CIFAR datasets, we demonstrate the accuracy of CWFL for different numbers of clusters across communication rounds.
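As a rough illustration of over-the-air aggregation within clusters, the sketch below models each cluster's clients transmitting simultaneously so that the receiver observes the sum of their updates plus additive Gaussian noise. This is a simplified assumption, not the CWFL protocol: fading, power control, and any exchange between clusters are omitted, and the names `over_the_air_average` and `clustered_round` are hypothetical.

```python
import random
from typing import List


def over_the_air_average(
    updates: List[List[float]],  # one parameter-update vector per client
    noise_std: float,            # assumed standard deviation of receiver noise
    rng: random.Random,
) -> List[float]:
    """Noisy average: the channel sums the updates, the receiver adds noise."""
    n = len(updates)
    dim = len(updates[0])
    # Superposition on the channel: per-coordinate sum plus Gaussian noise.
    received = [
        sum(u[i] for u in updates) + rng.gauss(0.0, noise_std)
        for i in range(dim)
    ]
    # Scale by the number of transmitting clients to obtain the average.
    return [r / n for r in received]


def clustered_round(
    clusters: List[List[List[float]]],  # clusters -> clients -> update vector
    noise_std: float,
    rng: random.Random,
) -> List[List[float]]:
    """Aggregate each cluster independently, with no central server."""
    return [over_the_air_average(c, noise_std, rng) for c in clusters]
```

A usage example under the same assumptions: `clustered_round([[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0]]], 0.1, random.Random(0))` returns one noisy averaged vector per cluster. Because the noise is added once to the summed signal and then divided by the cluster size, its effect on the average shrinks as more clients transmit in this model.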