1use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use trustformers_core::{errors::invalid_input, tensor::Tensor, traits::Model, Result};
42
/// Configuration for continual (lifelong) learning training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContinualLearningConfig {
    /// Strategy used to mitigate catastrophic forgetting.
    pub strategy: ContinualStrategy,
    /// Maximum number of examples retained in the replay memory buffer.
    pub memory_size: usize,
    /// Policy for choosing which examples enter/leave the memory buffer.
    pub memory_selection: MemorySelectionStrategy,
    /// Whether each task gets its own output head.
    pub task_specific_heads: bool,
    /// Upper bound on the number of tasks this trainer expects to see.
    pub max_tasks: usize,
    /// Learning-rate schedule applied during training.
    pub learning_rate_schedule: LearningRateSchedule,
    /// How often (in steps) evaluation should be run.
    pub evaluation_frequency: usize,
    /// Enables automatic detection of task boundaries from the data stream.
    pub automatic_task_detection: bool,
    /// Threshold handed to `TaskDetector` when automatic detection is on.
    pub task_detection_threshold: f32,
}
65
66impl Default for ContinualLearningConfig {
67 fn default() -> Self {
68 Self {
69 strategy: ContinualStrategy::ElasticWeightConsolidation {
70 lambda: 0.4,
71 fisher_samples: 1000,
72 },
73 memory_size: 1000,
74 memory_selection: MemorySelectionStrategy::Random,
75 task_specific_heads: true,
76 max_tasks: 10,
77 learning_rate_schedule: LearningRateSchedule::Constant { lr: 1e-4 },
78 evaluation_frequency: 1000,
79 automatic_task_detection: false,
80 task_detection_threshold: 0.8,
81 }
82 }
83}
84
/// Algorithms for mitigating catastrophic forgetting across tasks.
///
/// NOTE(review): several strategies are only partially implemented by the
/// trainer below (see the `compute_*` placeholders); treat the variant
/// descriptions as intent, not verified behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContinualStrategy {
    /// Elastic Weight Consolidation: quadratic penalty weighted by a Fisher
    /// estimate built from `fisher_samples` examples.
    ElasticWeightConsolidation { lambda: f32, fisher_samples: usize },
    /// Online variant of EWC with decay factor `gamma`.
    OnlineElasticWeightConsolidation {
        lambda: f32,
        gamma: f32,
        fisher_samples: usize,
    },
    /// Synaptic Intelligence with scale `c` and damping `xi`.
    SynapticIntelligence { c: f32, xi: f32 },
    /// Learning without Forgetting: distillation at the given temperature.
    LearningWithoutForgetting { lambda: f32, temperature: f32 },
    /// Progressive networks: new columns per task, optional laterals/adapters.
    ProgressiveNeuralNetworks {
        lateral_connections: bool,
        adapter_layers: bool,
    },
    /// PackNet: prune `prune_ratio` of weights per task, then retrain.
    PackNet {
        prune_ratio: f32,
        retrain_epochs: usize,
    },
    /// Replay stored examples alongside the current batch.
    ExperienceReplay {
        memory_strength: f32,
        replay_batch_size: usize,
    },
    /// Gradient Episodic Memory (GEM).
    GradientEpisodicMemory {
        memory_strength: f32,
        constraint_violation_threshold: f32,
    },
    /// Averaged GEM (A-GEM).
    AveragedGradientEpisodicMemory {
        memory_strength: f32,
        replay_batch_size: usize,
    },
    /// Meta-Experience Replay (MER).
    MetaExperienceReplay {
        beta: f32,
        gamma: f32,
        replay_steps: usize,
    },
    /// Plain L2 penalty; keeps no replay memory (see `learn_batch`).
    L2Regularization { lambda: f32 },
    /// Variational continual learning with a KL term toward the prior.
    VariationalContinualLearning {
        kl_weight: f32,
        prior_precision: f32,
    },
}
139
/// Policies for deciding which examples to keep in / sample from memory.
///
/// Only `Random`, `FIFO`/`RingBuffer`, `Uncertainty`, and `HighestLoss` are
/// given distinct handling in this file; the rest currently fall through to
/// priority-based replacement with a constant priority of 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemorySelectionStrategy {
    /// Uniform random replacement and sampling.
    Random,
    /// Prefer examples the model is uncertain about (predictive entropy).
    Uncertainty,
    /// Prefer examples that diversify the buffer (not yet specialized).
    Diversity,
    /// Prefer examples with informative gradients (not yet specialized).
    Gradient,
    /// Prefer examples with the highest task loss.
    HighestLoss,
    /// Cluster-based coverage of the input space (not yet specialized).
    ClusterBased,
    /// First-in-first-out replacement.
    FIFO,
    /// Circular-buffer replacement (handled identically to FIFO here).
    RingBuffer,
}

