//! Progressive Networks for continual learning.
//!
//! File: trustformers_training/continual/progressive_networks.rs

use anyhow::Result;
use scirs2_core::ndarray::{Array1, Array2}; // SciRS2 Integration Policy
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

6/// Configuration for Progressive Networks
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ProgressiveConfig {
9    /// Number of layers per task column
10    pub layers_per_column: usize,
11    /// Hidden dimension for each layer
12    pub hidden_dim: usize,
13    /// Whether to use lateral connections
14    pub use_lateral_connections: bool,
15    /// Adapter dimension for lateral connections
16    pub adapter_dim: usize,
17    /// Learning rate for adapters
18    pub adapter_lr: f32,
19    /// Freeze previous columns during training
20    pub freeze_previous_columns: bool,
21    /// Maximum number of task columns
22    pub max_columns: usize,
23}
24
25impl Default for ProgressiveConfig {
26    fn default() -> Self {
27        Self {
28            layers_per_column: 3,
29            hidden_dim: 512,
30            use_lateral_connections: true,
31            adapter_dim: 64,
32            adapter_lr: 0.001,
33            freeze_previous_columns: true,
34            max_columns: 10,
35        }
36    }
37}
38
39/// Task-specific module in a progressive network
40#[derive(Debug, Clone)]
41pub struct TaskModule {
42    /// Task ID this module belongs to
43    pub task_id: String,
44    /// Column index in the progressive network
45    pub column_index: usize,
46    /// Layer weights for this task
47    pub layers: Vec<Layer>,
48    /// Lateral connections from previous tasks
49    pub lateral_connections: HashMap<usize, Vec<LateralAdapter>>,
50    /// Whether this module is frozen
51    pub frozen: bool,
52}
53
54impl TaskModule {
55    pub fn new(task_id: String, column_index: usize, config: &ProgressiveConfig) -> Self {
56        let mut layers = Vec::new();
57
58        // Create layers for this task column
59        for layer_idx in 0..config.layers_per_column {
60            let layer = Layer::new(
61                format!("{}_{}", task_id, layer_idx),
62                config.hidden_dim,
63                config.hidden_dim,
64            );
65            layers.push(layer);
66        }
67
68        Self {
69            task_id,
70            column_index,
71            layers,
72            lateral_connections: HashMap::new(),
73            frozen: false,
74        }
75    }
76
77    /// Add lateral connection from another task
78    pub fn add_lateral_connection(
79        &mut self,
80        source_column: usize,
81        layer_idx: usize,
82        config: &ProgressiveConfig,
83    ) -> Result<()> {
84        if layer_idx >= self.layers.len() {
85            return Err(anyhow::anyhow!("Layer index out of bounds"));
86        }
87
88        let adapter = LateralAdapter::new(config.hidden_dim, config.adapter_dim, config.hidden_dim);
89
90        self.lateral_connections.entry(source_column).or_default().push(adapter);
91
92        Ok(())
93    }
94
95    /// Forward pass through this task module
96    pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
97        let mut output = input.clone();
98
99        for (layer_idx, layer) in self.layers.iter().enumerate() {
100            // Apply main layer transformation
101            output = layer.forward(&output)?;
102
103            // Add lateral connections from previous tasks
104            for adapters in self.lateral_connections.values() {
105                if layer_idx < adapters.len() {
106                    // This would typically receive activations from the corresponding layer
107                    // in the source column, but for simplicity we'll skip that here
108                    // In practice, this would be: output += adapter.forward(&source_activations)
109                }
110            }
111
112            // Apply activation function (ReLU)
113            output.mapv_inplace(|x| x.max(0.0));
114        }
115
116        Ok(output)
117    }
118
119    /// Freeze this module (prevent parameter updates)
120    pub fn freeze(&mut self) {
121        self.frozen = true;
122    }
123
124    /// Unfreeze this module
125    pub fn unfreeze(&mut self) {
126        self.frozen = false;
127    }
128
129    /// Get number of parameters in this module
130    pub fn num_parameters(&self) -> usize {
131        let layer_params: usize = self.layers.iter().map(|layer| layer.num_parameters()).sum();
132
133        let adapter_params: usize = self
134            .lateral_connections
135            .values()
136            .flatten()
137            .map(|adapter| adapter.num_parameters())
138            .sum();
139
140        layer_params + adapter_params
141    }
142}
143
144/// Neural network layer
145#[derive(Debug, Clone)]
146pub struct Layer {
147    pub name: String,
148    pub weights: Array2<f32>,
149    pub bias: Array1<f32>,
150    pub input_dim: usize,
151    pub output_dim: usize,
152}
153
154impl Layer {
155    pub fn new(name: String, input_dim: usize, output_dim: usize) -> Self {
156        // Initialize with small random weights
157        let weights = Array2::zeros((output_dim, input_dim));
158        let bias = Array1::zeros(output_dim);
159
160        Self {
161            name,
162            weights,
163            bias,
164            input_dim,
165            output_dim,
166        }
167    }
168
169    pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
170        if input.len() != self.input_dim {
171            return Err(anyhow::anyhow!(
172                "Input dimension mismatch: expected {}, got {}",
173                self.input_dim,
174                input.len()
175            ));
176        }
177
178        let output = self.weights.dot(input) + &self.bias;
179        Ok(output)
180    }
181
182    pub fn num_parameters(&self) -> usize {
183        self.weights.len() + self.bias.len()
184    }
185}
186
187/// Lateral adapter for connections between task columns
188#[derive(Debug, Clone)]
189pub struct LateralAdapter {
190    pub down_projection: Array2<f32>,
191    pub up_projection: Array2<f32>,
192    pub input_dim: usize,
193    pub adapter_dim: usize,
194    pub output_dim: usize,
195}
196
197impl LateralAdapter {
198    pub fn new(input_dim: usize, adapter_dim: usize, output_dim: usize) -> Self {
199        let down_projection = Array2::zeros((adapter_dim, input_dim));
200        let up_projection = Array2::zeros((output_dim, adapter_dim));
201
202        Self {
203            down_projection,
204            up_projection,
205            input_dim,
206            adapter_dim,
207            output_dim,
208        }
209    }
210
211    pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
212        if input.len() != self.input_dim {
213            return Err(anyhow::anyhow!("Input dimension mismatch"));
214        }
215
216        // Down-project, apply activation, then up-project
217        let hidden = self.down_projection.dot(input);
218        let activated = hidden.mapv(|x| x.max(0.0)); // ReLU
219        let output = self.up_projection.dot(&activated);
220
221        Ok(output)
222    }
223
224    pub fn num_parameters(&self) -> usize {
225        self.down_projection.len() + self.up_projection.len()
226    }
227}
228
229/// Progressive Network architecture
230#[derive(Debug)]
231pub struct ProgressiveNetwork {
232    config: ProgressiveConfig,
233    task_modules: HashMap<String, TaskModule>,
234    column_order: Vec<String>,
235    current_task: Option<String>,
236}
237
238impl ProgressiveNetwork {
239    pub fn new(config: ProgressiveConfig) -> Self {
240        Self {
241            config,
242            task_modules: HashMap::new(),
243            column_order: Vec::new(),
244            current_task: None,
245        }
246    }
247
248    /// Add a new task column to the network
249    pub fn add_task(&mut self, task_id: String) -> Result<()> {
250        if self.task_modules.contains_key(&task_id) {
251            return Err(anyhow::anyhow!("Task {} already exists", task_id));
252        }
253
254        if self.column_order.len() >= self.config.max_columns {
255            return Err(anyhow::anyhow!("Maximum number of columns reached"));
256        }
257
258        let column_index = self.column_order.len();
259        let mut task_module = TaskModule::new(task_id.clone(), column_index, &self.config);
260
261        // Add lateral connections from all previous tasks
262        if self.config.use_lateral_connections {
263            for prev_column in 0..column_index {
264                for layer_idx in 0..self.config.layers_per_column {
265                    task_module.add_lateral_connection(prev_column, layer_idx, &self.config)?;
266                }
267            }
268        }
269
270        self.task_modules.insert(task_id.clone(), task_module);
271        self.column_order.push(task_id.clone());
272
273        // Freeze previous columns if configured
274        if self.config.freeze_previous_columns {
275            self.freeze_previous_columns(&task_id);
276        }
277
278        self.current_task = Some(task_id);
279        Ok(())
280    }
281
282    /// Set current active task
283    pub fn set_current_task(&mut self, task_id: String) -> Result<()> {
284        if !self.task_modules.contains_key(&task_id) {
285            return Err(anyhow::anyhow!("Task {} not found", task_id));
286        }
287
288        self.current_task = Some(task_id);
289        Ok(())
290    }
291
292    /// Forward pass for the current task
293    pub fn forward(&self, input: &Array1<f32>) -> Result<Array1<f32>> {
294        let task_id = self
295            .current_task
296            .as_ref()
297            .ok_or_else(|| anyhow::anyhow!("No current task set"))?;
298
299        let task_module = self
300            .task_modules
301            .get(task_id)
302            .ok_or_else(|| anyhow::anyhow!("Task module not found"))?;
303
304        task_module.forward(input)
305    }
306
307    /// Forward pass for a specific task
308    pub fn forward_task(&self, task_id: &str, input: &Array1<f32>) -> Result<Array1<f32>> {
309        let task_module = self
310            .task_modules
311            .get(task_id)
312            .ok_or_else(|| anyhow::anyhow!("Task module not found: {}", task_id))?;
313
314        task_module.forward(input)
315    }
316
317    /// Freeze all columns except the current one
318    fn freeze_previous_columns(&mut self, current_task: &str) {
319        for (task_id, module) in &mut self.task_modules {
320            if task_id != current_task {
321                module.freeze();
322            }
323        }
324    }
325
326    /// Get network statistics
327    pub fn get_network_stats(&self) -> NetworkStats {
328        let total_params: usize =
329            self.task_modules.values().map(|module| module.num_parameters()).sum();
330
331        let frozen_modules: usize =
332            self.task_modules.values().filter(|module| module.frozen).count();
333
334        NetworkStats {
335            num_tasks: self.task_modules.len(),
336            total_parameters: total_params,
337            frozen_modules,
338            current_task: self.current_task.clone(),
339            column_order: self.column_order.clone(),
340        }
341    }
342
343    /// Remove a task from the network
344    pub fn remove_task(&mut self, task_id: &str) -> Result<()> {
345        if !self.task_modules.contains_key(task_id) {
346            return Err(anyhow::anyhow!("Task {} not found", task_id));
347        }
348
349        self.task_modules.remove(task_id);
350        self.column_order.retain(|id| id != task_id);
351
352        if self.current_task.as_ref() == Some(&task_id.to_string()) {
353            self.current_task = None;
354        }
355
356        Ok(())
357    }
358
359    /// Get task module for a specific task
360    pub fn get_task_module(&self, task_id: &str) -> Option<&TaskModule> {
361        self.task_modules.get(task_id)
362    }
363
364    /// Check if network has capacity for more tasks
365    pub fn has_capacity(&self) -> bool {
366        self.column_order.len() < self.config.max_columns
367    }
368
369    /// Get number of tasks
370    pub fn num_tasks(&self) -> usize {
371        self.task_modules.len()
372    }
373}
374
375/// Network statistics
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct NetworkStats {
378    pub num_tasks: usize,
379    pub total_parameters: usize,
380    pub frozen_modules: usize,
381    pub current_task: Option<String>,
382    pub column_order: Vec<String>,
383}
384
385/// Utility functions for progressive networks
386pub mod utils {
387    use super::*;
388
389    /// Compute lateral connection importance
390    pub fn compute_lateral_importance(
391        source_activations: &[Array1<f32>],
392        target_gradients: &[Array1<f32>],
393    ) -> f32 {
394        let mut importance = 0.0;
395
396        for (activation, gradient) in source_activations.iter().zip(target_gradients.iter()) {
397            importance += (activation * gradient).sum().abs();
398        }
399
400        importance / source_activations.len() as f32
401    }
402
403    /// Prune weak lateral connections
404    pub fn prune_lateral_connections(
405        _network: &mut ProgressiveNetwork,
406        _importance_threshold: f32,
407    ) -> Result<usize> {
408        let pruned_count = 0;
409
410        // This would require access to activation and gradient history
411        // For now, this is a placeholder implementation
412
413        Ok(pruned_count)
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progressive_network_creation() {
        let mut network = ProgressiveNetwork::new(ProgressiveConfig::default());

        assert!(network.add_task("task1".to_string()).is_ok());
        assert_eq!(network.num_tasks(), 1);
        assert!(network.has_capacity());
    }

    #[test]
    fn test_task_module_forward() {
        let config = ProgressiveConfig {
            layers_per_column: 2,
            hidden_dim: 4,
            ..Default::default()
        };
        let module = TaskModule::new("test_task".to_string(), 0, &config);

        let result = module.forward(&Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]));
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        assert_eq!(output.len(), config.hidden_dim);
    }

    #[test]
    fn test_lateral_adapter() {
        let adapter = LateralAdapter::new(4, 2, 4);

        let result = adapter.forward(&Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]));
        assert!(result.is_ok());

        let output = result.expect("operation failed in test");
        assert_eq!(output.len(), 4);
    }

    #[test]
    fn test_multiple_tasks() {
        let mut network = ProgressiveNetwork::new(ProgressiveConfig {
            max_columns: 3,
            ..Default::default()
        });

        for id in ["task1", "task2", "task3"] {
            assert!(network.add_task(id.to_string()).is_ok());
        }

        // Should fail when max columns reached
        assert!(network.add_task("task4".to_string()).is_err());

        let stats = network.get_network_stats();
        assert_eq!(stats.num_tasks, 3);
        assert_eq!(stats.column_order.len(), 3);
    }

    #[test]
    fn test_network_forward() {
        let mut network = ProgressiveNetwork::new(ProgressiveConfig {
            hidden_dim: 4,
            ..Default::default()
        });

        network.add_task("task1".to_string()).expect("add operation failed");
        network.set_current_task("task1".to_string()).expect("operation failed in test");

        let result = network.forward(&Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]));
        assert!(result.is_ok());
    }
}