Skip to main content

trustformers_models/
continual_learning.rs

//! # Continual Learning Framework
//!
//! This module provides a comprehensive framework for continual learning,
//! enabling models to learn new tasks while retaining knowledge from previous tasks.
//!
//! ## Features
//!
//! - **Multiple Continual Learning Strategies**: EWC, PackNet, Progressive Networks, etc.
//! - **Catastrophic Forgetting Prevention**: Various regularization techniques
//! - **Memory Management**: Experience replay and memory-based approaches
//! - **Task Detection**: Automatic task boundary detection
//! - **Evaluation Metrics**: Specialized metrics for continual learning scenarios
//! - **Multi-task Support**: Learning multiple tasks simultaneously
//!
//! Note: several strategies (e.g. LwF, GEM, PackNet, Progressive Networks) and the
//! automatic task detector are currently simplified placeholders.
//!
//! ## Usage
//!
//! ```rust,ignore
//! use trustformers_models::continual_learning::{
//!     ContinualLearningConfig, ContinualLearningTrainer, ContinualStrategy,
//! };
//!
//! let config = ContinualLearningConfig {
//!     strategy: ContinualStrategy::ElasticWeightConsolidation {
//!         lambda: 0.4,
//!         fisher_samples: 1000,
//!     },
//!     memory_size: 1000,
//!     ..Default::default()
//! };
//!
//! // `model` is any `Model<Input = Tensor, Output = Tensor>`.
//! let mut trainer = ContinualLearningTrainer::new(model, config)?;
//!
//! // Learn task 0.
//! trainer.start_task(0)?;
//! trainer.learn_batch(&task0_inputs, &task0_targets, None)?;
//!
//! // Starting task 1 first finalizes task 0 (e.g. computes EWC statistics).
//! trainer.start_task(1)?;
//! trainer.learn_batch(&task1_inputs, &task1_targets, None)?;
//! ```
38
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use trustformers_core::{errors::invalid_input, tensor::Tensor, traits::Model, Result};
42
43/// Configuration for continual learning
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ContinualLearningConfig {
46    /// Continual learning strategy to use
47    pub strategy: ContinualStrategy,
48    /// Size of memory buffer for experience replay
49    pub memory_size: usize,
50    /// Memory selection strategy
51    pub memory_selection: MemorySelectionStrategy,
52    /// Whether to use task-specific heads
53    pub task_specific_heads: bool,
54    /// Number of tasks to prepare for
55    pub max_tasks: usize,
56    /// Learning rate schedule for continual learning
57    pub learning_rate_schedule: LearningRateSchedule,
58    /// Evaluation frequency (in training steps)
59    pub evaluation_frequency: usize,
60    /// Whether to use task detection
61    pub automatic_task_detection: bool,
62    /// Task detection threshold
63    pub task_detection_threshold: f32,
64}
65
66impl Default for ContinualLearningConfig {
67    fn default() -> Self {
68        Self {
69            strategy: ContinualStrategy::ElasticWeightConsolidation {
70                lambda: 0.4,
71                fisher_samples: 1000,
72            },
73            memory_size: 1000,
74            memory_selection: MemorySelectionStrategy::Random,
75            task_specific_heads: true,
76            max_tasks: 10,
77            learning_rate_schedule: LearningRateSchedule::Constant { lr: 1e-4 },
78            evaluation_frequency: 1000,
79            automatic_task_detection: false,
80            task_detection_threshold: 0.8,
81        }
82    }
83}
84
85/// Different continual learning strategies
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub enum ContinualStrategy {
88    /// Elastic Weight Consolidation (EWC)
89    ElasticWeightConsolidation { lambda: f32, fisher_samples: usize },
90    /// Online EWC
91    OnlineElasticWeightConsolidation {
92        lambda: f32,
93        gamma: f32,
94        fisher_samples: usize,
95    },
96    /// Synaptic Intelligence (SI)
97    SynapticIntelligence { c: f32, xi: f32 },
98    /// Learning without Forgetting (LwF)
99    LearningWithoutForgetting { lambda: f32, temperature: f32 },
100    /// Progressive Neural Networks
101    ProgressiveNeuralNetworks {
102        lateral_connections: bool,
103        adapter_layers: bool,
104    },
105    /// PackNet
106    PackNet {
107        prune_ratio: f32,
108        retrain_epochs: usize,
109    },
110    /// Experience Replay
111    ExperienceReplay {
112        memory_strength: f32,
113        replay_batch_size: usize,
114    },
115    /// Gradient Episodic Memory (GEM)
116    GradientEpisodicMemory {
117        memory_strength: f32,
118        constraint_violation_threshold: f32,
119    },
120    /// Averaged Gradient Episodic Memory (A-GEM)
121    AveragedGradientEpisodicMemory {
122        memory_strength: f32,
123        replay_batch_size: usize,
124    },
125    /// Meta-Experience Replay (MER)
126    MetaExperienceReplay {
127        beta: f32,
128        gamma: f32,
129        replay_steps: usize,
130    },
131    /// L2 Regularization (simple baseline)
132    L2Regularization { lambda: f32 },
133    /// Dropout-based approaches
134    VariationalContinualLearning {
135        kl_weight: f32,
136        prior_precision: f32,
137    },
138}
139
140/// Memory selection strategies for experience replay
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub enum MemorySelectionStrategy {
143    /// Random selection
144    Random,
145    /// Select most uncertain examples
146    Uncertainty,
147    /// Select most diverse examples
148    Diversity,
149    /// Gradient-based selection
150    Gradient,
151    /// Select examples with highest loss
152    HighestLoss,
153    /// Cluster-based selection
154    ClusterBased,
155    /// FIFO (First In, First Out)
156    FIFO,
157    /// Ring buffer
158    RingBuffer,
159}
160
161/// Learning rate scheduling for continual learning
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub enum LearningRateSchedule {
164    /// Constant learning rate
165    Constant { lr: f32 },
166    /// Exponential decay
167    ExponentialDecay { initial_lr: f32, decay_rate: f32 },
168    /// Step decay
169    StepDecay {
170        initial_lr: f32,
171        step_size: usize,
172        gamma: f32,
173    },
174    /// Cosine annealing
175    CosineAnnealing { initial_lr: f32, t_max: usize },
176    /// Warm restart
177    WarmRestart {
178        initial_lr: f32,
179        t_0: usize,
180        t_mult: usize,
181    },
182}
183
184/// Memory buffer for storing past experiences
185#[derive(Debug, Clone)]
186pub struct MemoryBuffer {
187    /// Stored examples (inputs)
188    pub inputs: Vec<Tensor>,
189    /// Stored targets
190    pub targets: Vec<Tensor>,
191    /// Task IDs for each example
192    pub task_ids: Vec<usize>,
193    /// Example priorities/weights
194    pub priorities: Vec<f32>,
195    /// Maximum buffer size
196    pub max_size: usize,
197    /// Current insertion pointer
198    pub insertion_ptr: usize,
199    /// Selection strategy
200    pub selection_strategy: MemorySelectionStrategy,
201}
202
203impl MemoryBuffer {
204    /// Create a new memory buffer
205    pub fn new(max_size: usize, selection_strategy: MemorySelectionStrategy) -> Self {
206        Self {
207            inputs: Vec::new(),
208            targets: Vec::new(),
209            task_ids: Vec::new(),
210            priorities: Vec::new(),
211            max_size,
212            insertion_ptr: 0,
213            selection_strategy,
214        }
215    }
216
217    /// Add a new example to the buffer
218    pub fn add_example(&mut self, input: Tensor, target: Tensor, task_id: usize, priority: f32) {
219        if self.inputs.len() < self.max_size {
220            // Buffer not full, just append
221            self.inputs.push(input);
222            self.targets.push(target);
223            self.task_ids.push(task_id);
224            self.priorities.push(priority);
225        } else {
226            // Buffer full, need to replace
227            match self.selection_strategy {
228                MemorySelectionStrategy::Random => {
229                    let idx = fastrand::usize(..self.max_size);
230                    self.inputs[idx] = input;
231                    self.targets[idx] = target;
232                    self.task_ids[idx] = task_id;
233                    self.priorities[idx] = priority;
234                },
235                MemorySelectionStrategy::FIFO | MemorySelectionStrategy::RingBuffer => {
236                    self.inputs[self.insertion_ptr] = input;
237                    self.targets[self.insertion_ptr] = target;
238                    self.task_ids[self.insertion_ptr] = task_id;
239                    self.priorities[self.insertion_ptr] = priority;
240                    self.insertion_ptr = (self.insertion_ptr + 1) % self.max_size;
241                },
242                _ => {
243                    // For other strategies, replace the least important example
244                    let min_idx = self
245                        .priorities
246                        .iter()
247                        .enumerate()
248                        .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
249                        .map(|(idx, _)| idx)
250                        .unwrap_or(0);
251
252                    if priority > self.priorities[min_idx] {
253                        self.inputs[min_idx] = input;
254                        self.targets[min_idx] = target;
255                        self.task_ids[min_idx] = task_id;
256                        self.priorities[min_idx] = priority;
257                    }
258                },
259            }
260        }
261    }
262
263    /// Sample a batch from the buffer
264    pub fn sample_batch(
265        &self,
266        batch_size: usize,
267    ) -> Result<(Vec<Tensor>, Vec<Tensor>, Vec<usize>)> {
268        if self.inputs.is_empty() {
269            return Ok((Vec::new(), Vec::new(), Vec::new()));
270        }
271
272        let sample_size = batch_size.min(self.inputs.len());
273        let mut indices = Vec::new();
274
275        match self.selection_strategy {
276            MemorySelectionStrategy::Random => {
277                for _ in 0..sample_size {
278                    indices.push(fastrand::usize(..self.inputs.len()));
279                }
280            },
281            _ => {
282                // For other strategies, sample proportional to priority
283                let total_priority: f32 = self.priorities.iter().sum();
284                for _ in 0..sample_size {
285                    let mut cumsum = 0.0;
286                    let threshold = fastrand::f32() * total_priority;
287                    for (i, &priority) in self.priorities.iter().enumerate() {
288                        cumsum += priority;
289                        if cumsum >= threshold {
290                            indices.push(i);
291                            break;
292                        }
293                    }
294                }
295            },
296        }
297
298        let inputs: Vec<Tensor> = indices.iter().map(|&i| self.inputs[i].clone()).collect();
299        let targets: Vec<Tensor> = indices.iter().map(|&i| self.targets[i].clone()).collect();
300        let task_ids: Vec<usize> = indices.iter().map(|&i| self.task_ids[i]).collect();
301
302        Ok((inputs, targets, task_ids))
303    }
304
305    /// Get examples from a specific task
306    pub fn get_task_examples(&self, task_id: usize) -> (Vec<Tensor>, Vec<Tensor>) {
307        let mut inputs = Vec::new();
308        let mut targets = Vec::new();
309
310        for (i, &tid) in self.task_ids.iter().enumerate() {
311            if tid == task_id {
312                inputs.push(self.inputs[i].clone());
313                targets.push(self.targets[i].clone());
314            }
315        }
316
317        (inputs, targets)
318    }
319
320    /// Clear the buffer
321    pub fn clear(&mut self) {
322        self.inputs.clear();
323        self.targets.clear();
324        self.task_ids.clear();
325        self.priorities.clear();
326        self.insertion_ptr = 0;
327    }
328
329    /// Get buffer size
330    pub fn size(&self) -> usize {
331        self.inputs.len()
332    }
333
334    /// Check if buffer is empty
335    pub fn is_empty(&self) -> bool {
336        self.inputs.is_empty()
337    }
338}
339
340/// Continual learning trainer
341pub struct ContinualLearningTrainer<M: Model> {
342    /// The model being trained
343    pub model: M,
344    /// Configuration
345    pub config: ContinualLearningConfig,
346    /// Memory buffer for experience replay
347    pub memory: MemoryBuffer,
348    /// Task-specific information
349    pub task_info: HashMap<usize, TaskInfo>,
350    /// Current task ID
351    pub current_task: Option<usize>,
352    /// Fisher information matrices (for EWC)
353    pub fisher_matrices: HashMap<String, Tensor>,
354    /// Optimal parameters (for EWC)
355    pub optimal_parameters: HashMap<String, Tensor>,
356    /// Training step counter
357    pub step_counter: usize,
358    /// Task detection state
359    pub task_detector: Option<TaskDetector>,
360}
361
362impl<M: Model<Input = Tensor, Output = Tensor>> ContinualLearningTrainer<M> {
363    /// Create a new continual learning trainer
364    pub fn new(model: M, config: ContinualLearningConfig) -> Result<Self> {
365        let memory = MemoryBuffer::new(config.memory_size, config.memory_selection.clone());
366
367        let task_detector = if config.automatic_task_detection {
368            Some(TaskDetector::new(config.task_detection_threshold))
369        } else {
370            None
371        };
372
373        Ok(Self {
374            model,
375            config,
376            memory,
377            task_info: HashMap::new(),
378            current_task: None,
379            fisher_matrices: HashMap::new(),
380            optimal_parameters: HashMap::new(),
381            step_counter: 0,
382            task_detector,
383        })
384    }
385
386    /// Start learning a new task
387    pub fn start_task(&mut self, task_id: usize) -> Result<()> {
388        // Save current task information if this is a task switch
389        if let Some(current_id) = self.current_task {
390            if current_id != task_id {
391                self.finalize_task(current_id)?;
392            }
393        }
394
395        self.current_task = Some(task_id);
396
397        // Initialize task info if new
398        self.task_info.entry(task_id).or_insert_with(|| TaskInfo::new(task_id));
399
400        // Apply strategy-specific initialization
401        match &self.config.strategy {
402            ContinualStrategy::ProgressiveNeuralNetworks { .. } => {
403                // Add new columns for progressive networks
404                self.add_progressive_columns(task_id)?;
405            },
406            ContinualStrategy::PackNet { .. } => {
407                // Prepare for pruning-based learning
408                self.prepare_packnet(task_id)?;
409            },
410            _ => {
411                // Most strategies don't require special initialization
412            },
413        }
414
415        Ok(())
416    }
417
418    /// Learn from a batch of data
419    pub fn learn_batch(
420        &mut self,
421        inputs: &[Tensor],
422        targets: &[Tensor],
423        task_id: Option<usize>,
424    ) -> Result<ContinualLearningOutput> {
425        let task_id = task_id
426            .or(self.current_task)
427            .ok_or_else(|| invalid_input("No task ID specified"))?;
428
429        // Detect task boundaries if enabled
430        if let Some(detector) = &mut self.task_detector {
431            if let Some(detected_task) = detector.detect_task_change(inputs, targets)? {
432                if detected_task != task_id {
433                    self.start_task(detected_task)?;
434                }
435            }
436        }
437
438        // Compute forward pass and loss
439        let outputs = self.model.forward(inputs[0].clone())?; // Simplified single input
440        let current_loss = self.compute_task_loss(&outputs, &targets[0])?;
441        let current_loss_for_output = current_loss.clone();
442
443        // Apply continual learning strategy
444        let total_loss = match &self.config.strategy {
445            ContinualStrategy::ElasticWeightConsolidation { lambda, .. } => {
446                let ewc_loss = self.compute_ewc_loss(*lambda)?;
447                current_loss.add(&ewc_loss)?
448            },
449            ContinualStrategy::LearningWithoutForgetting {
450                lambda,
451                temperature,
452            } => {
453                let distillation_loss = self.compute_lwf_loss(inputs, *lambda, *temperature)?;
454                current_loss.add(&distillation_loss)?
455            },
456            ContinualStrategy::ExperienceReplay {
457                memory_strength,
458                replay_batch_size,
459            } => {
460                let replay_loss = self.compute_replay_loss(*memory_strength, *replay_batch_size)?;
461                current_loss.add(&replay_loss)?
462            },
463            ContinualStrategy::GradientEpisodicMemory {
464                memory_strength, ..
465            } => self.compute_gem_loss(&current_loss, *memory_strength)?,
466            ContinualStrategy::L2Regularization { lambda } => {
467                let l2_loss = self.compute_l2_regularization(*lambda)?;
468                current_loss.add(&l2_loss)?
469            },
470            _ => current_loss,
471        };
472
473        // Store examples in memory if needed
474        if !matches!(
475            self.config.strategy,
476            ContinualStrategy::L2Regularization { .. }
477        ) {
478            for (input, target) in inputs.iter().zip(targets.iter()) {
479                let priority = self.compute_example_priority(input, target)?;
480                self.memory.add_example(input.clone(), target.clone(), task_id, priority);
481            }
482        }
483
484        // Update training step counter
485        self.step_counter += 1;
486
487        // Update task statistics
488        if let Some(task_info) = self.task_info.get_mut(&task_id) {
489            task_info.update_statistics(total_loss.to_scalar().unwrap_or(0.0));
490        }
491
492        let total_loss_clone = total_loss.clone();
493
494        Ok(ContinualLearningOutput {
495            total_loss: total_loss_clone.clone(),
496            task_loss: current_loss_for_output.clone(),
497            regularization_loss: total_loss_clone.sub(&current_loss_for_output)?,
498            task_id,
499            memory_usage: self.memory.size(),
500        })
501    }
502
503    /// Finalize learning for a task
504    pub fn finalize_task(&mut self, task_id: usize) -> Result<()> {
505        match self.config.strategy.clone() {
506            ContinualStrategy::ElasticWeightConsolidation { fisher_samples, .. }
507            | ContinualStrategy::OnlineElasticWeightConsolidation { fisher_samples, .. } => {
508                self.compute_fisher_information(task_id, fisher_samples)?;
509                self.save_optimal_parameters()?;
510            },
511            ContinualStrategy::PackNet {
512                prune_ratio,
513                retrain_epochs,
514            } => {
515                self.apply_packnet_pruning(prune_ratio)?;
516                self.retrain_after_pruning(retrain_epochs)?;
517            },
518            _ => {
519                // Most strategies don't require finalization
520            },
521        }
522
523        Ok(())
524    }
525
526    /// Compute task-specific loss
527    fn compute_task_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
528        // Implement cross-entropy loss for classification tasks
529        let log_probs = outputs.softmax(-1)?.log()?;
530
531        // Check if targets are one-hot encoded or class indices
532        let targets_shape = targets.shape();
533        let outputs_shape = outputs.shape();
534
535        if targets_shape == outputs_shape {
536            // Targets are one-hot encoded
537            let element_wise = log_probs.mul(targets)?;
538            let sum_per_sample = element_wise.sum(Some(vec![outputs_shape.len() - 1]), false)?; // Sum across the last dimension
539            Ok(sum_per_sample.neg()?.mean()?)
540        } else {
541            // Targets are class indices - use simplified approach
542            // In a full implementation, we'd use proper gather operation
543            // For now, compute difference between predictions and one-hot targets
544            let batch_size = outputs_shape[0];
545            let num_classes = outputs_shape[outputs_shape.len() - 1];
546
547            // Create one-hot encoding manually (simplified)
548            let mut one_hot_data = vec![0.0f32; batch_size * num_classes];
549            let targets_data = targets.data()?;
550
551            for (i, &target_idx) in targets_data.iter().enumerate() {
552                if target_idx >= 0.0 && (target_idx as usize) < num_classes {
553                    one_hot_data[i * num_classes + target_idx as usize] = 1.0;
554                }
555            }
556
557            let one_hot_targets = Tensor::new(one_hot_data)?.reshape(&outputs_shape)?;
558            let element_wise = log_probs.mul(&one_hot_targets)?;
559            let sum_per_sample = element_wise.sum(Some(vec![outputs_shape.len() - 1]), false)?; // Sum across the last dimension
560            Ok(sum_per_sample.neg()?.mean()?)
561        }
562    }
563
564    /// Compute EWC regularization loss
565    fn compute_ewc_loss(&self, lambda: f32) -> Result<Tensor> {
566        let mut total_loss = Tensor::zeros(&[1])?;
567
568        // This is a simplified implementation
569        // In practice, you'd iterate through model parameters
570        for (param_name, fisher) in &self.fisher_matrices {
571            if let Some(optimal) = self.optimal_parameters.get(param_name) {
572                // Get current parameter (simplified)
573                let current_param = Tensor::zeros_like(optimal)?; // Placeholder
574                let diff = current_param.sub(optimal)?;
575                let squared_diff = diff.mul(&diff)?;
576                let weighted_diff = fisher.mul(&squared_diff)?;
577                total_loss = total_loss.add(&weighted_diff.sum(None, false)?)?;
578            }
579        }
580
581        total_loss.scalar_mul(lambda)
582    }
583
584    /// Compute Learning without Forgetting distillation loss
585    fn compute_lwf_loss(
586        &self,
587        _inputs: &[Tensor],
588        lambda: f32,
589        _temperature: f32,
590    ) -> Result<Tensor> {
591        // This would compute the distillation loss from previous tasks
592        // Simplified implementation
593        Tensor::zeros(&[1])?.scalar_mul(lambda)
594    }
595
596    /// Compute experience replay loss
597    fn compute_replay_loss(
598        &mut self,
599        memory_strength: f32,
600        replay_batch_size: usize,
601    ) -> Result<Tensor> {
602        if self.memory.is_empty() {
603            return Tensor::zeros(&[1]);
604        }
605
606        let (replay_inputs, replay_targets, _) = self.memory.sample_batch(replay_batch_size)?;
607
608        if replay_inputs.is_empty() {
609            return Tensor::zeros(&[1]);
610        }
611
612        // Compute loss on replay data
613        let replay_outputs = self.model.forward(replay_inputs[0].clone())?; // Simplified
614        let replay_loss = self.compute_task_loss(&replay_outputs, &replay_targets[0])?;
615
616        replay_loss.scalar_mul(memory_strength)
617    }
618
619    /// Compute GEM constraint loss
620    fn compute_gem_loss(&mut self, current_loss: &Tensor, memory_strength: f32) -> Result<Tensor> {
621        // GEM computes gradients on memory and projects current gradients
622        // This is a simplified implementation
623        current_loss.scalar_mul(memory_strength)
624    }
625
626    /// Compute L2 regularization loss
627    fn compute_l2_regularization(&self, lambda: f32) -> Result<Tensor> {
628        // Compute L2 norm of parameters
629        // This is a simplified implementation
630        Tensor::zeros(&[1])?.scalar_mul(lambda)
631    }
632
633    /// Compute example priority for memory storage
634    fn compute_example_priority(&self, input: &Tensor, target: &Tensor) -> Result<f32> {
635        match self.config.memory_selection {
636            MemorySelectionStrategy::Random => Ok(1.0),
637            MemorySelectionStrategy::Uncertainty => {
638                // Compute prediction uncertainty
639                let outputs = self.model.forward(input.clone())?;
640                let probs = outputs.softmax(-1)?;
641                let entropy = -(probs.clone().mul(&probs.log()?)?)
642                    .sum(Some(vec![1]), false)?
643                    .to_scalar()
644                    .unwrap_or(0.0);
645                Ok(entropy)
646            },
647            MemorySelectionStrategy::HighestLoss => {
648                let outputs = self.model.forward(input.clone())?;
649                let loss = self.compute_task_loss(&outputs, target)?;
650                Ok(loss.to_scalar().unwrap_or(0.0))
651            },
652            _ => Ok(1.0), // Default priority
653        }
654    }
655
656    /// Compute Fisher information for EWC
657    fn compute_fisher_information(&mut self, task_id: usize, num_samples: usize) -> Result<()> {
658        // Get examples from current task
659        let (task_inputs, task_targets) = self.memory.get_task_examples(task_id);
660
661        if task_inputs.is_empty() {
662            return Ok(());
663        }
664
665        // Sample examples for Fisher computation
666        let sample_size = num_samples.min(task_inputs.len());
667
668        // This is a simplified implementation
669        // In practice, you'd compute Fisher information for each parameter
670        for i in 0..sample_size {
671            let input = &task_inputs[i % task_inputs.len()];
672            let target = &task_targets[i % task_targets.len()];
673
674            // Compute gradients and accumulate Fisher information
675            let outputs = self.model.forward(input.clone())?;
676            let _loss = self.compute_task_loss(&outputs, target)?;
677
678            // Store Fisher information (simplified)
679            self.fisher_matrices.insert(
680                format!("param_{}", i),
681                Tensor::ones(&[10])?, // Placeholder
682            );
683        }
684
685        Ok(())
686    }
687
688    /// Save optimal parameters for EWC
689    fn save_optimal_parameters(&mut self) -> Result<()> {
690        // Save current model parameters as optimal
691        // This is a simplified implementation
692        self.optimal_parameters.insert(
693            "param_0".to_string(),
694            Tensor::zeros(&[10])?, // Placeholder
695        );
696        Ok(())
697    }
698
699    /// Add progressive network columns
700    fn add_progressive_columns(&mut self, _task_id: usize) -> Result<()> {
701        // Add new columns to the network for the new task
702        // This is a simplified implementation
703        Ok(())
704    }
705
706    /// Prepare for PackNet pruning
707    fn prepare_packnet(&mut self, _task_id: usize) -> Result<()> {
708        // Prepare the network for pruning-based continual learning
709        Ok(())
710    }
711
712    /// Apply PackNet pruning
713    fn apply_packnet_pruning(&mut self, _prune_ratio: f32) -> Result<()> {
714        // Prune the network and freeze pruned weights
715        Ok(())
716    }
717
718    /// Retrain after PackNet pruning
719    fn retrain_after_pruning(&mut self, _epochs: usize) -> Result<()> {
720        // Retrain the unpruned weights
721        Ok(())
722    }
723
724    /// Evaluate on all tasks
725    pub fn evaluate_all_tasks(&self) -> Result<HashMap<usize, TaskEvaluation>> {
726        let mut evaluations = HashMap::new();
727
728        for &task_id in self.task_info.keys() {
729            let (task_inputs, task_targets) = self.memory.get_task_examples(task_id);
730
731            if !task_inputs.is_empty() {
732                let evaluation = self.evaluate_task(&task_inputs, &task_targets, task_id)?;
733                evaluations.insert(task_id, evaluation);
734            }
735        }
736
737        Ok(evaluations)
738    }
739
740    /// Evaluate on a specific task
741    fn evaluate_task(
742        &self,
743        inputs: &[Tensor],
744        targets: &[Tensor],
745        task_id: usize,
746    ) -> Result<TaskEvaluation> {
747        let mut total_loss = 0.0;
748        let mut correct_predictions = 0;
749        let total_examples = inputs.len();
750
751        for (input, target) in inputs.iter().zip(targets.iter()) {
752            let outputs = self.model.forward(input.clone())?;
753            let loss = self.compute_task_loss(&outputs, target)?;
754            total_loss += loss.to_scalar().unwrap_or(0.0);
755
756            // Compute accuracy (simplified)
757            let predicted = Tensor::zeros(&[1])?; // Simplified placeholder - ideally should be argmax
758            let target_class = Tensor::zeros(&[1])?; // Simplified placeholder - ideally should be argmax
759            if predicted.to_scalar().unwrap_or(-1.0) == target_class.to_scalar().unwrap_or(-2.0) {
760                correct_predictions += 1;
761            }
762        }
763
764        Ok(TaskEvaluation {
765            task_id,
766            average_loss: total_loss / total_examples as f32,
767            accuracy: correct_predictions as f32 / total_examples as f32,
768            num_examples: total_examples,
769        })
770    }
771
772    /// Get continual learning metrics
773    pub fn get_metrics(&self) -> ContinualLearningMetrics {
774        let all_evaluations = self.evaluate_all_tasks().unwrap_or_default();
775
776        let average_accuracy = if !all_evaluations.is_empty() {
777            all_evaluations.values().map(|e| e.accuracy).sum::<f32>() / all_evaluations.len() as f32
778        } else {
779            0.0
780        };
781
782        let memory_efficiency = self.memory.size() as f32 / self.config.memory_size as f32;
783
784        ContinualLearningMetrics {
785            average_accuracy,
786            task_evaluations: all_evaluations,
787            memory_efficiency,
788            num_tasks_learned: self.task_info.len(),
789            current_task: self.current_task,
790        }
791    }
792}
793
/// Running statistics about a single task.
#[derive(Debug, Clone)]
pub struct TaskInfo {
    /// Identifier of the task these statistics belong to.
    pub task_id: usize,
    /// Training step at which the task started (0 when built with [`TaskInfo::new`]).
    pub start_step: usize,
    /// Number of `update_statistics` calls recorded so far.
    pub num_examples_seen: usize,
    /// Running mean of every loss passed to `update_statistics`.
    pub average_loss: f32,
    /// Most recent accuracy (starts at 0.0; not updated by this impl).
    pub last_accuracy: f32,
}