/// Learning-rate schedules available in the configuration.
///
/// NOTE(review): the schedule is stored but not interpreted anywhere in this
/// file — confirm where it is consumed before relying on exact semantics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningRateSchedule {
    /// Fixed learning rate.
    Constant { lr: f32 },
    /// Exponential decay from `initial_lr` by `decay_rate`.
    ExponentialDecay { initial_lr: f32, decay_rate: f32 },
    /// Multiply by `gamma` every `step_size` steps.
    StepDecay {
        initial_lr: f32,
        step_size: usize,
        gamma: f32,
    },
    /// Cosine annealing over `t_max` steps.
    CosineAnnealing { initial_lr: f32, t_max: usize },
    /// Cosine annealing with warm restarts (`t_0` period, `t_mult` growth).
    WarmRestart {
        initial_lr: f32,
        t_0: usize,
        t_mult: usize,
    },
}
183
/// Episodic replay memory holding (input, target, task, priority) tuples in
/// parallel vectors.
#[derive(Debug, Clone)]
pub struct MemoryBuffer {
    /// Stored input tensors.
    pub inputs: Vec<Tensor>,
    /// Stored target tensors (parallel to `inputs`).
    pub targets: Vec<Tensor>,
    /// Task id of each stored example (parallel to `inputs`).
    pub task_ids: Vec<usize>,
    /// Retention priority of each stored example (parallel to `inputs`).
    pub priorities: Vec<f32>,
    /// Capacity limit of the buffer.
    pub max_size: usize,
    /// Next slot to overwrite under FIFO/RingBuffer replacement.
    pub insertion_ptr: usize,
    /// Policy used for eviction and sampling.
    pub selection_strategy: MemorySelectionStrategy,
}
202
203impl MemoryBuffer {
204 pub fn new(max_size: usize, selection_strategy: MemorySelectionStrategy) -> Self {
206 Self {
207 inputs: Vec::new(),
208 targets: Vec::new(),
209 task_ids: Vec::new(),
210 priorities: Vec::new(),
211 max_size,
212 insertion_ptr: 0,
213 selection_strategy,
214 }
215 }
216
217 pub fn add_example(&mut self, input: Tensor, target: Tensor, task_id: usize, priority: f32) {
219 if self.inputs.len() < self.max_size {
220 self.inputs.push(input);
222 self.targets.push(target);
223 self.task_ids.push(task_id);
224 self.priorities.push(priority);
225 } else {
226 match self.selection_strategy {
228 MemorySelectionStrategy::Random => {
229 let idx = fastrand::usize(..self.max_size);
230 self.inputs[idx] = input;
231 self.targets[idx] = target;
232 self.task_ids[idx] = task_id;
233 self.priorities[idx] = priority;
234 },
235 MemorySelectionStrategy::FIFO | MemorySelectionStrategy::RingBuffer => {
236 self.inputs[self.insertion_ptr] = input;
237 self.targets[self.insertion_ptr] = target;
238 self.task_ids[self.insertion_ptr] = task_id;
239 self.priorities[self.insertion_ptr] = priority;
240 self.insertion_ptr = (self.insertion_ptr + 1) % self.max_size;
241 },
242 _ => {
243 let min_idx = self
245 .priorities
246 .iter()
247 .enumerate()
248 .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
249 .map(|(idx, _)| idx)
250 .unwrap_or(0);
251
252 if priority > self.priorities[min_idx] {
253 self.inputs[min_idx] = input;
254 self.targets[min_idx] = target;
255 self.task_ids[min_idx] = task_id;
256 self.priorities[min_idx] = priority;
257 }
258 },
259 }
260 }
261 }
262
263 pub fn sample_batch(
265 &self,
266 batch_size: usize,
267 ) -> Result<(Vec<Tensor>, Vec<Tensor>, Vec<usize>)> {
268 if self.inputs.is_empty() {
269 return Ok((Vec::new(), Vec::new(), Vec::new()));
270 }
271
272 let sample_size = batch_size.min(self.inputs.len());
273 let mut indices = Vec::new();
274
275 match self.selection_strategy {
276 MemorySelectionStrategy::Random => {
277 for _ in 0..sample_size {
278 indices.push(fastrand::usize(..self.inputs.len()));
279 }
280 },
281 _ => {
282 let total_priority: f32 = self.priorities.iter().sum();
284 for _ in 0..sample_size {
285 let mut cumsum = 0.0;
286 let threshold = fastrand::f32() * total_priority;
287 for (i, &priority) in self.priorities.iter().enumerate() {
288 cumsum += priority;
289 if cumsum >= threshold {
290 indices.push(i);
291 break;
292 }
293 }
294 }
295 },
296 }
297
298 let inputs: Vec<Tensor> = indices.iter().map(|&i| self.inputs[i].clone()).collect();
299 let targets: Vec<Tensor> = indices.iter().map(|&i| self.targets[i].clone()).collect();
300 let task_ids: Vec<usize> = indices.iter().map(|&i| self.task_ids[i]).collect();
301
302 Ok((inputs, targets, task_ids))
303 }
304
305 pub fn get_task_examples(&self, task_id: usize) -> (Vec<Tensor>, Vec<Tensor>) {
307 let mut inputs = Vec::new();
308 let mut targets = Vec::new();
309
310 for (i, &tid) in self.task_ids.iter().enumerate() {
311 if tid == task_id {
312 inputs.push(self.inputs[i].clone());
313 targets.push(self.targets[i].clone());
314 }
315 }
316
317 (inputs, targets)
318 }
319
320 pub fn clear(&mut self) {
322 self.inputs.clear();
323 self.targets.clear();
324 self.task_ids.clear();
325 self.priorities.clear();
326 self.insertion_ptr = 0;
327 }
328
329 pub fn size(&self) -> usize {
331 self.inputs.len()
332 }
333
334 pub fn is_empty(&self) -> bool {
336 self.inputs.is_empty()
337 }
338}
339
/// Orchestrates continual learning over a wrapped model: task lifecycle,
/// replay memory, and strategy-specific state (Fisher matrices, snapshots).
pub struct ContinualLearningTrainer<M: Model> {
    /// The model being trained.
    pub model: M,
    /// Training configuration.
    pub config: ContinualLearningConfig,
    /// Episodic replay memory.
    pub memory: MemoryBuffer,
    /// Per-task statistics keyed by task id.
    pub task_info: HashMap<usize, TaskInfo>,
    /// Task currently being trained, if any.
    pub current_task: Option<usize>,
    /// Fisher estimates keyed by parameter name (EWC variants).
    pub fisher_matrices: HashMap<String, Tensor>,
    /// Parameter snapshots taken when a task was finalized.
    pub optimal_parameters: HashMap<String, Tensor>,
    /// Number of `learn_batch` calls processed so far.
    pub step_counter: usize,
    /// Optional automatic task-boundary detector.
    pub task_detector: Option<TaskDetector>,
}
361
362impl<M: Model<Input = Tensor, Output = Tensor>> ContinualLearningTrainer<M> {
363 pub fn new(model: M, config: ContinualLearningConfig) -> Result<Self> {
365 let memory = MemoryBuffer::new(config.memory_size, config.memory_selection.clone());
366
367 let task_detector = if config.automatic_task_detection {
368 Some(TaskDetector::new(config.task_detection_threshold))
369 } else {
370 None
371 };
372
373 Ok(Self {
374 model,
375 config,
376 memory,
377 task_info: HashMap::new(),
378 current_task: None,
379 fisher_matrices: HashMap::new(),
380 optimal_parameters: HashMap::new(),
381 step_counter: 0,
382 task_detector,
383 })
384 }
385
386 pub fn start_task(&mut self, task_id: usize) -> Result<()> {
388 if let Some(current_id) = self.current_task {
390 if current_id != task_id {
391 self.finalize_task(current_id)?;
392 }
393 }
394
395 self.current_task = Some(task_id);
396
397 self.task_info.entry(task_id).or_insert_with(|| TaskInfo::new(task_id));
399
400 match &self.config.strategy {
402 ContinualStrategy::ProgressiveNeuralNetworks { .. } => {
403 self.add_progressive_columns(task_id)?;
405 },
406 ContinualStrategy::PackNet { .. } => {
407 self.prepare_packnet(task_id)?;
409 },
410 _ => {
411 },
413 }
414
415 Ok(())
416 }
417
418 pub fn learn_batch(
420 &mut self,
421 inputs: &[Tensor],
422 targets: &[Tensor],
423 task_id: Option<usize>,
424 ) -> Result<ContinualLearningOutput> {
425 let task_id = task_id
426 .or(self.current_task)
427 .ok_or_else(|| invalid_input("No task ID specified"))?;
428
429 if let Some(detector) = &mut self.task_detector {
431 if let Some(detected_task) = detector.detect_task_change(inputs, targets)? {
432 if detected_task != task_id {
433 self.start_task(detected_task)?;
434 }
435 }
436 }
437
438 let outputs = self.model.forward(inputs[0].clone())?; let current_loss = self.compute_task_loss(&outputs, &targets[0])?;
441 let current_loss_for_output = current_loss.clone();
442
443 let total_loss = match &self.config.strategy {
445 ContinualStrategy::ElasticWeightConsolidation { lambda, .. } => {
446 let ewc_loss = self.compute_ewc_loss(*lambda)?;
447 current_loss.add(&ewc_loss)?
448 },
449 ContinualStrategy::LearningWithoutForgetting {
450 lambda,
451 temperature,
452 } => {
453 let distillation_loss = self.compute_lwf_loss(inputs, *lambda, *temperature)?;
454 current_loss.add(&distillation_loss)?
455 },
456 ContinualStrategy::ExperienceReplay {
457 memory_strength,
458 replay_batch_size,
459 } => {
460 let replay_loss = self.compute_replay_loss(*memory_strength, *replay_batch_size)?;
461 current_loss.add(&replay_loss)?
462 },
463 ContinualStrategy::GradientEpisodicMemory {
464 memory_strength, ..
465 } => self.compute_gem_loss(¤t_loss, *memory_strength)?,
466 ContinualStrategy::L2Regularization { lambda } => {
467 let l2_loss = self.compute_l2_regularization(*lambda)?;
468 current_loss.add(&l2_loss)?
469 },
470 _ => current_loss,
471 };
472
473 if !matches!(
475 self.config.strategy,
476 ContinualStrategy::L2Regularization { .. }
477 ) {
478 for (input, target) in inputs.iter().zip(targets.iter()) {
479 let priority = self.compute_example_priority(input, target)?;
480 self.memory.add_example(input.clone(), target.clone(), task_id, priority);
481 }
482 }
483
484 self.step_counter += 1;
486
487 if let Some(task_info) = self.task_info.get_mut(&task_id) {
489 task_info.update_statistics(total_loss.to_scalar().unwrap_or(0.0));
490 }
491
492 let total_loss_clone = total_loss.clone();
493
494 Ok(ContinualLearningOutput {
495 total_loss: total_loss_clone.clone(),
496 task_loss: current_loss_for_output.clone(),
497 regularization_loss: total_loss_clone.sub(¤t_loss_for_output)?,
498 task_id,
499 memory_usage: self.memory.size(),
500 })
501 }
502
503 pub fn finalize_task(&mut self, task_id: usize) -> Result<()> {
505 match self.config.strategy.clone() {
506 ContinualStrategy::ElasticWeightConsolidation { fisher_samples, .. }
507 | ContinualStrategy::OnlineElasticWeightConsolidation { fisher_samples, .. } => {
508 self.compute_fisher_information(task_id, fisher_samples)?;
509 self.save_optimal_parameters()?;
510 },
511 ContinualStrategy::PackNet {
512 prune_ratio,
513 retrain_epochs,
514 } => {
515 self.apply_packnet_pruning(prune_ratio)?;
516 self.retrain_after_pruning(retrain_epochs)?;
517 },
518 _ => {
519 },
521 }
522
523 Ok(())
524 }
525
    /// Cross-entropy loss between `outputs` (logits) and `targets`.
    ///
    /// Two target encodings are supported:
    /// * same shape as `outputs` — treated as soft/one-hot labels directly;
    /// * otherwise — treated as class indices and expanded to one-hot.
    ///
    /// NOTE(review): `softmax(-1)?.log()?` is numerically less stable than a
    /// fused log-softmax; consider replacing if `Tensor` exposes one.
    fn compute_task_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
        // log p = log(softmax(logits)) along the last axis.
        let log_probs = outputs.softmax(-1)?.log()?;

        let targets_shape = targets.shape();
        let outputs_shape = outputs.shape();

        if targets_shape == outputs_shape {
            // Soft-label case: mean over batch of -sum(t * log p).
            let element_wise = log_probs.mul(targets)?;
            let sum_per_sample = element_wise.sum(Some(vec![outputs_shape.len() - 1]), false)?;
            Ok(sum_per_sample.neg()?.mean()?)
        } else {
            // Index-label case: build a dense one-hot matrix on the host.
            // NOTE(review): assumes `outputs` is [batch, num_classes]; for a
            // higher-rank output the reshape below would not match the
            // batch_size * num_classes buffer — confirm callers pass 2-D.
            let batch_size = outputs_shape[0];
            let num_classes = outputs_shape[outputs_shape.len() - 1];

            let mut one_hot_data = vec![0.0f32; batch_size * num_classes];
            let targets_data = targets.data()?;

            for (i, &target_idx) in targets_data.iter().enumerate() {
                // Negative or out-of-range indices are silently skipped,
                // leaving an all-zero row (zero loss contribution).
                if target_idx >= 0.0 && (target_idx as usize) < num_classes {
                    one_hot_data[i * num_classes + target_idx as usize] = 1.0;
                }
            }

            let one_hot_targets = Tensor::new(one_hot_data)?.reshape(&outputs_shape)?;
            let element_wise = log_probs.mul(&one_hot_targets)?;
            let sum_per_sample = element_wise.sum(Some(vec![outputs_shape.len() - 1]), false)?;
            Ok(sum_per_sample.neg()?.mean()?)
        }
    }
