trustformers_training/continual/
mod.rs1pub mod catastrophic_prevention;
2pub mod ewc;
3pub mod memory_replay;
4pub mod progressive_networks;
5pub mod task_boundary;
6
7pub use catastrophic_prevention::{CatastrophicPreventionStrategy, RegularizationMethod};
8pub use ewc::{EWCConfig, EWCTrainer, FisherInformation};
9pub use memory_replay::{ExperienceBuffer, MemoryReplay, MemoryReplayConfig};
10pub use progressive_networks::{ProgressiveConfig, ProgressiveNetwork, TaskModule};
11pub use task_boundary::{BoundaryDetectionConfig, TaskBoundaryDetector, TaskTransition};
12
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ContinualLearningConfig {
19 pub prevention_method: CatastrophicPreventionStrategy,
21 pub boundary_detection: BoundaryDetectionConfig,
23 pub memory_replay: Option<MemoryReplayConfig>,
25 pub ewc: Option<EWCConfig>,
27 pub progressive: Option<ProgressiveConfig>,
29 pub max_tasks: usize,
31 pub online_learning: bool,
33}
34
35impl Default for ContinualLearningConfig {
36 fn default() -> Self {
37 Self {
38 prevention_method: CatastrophicPreventionStrategy::EWC,
39 boundary_detection: BoundaryDetectionConfig::default(),
40 memory_replay: None,
41 ewc: Some(EWCConfig::default()),
42 progressive: None,
43 max_tasks: 10,
44 online_learning: true,
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct TaskInfo {
52 pub task_id: String,
53 pub name: String,
54 pub description: Option<String>,
55 pub data_size: usize,
56 pub num_classes: Option<usize>,
57 pub created_at: chrono::DateTime<chrono::Utc>,
58}
59
60pub struct ContinualLearningManager {
62 config: ContinualLearningConfig,
63 tasks: Vec<TaskInfo>,
64 current_task: Option<String>,
65 task_transitions: Vec<TaskTransition>,
66 #[allow(dead_code)]
67 prevention_strategies: HashMap<String, Box<dyn RegularizationMethod>>,
68}
69
70impl ContinualLearningManager {
71 pub fn new(config: ContinualLearningConfig) -> Self {
72 Self {
73 config,
74 tasks: Vec::new(),
75 current_task: None,
76 task_transitions: Vec::new(),
77 prevention_strategies: HashMap::new(),
78 }
79 }
80
81 pub fn add_task(&mut self, task: TaskInfo) -> anyhow::Result<()> {
82 if self.tasks.len() >= self.config.max_tasks {
83 return Err(anyhow::anyhow!("Maximum number of tasks reached"));
84 }
85
86 self.tasks.push(task);
87 Ok(())
88 }
89
90 pub fn set_current_task(&mut self, task_id: String) -> anyhow::Result<()> {
91 if !self.tasks.iter().any(|t| t.task_id == task_id) {
92 return Err(anyhow::anyhow!("Task not found: {}", task_id));
93 }
94
95 if let Some(prev_task) = &self.current_task {
96 let transition = TaskTransition {
97 from_task: prev_task.clone(),
98 to_task: task_id.clone(),
99 timestamp: chrono::Utc::now(),
100 boundary_score: 1.0, };
102 self.task_transitions.push(transition);
103 }
104
105 self.current_task = Some(task_id);
106 Ok(())
107 }
108
109 pub fn get_current_task(&self) -> Option<&TaskInfo> {
110 self.current_task
111 .as_ref()
112 .and_then(|id| self.tasks.iter().find(|t| &t.task_id == id))
113 }
114
115 pub fn get_task_count(&self) -> usize {
116 self.tasks.len()
117 }
118
119 pub fn get_task_transitions(&self) -> &[TaskTransition] {
120 &self.task_transitions
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TaskInfo` stamped with the current time.
    fn task_info(
        task_id: String,
        name: String,
        description: Option<String>,
        data_size: usize,
        num_classes: Option<usize>,
    ) -> TaskInfo {
        TaskInfo {
            task_id,
            name,
            description,
            data_size,
            num_classes,
            created_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_continual_learning_manager() {
        let mut manager = ContinualLearningManager::new(ContinualLearningConfig::default());

        let first = task_info(
            "task1".to_string(),
            "Classification Task 1".to_string(),
            Some("First classification task".to_string()),
            1000,
            Some(10),
        );
        manager.add_task(first).expect("add operation failed");
        assert_eq!(manager.get_task_count(), 1);

        manager
            .set_current_task("task1".to_string())
            .expect("operation failed in test");
        assert!(manager.get_current_task().is_some());
    }

    #[test]
    fn test_max_tasks_limit() {
        let mut manager = ContinualLearningManager::new(ContinualLearningConfig {
            max_tasks: 2,
            ..ContinualLearningConfig::default()
        });

        for i in 0..3 {
            let task = task_info(format!("task{}", i), format!("Task {}", i), None, 100, Some(5));
            // The first two tasks fit within `max_tasks`; the third must be rejected.
            let accepted = manager.add_task(task).is_ok();
            assert_eq!(accepted, i < 2);
        }
    }
}