A robust framework for training neural networks that can learn sequentially without forgetting
Overview • Key Features • Techniques • Installation • Usage • Results • Contributing
The Continual Learning System is a comprehensive framework for developing neural networks that can learn tasks sequentially without suffering from catastrophic forgetting. This project implements several state-of-the-art techniques to mitigate forgetting in neural networks, allowing them to adapt to new tasks while retaining performance on previously learned ones.
- Task Sequential Learning: Train models on a sequence of tasks without complete retraining
- Forgetting Mitigation: Advanced techniques to prevent catastrophic forgetting
- Performance Tracking: Comprehensive metrics to monitor how well knowledge is retained
- Experiment Framework: Easily run and compare different continual learning approaches
- Visualization Tools: Track and visualize forgetting metrics across sequential tasks
EWC measures the importance of neural network weights for previously learned tasks and penalizes changes to important weights when learning new tasks.
```python
# Loss calculation with EWC
loss = task_loss + lambda_ewc * ewc_loss
```
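As a rough illustration, the EWC penalty can be built from a diagonal Fisher information estimate computed on the previous task's data. The `EWCPenalty` class below is a minimal sketch of that idea, not the project's actual API:

```python
import torch
import torch.nn.functional as F


class EWCPenalty:
    """Quadratic penalty on drifting away from parameters important to an old task."""

    def __init__(self, model, dataloader, device="cpu"):
        # Snapshot of the parameters after finishing the previous task.
        self.params = {n: p.clone().detach()
                       for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = self._estimate_fisher(model, dataloader, device)

    def _estimate_fisher(self, model, dataloader, device):
        # Diagonal Fisher information: average squared gradients of the loss.
        fisher = {n: torch.zeros_like(p)
                  for n, p in model.named_parameters() if p.requires_grad}
        model.eval()
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            F.cross_entropy(model(x), y).backward()
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2 / len(dataloader)
        return fisher

    def __call__(self, model):
        # Fisher-weighted squared distance from the old-task parameters.
        return sum((self.fisher[n] * (p - self.params[n]) ** 2).sum()
                   for n, p in model.named_parameters() if n in self.fisher)
```

With a helper like this, the loss above becomes `loss = task_loss + lambda_ewc * ewc_penalty(model)`.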
This technique maintains a memory buffer of examples from previous tasks and periodically replays them during training on new tasks.

```python
# Replay during training
combined_loss = current_task_loss + alpha * replay_loss
```
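A minimal sketch of such a buffer, here using reservoir sampling so the memory stays a fixed size; the `ReplayBuffer` name and interface are illustrative assumptions, not the project's implementation:

```python
import random
import torch


class ReplayBuffer:
    """Fixed-size memory of (input, label) pairs from previous tasks."""

    def __init__(self, capacity=500):
        self.capacity = capacity
        self.examples = []
        self.seen = 0

    def add(self, x, y):
        # Reservoir sampling keeps a uniform sample of everything seen so far.
        for xi, yi in zip(x, y):
            self.seen += 1
            if len(self.examples) < self.capacity:
                self.examples.append((xi.clone(), yi.clone()))
            else:
                idx = random.randint(0, self.seen - 1)
                if idx < self.capacity:
                    self.examples[idx] = (xi.clone(), yi.clone())

    def sample(self, batch_size):
        batch = random.sample(self.examples, min(batch_size, len(self.examples)))
        xs, ys = zip(*batch)
        return torch.stack(xs), torch.stack(ys)
```

During training on a new task, a batch sampled from the buffer supplies the `replay_loss` term shown above.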
LwF uses knowledge distillation to preserve the model's behavior on previous tasks when learning new ones.

```python
# LwF distillation loss
distillation_loss = KL_divergence(current_outputs, previous_outputs)
```

For some approaches, we isolate or add task-specific parameters while sharing a common feature extraction backbone.
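As an illustration of the parameter-isolation idea, the sketch below adds one classification head per task on top of a shared backbone; the architecture and the `MultiHeadNet` name are assumptions made for the example, not necessarily the project's model:

```python
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Shared feature extractor with one isolated output head per task."""

    def __init__(self, feature_dim=256, task_classes=(5, 5)):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
        )
        # One classifier head per task; only the active head is used for a given task.
        self.heads = nn.ModuleList(nn.Linear(feature_dim, c) for c in task_classes)

    def forward(self, x, task_id):
        return self.heads[task_id](self.backbone(x))
```

When training task `t`, only the head selected by `task_id` (plus the shared backbone) receives gradients; stricter isolation methods additionally freeze or mask parts of the backbone.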
```bash
# Clone the repository
git clone https://github.com/1Utkarsh1/continual-learning.git
cd continual-learning

# Create a virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt
```
```bash
# Run baseline experiment (sequential training without any continual learning techniques)
python src/main.py --method baseline --tasks mnist_split

# Run EWC experiment
python src/main.py --method ewc --tasks mnist_split --lambda_ewc 5000

# Run Experience Replay experiment
python src/main.py --method replay --tasks mnist_split --buffer_size 500
```

You can define your own task sequences in a YAML configuration file:
```yaml
# config/tasks/custom_sequence.yaml
task_sequence:
  - name: "mnist_digits_0_4"
    dataset: "mnist"
    classes: [0, 1, 2, 3, 4]
  - name: "mnist_digits_5_9"
    dataset: "mnist"
    classes: [5, 6, 7, 8, 9]
  - name: "fashion_mnist"
    dataset: "fashion_mnist"
    classes: "all"
```
| Method | Average Accuracy | Average Forgetting | Training Time |
|---|---|---|---|
| Naïve Fine-tuning | 45.2% | 35.8% | 1.0x |
| EWC | 78.5% | 10.2% | 1.2x |
| Experience Replay | 82.3% | 7.5% | 1.5x |
| LwF | 75.7% | 12.8% | 1.3x |
The following experiment results were obtained on March 11, 2025 using the MNIST split task sequence:

Baseline (Naïve Fine-tuning):
- Command: `python src/main.py --method baseline --tasks mnist_split --epochs 5`
- Task sequence: ['mnist_0_4', 'mnist_5_9']
- Average final accuracy: 49.74%
- Average forgetting: 49.90%

Learning without Forgetting (LwF):
- Command: `python src/main.py --method lwf --tasks mnist_split --epochs 5`
- Task sequence: ['mnist_0_4', 'mnist_5_9']
- Average final accuracy: 49.67%
- Average forgetting: 49.83%
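For reference, average accuracy and forgetting are typically computed from a matrix of per-task accuracies recorded after each training stage. The helper below is a sketch of those standard definitions, not necessarily the exact code that produced the numbers above:

```python
import numpy as np


def summarize(acc):
    """acc[i][j] = accuracy on task j measured after finishing training on task i."""
    acc = np.asarray(acc)
    n = acc.shape[0]
    # Average accuracy: mean accuracy over all tasks after the final stage.
    avg_accuracy = acc[-1].mean()
    # Forgetting of task j: best accuracy reached on j before the final stage
    # minus the final accuracy on j (the last task has no forgetting yet).
    forgetting = [acc[:-1, j].max() - acc[-1, j] for j in range(n - 1)]
    avg_forgetting = float(np.mean(forgetting)) if forgetting else 0.0
    return avg_accuracy, avg_forgetting
```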
```
continual_learning/
├── src/                  # Source code
│   ├── models/           # Neural network architectures
│   ├── data/             # Data loading and preprocessing
│   ├── methods/          # Continual learning algorithms
│   ├── utils/            # Utility functions
│   └── main.py           # Main entry point
├── experiments/          # Jupyter notebooks for experiments
├── config/               # Configuration files
│   ├── models/           # Model configurations
│   └── tasks/            # Task sequence definitions
├── results/              # Saved results and visualizations
└── docs/                 # Documentation
```
- Split MNIST
  - Train on digits 0-4, then 5-9 (see the data-loading sketch after this list)
  - Compare different methods' ability to remember the first task
- Task Incremental Learning
  - Train on MNIST → Fashion-MNIST → KMNIST
  - Measure accuracy on all previous datasets after each task
- Class Incremental Learning
  - Add new classes (one at a time) to a classifier
  - Test identification of all classes after each addition
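A minimal sketch of how a Split MNIST task could be built with torchvision; the `split_mnist_task` helper is illustrative and not part of this repository:

```python
from torch.utils.data import Subset
from torchvision import datasets, transforms


def split_mnist_task(classes, train=True, root="./data"):
    """Return an MNIST subset restricted to the given digit classes."""
    dataset = datasets.MNIST(root, train=train, download=True,
                             transform=transforms.ToTensor())
    indices = [i for i, label in enumerate(dataset.targets.tolist())
               if label in classes]
    return Subset(dataset, indices)


first_task = split_mnist_task([0, 1, 2, 3, 4])   # digits 0-4
second_task = split_mnist_task([5, 6, 7, 8, 9])  # digits 5-9
```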
- Implement baseline sequential training
- Implement Elastic Weight Consolidation (EWC)
- Implement Experience Replay
- Implement Learning without Forgetting (LwF)
- Add support for generative replay
- Implement parameter isolation methods
- Add support for continual reinforcement learning
- Develop benchmark suite for comparing methods
Contributions are welcome! Please feel free to submit a Pull Request.
- Fork the repository
- Create your feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add some amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
- Kirkpatrick, J. et al. "Overcoming catastrophic forgetting in neural networks" - Proceedings of the National Academy of Sciences (2017)
- Rebuffi, S. et al. "iCaRL: Incremental Classifier and Representation Learning" - CVPR (2017)
- Li, Z. and Hoiem, D. "Learning without Forgetting" - IEEE Transactions on Pattern Analysis and Machine Intelligence (2018)
- Chaudhry, A. et al. "Efficient Lifelong Learning with A-GEM" - ICLR (2019)
This project is licensed under the MIT License - see the LICENSE file for details.