impl TaskInfo {
    /// Create zeroed statistics for `task_id`.
    pub fn new(task_id: usize) -> Self {
        TaskInfo {
            task_id,
            start_step: 0,
            num_examples_seen: 0,
            average_loss: 0.0,
            last_accuracy: 0.0,
        }
    }

    /// Fold `loss` into the running mean and increment the observation count.
    pub fn update_statistics(&mut self, loss: f32) {
        let previous_count = self.num_examples_seen;
        self.num_examples_seen = previous_count + 1;
        let weighted_sum = self.average_loss * previous_count as f32 + loss;
        self.average_loss = weighted_sum / self.num_examples_seen as f32;
    }
}
821
822/// Task detector for automatic task boundary detection
823pub struct TaskDetector {
824    #[allow(dead_code)]
825    threshold: f32,
826    #[allow(dead_code)]
827    recent_losses: Vec<f32>,
828    #[allow(dead_code)]
829    window_size: usize,
830}
831
832impl TaskDetector {
833    pub fn new(threshold: f32) -> Self {
834        Self {
835            threshold,
836            recent_losses: Vec::new(),
837            window_size: 100,
838        }
839    }
840
841    pub fn detect_task_change(
842        &mut self,
843        _inputs: &[Tensor],
844        _targets: &[Tensor],
845    ) -> Result<Option<usize>> {
846        // Simplified task detection based on loss spikes
847        // In practice, this would be more sophisticated
848        Ok(None)
849    }
850}
851
852/// Output from continual learning step
853#[derive(Debug, Clone)]
854pub struct ContinualLearningOutput {
855    pub total_loss: Tensor,
856    pub task_loss: Tensor,
857    pub regularization_loss: Tensor,
858    pub task_id: usize,
859    pub memory_usage: usize,
860}
861
/// Evaluation results for a specific task
///
/// A plain value record; derives `PartialEq` so evaluations can be compared
/// directly (e.g. in tests or when checking metrics for changes).
#[derive(Debug, Clone, PartialEq)]
pub struct TaskEvaluation {
    /// Identifier of the evaluated task.
    pub task_id: usize,
    /// Mean loss over the evaluated examples.
    pub average_loss: f32,
    /// Accuracy measured on the evaluated examples.
    pub accuracy: f32,
    /// Number of examples the evaluation covered.
    pub num_examples: usize,
}
870
871/// Overall continual learning metrics
872#[derive(Debug, Clone)]
873pub struct ContinualLearningMetrics {
874    pub average_accuracy: f32,
875    pub task_evaluations: HashMap<usize, TaskEvaluation>,
876    pub memory_efficiency: f32,
877    pub num_tasks_learned: usize,
878    pub current_task: Option<usize>,
879}
880
881/// Utilities for continual learning
882pub mod utils {
883    use super::*;
884
885    /// Create EWC configuration
886    pub fn ewc_config(
887        lambda: f32,
888        fisher_samples: usize,
889        memory_size: usize,
890    ) -> ContinualLearningConfig {
891        ContinualLearningConfig {
892            strategy: ContinualStrategy::ElasticWeightConsolidation {
893                lambda,
894                fisher_samples,
895            },
896            memory_size,
897            ..Default::default()
898        }
899    }
900
901    /// Create experience replay configuration
902    pub fn experience_replay_config(
903        memory_size: usize,
904        replay_batch_size: usize,
905    ) -> ContinualLearningConfig {
906        ContinualLearningConfig {
907            strategy: ContinualStrategy::ExperienceReplay {
908                memory_strength: 1.0,
909                replay_batch_size,
910            },
911            memory_size,
912            memory_selection: MemorySelectionStrategy::Random,
913            ..Default::default()
914        }
915    }
916
917    /// Create L2 regularization configuration
918    pub fn l2_regularization_config(lambda: f32) -> ContinualLearningConfig {
919        ContinualLearningConfig {
920            strategy: ContinualStrategy::L2Regularization { lambda },
921            memory_size: 0, // No memory needed for L2 regularization
922            ..Default::default()
923        }
924    }
925
926    /// Create progressive networks configuration
927    pub fn progressive_networks_config() -> ContinualLearningConfig {
928        ContinualLearningConfig {
929            strategy: ContinualStrategy::ProgressiveNeuralNetworks {
930                lateral_connections: true,
931                adapter_layers: true,
932            },
933            task_specific_heads: true,
934            ..Default::default()
935        }
936    }
937
938    /// Compute backward transfer (improvement on previous tasks)
939    pub fn compute_backward_transfer(
940        evaluations_before: &HashMap<usize, TaskEvaluation>,
941        evaluations_after: &HashMap<usize, TaskEvaluation>,
942    ) -> f32 {
943        let mut total_transfer = 0.0;
944        let mut num_tasks = 0;
945
946        for (&task_id, after_eval) in evaluations_after {
947            if let Some(before_eval) = evaluations_before.get(&task_id) {
948                total_transfer += after_eval.accuracy - before_eval.accuracy;
949                num_tasks += 1;
950            }
951        }
952
953        if num_tasks > 0 {
954            total_transfer / num_tasks as f32
955        } else {
956            0.0
957        }
958    }
959
960    /// Compute forward transfer (improvement on new tasks)
961    pub fn compute_forward_transfer(baseline_accuracy: f32, continual_accuracy: f32) -> f32 {
962        continual_accuracy - baseline_accuracy
963    }
964
965    /// Compute forgetting measure
966    pub fn compute_forgetting(
967        max_accuracies: &HashMap<usize, f32>,
968        final_accuracies: &HashMap<usize, f32>,
969    ) -> f32 {
970        let mut total_forgetting = 0.0;
971        let mut num_tasks = 0;
972
973        for (&task_id, &max_acc) in max_accuracies {
974            if let Some(&final_acc) = final_accuracies.get(&task_id) {
975                total_forgetting += max_acc - final_acc;
976                num_tasks += 1;
977            }
978        }
979
980        if num_tasks > 0 {
981            total_forgetting / num_tasks as f32
982        } else {
983            0.0
984        }
985    }
986}
987
#[cfg(test)]
mod tests {
    // Unit tests for the configuration helpers, memory buffer, task bookkeeping,
    // and the continual-learning metrics in `utils`.
    use super::*;

