Skip to main content

trustformers_training/continual/
mod.rs

1pub mod catastrophic_prevention;
2pub mod ewc;
3pub mod memory_replay;
4pub mod progressive_networks;
5pub mod task_boundary;
6
7pub use catastrophic_prevention::{CatastrophicPreventionStrategy, RegularizationMethod};
8pub use ewc::{EWCConfig, EWCTrainer, FisherInformation};
9pub use memory_replay::{ExperienceBuffer, MemoryReplay, MemoryReplayConfig};
10pub use progressive_networks::{ProgressiveConfig, ProgressiveNetwork, TaskModule};
11pub use task_boundary::{BoundaryDetectionConfig, TaskBoundaryDetector, TaskTransition};
12
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16/// Configuration for continual learning
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ContinualLearningConfig {
19    /// Method for preventing catastrophic forgetting
20    pub prevention_method: CatastrophicPreventionStrategy,
21    /// Task boundary detection configuration
22    pub boundary_detection: BoundaryDetectionConfig,
23    /// Memory replay configuration
24    pub memory_replay: Option<MemoryReplayConfig>,
25    /// EWC configuration
26    pub ewc: Option<EWCConfig>,
27    /// Progressive networks configuration
28    pub progressive: Option<ProgressiveConfig>,
29    /// Maximum number of tasks to remember
30    pub max_tasks: usize,
31    /// Whether to use online or offline learning
32    pub online_learning: bool,
33}
34
35impl Default for ContinualLearningConfig {
36    fn default() -> Self {
37        Self {
38            prevention_method: CatastrophicPreventionStrategy::EWC,
39            boundary_detection: BoundaryDetectionConfig::default(),
40            memory_replay: None,
41            ewc: Some(EWCConfig::default()),
42            progressive: None,
43            max_tasks: 10,
44            online_learning: true,
45        }
46    }
47}
48
49/// Task information for continual learning
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct TaskInfo {
52    pub task_id: String,
53    pub name: String,
54    pub description: Option<String>,
55    pub data_size: usize,
56    pub num_classes: Option<usize>,
57    pub created_at: chrono::DateTime<chrono::Utc>,
58}
59
60/// Continual learning manager
61pub struct ContinualLearningManager {
62    config: ContinualLearningConfig,
63    tasks: Vec<TaskInfo>,
64    current_task: Option<String>,
65    task_transitions: Vec<TaskTransition>,
66    #[allow(dead_code)]
67    prevention_strategies: HashMap<String, Box<dyn RegularizationMethod>>,
68}
69
70impl ContinualLearningManager {
71    pub fn new(config: ContinualLearningConfig) -> Self {
72        Self {
73            config,
74            tasks: Vec::new(),
75            current_task: None,
76            task_transitions: Vec::new(),
77            prevention_strategies: HashMap::new(),
78        }
79    }
80
81    pub fn add_task(&mut self, task: TaskInfo) -> anyhow::Result<()> {
82        if self.tasks.len() >= self.config.max_tasks {
83            return Err(anyhow::anyhow!("Maximum number of tasks reached"));
84        }
85
86        self.tasks.push(task);
87        Ok(())
88    }
89
90    pub fn set_current_task(&mut self, task_id: String) -> anyhow::Result<()> {
91        if !self.tasks.iter().any(|t| t.task_id == task_id) {
92            return Err(anyhow::anyhow!("Task not found: {}", task_id));
93        }
94
95        if let Some(prev_task) = &self.current_task {
96            let transition = TaskTransition {
97                from_task: prev_task.clone(),
98                to_task: task_id.clone(),
99                timestamp: chrono::Utc::now(),
100                boundary_score: 1.0, // This would be computed by boundary detector
101            };
102            self.task_transitions.push(transition);
103        }
104
105        self.current_task = Some(task_id);
106        Ok(())
107    }
108
109    pub fn get_current_task(&self) -> Option<&TaskInfo> {
110        self.current_task
111            .as_ref()
112            .and_then(|id| self.tasks.iter().find(|t| &t.task_id == id))
113    }
114
115    pub fn get_task_count(&self) -> usize {
116        self.tasks.len()
117    }
118
119    pub fn get_task_transitions(&self) -> &[TaskTransition] {
120        &self.task_transitions
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TaskInfo` stamped with the current UTC time.
    fn make_task(
        task_id: String,
        name: String,
        description: Option<String>,
        data_size: usize,
        num_classes: Option<usize>,
    ) -> TaskInfo {
        TaskInfo {
            task_id,
            name,
            description,
            data_size,
            num_classes,
            created_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_continual_learning_manager() {
        let mut manager = ContinualLearningManager::new(ContinualLearningConfig::default());

        let first = make_task(
            "task1".to_string(),
            "Classification Task 1".to_string(),
            Some("First classification task".to_string()),
            1000,
            Some(10),
        );
        manager.add_task(first).expect("add operation failed");
        assert_eq!(manager.get_task_count(), 1);

        manager
            .set_current_task("task1".to_string())
            .expect("operation failed in test");
        assert!(manager.get_current_task().is_some());
    }

    #[test]
    fn test_max_tasks_limit() {
        let limit = 2;
        let mut manager = ContinualLearningManager::new(ContinualLearningConfig {
            max_tasks: limit,
            ..ContinualLearningConfig::default()
        });

        let candidates = (0..3).map(|i| {
            make_task(format!("task{}", i), format!("Task {}", i), None, 100, Some(5))
        });

        // Every task below the limit is accepted; the first one past it is rejected.
        for (index, task) in candidates.enumerate() {
            assert_eq!(manager.add_task(task).is_ok(), index < limit);
        }
    }
}