MGSER-SAM: A Method for Solving the Catastrophic Forgetting Problem in Continual Learning
3 main points
✔️ Proposal of MGSER-SAM to address the "catastrophic forgetting" problem in continual learning.
✔️ A new algorithm integrating Sharpness-Aware Minimization (SAM) and Experience Replay (ER).
✔️ Demonstration that soft logits and memory gradient direction consistency improve the model's generalizability and learning performance.
MGSER-SAM: Memory-Guided Soft Experience Replay with Sharpness-Aware Optimization for Enhanced Continual Learning
written by Xingyu Li, Bo Tang
(Submitted on 15 May 2024)
Comments: 8 pages, 5 figures
Subjects: Machine Learning (cs.LG)
The images used in this article are from the paper, the introductory slides, or were created based on them.
Summary
In continual learning (CL), "catastrophic forgetting," in which previously learned information is lost when a new task is learned, is a serious problem. To address this problem, we propose a novel memory replay-based algorithm called MGSER-SAM. First, we integrate the SAM optimizer and adapt it to existing experience replay frameworks such as ER and DER++. Second, we strategically incorporate soft logits and memory gradient direction alignment to resolve the conflict between the weight perturbation directions of the current task and previously stored memories during continual learning. This allows MGSER-SAM to effectively minimize the various training loss terms simultaneously. Experimental results show that MGSER-SAM outperforms existing baselines in all CL scenarios.
Related Research
Three Scenarios of Continuous Learning
Evaluating continual learning approaches poses significant challenges due to differences in experimental protocols and in the degree of access to task identity during testing. In response, three standardized scenarios for evaluating continual learning have been introduced: task incremental learning (task-IL), class incremental learning (class-IL), and domain incremental learning (domain-IL).
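In practice, the three scenarios differ in what information is available at test time. The toy sketch below illustrates this (a minimal sketch following this common taxonomy; the function, names, and shapes are illustrative, not from the paper):

```python
import torch

def predict(logits, scenario, task_id=None, classes_per_task=2):
    """Illustrate how prediction differs across the three CL scenarios.

    task-IL:   task identity is given at test time, so prediction is
               restricted to that task's classes.
    class-IL:  no task identity; predict over all classes seen so far.
    domain-IL: the label space is shared across tasks (e.g. permuted
               MNIST), so prediction is always over the same classes.
    """
    if scenario == "task-IL":
        lo = task_id * classes_per_task
        return lo + logits[:, lo:lo + classes_per_task].argmax(dim=1)
    # class-IL and domain-IL both predict over the full output layer.
    return logits.argmax(dim=1)

logits = torch.randn(4, 10)                   # e.g. 5 tasks x 2 classes
print(predict(logits, "task-IL", task_id=2))  # restricted to classes 4-5
print(predict(logits, "class-IL"))            # over all 10 classes
```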
Three Types of Continuous Learning Approaches
Regularization approaches: aim to integrate information from new tasks while retaining knowledge from previous tasks in order to reduce the problem of catastrophic forgetting. Examples include LwF (Learning without Forgetting) and EWC (Elastic Weight Consolidation).
Architectural approaches: are dynamic approaches that adapt the structure of the model itself to accommodate new tasks. Examples include PNNs (Progressive Neural Networks) and DENs (Dynamically Expanding Networks).
Memory replay approaches: inspired by the relationship between the mammalian hippocampus and neocortex, these methods store previously seen data points as episodic memory and replay them while learning a new task to prevent catastrophic forgetting (a minimal buffer sketch follows this list).
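As an illustration of the episodic memory these methods rely on, here is a minimal replay buffer using reservoir sampling, a common choice in ER implementations (the class and its API are assumptions for illustration, not the paper's code):

```python
import random

class ReplayBuffer:
    """Minimal episodic memory with reservoir sampling: every example
    seen so far ends up in the buffer with equal probability."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []   # stored (x, y) pairs
        self.seen = 0    # number of examples observed so far

    def add(self, example):
        self.seen += 1
        if len(self.data) < self.capacity:
            self.data.append(example)
        else:
            # Replace a random slot with probability capacity / seen.
            idx = random.randrange(self.seen)
            if idx < self.capacity:
                self.data[idx] = example

    def sample(self, k):
        return random.sample(self.data, min(k, len(self.data)))
```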
Proposed Method (MGSER-SAM)
ER-SAM (Experience Replay with Sharpness-Aware Minimization)
ER-SAM, the foundation of MGSER-SAM, integrates the SAM optimizer into experience replay (ER) to improve the model's generalizability by flattening the geometry of the loss landscape: SAM minimizes the worst-case loss in a neighborhood of the current point in parameter space, which steers training toward flat minima that generalize better.
Integrating SAM into ER optimizes the loss function
$$\min_{\theta}\; \max_{\|\delta\|_2 \le \rho} L_{ER}(\theta + \delta),$$
where $\rho$ is a constant that controls the radius of the neighborhood and $\delta$ is the weight perturbation that maximizes the worst-case loss. The model is then updated as
$$\theta \leftarrow \theta - \eta\, g_{ER\text{-}SAM},$$
where $\eta$ is the learning rate and $g_{ER\text{-}SAM}$ is the gradient of the loss function evaluated after the perturbation.
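To make the two-step update concrete, here is a minimal PyTorch sketch of one ER-SAM step (a sketch, not the authors' code; the function name, batch variables, and hyperparameter values are illustrative):

```python
import torch
import torch.nn.functional as F

def er_sam_step(f, x, y, x_mem, y_mem, rho=0.05, lr=0.1):
    """One ER-SAM update: perturb the weights toward the local worst
    case, then descend using the gradient at the perturbed point."""
    params = list(f.parameters())

    def er_loss():
        # Experience replay loss: current batch plus replayed memory batch.
        return F.cross_entropy(f(x), y) + F.cross_entropy(f(x_mem), y_mem)

    # Step 1: the gradient at theta gives the perturbation delta (scaled to rho).
    grads = torch.autograd.grad(er_loss(), params)
    norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    deltas = [rho * g / (norm + 1e-12) for g in grads]
    with torch.no_grad():
        for p, d in zip(params, deltas):
            p.add_(d)                 # move to theta + delta

    # Step 2: g_ER-SAM is the gradient evaluated at the perturbed point.
    g_sam = torch.autograd.grad(er_loss(), params)
    with torch.no_grad():
        for p, d, g in zip(params, deltas, g_sam):
            p.sub_(d)                 # restore theta
            p.sub_(lr * g)            # theta <- theta - eta * g_ER-SAM
```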
MGSER-SAM
MGSER-SAM is designed to overcome the limitations of ER-SAM by resolving the conflict between the weight perturbation directions of the current task and the stored memory. Specifically, it introduces two regularization terms:
1. Soft logits: stores the model's output logits alongside the memory samples and uses them as targets for the memory loss term, so that updates made while learning a new task remain consistent with the logits produced for previous tasks. Specifically, the following loss function is optimized:
$$L_{\text{logit}}(\theta) = \mathbb{E}_{(x', z') \sim \mathcal{M}}\big[\,\|f_{\theta}(x') - z'\|_2^2\,\big],$$
where $z'$ is the soft logit corresponding to the memory data $x'$.
2. Memory gradient direction consistency: integrates the SAM optimizer with memory rehearsal so that the weight perturbation is guided along the memory gradient direction, yielding a more balanced learning process.
The final MGSER-SAM update combines the ER-SAM perturbed gradient with these two regularization terms, so that all of the training loss terms are minimized simultaneously, as in the sketch below.
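To show how the pieces fit together, here is a hedged PyTorch sketch of one MGSER-SAM-style step. It makes two simplifying assumptions that are not the paper's exact equations: the soft-logit term is written as a DER++-style MSE against the stored logits $z'$, and the memory gradient direction influences learning by forming the SAM perturbation from the combined objective. The weighting `beta` and all names are illustrative:

```python
import torch
import torch.nn.functional as F

def mgser_sam_step(f, x, y, x_mem, y_mem, z_mem, rho=0.05, beta=0.5, lr=0.1):
    """Sketch of one MGSER-SAM-style update under the assumptions above."""
    params = list(f.parameters())

    def total_loss():
        task = F.cross_entropy(f(x), y)            # current-task loss
        replay = F.cross_entropy(f(x_mem), y_mem)  # memory replay loss
        logit = F.mse_loss(f(x_mem), z_mem)        # soft-logit consistency
        return task + replay + beta * logit

    # SAM step 1: the perturbation comes from the combined gradient, so the
    # memory gradient direction shapes where the worst case is sought.
    grads = torch.autograd.grad(total_loss(), params)
    norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    deltas = [rho * g / (norm + 1e-12) for g in grads]
    with torch.no_grad():
        for p, d in zip(params, deltas):
            p.add_(d)

    # SAM step 2: gradient at the perturbed point, then restore and descend.
    g_sam = torch.autograd.grad(total_loss(), params)
    with torch.no_grad():
        for p, d, g in zip(params, deltas, g_sam):
            p.sub_(d)
            p.sub_(lr * g)
```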
Experiment
Benchmark
In this study, multiple benchmarks are used across the three continual learning scenarios (task incremental learning, class incremental learning, and domain incremental learning). The benchmarks are as follows:
・S-MNIST (Sequential MNIST): task incremental learning (task-IL) and class incremental learning (class-IL)
・S-CIFAR10: task incremental learning (task-IL) and class incremental learning (class-IL)
・S-CIFAR100: task incremental learning (task-IL) and class incremental learning (class-IL)
・S-TinyImageNet: task incremental learning (task-IL) and class incremental learning (class-IL)
・P-MNIST (Permuted MNIST): domain incremental learning (domain-IL)
・R-MNIST (Rotated MNIST): domain incremental learning (domain-IL)
Baseline
To evaluate the performance of the proposed method MGSER-SAM, we compare it with the following representative baselines:
・LwF (Learning without Forgetting)
・PNN (Progressive Neural Networks)
・SI (Synaptic Intelligence)
・oEWC (Online Elastic Weight Consolidation)
・ER (Experience Replay)
・DER++ (Dark Experience Replay)
In addition, to evaluate the adaptability of the SAM optimizer, we also compare it with ER-SAM and DER++-SAM.
Evaluation Metrics
To ensure a fair comparison, all models are trained with the same hyperparameters and computational resources. We use the following two evaluation metrics:
1. Average accuracy (ACC): the average test accuracy over all tasks after training is complete
2. Forgetting (Forget): the difference between the highest and the final test accuracy on each previous task
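As a small sketch (using a common convention, not code from the paper), given a matrix `acc[i, j]` of test accuracies on task `j` measured after training on task `i`, the two metrics can be computed as follows:

```python
import numpy as np

def cl_metrics(acc):
    """acc: (T, T) array; acc[i, j] = accuracy on task j after task i."""
    T = acc.shape[0]
    final = acc[-1]              # accuracies after training the last task
    avg_acc = final.mean()       # ACC: mean final test accuracy
    # Forget: best-ever accuracy minus final accuracy, averaged over the
    # previous (non-final) tasks.
    forget = np.mean([acc[:, j].max() - final[j] for j in range(T - 1)])
    return avg_acc, forget
```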
Result
Performance Analysis
Table II shows the performance of each method on the S-MNIST, S-CIFAR10, S-CIFAR100, S-TinyImageNet, P-MNIST, and R-MNIST benchmarks. The results show that the proposed method MGSER-SAM achieves the best performance on all benchmarks. In particular, its test accuracy reaches 93.29% on S-MNIST, which is 4.2% and 17.6% better than ER and DER++, respectively.
Analysis During the CL Process
Figure 3 shows the change in first-task test accuracy on S-MNIST, S-CIFAR10, and S-CIFAR100 during the CL process; MGSER-SAM retains the highest first-task test accuracy after each new task is learned. For example, on S-CIFAR10, MGSER-SAM shows a 24.05% decrease in first-task accuracy, which is 54.92% and 12.06% less than the drop for ER and DER++, respectively.
ACC Across All Tasks During the CL Process
Figure 4 shows the change in average accuracy (ACC) over all tasks on P-MNIST and S-TinyImageNet; MGSER-SAM achieves the highest ACC after learning each task. For example, on P-MNIST, the ACC of MGSER-SAM after training on all 20 tasks is 89.92%, which is higher than ER's ACC right after training on the first task.
Impact of Memory Capacity
Figure 5 shows the average accuracy (ACC) in class incremental learning on each benchmark for different memory buffer sizes (M ∈ [400, 2000]). The results show that MGSER-SAM always performs best, and that the performance of all compared methods improves as the memory buffer size increases.
Conclusion
In this paper, we proposed a new algorithm, MGSER-SAM, which integrates Sharpness-Aware Minimization (SAM) into the Experience Replay (ER) framework to address the "catastrophic forgetting" problem in continual learning (CL). MGSER-SAM resolves the conflict of weight perturbation directions between tasks using soft logits and memory gradient direction consistency, achieving up to a 24.4% accuracy improvement and minimal forgetting across multiple benchmarks.
In the future, MGSER-SAM is expected to become even more useful through hyperparameter optimization, application to other CL scenarios and datasets, methods for reducing computational cost, deployment in real-time applications and on edge devices, and stronger theoretical underpinnings.