563
    /// EWC penalty: lambda * sum_i F_i * (theta_i - theta*_i)^2 over all
    /// parameters that have both a Fisher estimate and an optimal snapshot.
    ///
    /// NOTE(review): `current_param` is a zeros placeholder rather than the
    /// model's live parameter, so the penalty currently reduces to
    /// F * theta*^2 — wire in real parameter access before relying on this.
    fn compute_ewc_loss(&self, lambda: f32) -> Result<Tensor> {
        let mut total_loss = Tensor::zeros(&[1])?;

        for (param_name, fisher) in &self.fisher_matrices {
            if let Some(optimal) = self.optimal_parameters.get(param_name) {
                // Placeholder for the live parameter value (see note above).
                let current_param = Tensor::zeros_like(optimal)?;
                let diff = current_param.sub(optimal)?;
                let squared_diff = diff.mul(&diff)?;
                let weighted_diff = fisher.mul(&squared_diff)?;
                total_loss = total_loss.add(&weighted_diff.sum(None, false)?)?;
            }
        }

        total_loss.scalar_mul(lambda)
    }
583
    /// Learning-without-Forgetting distillation loss.
    ///
    /// Placeholder: always returns zero (scaled by `lambda`); the previous
    /// model's soft targets and the `temperature` scaling are not yet used.
    fn compute_lwf_loss(
        &self,
        _inputs: &[Tensor],
        lambda: f32,
        _temperature: f32,
    ) -> Result<Tensor> {
        Tensor::zeros(&[1])?.scalar_mul(lambda)
    }
