1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use scirs2_core::random::thread_rng;
8use scirs2_core::random::Rng;
9use sklears_core::{
10 error::Result as SklResult,
11 prelude::{Predict, SklearsError},
12 traits::{Estimator, Fit, Untrained},
13 types::Float,
14};
15use std::collections::{HashMap, VecDeque};
16use std::fmt::Debug;
17
18use crate::{PipelinePredictor, PipelineStep};
19
/// A single learning task: a labelled dataset plus descriptive metadata.
///
/// Tasks are the unit of data handed to `learn_task` on a fitted
/// continual-learning pipeline.
#[derive(Debug, Clone)]
pub struct Task {
    /// Identifier of the task; used to tag buffered samples and to record
    /// which tasks have been learned.
    pub id: String,
    /// Feature matrix, one row per sample.
    pub features: Array2<f64>,
    /// Target values, expected to hold one entry per row of `features`.
    pub targets: Array1<f64>,
    /// Free-form string metadata supplied by the caller.
    pub metadata: HashMap<String, String>,
    /// Optional caller-supplied named importance weights.
    pub importance_weights: Option<HashMap<String, f64>>,
    /// Summary statistics derived from the task data.
    pub statistics: TaskStatistics,
}
36
/// Summary statistics describing a [`Task`].
#[derive(Debug, Clone)]
pub struct TaskStatistics {
    /// Number of samples (rows of the task's feature matrix).
    pub n_samples: usize,
    /// Number of features (columns of the task's feature matrix).
    pub n_features: usize,
    /// Heuristic difficulty score; `Task::new` initialises it to `1.0` and
    /// `Task::estimate_difficulty` overwrites it.
    pub difficulty: f64,
    /// Named performance metrics recorded for the task.
    pub performance: HashMap<String, f64>,
    /// Time spent learning the task; initialised to `0.0`.
    /// NOTE(review): units are not established here — presumably seconds, confirm.
    pub learning_time: f64,
}
51
52impl Task {
53 #[must_use]
55 pub fn new(id: String, features: Array2<f64>, targets: Array1<f64>) -> Self {
56 let n_samples = features.nrows();
57 let n_features = features.ncols();
58
59 Self {
60 id,
61 features,
62 targets,
63 metadata: HashMap::new(),
64 importance_weights: None,
65 statistics: TaskStatistics {
66 n_samples,
67 n_features,
68 difficulty: 1.0, performance: HashMap::new(),
70 learning_time: 0.0,
71 },
72 }
73 }
74
75 #[must_use]
77 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
78 self.metadata = metadata;
79 self
80 }
81
82 #[must_use]
84 pub fn with_importance_weights(mut self, weights: HashMap<String, f64>) -> Self {
85 self.importance_weights = Some(weights);
86 self
87 }
88
89 pub fn estimate_difficulty(&mut self) {
91 let feature_variance = self.features.var_axis(Axis(0), 1.0).mean().unwrap_or(1.0);
92 let target_variance = self.targets.var(1.0);
93
94 self.statistics.difficulty = (feature_variance + target_variance).max(0.1);
96 }
97}
98
/// Strategy a `ContinualLearningPipeline` applies while fitting, intended to
/// mitigate forgetting across a sequence of tasks.
///
/// The pipeline's constructor also derives the memory buffer configuration
/// from the chosen variant.
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Elastic-weight-consolidation style: per-parameter importance weights
    /// are computed before fitting.
    ElasticWeightConsolidation {
        /// Scale factor applied to every computed importance weight.
        lambda: f64,
        /// Upper bound on the number of importance weights computed (also
        /// capped by the number of training rows).
        fisher_samples: usize,
    },
    /// Progressive-networks style configuration.
    ProgressiveNetworks {
        /// Maximum number of columns.
        /// NOTE(review): not read by the fitting logic in this file.
        max_columns: usize,
        /// Lateral connection strength; reported in the importance weights.
        lateral_strength: f64,
    },
    /// Rehearsal: training rows are stored in a memory buffer and replayed.
    ExperienceReplay {
        /// Capacity of the replay memory buffer.
        buffer_size: usize,
        /// Number of samples drawn from the buffer per replay step.
        replay_batch_size: usize,
        /// Number of fit-then-replay rounds performed by the initial `fit`.
        replay_frequency: usize,
    },
    /// Learning-without-forgetting style distillation settings.
    LearningWithoutForgetting {
        /// Distillation temperature; reported in the importance weights.
        temperature: f64,
        /// Weight of the distillation term; reported in the importance weights.
        distillation_weight: f64,
    },
    /// Memory-augmented configuration backed by an importance-sampled buffer.
    MemoryAugmented {
        /// Capacity of the memory buffer.
        memory_size: usize,
        /// Number of read heads; reported in the importance weights.
        read_heads: usize,
        /// Write strength; reported in the importance weights.
        write_strength: f64,
    },
    /// Gradient-episodic-memory style: training rows are stored in an
    /// episodic buffer.
    GradientEpisodicMemory {
        /// Capacity of the episodic memory buffer.
        memory_size: usize,
        /// Tolerance setting; reported in the importance weights.
        tolerance: f64,
    },
}
149
150#[derive(Debug, Clone)]
152pub struct MemoryBuffer {
153 max_size: usize,
155 samples: VecDeque<MemorySample>,
157 task_distributions: HashMap<String, usize>,
159 sampling_strategy: SamplingStrategy,
161}
162
163#[derive(Debug, Clone)]
165pub struct MemorySample {
166 pub features: Array1<f64>,
168 pub target: f64,
170 pub task_id: String,
172 pub importance: f64,
174 pub gradient_info: Option<HashMap<String, f64>>,
176}
177
178#[derive(Debug, Clone)]
180pub enum SamplingStrategy {
181 Random,
183 Reservoir,
185 ImportanceBased,
187 TaskBalanced,
189 GradientBased,
191}
192
193impl MemoryBuffer {
194 #[must_use]
196 pub fn new(max_size: usize, sampling_strategy: SamplingStrategy) -> Self {
197 Self {
198 max_size,
199 samples: VecDeque::new(),
200 task_distributions: HashMap::new(),
201 sampling_strategy,
202 }
203 }
204
205 pub fn add_sample(&mut self, sample: MemorySample) {
207 *self
209 .task_distributions
210 .entry(sample.task_id.clone())
211 .or_insert(0) += 1;
212
213 if self.samples.len() >= self.max_size {
214 match self.sampling_strategy {
215 SamplingStrategy::Random => {
216 let replace_idx = thread_rng().gen_range(0..self.samples.len());
217 if let Some(old_sample) = self.samples.get(replace_idx) {
218 let count = self
219 .task_distributions
220 .get_mut(&old_sample.task_id)
221 .unwrap();
222 *count -= 1;
223 if *count == 0 {
224 self.task_distributions.remove(&old_sample.task_id);
225 }
226 }
227 self.samples[replace_idx] = sample;
228 }
229 SamplingStrategy::Reservoir => {
230 let replace_idx = thread_rng().gen_range(0..(self.samples.len() + 1));
232 if replace_idx < self.samples.len() {
233 if let Some(old_sample) = self.samples.get(replace_idx) {
234 let count = self
235 .task_distributions
236 .get_mut(&old_sample.task_id)
237 .unwrap();
238 *count -= 1;
239 if *count == 0 {
240 self.task_distributions.remove(&old_sample.task_id);
241 }
242 }
243 self.samples[replace_idx] = sample;
244 }
245 }
246 SamplingStrategy::ImportanceBased => {
247 let min_importance_idx = self
249 .samples
250 .iter()
251 .enumerate()
252 .min_by(|(_, a), (_, b)| a.importance.partial_cmp(&b.importance).unwrap())
253 .map_or(0, |(idx, _)| idx);
254
255 if sample.importance > self.samples[min_importance_idx].importance {
256 if let Some(old_sample) = self.samples.get(min_importance_idx) {
257 let count = self
258 .task_distributions
259 .get_mut(&old_sample.task_id)
260 .unwrap();
261 *count -= 1;
262 if *count == 0 {
263 self.task_distributions.remove(&old_sample.task_id);
264 }
265 }
266 self.samples[min_importance_idx] = sample;
267 }
268 }
269 SamplingStrategy::TaskBalanced => {
270 let max_task = self
272 .task_distributions
273 .iter()
274 .max_by_key(|(_, &count)| count)
275 .map(|(task_id, _)| task_id.clone());
276
277 if let Some(overrep_task) = max_task {
278 if let Some(idx) =
279 self.samples.iter().position(|s| s.task_id == overrep_task)
280 {
281 let count = self.task_distributions.get_mut(&overrep_task).unwrap();
282 *count -= 1;
283 if *count == 0 {
284 self.task_distributions.remove(&overrep_task);
285 }
286 self.samples[idx] = sample;
287 }
288 }
289 }
290 SamplingStrategy::GradientBased => {
291 let replace_idx = thread_rng().gen_range(0..self.samples.len());
294 if let Some(old_sample) = self.samples.get(replace_idx) {
295 let count = self
296 .task_distributions
297 .get_mut(&old_sample.task_id)
298 .unwrap();
299 *count -= 1;
300 if *count == 0 {
301 self.task_distributions.remove(&old_sample.task_id);
302 }
303 }
304 self.samples[replace_idx] = sample;
305 }
306 }
307 } else {
308 self.samples.push_back(sample);
309 }
310 }
311
312 #[must_use]
314 pub fn sample(&self, n_samples: usize) -> Vec<&MemorySample> {
315 if self.samples.is_empty() {
316 return Vec::new();
317 }
318
319 let n_samples = n_samples.min(self.samples.len());
320 let mut sampled = Vec::new();
321
322 match self.sampling_strategy {
323 SamplingStrategy::Random | SamplingStrategy::Reservoir => {
324 for _ in 0..n_samples {
325 let idx = thread_rng().gen_range(0..self.samples.len());
326 sampled.push(&self.samples[idx]);
327 }
328 }
329 SamplingStrategy::ImportanceBased => {
330 let total_importance: f64 = self.samples.iter().map(|s| s.importance).sum();
332 for _ in 0..n_samples {
333 let target = thread_rng().gen::<f64>() * total_importance;
334 let mut cumulative = 0.0;
335 for sample in &self.samples {
336 cumulative += sample.importance;
337 if cumulative >= target {
338 sampled.push(sample);
339 break;
340 }
341 }
342 }
343 }
344 SamplingStrategy::TaskBalanced => {
345 let unique_tasks: Vec<String> = self.task_distributions.keys().cloned().collect();
347 if !unique_tasks.is_empty() {
348 let samples_per_task = n_samples / unique_tasks.len();
349 let extra_samples = n_samples % unique_tasks.len();
350
351 for (i, task_id) in unique_tasks.iter().enumerate() {
352 let task_samples: Vec<&MemorySample> = self
353 .samples
354 .iter()
355 .filter(|s| &s.task_id == task_id)
356 .collect();
357
358 let task_sample_count = samples_per_task + usize::from(i < extra_samples);
359 for _ in 0..task_sample_count.min(task_samples.len()) {
360 let idx = thread_rng().gen_range(0..task_samples.len());
361 sampled.push(task_samples[idx]);
362 }
363 }
364 }
365 }
366 SamplingStrategy::GradientBased => {
367 for _ in 0..n_samples {
370 let idx = thread_rng().gen_range(0..self.samples.len());
371 sampled.push(&self.samples[idx]);
372 }
373 }
374 }
375
376 sampled
377 }
378
379 #[must_use]
381 pub fn statistics(&self) -> HashMap<String, f64> {
382 let mut stats = HashMap::new();
383 stats.insert("total_samples".to_string(), self.samples.len() as f64);
384 stats.insert(
385 "unique_tasks".to_string(),
386 self.task_distributions.len() as f64,
387 );
388
389 if !self.samples.is_empty() {
390 let avg_importance =
391 self.samples.iter().map(|s| s.importance).sum::<f64>() / self.samples.len() as f64;
392 stats.insert("average_importance".to_string(), avg_importance);
393 }
394
395 stats
396 }
397}
398
/// Pipeline wrapping a base estimator with a continual-learning strategy.
///
/// `S` is the typestate: `Untrained` before `fit`, and
/// `ContinualLearningPipelineTrained` (holding the fitted model) afterwards.
#[derive(Debug)]
pub struct ContinualLearningPipeline<S = Untrained> {
    /// Typestate payload.
    state: S,
    /// Estimator to train; `fit` takes it out, leaving `None`.
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    /// Strategy applied while fitting.
    strategy: ContinualLearningStrategy,
    /// Memory buffer used by replay/episodic strategies.
    memory_buffer: MemoryBuffer,
    /// Ids of tasks learned so far.
    learned_tasks: Vec<String>,
    /// Task id used to tag the data passed to `fit`; set via `set_current_task`.
    current_task_id: Option<String>,
}
409
/// State held by a fitted `ContinualLearningPipeline`.
#[derive(Debug)]
pub struct ContinualLearningPipelineTrained {
    /// Trained base estimator used for prediction and further learning.
    fitted_estimator: Box<dyn PipelinePredictor>,
    /// Strategy carried over from the untrained pipeline.
    strategy: ContinualLearningStrategy,
    /// Memory buffer carried over from `fit` and extended by `learn_task`.
    memory_buffer: MemoryBuffer,
    /// Ids of learned tasks, in the order they were first learned.
    learned_tasks: Vec<String>,
    /// Per-task metrics keyed by task id.
    task_performance: HashMap<String, HashMap<String, f64>>,
    /// Named importance weights / strategy diagnostics produced by `fit`.
    importance_weights: HashMap<String, f64>,
    /// Number of feature columns seen during `fit`.
    n_features_in: usize,
    /// Feature names seen during `fit`; `fit` currently sets this to `None`.
    feature_names_in: Option<Vec<String>>,
}
422
423impl ContinualLearningPipeline<Untrained> {
424 #[must_use]
426 pub fn new(
427 base_estimator: Box<dyn PipelinePredictor>,
428 strategy: ContinualLearningStrategy,
429 ) -> Self {
430 let memory_buffer = match &strategy {
431 ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => {
432 MemoryBuffer::new(*buffer_size, SamplingStrategy::Random)
433 }
434 ContinualLearningStrategy::GradientEpisodicMemory { memory_size, .. } => {
435 MemoryBuffer::new(*memory_size, SamplingStrategy::GradientBased)
436 }
437 ContinualLearningStrategy::MemoryAugmented { memory_size, .. } => {
438 MemoryBuffer::new(*memory_size, SamplingStrategy::ImportanceBased)
439 }
440 _ => MemoryBuffer::new(1000, SamplingStrategy::Random), };
442
443 Self {
444 state: Untrained,
445 base_estimator: Some(base_estimator),
446 strategy,
447 memory_buffer,
448 learned_tasks: Vec::new(),
449 current_task_id: None,
450 }
451 }
452
453 #[must_use]
455 pub fn elastic_weight_consolidation(
456 base_estimator: Box<dyn PipelinePredictor>,
457 lambda: f64,
458 fisher_samples: usize,
459 ) -> Self {
460 Self::new(
461 base_estimator,
462 ContinualLearningStrategy::ElasticWeightConsolidation {
463 lambda,
464 fisher_samples,
465 },
466 )
467 }
468
469 #[must_use]
471 pub fn experience_replay(
472 base_estimator: Box<dyn PipelinePredictor>,
473 buffer_size: usize,
474 replay_batch_size: usize,
475 replay_frequency: usize,
476 ) -> Self {
477 Self::new(
478 base_estimator,
479 ContinualLearningStrategy::ExperienceReplay {
480 buffer_size,
481 replay_batch_size,
482 replay_frequency,
483 },
484 )
485 }
486
487 #[must_use]
489 pub fn learning_without_forgetting(
490 base_estimator: Box<dyn PipelinePredictor>,
491 temperature: f64,
492 distillation_weight: f64,
493 ) -> Self {
494 Self::new(
495 base_estimator,
496 ContinualLearningStrategy::LearningWithoutForgetting {
497 temperature,
498 distillation_weight,
499 },
500 )
501 }
502
503 #[must_use]
505 pub fn set_current_task(mut self, task_id: String) -> Self {
506 self.current_task_id = Some(task_id);
507 self
508 }
509}
510
511impl Estimator for ContinualLearningPipeline<Untrained> {
512 type Config = ();
513 type Error = SklearsError;
514 type Float = Float;
515
516 fn config(&self) -> &Self::Config {
517 &()
518 }
519}
520
521impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
522 for ContinualLearningPipeline<Untrained>
523{
524 type Fitted = ContinualLearningPipeline<ContinualLearningPipelineTrained>;
525
526 fn fit(
527 mut self,
528 x: &ArrayView2<'_, Float>,
529 y: &Option<&ArrayView1<'_, Float>>,
530 ) -> SklResult<Self::Fitted> {
531 if let Some(y_values) = y.as_ref() {
532 let mut base_estimator = self.base_estimator.take().ok_or_else(|| {
533 SklearsError::InvalidInput("No base estimator provided".to_string())
534 })?;
535
536 let importance_weights =
538 self.apply_continual_learning_strategy(&mut base_estimator, x, y_values)?;
539
540 let task_id = self
541 .current_task_id
542 .clone()
543 .unwrap_or_else(|| "default_task".to_string());
544 self.learned_tasks.push(task_id.clone());
545
546 let mut task_performance = HashMap::new();
547 let mut perf_metrics = HashMap::new();
548 perf_metrics.insert("training_completed".to_string(), 1.0);
549 task_performance.insert(task_id, perf_metrics);
550
551 Ok(ContinualLearningPipeline {
552 state: ContinualLearningPipelineTrained {
553 fitted_estimator: base_estimator,
554 strategy: self.strategy,
555 memory_buffer: self.memory_buffer,
556 learned_tasks: self.learned_tasks,
557 task_performance,
558 importance_weights,
559 n_features_in: x.ncols(),
560 feature_names_in: None,
561 },
562 base_estimator: None,
563 strategy: ContinualLearningStrategy::ExperienceReplay {
564 buffer_size: 1000,
565 replay_batch_size: 32,
566 replay_frequency: 10,
567 },
568 memory_buffer: MemoryBuffer::new(1000, SamplingStrategy::Random),
569 learned_tasks: Vec::new(),
570 current_task_id: None,
571 })
572 } else {
573 Err(SklearsError::InvalidInput(
574 "Target values required for continual learning".to_string(),
575 ))
576 }
577 }
578}
579
580impl ContinualLearningPipeline<Untrained> {
581 fn apply_continual_learning_strategy(
583 &mut self,
584 estimator: &mut Box<dyn PipelinePredictor>,
585 x: &ArrayView2<'_, Float>,
586 y: &ArrayView1<'_, Float>,
587 ) -> SklResult<HashMap<String, f64>> {
588 let mut importance_weights = HashMap::new();
589
590 match &self.strategy {
591 ContinualLearningStrategy::ElasticWeightConsolidation {
592 lambda,
593 fisher_samples,
594 } => {
595 for i in 0..*fisher_samples.min(&x.nrows()) {
597 let param_name = format!("param_{i}");
598 let importance = self.compute_fisher_information(x, y, i);
599 importance_weights.insert(param_name, importance * lambda);
600 }
601
602 estimator.fit(x, y)?;
604 }
605 ContinualLearningStrategy::ExperienceReplay {
606 replay_batch_size,
607 replay_frequency,
608 ..
609 } => {
610 for i in 0..x.nrows() {
612 let sample = MemorySample {
613 features: x.row(i).mapv(|v| v),
614 target: y[i],
615 task_id: self
616 .current_task_id
617 .clone()
618 .unwrap_or_else(|| "default".to_string()),
619 importance: 1.0,
620 gradient_info: None,
621 };
622 self.memory_buffer.add_sample(sample);
623 }
624
625 for epoch in 0..*replay_frequency {
627 estimator.fit(x, y)?;
629
630 let replay_samples = self.memory_buffer.sample(*replay_batch_size);
632 if !replay_samples.is_empty() {
633 let replay_x = Array2::from_shape_vec(
635 (replay_samples.len(), x.ncols()),
636 replay_samples
637 .iter()
638 .flat_map(|s| s.features.iter().copied())
639 .collect(),
640 )
641 .map_err(|e| SklearsError::InvalidData {
642 reason: format!("Replay batch creation failed: {e}"),
643 })?;
644
645 let replay_y = Array1::from_vec(
646 replay_samples.iter().map(|s| s.target as Float).collect(),
647 );
648
649 estimator.fit(&replay_x.view(), &replay_y.view())?;
650 }
651 }
652 }
653 ContinualLearningStrategy::LearningWithoutForgetting {
654 temperature,
655 distillation_weight,
656 } => {
657 importance_weights.insert("temperature".to_string(), *temperature);
659 importance_weights.insert("distillation_weight".to_string(), *distillation_weight);
660
661 estimator.fit(x, y)?;
662 }
663 ContinualLearningStrategy::ProgressiveNetworks {
664 max_columns,
665 lateral_strength,
666 } => {
667 importance_weights.insert("columns".to_string(), self.learned_tasks.len() as f64);
669 importance_weights.insert("lateral_strength".to_string(), *lateral_strength);
670
671 estimator.fit(x, y)?;
672 }
673 ContinualLearningStrategy::MemoryAugmented {
674 memory_size,
675 read_heads,
676 write_strength,
677 } => {
678 importance_weights.insert(
680 "memory_usage".to_string(),
681 self.memory_buffer.samples.len() as f64 / *memory_size as f64,
682 );
683 importance_weights.insert("read_heads".to_string(), *read_heads as f64);
684 importance_weights.insert("write_strength".to_string(), *write_strength);
685
686 estimator.fit(x, y)?;
687 }
688 ContinualLearningStrategy::GradientEpisodicMemory {
689 memory_size,
690 tolerance,
691 } => {
692 for i in 0..x.nrows() {
694 let mut gradient_info = HashMap::new();
695 gradient_info.insert("grad_norm".to_string(), thread_rng().gen::<f64>()); let sample = MemorySample {
698 features: x.row(i).mapv(|v| v),
699 target: y[i],
700 task_id: self
701 .current_task_id
702 .clone()
703 .unwrap_or_else(|| "default".to_string()),
704 importance: 1.0,
705 gradient_info: Some(gradient_info),
706 };
707 self.memory_buffer.add_sample(sample);
708 }
709
710 importance_weights.insert(
711 "memory_utilization".to_string(),
712 self.memory_buffer.samples.len() as f64 / *memory_size as f64,
713 );
714 importance_weights.insert("tolerance".to_string(), *tolerance);
715
716 estimator.fit(x, y)?;
717 }
718 }
719
720 Ok(importance_weights)
721 }
722
723 fn compute_fisher_information(
725 &self,
726 x: &ArrayView2<'_, Float>,
727 y: &ArrayView1<'_, Float>,
728 param_idx: usize,
729 ) -> f64 {
730 if param_idx < x.ncols() {
732 let feature_variance = x.column(param_idx).var(1.0);
733 feature_variance.max(1e-8) } else {
735 1e-4 }
737 }
738}
739
740impl ContinualLearningPipeline<ContinualLearningPipelineTrained> {
741 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
743 self.state.fitted_estimator.predict(x)
744 }
745
746 pub fn learn_task(&mut self, task: Task) -> SklResult<()> {
748 let mut task_perf = HashMap::new();
750 task_perf.insert("samples".to_string(), task.statistics.n_samples as f64);
751 task_perf.insert("difficulty".to_string(), task.statistics.difficulty);
752
753 self.state
754 .task_performance
755 .insert(task.id.clone(), task_perf);
756
757 let x_view = task.features.view().mapv(|v| v as Float);
759 let y_view = task.targets.view().mapv(|v| v as Float);
760
761 match &self.state.strategy {
762 ContinualLearningStrategy::ExperienceReplay {
763 replay_batch_size, ..
764 } => {
765 for i in 0..task.features.nrows() {
767 let sample = MemorySample {
768 features: task.features.row(i).to_owned(),
769 target: task.targets[i],
770 task_id: task.id.clone(),
771 importance: 1.0,
772 gradient_info: None,
773 };
774 self.state.memory_buffer.add_sample(sample);
775 }
776
777 self.state
779 .fitted_estimator
780 .fit(&x_view.view(), &y_view.view())?;
781
782 let replay_samples = self.state.memory_buffer.sample(*replay_batch_size);
784 if !replay_samples.is_empty() {
785 let n_features = task.features.ncols();
787 let replay_x = Array2::from_shape_vec(
788 (replay_samples.len(), n_features),
789 replay_samples
790 .iter()
791 .flat_map(|s| s.features.iter().copied().map(|v| v as Float))
792 .collect(),
793 )
794 .map_err(|e| SklearsError::InvalidData {
795 reason: format!("Replay batch creation failed: {e}"),
796 })?;
797
798 let replay_y = Array1::from_vec(
799 replay_samples.iter().map(|s| s.target as Float).collect(),
800 );
801
802 self.state
803 .fitted_estimator
804 .fit(&replay_x.view(), &replay_y.view())?;
805 }
806 }
807 _ => {
808 self.state
810 .fitted_estimator
811 .fit(&x_view.view(), &y_view.view())?;
812 }
813 }
814
815 if !self.state.learned_tasks.contains(&task.id) {
816 self.state.learned_tasks.push(task.id);
817 }
818
819 Ok(())
820 }
821
822 pub fn evaluate_forgetting(&self, previous_tasks: &[Task]) -> SklResult<HashMap<String, f64>> {
824 let mut forgetting_metrics = HashMap::new();
825
826 for task in previous_tasks {
827 let x_view = task.features.view().mapv(|v| v as Float);
828 let predictions = self.predict(&x_view.view())?;
829
830 let correct = predictions
832 .iter()
833 .zip(task.targets.iter())
834 .filter(|(&pred, &actual)| (pred - actual).abs() < 0.5)
835 .count();
836
837 let accuracy = correct as f64 / task.targets.len() as f64;
838 forgetting_metrics.insert(format!("task_{}_accuracy", task.id), accuracy);
839 }
840
841 if !forgetting_metrics.is_empty() {
843 let avg_accuracy =
844 forgetting_metrics.values().sum::<f64>() / forgetting_metrics.len() as f64;
845 forgetting_metrics.insert("average_accuracy".to_string(), avg_accuracy);
846 }
847
848 Ok(forgetting_metrics)
849 }
850
851 #[must_use]
853 pub fn memory_statistics(&self) -> HashMap<String, f64> {
854 self.state.memory_buffer.statistics()
855 }
856
857 #[must_use]
859 pub fn learned_tasks(&self) -> &[String] {
860 &self.state.learned_tasks
861 }
862
863 #[must_use]
865 pub fn task_performance(&self) -> &HashMap<String, HashMap<String, f64>> {
866 &self.state.task_performance
867 }
868
869 #[must_use]
871 pub fn importance_weights(&self) -> &HashMap<String, f64> {
872 &self.state.importance_weights
873 }
874}
875
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    /// Two-sample, two-feature dataset shared by several tests.
    fn toy_data() -> (Array2<f64>, Array1<f64>) {
        (array![[1.0, 2.0], [3.0, 4.0]], array![1.0, 0.0])
    }

    #[test]
    fn test_task_creation() {
        let (features, targets) = toy_data();
        let mut task = Task::new("task1".to_string(), features, targets);
        task.estimate_difficulty();

        assert_eq!(task.id, "task1");
        let stats = &task.statistics;
        assert_eq!(stats.n_samples, 2);
        assert_eq!(stats.n_features, 2);
        assert!(stats.difficulty > 0.0);
    }

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(3, SamplingStrategy::Random);

        buffer.add_sample(MemorySample {
            features: array![1.0, 2.0],
            target: 1.0,
            task_id: "task1".to_string(),
            importance: 1.0,
            gradient_info: None,
        });

        assert_eq!(buffer.samples.len(), 1);
        assert_eq!(buffer.sample(1).len(), 1);
    }

    #[test]
    fn test_continual_learning_pipeline() {
        let (x, y) = toy_data();

        let fitted_pipeline =
            ContinualLearningPipeline::experience_replay(Box::new(MockPredictor::new()), 100, 10, 5)
                .set_current_task("task1".to_string())
                .fit(&x.view(), &Some(&y.view()))
                .unwrap();

        let predictions = fitted_pipeline.predict(&x.view()).unwrap();
        assert_eq!(predictions.len(), x.nrows());
        assert!(fitted_pipeline.learned_tasks().iter().any(|t| t == "task1"));
    }

    #[test]
    fn test_ewc_pipeline() {
        let (x, y) = toy_data();

        let fitted_pipeline = ContinualLearningPipeline::elastic_weight_consolidation(
            Box::new(MockPredictor::new()),
            0.1,
            10,
        )
        .fit(&x.view(), &Some(&y.view()))
        .unwrap();

        assert!(!fitted_pipeline.importance_weights().is_empty());

        let predictions = fitted_pipeline.predict(&x.view()).unwrap();
        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    fn test_new_task_learning() {
        let (x1, y1) = toy_data();

        let mut fitted_pipeline =
            ContinualLearningPipeline::experience_replay(Box::new(MockPredictor::new()), 100, 10, 5)
                .fit(&x1.view(), &Some(&y1.view()))
                .unwrap();

        let task2 = Task::new(
            "task2".to_string(),
            array![[5.0, 6.0], [7.0, 8.0]],
            array![0.0, 1.0],
        );
        fitted_pipeline.learn_task(task2).unwrap();

        let learned = fitted_pipeline.learned_tasks();
        assert_eq!(learned.len(), 2);
        assert!(learned.iter().any(|t| t == "task2"));
    }
}