    #[test]
    fn test_continual_learning_config_default() {
        let config = ContinualLearningConfig::default();
        assert_eq!(config.memory_size, 1000);
        assert!(config.task_specific_heads);
        assert!(!config.automatic_task_detection);

        if let ContinualStrategy::ElasticWeightConsolidation {
            lambda,
            fisher_samples,
        } = config.strategy
        {
            assert_eq!(lambda, 0.4);
            assert_eq!(fisher_samples, 1000);
        } else {
            panic!("Expected EWC strategy");
        }
    }

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(3, MemorySelectionStrategy::Random);
        assert!(buffer.is_empty());
        assert_eq!(buffer.size(), 0);

        // Add examples
        let input1 = Tensor::zeros(&[1, 10]).expect("operation failed");
        let target1 = Tensor::zeros(&[1]).expect("operation failed");
        buffer.add_example(input1, target1, 0, 1.0);
        assert_eq!(buffer.size(), 1);

        let input2 = Tensor::ones(&[1, 10]).expect("operation failed");
        let target2 = Tensor::ones(&[1]).expect("operation failed");
        buffer.add_example(input2, target2, 1, 2.0);
        assert_eq!(buffer.size(), 2);

        // Sample batch
        let (inputs, targets, task_ids) = buffer.sample_batch(2).expect("operation failed");
        assert_eq!(inputs.len(), 2);
        assert_eq!(targets.len(), 2);
        assert_eq!(task_ids.len(), 2);
    }

    #[test]
    fn test_ewc_config() {
        let config = utils::ewc_config(0.5, 2000, 500);
        assert_eq!(config.memory_size, 500);

        if let ContinualStrategy::ElasticWeightConsolidation {
            lambda,
            fisher_samples,
        } = config.strategy
        {
            assert_eq!(lambda, 0.5);
            assert_eq!(fisher_samples, 2000);
        } else {
            panic!("Expected EWC strategy");
        }
    }

    #[test]
    fn test_experience_replay_config() {
        let config = utils::experience_replay_config(1000, 64);
        assert_eq!(config.memory_size, 1000);

        if let ContinualStrategy::ExperienceReplay {
            memory_strength,
            replay_batch_size,
        } = config.strategy
        {
            assert_eq!(memory_strength, 1.0);
            assert_eq!(replay_batch_size, 64);
        } else {
            panic!("Expected ExperienceReplay strategy");
        }
    }

    #[test]
    fn test_l2_regularization_config() {
        let config = utils::l2_regularization_config(0.01);
        assert_eq!(config.memory_size, 0);

        if let ContinualStrategy::L2Regularization { lambda } = config.strategy {
            assert_eq!(lambda, 0.01);
        } else {
            panic!("Expected L2Regularization strategy");
        }
    }

    #[test]
    fn test_progressive_networks_config() {
        let config = utils::progressive_networks_config();
        assert!(config.task_specific_heads);
        assert!(matches!(
            config.strategy,
            ContinualStrategy::ProgressiveNeuralNetworks {
                lateral_connections: true,
                adapter_layers: true,
            }
        ));
    }

    #[test]
    fn test_task_info() {
        let mut info = TaskInfo::new(5);
        assert_eq!(info.task_id, 5);
        assert_eq!(info.num_examples_seen, 0);

        info.update_statistics(0.5);
        assert_eq!(info.num_examples_seen, 1);
        assert_eq!(info.average_loss, 0.5);

        info.update_statistics(1.0);
        assert_eq!(info.num_examples_seen, 2);
        assert_eq!(info.average_loss, 0.75);
    }

    #[test]
    fn test_task_detector_reports_no_change() {
        // The detector is currently a stub that never signals a task boundary.
        let mut detector = TaskDetector::new(0.5);
        let change = detector
            .detect_task_change(&[], &[])
            .expect("task detection should not fail");
        assert_eq!(change, None);
    }

    #[test]
    fn test_backward_transfer_computation() {
        let mut before = HashMap::new();
        before.insert(
            0,
            TaskEvaluation {
                task_id: 0,
                average_loss: 0.5,
                accuracy: 0.8,
                num_examples: 100,
            },
        );
        before.insert(
            1,
            TaskEvaluation {
                task_id: 1,
                average_loss: 0.6,
                accuracy: 0.7,
                num_examples: 100,
            },
        );

        let mut after = HashMap::new();
        after.insert(
            0,
            TaskEvaluation {
                task_id: 0,
                average_loss: 0.4,
                accuracy: 0.85,
                num_examples: 100,
            },
        );
        after.insert(
            1,
            TaskEvaluation {
                task_id: 1,
                average_loss: 0.55,
                accuracy: 0.72,
                num_examples: 100,
            },
        );

        let backward_transfer = utils::compute_backward_transfer(&before, &after);
        assert!((backward_transfer - 0.035).abs() < 1e-6); // (0.05 + 0.02) / 2
    }

    #[test]
    fn test_forward_transfer_computation() {
        let forward_transfer = utils::compute_forward_transfer(0.6, 0.75);
        assert!((forward_transfer - 0.15).abs() < 1e-6);
    }

    #[test]
    fn test_forgetting_computation() {
        let mut max_accuracies = HashMap::new();
        max_accuracies.insert(0, 0.9);
        max_accuracies.insert(1, 0.85);

        let mut final_accuracies = HashMap::new();
        final_accuracies.insert(0, 0.8);
        final_accuracies.insert(1, 0.75);

        let forgetting = utils::compute_forgetting(&max_accuracies, &final_accuracies);
        assert!((forgetting - 0.1).abs() < 1e-6); // (0.1 + 0.1) / 2
    }

    #[test]
    fn test_transfer_and_forgetting_without_shared_tasks() {
        // Both metrics only average over task ids present in both maps; with none in
        // common they must fall back to 0.0 instead of dividing by zero.
        let mut before = HashMap::new();
        before.insert(
            0,
            TaskEvaluation {
                task_id: 0,
                average_loss: 0.5,
                accuracy: 0.8,
                num_examples: 10,
            },
        );
        let mut after = HashMap::new();
        after.insert(
            1,
            TaskEvaluation {
                task_id: 1,
                average_loss: 0.4,
                accuracy: 0.9,
                num_examples: 10,
            },
        );
        assert_eq!(utils::compute_backward_transfer(&before, &after), 0.0);

        let empty: HashMap<usize, f32> = HashMap::new();
        assert_eq!(utils::compute_forgetting(&empty, &empty), 0.0);
    }
}