595
596 fn compute_replay_loss(
598 &mut self,
599 memory_strength: f32,
600 replay_batch_size: usize,
601 ) -> Result<Tensor> {
602 if self.memory.is_empty() {
603 return Tensor::zeros(&[1]);
604 }
605
606 let (replay_inputs, replay_targets, _) = self.memory.sample_batch(replay_batch_size)?;
607
608 if replay_inputs.is_empty() {
609 return Tensor::zeros(&[1]);
610 }
611
612 let replay_outputs = self.model.forward(replay_inputs[0].clone())?; let replay_loss = self.compute_task_loss(&replay_outputs, &replay_targets[0])?;
615
616 replay_loss.scalar_mul(memory_strength)
617 }
618
    /// GEM loss placeholder: the gradient-projection constraint check is not
    /// implemented; the current loss is simply scaled by `memory_strength`.
    fn compute_gem_loss(&mut self, current_loss: &Tensor, memory_strength: f32) -> Result<Tensor> {
        current_loss.scalar_mul(memory_strength)
    }

    /// L2 penalty placeholder: returns zero scaled by `lambda` until real
    /// model-parameter access is available.
    fn compute_l2_regularization(&self, lambda: f32) -> Result<Tensor> {
        Tensor::zeros(&[1])?.scalar_mul(lambda)
    }
