Skip to main content

trustformers_optim/
continual_learning.rs

1//! # Continual Learning Optimizers
2//!
3//! This module implements optimization algorithms specifically designed for
4//! continual learning scenarios where models must learn new tasks while
5//! retaining knowledge of previous tasks.
6//!
7//! ## Available Methods
8//!
9//! - **EWC (Elastic Weight Consolidation)**: Protects important weights from changes
//! - **PackNet**: Per-task parameter allocation via binary gradient masks
11//! - **L2 Regularization**: Simple regularization towards previous weights
12//! - **Memory Replay**: Gradient-based memory replay optimization
13//! - **Meta-Learning**: Model-Agnostic Meta-Learning for continual adaptation
14
15use anyhow::{anyhow, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use trustformers_core::tensor::Tensor;
19
20/// Configuration for Elastic Weight Consolidation (EWC).
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EWCConfig {
23    /// Base learning rate
24    pub learning_rate: f32,
25    /// Importance weight for Fisher information regularization
26    pub lambda: f32,
27    /// Method for computing Fisher information
28    pub fisher_method: FisherMethod,
29    /// Number of samples for Fisher information estimation
30    pub fisher_samples: usize,
31    /// Online vs offline EWC
32    pub online: bool,
33    /// Decay factor for online EWC
34    pub decay_factor: f32,
35}
36
37impl Default for EWCConfig {
38    fn default() -> Self {
39        Self {
40            learning_rate: 1e-3,
41            lambda: 1000.0,
42            fisher_method: FisherMethod::Empirical,
43            fisher_samples: 1000,
44            online: false,
45            decay_factor: 0.9,
46        }
47    }
48}
49
50/// Methods for computing Fisher information matrix.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub enum FisherMethod {
53    /// Empirical Fisher information
54    Empirical,
55    /// True Fisher information (computationally expensive)
56    True,
57    /// Diagonal approximation
58    Diagonal,
59}
60
61/// Configuration for Progressive Networks (PackNet).
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct PackNetConfig {
64    /// Base learning rate
65    pub learning_rate: f32,
66    /// Sparsity level for each task
67    pub sparsity_level: f32,
68    /// Number of tasks
69    pub num_tasks: usize,
70    /// Parameter allocation strategy
71    pub allocation_strategy: AllocationStrategy,
72}
73
74impl Default for PackNetConfig {
75    fn default() -> Self {
76        Self {
77            learning_rate: 1e-3,
78            sparsity_level: 0.5,
79            num_tasks: 10,
80            allocation_strategy: AllocationStrategy::Sequential,
81        }
82    }
83}
84
85/// Parameter allocation strategies for PackNet.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub enum AllocationStrategy {
88    /// Sequential allocation
89    Sequential,
90    /// Random allocation
91    Random,
92    /// Importance-based allocation
93    ImportanceBased,
94}
95
96/// Configuration for L2 regularization towards previous parameters.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct L2RegularizationConfig {
99    /// Base learning rate
100    pub learning_rate: f32,
101    /// Regularization strength
102    pub reg_strength: f32,
103    /// Update strategy for anchor parameters
104    pub update_strategy: UpdateStrategy,
105}
106
107impl Default for L2RegularizationConfig {
108    fn default() -> Self {
109        Self {
110            learning_rate: 1e-3,
111            reg_strength: 0.1,
112            update_strategy: UpdateStrategy::EMA,
113        }
114    }
115}
116
117/// Strategies for updating anchor parameters.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub enum UpdateStrategy {
120    /// No update (fixed anchors)
121    Fixed,
122    /// Exponential moving average
123    EMA,
124    /// Update at task boundaries
125    TaskBoundary,
126}
127
128/// Configuration for Memory Replay optimization.
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct MemoryReplayConfig {
131    /// Base learning rate
132    pub learning_rate: f32,
133    /// Memory buffer size
134    pub memory_size: usize,
135    /// Replay frequency (every N steps)
136    pub replay_frequency: usize,
137    /// Replay batch size
138    pub replay_batch_size: usize,
139    /// Memory selection strategy
140    pub selection_strategy: MemorySelectionStrategy,
141}
142
143impl Default for MemoryReplayConfig {
144    fn default() -> Self {
145        Self {
146            learning_rate: 1e-3,
147            memory_size: 1000,
148            replay_frequency: 10,
149            replay_batch_size: 32,
150            selection_strategy: MemorySelectionStrategy::Random,
151        }
152    }
153}
154
155/// Memory selection strategies for replay.
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub enum MemorySelectionStrategy {
158    /// Random selection
159    Random,
160    /// Gradient-based selection
161    GradientBased,
162    /// Uncertainty-based selection
163    UncertaintyBased,
164}
165
166/// Elastic Weight Consolidation optimizer.
167pub struct EWC {
168    config: EWCConfig,
169    parameters: Vec<Tensor>,
170    importance_weights: Vec<Tensor>,
171    anchor_parameters: Vec<Tensor>,
172    current_task: usize,
173    accumulated_importance: Vec<Tensor>,
174}
175
176impl EWC {
177    /// Create a new EWC optimizer.
178    pub fn new(config: EWCConfig, initial_parameters: Vec<Tensor>) -> Result<Self> {
179        let param_count = initial_parameters.len();
180
181        Ok(Self {
182            config,
183            parameters: initial_parameters.clone(),
184            importance_weights: (0..param_count)
185                .map(|i| Tensor::zeros(&initial_parameters[i].shape()).unwrap())
186                .collect(),
187            anchor_parameters: initial_parameters.clone(),
188            current_task: 0,
189            accumulated_importance: (0..param_count)
190                .map(|i| Tensor::zeros(&initial_parameters[i].shape()).unwrap())
191                .collect(),
192        })
193    }
194
195    /// Compute Fisher information matrix for current task.
196    pub fn compute_fisher_information(&mut self, gradients_samples: &[Vec<Tensor>]) -> Result<()> {
197        let num_samples = gradients_samples.len();
198        if num_samples == 0 {
199            return Err(anyhow!("No gradient samples provided"));
200        }
201
202        // Reset importance weights for new task
203        for importance in self.importance_weights.iter_mut() {
204            *importance = Tensor::zeros(&importance.shape())?;
205        }
206
207        // Compute empirical Fisher information
208        for gradient_sample in gradients_samples {
209            for (i, gradient) in gradient_sample.iter().enumerate() {
210                if i < self.importance_weights.len() {
211                    let squared_grad = gradient.mul(gradient)?;
212                    self.importance_weights[i] = self.importance_weights[i].add(&squared_grad)?;
213                }
214            }
215        }
216
217        // Average over samples
218        for importance in self.importance_weights.iter_mut() {
219            *importance = importance.div_scalar(num_samples as f32)?;
220        }
221
222        // Update accumulated importance for online EWC
223        if self.config.online {
224            for i in 0..self.accumulated_importance.len() {
225                let decayed =
226                    self.accumulated_importance[i].mul_scalar(self.config.decay_factor)?;
227                self.accumulated_importance[i] = decayed.add(&self.importance_weights[i])?;
228            }
229        }
230
231        Ok(())
232    }
233
234    /// Complete current task and prepare for next task.
235    pub fn finish_task(&mut self) -> Result<()> {
236        // Update anchor parameters to current parameters
237        self.anchor_parameters = self.parameters.clone();
238        self.current_task += 1;
239        Ok(())
240    }
241
242    /// Perform optimization step with EWC regularization.
243    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
244        for (i, gradient) in gradients.iter().enumerate() {
245            if i < self.parameters.len() {
246                // Compute EWC penalty gradient
247                let param_diff = self.parameters[i].sub(&self.anchor_parameters[i])?;
248                let importance = if self.config.online {
249                    &self.accumulated_importance[i]
250                } else {
251                    &self.importance_weights[i]
252                };
253                let ewc_grad = param_diff.mul(importance)?.mul_scalar(self.config.lambda)?;
254
255                // Combine original gradient with EWC penalty
256                let total_grad = gradient.add(&ewc_grad)?;
257
258                // Apply update
259                let update = total_grad.mul_scalar(self.config.learning_rate)?;
260                self.parameters[i] = self.parameters[i].sub(&update)?;
261            }
262        }
263        Ok(())
264    }
265
266    /// Get current parameters.
267    pub fn get_parameters(&self) -> &[Tensor] {
268        &self.parameters
269    }
270
271    /// Get importance weights.
272    pub fn get_importance_weights(&self) -> &[Tensor] {
273        &self.importance_weights
274    }
275}
276
277/// Progressive Networks (PackNet) optimizer.
278pub struct PackNet {
279    config: PackNetConfig,
280    parameters: Vec<Tensor>,
281    #[allow(dead_code)]
282    parameter_masks: Vec<Tensor>,
283    task_allocations: HashMap<usize, Vec<Tensor>>,
284    current_task: usize,
285    available_capacity: Vec<f32>,
286}
287
288impl PackNet {
289    /// Create a new PackNet optimizer.
290    pub fn new(config: PackNetConfig, initial_parameters: Vec<Tensor>) -> Result<Self> {
291        let param_count = initial_parameters.len();
292
293        Ok(Self {
294            config,
295            parameters: initial_parameters.clone(),
296            parameter_masks: (0..param_count)
297                .map(|i| Tensor::ones(&initial_parameters[i].shape()).unwrap())
298                .collect(),
299            task_allocations: HashMap::new(),
300            current_task: 0,
301            available_capacity: vec![1.0; param_count],
302        })
303    }
304
305    /// Allocate parameters for a new task.
306    pub fn allocate_task(&mut self, task_id: usize) -> Result<()> {
307        if self.available_capacity.iter().any(|&cap| cap < self.config.sparsity_level) {
308            return Err(anyhow!("Insufficient parameter capacity for new task"));
309        }
310
311        let mut task_masks = Vec::new();
312
313        for (i, param) in self.parameters.iter().enumerate() {
314            let shape = param.shape();
315            let total_params = shape.iter().product::<usize>();
316            let allocated_params = (total_params as f32 * self.config.sparsity_level) as usize;
317
318            // Create allocation mask
319            let mut mask_data = vec![0.0; total_params];
320
321            match self.config.allocation_strategy {
322                AllocationStrategy::Sequential => {
323                    let start_idx =
324                        ((1.0 - self.available_capacity[i]) * total_params as f32) as usize;
325                    let end_idx = (start_idx + allocated_params).min(total_params);
326                    for idx in start_idx..end_idx {
327                        mask_data[idx] = 1.0;
328                    }
329                },
330                AllocationStrategy::Random => {
331                    use scirs2_core::random::*; // SciRS2 Integration Policy
332                    let mut indices: Vec<usize> = (0..total_params).collect();
333                    let mut rng = thread_rng();
334                    indices.shuffle(rng.rng_mut());
335                    for &idx in indices.iter().take(allocated_params) {
336                        mask_data[idx] = 1.0;
337                    }
338                },
339                AllocationStrategy::ImportanceBased => {
340                    // Simplified importance-based allocation
341                    // In practice, this would use gradient magnitudes or other importance metrics
342                    for idx in 0..allocated_params.min(total_params) {
343                        mask_data[idx] = 1.0;
344                    }
345                },
346            }
347
348            let task_mask = Tensor::new(mask_data)?;
349            task_masks.push(task_mask);
350
351            // Update available capacity
352            self.available_capacity[i] -= self.config.sparsity_level;
353        }
354
355        self.task_allocations.insert(task_id, task_masks);
356        self.current_task = task_id;
357        Ok(())
358    }
359
360    /// Perform optimization step with parameter masking.
361    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
362        let task_masks = self
363            .task_allocations
364            .get(&self.current_task)
365            .ok_or_else(|| anyhow!("No allocation for current task"))?;
366
367        for (i, gradient) in gradients.iter().enumerate() {
368            if i < self.parameters.len() && i < task_masks.len() {
369                // Apply task-specific mask to gradient
370                let masked_grad = gradient.mul(&task_masks[i])?;
371
372                // Apply update
373                let update = masked_grad.mul_scalar(self.config.learning_rate)?;
374                self.parameters[i] = self.parameters[i].sub(&update)?;
375            }
376        }
377        Ok(())
378    }
379
380    /// Get current parameters.
381    pub fn get_parameters(&self) -> &[Tensor] {
382        &self.parameters
383    }
384
385    /// Get available capacity for new tasks.
386    pub fn get_available_capacity(&self) -> &[f32] {
387        &self.available_capacity
388    }
389}
390
391/// L2 Regularization optimizer for continual learning.
392pub struct L2Regularization {
393    config: L2RegularizationConfig,
394    parameters: Vec<Tensor>,
395    anchor_parameters: Vec<Tensor>,
396    ema_decay: f32,
397}
398
399impl L2Regularization {
400    /// Create a new L2 regularization optimizer.
401    pub fn new(config: L2RegularizationConfig, initial_parameters: Vec<Tensor>) -> Self {
402        Self {
403            config,
404            parameters: initial_parameters.clone(),
405            anchor_parameters: initial_parameters,
406            ema_decay: 0.999,
407        }
408    }
409
410    /// Perform optimization step with L2 regularization.
411    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
412        for (i, gradient) in gradients.iter().enumerate() {
413            if i < self.parameters.len() {
414                // Compute L2 regularization term
415                let param_diff = self.parameters[i].sub(&self.anchor_parameters[i])?;
416                let reg_grad = param_diff.mul_scalar(self.config.reg_strength)?;
417
418                // Combine gradient with regularization
419                let total_grad = gradient.add(&reg_grad)?;
420
421                // Apply update
422                let update = total_grad.mul_scalar(self.config.learning_rate)?;
423                self.parameters[i] = self.parameters[i].sub(&update)?;
424
425                // Update anchor parameters based on strategy
426                match self.config.update_strategy {
427                    UpdateStrategy::Fixed => {
428                        // Don't update anchors
429                    },
430                    UpdateStrategy::EMA => {
431                        // Exponential moving average update
432                        let anchor_update = self.parameters[i].mul_scalar(1.0 - self.ema_decay)?;
433                        let anchor_keep = self.anchor_parameters[i].mul_scalar(self.ema_decay)?;
434                        self.anchor_parameters[i] = anchor_update.add(&anchor_keep)?;
435                    },
436                    UpdateStrategy::TaskBoundary => {
437                        // Will be updated when finish_task() is called
438                    },
439                }
440            }
441        }
442        Ok(())
443    }
444
445    /// Finish current task (update anchors for TaskBoundary strategy).
446    pub fn finish_task(&mut self) -> Result<()> {
447        if matches!(self.config.update_strategy, UpdateStrategy::TaskBoundary) {
448            self.anchor_parameters = self.parameters.clone();
449        }
450        Ok(())
451    }
452
453    /// Get current parameters.
454    pub fn get_parameters(&self) -> &[Tensor] {
455        &self.parameters
456    }
457}
458
459/// Memory replay optimizer.
460pub struct MemoryReplay {
461    config: MemoryReplayConfig,
462    parameters: Vec<Tensor>,
463    memory_buffer: Vec<Vec<Tensor>>, // Stored gradients
464    step_count: usize,
465}
466
467impl MemoryReplay {
468    /// Create a new memory replay optimizer.
469    pub fn new(config: MemoryReplayConfig, initial_parameters: Vec<Tensor>) -> Self {
470        Self {
471            config,
472            parameters: initial_parameters,
473            memory_buffer: Vec::new(),
474            step_count: 0,
475        }
476    }
477
478    /// Add gradient to memory buffer.
479    pub fn store_gradient(&mut self, gradients: &[Tensor]) -> Result<()> {
480        if self.memory_buffer.len() >= self.config.memory_size {
481            // Remove oldest or least important gradient
482            match self.config.selection_strategy {
483                MemorySelectionStrategy::Random => {
484                    use scirs2_core::random::*; // SciRS2 Integration Policy
485                    let idx = thread_rng().gen_range(0..self.memory_buffer.len());
486                    self.memory_buffer.remove(idx);
487                },
488                _ => {
489                    self.memory_buffer.remove(0); // FIFO for simplicity
490                },
491            }
492        }
493
494        self.memory_buffer.push(gradients.to_vec());
495        Ok(())
496    }
497
498    /// Perform optimization step with memory replay.
499    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
500        // Regular gradient update
501        for (i, gradient) in gradients.iter().enumerate() {
502            if i < self.parameters.len() {
503                let update = gradient.mul_scalar(self.config.learning_rate)?;
504                self.parameters[i] = self.parameters[i].sub(&update)?;
505            }
506        }
507
508        // Store current gradient
509        self.store_gradient(gradients)?;
510
511        // Replay from memory
512        if self.step_count % self.config.replay_frequency == 0 && !self.memory_buffer.is_empty() {
513            self.replay_step()?;
514        }
515
516        self.step_count += 1;
517        Ok(())
518    }
519
520    fn replay_step(&mut self) -> Result<()> {
521        let batch_size = self.config.replay_batch_size.min(self.memory_buffer.len());
522
523        // Select random batch from memory
524        use scirs2_core::random::*; // SciRS2 Integration Policy
525        let mut indices: Vec<usize> = (0..self.memory_buffer.len()).collect();
526        let mut rng = thread_rng();
527        indices.shuffle(rng.rng_mut());
528
529        for &idx in indices.iter().take(batch_size) {
530            let replay_gradients = &self.memory_buffer[idx];
531
532            // Apply replay gradient with reduced learning rate
533            let replay_lr = self.config.learning_rate * 0.5;
534            for (i, gradient) in replay_gradients.iter().enumerate() {
535                if i < self.parameters.len() {
536                    let update = gradient.mul_scalar(replay_lr)?;
537                    self.parameters[i] = self.parameters[i].sub(&update)?;
538                }
539            }
540        }
541
542        Ok(())
543    }
544
545    /// Get current parameters.
546    pub fn get_parameters(&self) -> &[Tensor] {
547        &self.parameters
548    }
549
550    /// Get memory buffer size.
551    pub fn memory_size(&self) -> usize {
552        self.memory_buffer.len()
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_config() {
        let config = EWCConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.lambda, 1000.0);
        assert!(!config.online);
    }

    #[test]
    fn test_packnet_config() {
        let config = PackNetConfig::default();
        assert_eq!(config.sparsity_level, 0.5);
        assert_eq!(config.num_tasks, 10);
    }

    #[test]
    fn test_l2_regularization_config() {
        let config = L2RegularizationConfig::default();
        assert_eq!(config.reg_strength, 0.1);
        assert!(matches!(config.update_strategy, UpdateStrategy::EMA));
    }

    #[test]
    fn test_memory_replay_config() {
        let config = MemoryReplayConfig::default();
        assert_eq!(config.memory_size, 1000);
        assert_eq!(config.replay_frequency, 10);
        assert!(matches!(
            config.selection_strategy,
            MemorySelectionStrategy::Random
        ));
    }

    #[test]
    fn test_fisher_methods() {
        assert!(matches!(FisherMethod::Empirical, FisherMethod::Empirical));
        assert!(matches!(FisherMethod::True, FisherMethod::True));
        assert!(matches!(FisherMethod::Diagonal, FisherMethod::Diagonal));
    }

    /// With the default 0.5 share, exactly two tasks fit and a third is rejected.
    #[test]
    fn test_packnet_capacity_is_consumed_and_exhausted() -> Result<()> {
        let params = vec![Tensor::new(vec![0.0_f32; 4])?];
        let mut packnet = PackNet::new(PackNetConfig::default(), params)?;

        packnet.allocate_task(0)?;
        assert_eq!(packnet.get_available_capacity(), &[0.5_f32]);
        packnet.allocate_task(1)?;
        assert_eq!(packnet.get_available_capacity(), &[0.0_f32]);
        assert!(packnet.allocate_task(2).is_err());
        Ok(())
    }

    /// Fisher estimation requires at least one gradient sample.
    #[test]
    fn test_ewc_rejects_empty_fisher_samples() -> Result<()> {
        let params = vec![Tensor::new(vec![1.0_f32, 2.0])?];
        let mut ewc = EWC::new(EWCConfig::default(), params)?;
        assert_eq!(ewc.get_importance_weights().len(), 1);
        assert!(ewc.compute_fisher_information(&[]).is_err());
        Ok(())
    }

    /// The replay buffer never grows beyond `memory_size`.
    #[test]
    fn test_memory_replay_buffer_is_bounded() -> Result<()> {
        let config = MemoryReplayConfig {
            memory_size: 2,
            ..MemoryReplayConfig::default()
        };
        let params = vec![Tensor::new(vec![1.0_f32; 3])?];
        let mut replay = MemoryReplay::new(config, params);
        for _ in 0..3 {
            replay.step(&[Tensor::new(vec![0.1_f32; 3])?])?;
        }
        assert_eq!(replay.memory_size(), 2);
        Ok(())
    }
}