// trustformers_training/continual/progressive_networks.rs
use anyhow::Result;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
5
/// Configuration shared by all columns of a [`ProgressiveNetwork`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveConfig {
    /// Number of layers in each task column.
    pub layers_per_column: usize,
    /// Width of every layer (layers are square: `hidden_dim` x `hidden_dim`).
    pub hidden_dim: usize,
    /// Whether new columns receive lateral adapters from earlier columns.
    pub use_lateral_connections: bool,
    /// Bottleneck width of each lateral adapter.
    pub adapter_dim: usize,
    /// Learning rate reserved for lateral-adapter parameters.
    pub adapter_lr: f32,
    /// Freeze all earlier columns whenever a new task column is added.
    pub freeze_previous_columns: bool,
    /// Hard cap on the number of task columns the network may hold.
    pub max_columns: usize,
}
24
25impl Default for ProgressiveConfig {
26 fn default() -> Self {
27 Self {
28 layers_per_column: 3,
29 hidden_dim: 512,
30 use_lateral_connections: true,
31 adapter_dim: 64,
32 adapter_lr: 0.001,
33 freeze_previous_columns: true,
34 max_columns: 10,
35 }
36 }
37}
38
/// One column of the progressive network, owned by a single task.
#[derive(Debug, Clone)]
pub struct TaskModule {
    /// Identifier of the task this column was created for.
    pub task_id: String,
    /// Position of this column in the network's column order.
    pub column_index: usize,
    /// Feed-forward layers of the column, applied in order.
    pub layers: Vec<Layer>,
    /// Lateral adapters keyed by source column index. Adapters for a given
    /// source column are stored in the order they were added (callers add
    /// one per layer, in ascending layer order).
    pub lateral_connections: HashMap<usize, Vec<LateralAdapter>>,
    /// When true, this column's parameters should not be updated.
    pub frozen: bool,
}
53
54impl TaskModule {
55 pub fn new(task_id: String, column_index: usize, config: &ProgressiveConfig) -> Self {
56 let mut layers = Vec::new();
57
58 for layer_idx in 0..config.layers_per_column {
60 let layer = Layer::new(
61 format!("{}_{}", task_id, layer_idx),
62 config.hidden_dim,
63 config.hidden_dim,
64 );
65 layers.push(layer);
66 }
67
68 Self {
69 task_id,
70 column_index,
71 layers,
72 lateral_connections: HashMap::new(),
73 frozen: false,
74 }
75 }
76
77 pub fn add_lateral_connection(
79 &mut self,
80 source_column: usize,
81 layer_idx: usize,
82 config: &ProgressiveConfig,
83 ) -> Result<()> {
84 if layer_idx >= self.layers.len() {
85 return Err(anyhow::anyhow!("Layer index out of bounds"));
86 }
87
88 let adapter = LateralAdapter::new(config.hidden_dim, config.adapter_dim, config.hidden_dim);
89
90 self.lateral_connections.entry(source_column).or_default().push(adapter);
91
92 Ok(())
93 }
94
95 pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
97 let mut output = input.clone();
98
99 for (layer_idx, layer) in self.layers.iter().enumerate() {
100 output = layer.forward(&output)?;
102
103 for adapters in self.lateral_connections.values() {
105 if layer_idx < adapters.len() {
106 }
110 }
111
112 output.mapv_inplace(|x| x.max(0.0));
114 }
115
116 Ok(output)
117 }
118
119 pub fn freeze(&mut self) {
121 self.frozen = true;
122 }
123
124 pub fn unfreeze(&mut self) {
126 self.frozen = false;
127 }
128
129 pub fn num_parameters(&self) -> usize {
131 let layer_params: usize = self.layers.iter().map(|layer| layer.num_parameters()).sum();
132
133 let adapter_params: usize = self
134 .lateral_connections
135 .values()
136 .flatten()
137 .map(|adapter| adapter.num_parameters())
138 .sum();
139
140 layer_params + adapter_params
141 }
142}
143
/// A dense (fully connected) layer: `output = weights · input + bias`.
#[derive(Debug, Clone)]
pub struct Layer {
    /// Human-readable identifier ("{task_id}_{layer_idx}" by convention).
    pub name: String,
    /// Weight matrix with shape `(output_dim, input_dim)`.
    pub weights: Array2<f32>,
    /// Bias vector of length `output_dim`.
    pub bias: Array1<f32>,
    /// Expected length of input vectors.
    pub input_dim: usize,
    /// Length of produced output vectors.
    pub output_dim: usize,
}
153
154impl Layer {
155 pub fn new(name: String, input_dim: usize, output_dim: usize) -> Self {
156 let weights = Array2::zeros((output_dim, input_dim));
158 let bias = Array1::zeros(output_dim);
159
160 Self {
161 name,
162 weights,
163 bias,
164 input_dim,
165 output_dim,
166 }
167 }
168
169 pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
170 if input.len() != self.input_dim {
171 return Err(anyhow::anyhow!(
172 "Input dimension mismatch: expected {}, got {}",
173 self.input_dim,
174 input.len()
175 ));
176 }
177
178 let output = self.weights.dot(input) + &self.bias;
179 Ok(output)
180 }
181
182 pub fn num_parameters(&self) -> usize {
183 self.weights.len() + self.bias.len()
184 }
185}
186
/// Bottleneck adapter used for lateral connections between columns:
/// down-projects to `adapter_dim`, applies ReLU, then up-projects.
#[derive(Debug, Clone)]
pub struct LateralAdapter {
    /// Projection `(adapter_dim, input_dim)` into the bottleneck.
    pub down_projection: Array2<f32>,
    /// Projection `(output_dim, adapter_dim)` back out of the bottleneck.
    pub up_projection: Array2<f32>,
    /// Expected input vector length.
    pub input_dim: usize,
    /// Bottleneck width.
    pub adapter_dim: usize,
    /// Produced output vector length.
    pub output_dim: usize,
}
196
197impl LateralAdapter {
198 pub fn new(input_dim: usize, adapter_dim: usize, output_dim: usize) -> Self {
199 let down_projection = Array2::zeros((adapter_dim, input_dim));
200 let up_projection = Array2::zeros((output_dim, adapter_dim));
201
202 Self {
203 down_projection,
204 up_projection,
205 input_dim,
206 adapter_dim,
207 output_dim,
208 }
209 }
210
211 pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
212 if input.len() != self.input_dim {
213 return Err(anyhow::anyhow!("Input dimension mismatch"));
214 }
215
216 let hidden = self.down_projection.dot(input);
218 let activated = hidden.mapv(|x| x.max(0.0)); let output = self.up_projection.dot(&activated);
220
221 Ok(output)
222 }
223
224 pub fn num_parameters(&self) -> usize {
225 self.down_projection.len() + self.up_projection.len()
226 }
227}
228
/// Progressive network: one column per task, with optional freezing of
/// earlier columns and lateral adapters into newer columns.
#[derive(Debug)]
pub struct ProgressiveNetwork {
    /// Construction-time settings shared by all columns.
    config: ProgressiveConfig,
    /// Task columns keyed by task id.
    task_modules: HashMap<String, TaskModule>,
    /// Task ids in the order their columns were added.
    column_order: Vec<String>,
    /// Task whose column `forward` routes through, if any.
    current_task: Option<String>,
}
237
238impl ProgressiveNetwork {
239 pub fn new(config: ProgressiveConfig) -> Self {
240 Self {
241 config,
242 task_modules: HashMap::new(),
243 column_order: Vec::new(),
244 current_task: None,
245 }
246 }
247
248 pub fn add_task(&mut self, task_id: String) -> Result<()> {
250 if self.task_modules.contains_key(&task_id) {
251 return Err(anyhow::anyhow!("Task {} already exists", task_id));
252 }
253
254 if self.column_order.len() >= self.config.max_columns {
255 return Err(anyhow::anyhow!("Maximum number of columns reached"));
256 }
257
258 let column_index = self.column_order.len();
259 let mut task_module = TaskModule::new(task_id.clone(), column_index, &self.config);
260
261 if self.config.use_lateral_connections {
263 for prev_column in 0..column_index {
264 for layer_idx in 0..self.config.layers_per_column {
265 task_module.add_lateral_connection(prev_column, layer_idx, &self.config)?;
266 }
267 }
268 }
269
270 self.task_modules.insert(task_id.clone(), task_module);
271 self.column_order.push(task_id.clone());
272
273 if self.config.freeze_previous_columns {
275 self.freeze_previous_columns(&task_id);
276 }
277
278 self.current_task = Some(task_id);
279 Ok(())
280 }
281
282 pub fn set_current_task(&mut self, task_id: String) -> Result<()> {
284 if !self.task_modules.contains_key(&task_id) {
285 return Err(anyhow::anyhow!("Task {} not found", task_id));
286 }
287
288 self.current_task = Some(task_id);
289 Ok(())
290 }
291
292 pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
294 let task_id = self
295 .current_task
296 .as_ref()
297 .ok_or_else(|| anyhow::anyhow!("No current task set"))?;
298
299 let task_module = self
300 .task_modules
301 .get(task_id)
302 .ok_or_else(|| anyhow::anyhow!("Task module not found"))?;
303
304 task_module.forward(input)
305 }
306
307 pub fn forward_task(&self, task_id: &str, input: &Array1<f32>) -> Result<Array1<f32>> {
309 let task_module = self
310 .task_modules
311 .get(task_id)
312 .ok_or_else(|| anyhow::anyhow!("Task module not found: {}", task_id))?;
313
314 task_module.forward(input)
315 }
316
317 fn freeze_previous_columns(&mut self, current_task: &str) {
319 for (task_id, module) in &mut self.task_modules {
320 if task_id != current_task {
321 module.freeze();
322 }
323 }
324 }
325
326 pub fn get_network_stats(&self) -> NetworkStats {
328 let total_params: usize =
329 self.task_modules.values().map(|module| module.num_parameters()).sum();
330
331 let frozen_modules: usize =
332 self.task_modules.values().filter(|module| module.frozen).count();
333
334 NetworkStats {
335 num_tasks: self.task_modules.len(),
336 total_parameters: total_params,
337 frozen_modules,
338 current_task: self.current_task.clone(),
339 column_order: self.column_order.clone(),
340 }
341 }
342
343 pub fn remove_task(&mut self, task_id: &str) -> Result<()> {
345 if !self.task_modules.contains_key(task_id) {
346 return Err(anyhow::anyhow!("Task {} not found", task_id));
347 }
348
349 self.task_modules.remove(task_id);
350 self.column_order.retain(|id| id != task_id);
351
352 if self.current_task.as_ref() == Some(&task_id.to_string()) {
353 self.current_task = None;
354 }
355
356 Ok(())
357 }
358
359 pub fn get_task_module(&self, task_id: &str) -> Option<&TaskModule> {
361 self.task_modules.get(task_id)
362 }
363
364 pub fn has_capacity(&self) -> bool {
366 self.column_order.len() < self.config.max_columns
367 }
368
369 pub fn num_tasks(&self) -> usize {
371 self.task_modules.len()
372 }
373}
374
/// Snapshot of a [`ProgressiveNetwork`]'s size and state, as produced by
/// `get_network_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkStats {
    /// Number of task columns in the network.
    pub num_tasks: usize,
    /// Sum of parameter counts over all columns and adapters.
    pub total_parameters: usize,
    /// How many columns are currently frozen.
    pub frozen_modules: usize,
    /// The active task at snapshot time, if any.
    pub current_task: Option<String>,
    /// Task ids in column-creation order.
    pub column_order: Vec<String>,
}
384
385pub mod utils {
387 use super::*;
388
389 pub fn compute_lateral_importance(
391 source_activations: &[Array1<f32>],
392 target_gradients: &[Array1<f32>],
393 ) -> f32 {
394 let mut importance = 0.0;
395
396 for (activation, gradient) in source_activations.iter().zip(target_gradients.iter()) {
397 importance += (activation * gradient).sum().abs();
398 }
399
400 importance / source_activations.len() as f32
401 }
402
403 pub fn prune_lateral_connections(
405 _network: &mut ProgressiveNetwork,
406 _importance_threshold: f32,
407 ) -> Result<usize> {
408 let pruned_count = 0;
409
410 Ok(pruned_count)
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    // Adding a first task should succeed and leave spare capacity.
    #[test]
    fn test_progressive_network_creation() {
        let config = ProgressiveConfig::default();
        let mut network = ProgressiveNetwork::new(config);

        assert!(network.add_task("task1".to_string()).is_ok());
        assert_eq!(network.num_tasks(), 1);
        assert!(network.has_capacity());
    }

    // A column's forward pass preserves the hidden dimension.
    #[test]
    fn test_task_module_forward() {
        let config = ProgressiveConfig {
            layers_per_column: 2,
            hidden_dim: 4,
            ..Default::default()
        };

        let task_module = TaskModule::new("test_task".to_string(), 0, &config);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = task_module.forward(&input);
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        assert_eq!(output.len(), config.hidden_dim);
    }

    // Adapter output length matches its configured output dimension.
    #[test]
    fn test_lateral_adapter() {
        let adapter = LateralAdapter::new(4, 2, 4);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = adapter.forward(&input);
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        assert_eq!(output.len(), 4);
    }

    // max_columns caps how many tasks can be added; stats reflect the count.
    #[test]
    fn test_multiple_tasks() {
        let config = ProgressiveConfig {
            max_columns: 3,
            ..Default::default()
        };
        let mut network = ProgressiveNetwork::new(config);

        assert!(network.add_task("task1".to_string()).is_ok());
        assert!(network.add_task("task2".to_string()).is_ok());
        assert!(network.add_task("task3".to_string()).is_ok());

        // Fourth task exceeds max_columns and must be rejected.
        assert!(network.add_task("task4".to_string()).is_err());

        let stats = network.get_network_stats();
        assert_eq!(stats.num_tasks, 3);
        assert_eq!(stats.column_order.len(), 3);
    }

    // forward routes through the current task's column without error.
    #[test]
    fn test_network_forward() {
        let config = ProgressiveConfig {
            hidden_dim: 4,
            ..Default::default()
        };
        let mut network = ProgressiveNetwork::new(config);

        network.add_task("task1".to_string()).expect("add operation failed");
        network.set_current_task("task1".to_string()).expect("operation failed in test");

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = network.forward(&input);

        assert!(result.is_ok());
    }
}