632
633 fn compute_example_priority(&self, input: &Tensor, target: &Tensor) -> Result<f32> {
635 match self.config.memory_selection {
636 MemorySelectionStrategy::Random => Ok(1.0),
637 MemorySelectionStrategy::Uncertainty => {
638 let outputs = self.model.forward(input.clone())?;
640 let probs = outputs.softmax(-1)?;
641 let entropy = -(probs.clone().mul(&probs.log()?)?)
642 .sum(Some(vec![1]), false)?
643 .to_scalar()
644 .unwrap_or(0.0);
645 Ok(entropy)
646 },
647 MemorySelectionStrategy::HighestLoss => {
648 let outputs = self.model.forward(input.clone())?;
649 let loss = self.compute_task_loss(&outputs, target)?;
650 Ok(loss.to_scalar().unwrap_or(0.0))
651 },
652 _ => Ok(1.0), }
654 }
655
    /// Estimates Fisher information for `task_id` from remembered examples.
    ///
    /// NOTE(review): placeholder implementation — it forwards up to
    /// `num_samples` examples but discards each loss and stores a ones
    /// tensor under a synthetic `param_{i}` key, instead of accumulating
    /// squared gradients per model parameter.
    fn compute_fisher_information(&mut self, task_id: usize, num_samples: usize) -> Result<()> {
        let (task_inputs, task_targets) = self.memory.get_task_examples(task_id);

        if task_inputs.is_empty() {
            return Ok(());
        }

        let sample_size = num_samples.min(task_inputs.len());

        for i in 0..sample_size {
            // Wraps around, though sample_size <= task_inputs.len() here.
            let input = &task_inputs[i % task_inputs.len()];
            let target = &task_targets[i % task_targets.len()];

            let outputs = self.model.forward(input.clone())?;
            let _loss = self.compute_task_loss(&outputs, target)?;

            // Placeholder Fisher entry (see note above).
            self.fisher_matrices.insert(
                format!("param_{}", i),
                Tensor::ones(&[10])?,
            );
        }

        Ok(())
    }
687
    /// Snapshots the parameters used as anchors by the EWC penalty.
    ///
    /// NOTE(review): placeholder — stores a single zeros tensor under
    /// "param_0" instead of the model's real parameters.
    fn save_optimal_parameters(&mut self) -> Result<()> {
        self.optimal_parameters.insert(
            "param_0".to_string(),
            Tensor::zeros(&[10])?,
        );
        Ok(())
    }
698
    /// Adds a new progressive-network column for `task_id` (stub: no-op).
    fn add_progressive_columns(&mut self, _task_id: usize) -> Result<()> {
        Ok(())
    }

    /// Prepares PackNet masks/state for `task_id` (stub: no-op).
    fn prepare_packnet(&mut self, _task_id: usize) -> Result<()> {
        Ok(())
    }

    /// Prunes weights after a task per the PackNet ratio (stub: no-op).
    fn apply_packnet_pruning(&mut self, _prune_ratio: f32) -> Result<()> {
        Ok(())
    }

    /// Retrains surviving weights after pruning (stub: no-op).
    fn retrain_after_pruning(&mut self, _epochs: usize) -> Result<()> {
        Ok(())
    }
723
724 pub fn evaluate_all_tasks(&self) -> Result<HashMap<usize, TaskEvaluation>> {
726 let mut evaluations = HashMap::new();
727
728 for &task_id in self.task_info.keys() {
729 let (task_inputs, task_targets) = self.memory.get_task_examples(task_id);
730
731 if !task_inputs.is_empty() {
732 let evaluation = self.evaluate_task(&task_inputs, &task_targets, task_id)?;
733 evaluations.insert(task_id, evaluation);
734 }
735 }
736
737 Ok(evaluations)
738 }
739
740 fn evaluate_task(
742 &self,
743 inputs: &[Tensor],
744 targets: &[Tensor],
745 task_id: usize,
746 ) -> Result<TaskEvaluation> {
747 let mut total_loss = 0.0;
748 let mut correct_predictions = 0;
749 let total_examples = inputs.len();
750
751 for (input, target) in inputs.iter().zip(targets.iter()) {
752 let outputs = self.model.forward(input.clone())?;
753 let loss = self.compute_task_loss(&outputs, target)?;
754 total_loss += loss.to_scalar().unwrap_or(0.0);
755
756 let predicted = Tensor::zeros(&[1])?; let target_class = Tensor::zeros(&[1])?; if predicted.to_scalar().unwrap_or(-1.0) == target_class.to_scalar().unwrap_or(-2.0) {
760 correct_predictions += 1;
761 }
762 }
763
764 Ok(TaskEvaluation {
765 task_id,
766 average_loss: total_loss / total_examples as f32,
767 accuracy: correct_predictions as f32 / total_examples as f32,
768 num_examples: total_examples,
769 })
770 }
771
772 pub fn get_metrics(&self) -> ContinualLearningMetrics {
774 let all_evaluations = self.evaluate_all_tasks().unwrap_or_default();
775
776 let average_accuracy = if !all_evaluations.is_empty() {
777 all_evaluations.values().map(|e| e.accuracy).sum::<f32>() / all_evaluations.len() as f32
778 } else {
779 0.0
780 };
781
782 let memory_efficiency = self.memory.size() as f32 / self.config.memory_size as f32;
783
784 ContinualLearningMetrics {
785 average_accuracy,
786 task_evaluations: all_evaluations,
787 memory_efficiency,
788 num_tasks_learned: self.task_info.len(),
789 current_task: self.current_task,
790 }
791 }
792}
793
/// Bookkeeping for a single task seen during continual learning.
#[derive(Debug, Clone)]
pub struct TaskInfo {
    /// Identifier of the task.
    pub task_id: usize,
    /// Global step at which the task began (currently always 0).
    pub start_step: usize,
    /// Number of loss updates recorded for this task.
    pub num_examples_seen: usize,
    /// Running mean of the recorded losses.
    pub average_loss: f32,
    /// Most recent accuracy measurement for this task.
    pub last_accuracy: f32,
}

impl TaskInfo {
    /// Creates zeroed statistics for `task_id`.
    pub fn new(task_id: usize) -> Self {
        Self {
            task_id,
            start_step: 0,
            num_examples_seen: 0,
            average_loss: 0.0,
            last_accuracy: 0.0,
        }
    }

    /// Folds one loss value into the running mean.
    pub fn update_statistics(&mut self, loss: f32) {
        let seen = self.num_examples_seen as f32;
        self.num_examples_seen += 1;
        self.average_loss = (self.average_loss * seen + loss) / (seen + 1.0);
    }
}
821
/// Detects task-boundary changes from the incoming data stream.
///
/// NOTE(review): currently a stub — the fields are unused and
/// `detect_task_change` always reports "no change".
pub struct TaskDetector {
    // Detection threshold supplied by the configuration (unused).
    #[allow(dead_code)]
    threshold: f32,
    // Sliding window of recent losses (unused).
    #[allow(dead_code)]
    recent_losses: Vec<f32>,
    // Window length for the loss history (unused; fixed at 100).
    #[allow(dead_code)]
    window_size: usize,
}

impl TaskDetector {
    /// Creates a detector with the given threshold and a 100-step window.
    pub fn new(threshold: f32) -> Self {
        Self {
            threshold,
            recent_losses: Vec::new(),
            window_size: 100,
        }
    }

    /// Inspects a batch for a task-distribution shift.
    ///
    /// Placeholder: always returns `Ok(None)` (no change detected).
    pub fn detect_task_change(
        &mut self,
        _inputs: &[Tensor],
        _targets: &[Tensor],
    ) -> Result<Option<usize>> {
        Ok(None)
    }
}
851
/// Result of a single `learn_batch` step.
#[derive(Debug, Clone)]
pub struct ContinualLearningOutput {
    /// Task loss plus the strategy-specific penalty term.
    pub total_loss: Tensor,
    /// Loss on the current batch alone.
    pub task_loss: Tensor,
    /// `total_loss - task_loss`.
    pub regularization_loss: Tensor,
    /// Task the batch was attributed to.
    pub task_id: usize,
    /// Number of examples currently held in replay memory.
    pub memory_usage: usize,
}

/// Evaluation summary for one task.
#[derive(Debug, Clone)]
pub struct TaskEvaluation {
    pub task_id: usize,
    /// Mean loss over the evaluated examples.
    pub average_loss: f32,
    /// Fraction of examples counted as correct.
    pub accuracy: f32,
    /// Number of examples evaluated.
    pub num_examples: usize,
}

/// Aggregate metrics across all tasks seen so far.
#[derive(Debug, Clone)]
pub struct ContinualLearningMetrics {
    /// Mean accuracy over all evaluated tasks (0.0 when none).
    pub average_accuracy: f32,
    /// Per-task evaluation results.
    pub task_evaluations: HashMap<usize, TaskEvaluation>,
    /// Fill ratio of the replay memory (size / capacity).
    pub memory_efficiency: f32,
    /// Number of distinct tasks started so far.
    pub num_tasks_learned: usize,
    /// Task currently active, if any.
    pub current_task: Option<usize>,
}
880
881pub mod utils {
883 use super::*;
884
885 pub fn ewc_config(
887 lambda: f32,
888 fisher_samples: usize,
889 memory_size: usize,
890 ) -> ContinualLearningConfig {
891 ContinualLearningConfig {
892 strategy: ContinualStrategy::ElasticWeightConsolidation {
893 lambda,
894 fisher_samples,
895 },
896 memory_size,
897 ..Default::default()
898 }
899 }
900
901 pub fn experience_replay_config(
903 memory_size: usize,
904 replay_batch_size: usize,
905 ) -> ContinualLearningConfig {
906 ContinualLearningConfig {
907 strategy: ContinualStrategy::ExperienceReplay {
908 memory_strength: 1.0,
909 replay_batch_size,
910 },
911 memory_size,
912 memory_selection: MemorySelectionStrategy::Random,
913 ..Default::default()
914 }
915 }
916
917 pub fn l2_regularization_config(lambda: f32) -> ContinualLearningConfig {
919 ContinualLearningConfig {
920 strategy: ContinualStrategy::L2Regularization { lambda },
921 memory_size: 0, ..Default::default()
923 }
924 }
925
926 pub fn progressive_networks_config() -> ContinualLearningConfig {
928 ContinualLearningConfig {
929 strategy: ContinualStrategy::ProgressiveNeuralNetworks {
930 lateral_connections: true,
931 adapter_layers: true,
932 },
933 task_specific_heads: true,
934 ..Default::default()
935 }
936 }
937
938 pub fn compute_backward_transfer(
940 evaluations_before: &HashMap<usize, TaskEvaluation>,
941 evaluations_after: &HashMap<usize, TaskEvaluation>,
942 ) -> f32 {
943 let mut total_transfer = 0.0;
944 let mut num_tasks = 0;
945
946 for (&task_id, after_eval) in evaluations_after {
947 if let Some(before_eval) = evaluations_before.get(&task_id) {
948 total_transfer += after_eval.accuracy - before_eval.accuracy;
949 num_tasks += 1;
950 }
951 }
952
953 if num_tasks > 0 {
954 total_transfer / num_tasks as f32
955 } else {
956 0.0
957 }
958 }
959
960 pub fn compute_forward_transfer(baseline_accuracy: f32, continual_accuracy: f32) -> f32 {
962 continual_accuracy - baseline_accuracy
963 }
964
965 pub fn compute_forgetting(
967 max_accuracies: &HashMap<usize, f32>,
968 final_accuracies: &HashMap<usize, f32>,
969 ) -> f32 {
970 let mut total_forgetting = 0.0;
971 let mut num_tasks = 0;
972
973 for (&task_id, &max_acc) in max_accuracies {
974 if let Some(&final_acc) = final_accuracies.get(&task_id) {
975 total_forgetting += max_acc - final_acc;
976 num_tasks += 1;
977 }
978 }
979
980 if num_tasks > 0 {
981 total_forgetting / num_tasks as f32
982 } else {
983 0.0
984 }
985 }
986}
987
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: EWC(lambda=0.4, 1000 Fisher samples), 1000-slot
    // memory, task-specific heads on, automatic task detection off.
    #[test]
    fn test_continual_learning_config_default() {
        let config = ContinualLearningConfig::default();
        assert_eq!(config.memory_size, 1000);
        assert!(config.task_specific_heads);
        assert!(!config.automatic_task_detection);

        if let ContinualStrategy::ElasticWeightConsolidation {
            lambda,
            fisher_samples,
        } = config.strategy
        {
            assert_eq!(lambda, 0.4);
            assert_eq!(fisher_samples, 1000);
        } else {
            panic!("Expected EWC strategy");
        }
    }

    // Insert/size/sample round-trip on a small random-replacement buffer.
    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(3, MemorySelectionStrategy::Random);
        assert!(buffer.is_empty());
        assert_eq!(buffer.size(), 0);

        let input1 = Tensor::zeros(&[1, 10]).expect("operation failed");
        let target1 = Tensor::zeros(&[1]).expect("operation failed");
        buffer.add_example(input1, target1, 0, 1.0);
        assert_eq!(buffer.size(), 1);

        let input2 = Tensor::ones(&[1, 10]).expect("operation failed");
        let target2 = Tensor::ones(&[1]).expect("operation failed");
        buffer.add_example(input2, target2, 1, 2.0);
        assert_eq!(buffer.size(), 2);

        // Sampling is with replacement, so 2 draws from 2 examples is fine.
        let (inputs, targets, task_ids) = buffer.sample_batch(2).expect("operation failed");
        assert_eq!(inputs.len(), 2);
        assert_eq!(targets.len(), 2);
        assert_eq!(task_ids.len(), 2);
    }

    // utils::ewc_config should propagate lambda/fisher_samples/memory_size.
    #[test]
    fn test_ewc_config() {
        let config = utils::ewc_config(0.5, 2000, 500);
        assert_eq!(config.memory_size, 500);

        if let ContinualStrategy::ElasticWeightConsolidation {
            lambda,
            fisher_samples,
        } = config.strategy
        {
            assert_eq!(lambda, 0.5);
            assert_eq!(fisher_samples, 2000);
        } else {
            panic!("Expected EWC strategy");
        }
    }

    // utils::experience_replay_config fixes memory_strength at 1.0.
    #[test]
    fn test_experience_replay_config() {
        let config = utils::experience_replay_config(1000, 64);
        assert_eq!(config.memory_size, 1000);

        if let ContinualStrategy::ExperienceReplay {
            memory_strength,
            replay_batch_size,
        } = config.strategy
        {
            assert_eq!(memory_strength, 1.0);
            assert_eq!(replay_batch_size, 64);
        } else {
            panic!("Expected ExperienceReplay strategy");
        }
    }

    // utils::l2_regularization_config keeps no replay memory.
    #[test]
    fn test_l2_regularization_config() {
        let config = utils::l2_regularization_config(0.01);
        assert_eq!(config.memory_size, 0);

        if let ContinualStrategy::L2Regularization { lambda } = config.strategy {
            assert_eq!(lambda, 0.01);
        } else {
            panic!("Expected L2Regularization strategy");
        }
    }

    // Running-mean bookkeeping in TaskInfo::update_statistics.
    #[test]
    fn test_task_info() {
        let mut info = TaskInfo::new(5);
        assert_eq!(info.task_id, 5);
        assert_eq!(info.num_examples_seen, 0);

        info.update_statistics(0.5);
        assert_eq!(info.num_examples_seen, 1);
        assert_eq!(info.average_loss, 0.5);

        info.update_statistics(1.0);
        assert_eq!(info.num_examples_seen, 2);
        assert_eq!(info.average_loss, 0.75);
    }

    // Backward transfer = mean accuracy delta on shared tasks:
    // ((0.85 - 0.8) + (0.72 - 0.7)) / 2 = 0.035.
    #[test]
    fn test_backward_transfer_computation() {
        let mut before = HashMap::new();
        before.insert(
            0,
            TaskEvaluation {
                task_id: 0,
                average_loss: 0.5,
                accuracy: 0.8,
                num_examples: 100,
            },
        );
        before.insert(
            1,
            TaskEvaluation {
                task_id: 1,
                average_loss: 0.6,
                accuracy: 0.7,
                num_examples: 100,
            },
        );

        let mut after = HashMap::new();
        after.insert(
            0,
            TaskEvaluation {
                task_id: 0,
                average_loss: 0.4,
                accuracy: 0.85,
                num_examples: 100,
            },
        );
        after.insert(
            1,
            TaskEvaluation {
                task_id: 1,
                average_loss: 0.55,
                accuracy: 0.72,
                num_examples: 100,
            },
        );

        let backward_transfer = utils::compute_backward_transfer(&before, &after);
        assert!((backward_transfer - 0.035).abs() < 1e-6);
    }

    // Forgetting = mean (max - final) accuracy:
    // ((0.9 - 0.8) + (0.85 - 0.75)) / 2 = 0.1.
    #[test]
    fn test_forgetting_computation() {
        let mut max_accuracies = HashMap::new();
        max_accuracies.insert(0, 0.9);
        max_accuracies.insert(1, 0.85);

        let mut final_accuracies = HashMap::new();
        final_accuracies.insert(0, 0.8);
        final_accuracies.insert(1, 0.75);

        let forgetting = utils::compute_forgetting(&max_accuracies, &final_accuracies);
        assert!((forgetting - 0.1).abs() < 1e-6);
    }
}