1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use scirs2_core::random::thread_rng;
8use scirs2_core::random::Rng;
9use sklears_core::{
10 error::Result as SklResult,
11 prelude::{Predict, SklearsError},
12 traits::{Estimator, Fit, Untrained},
13 types::Float,
14};
15use std::collections::{HashMap, VecDeque};
16use std::fmt::Debug;
17
18use crate::{PipelinePredictor, PipelineStep};
19
/// A single learning task: a labelled dataset plus bookkeeping used by the
/// continual-learning pipeline.
#[derive(Debug, Clone)]
pub struct Task {
    /// Identifier of the task; used as the key in the pipeline's
    /// `task_performance` map and `learned_tasks` list.
    pub id: String,
    /// Feature matrix; `Task::new` treats rows as samples and columns as features.
    pub features: Array2<f64>,
    /// Target values; presumably one per row of `features` (not validated by
    /// `Task::new` — TODO confirm callers uphold this).
    pub targets: Array1<f64>,
    /// Free-form key/value annotations, set via `with_metadata`.
    pub metadata: HashMap<String, String>,
    /// Optional named importance weights, set via `with_importance_weights`.
    pub importance_weights: Option<HashMap<String, f64>>,
    /// Shape and difficulty summary, initialised by `Task::new`.
    pub statistics: TaskStatistics,
}
36
/// Summary statistics attached to a [`Task`].
#[derive(Debug, Clone)]
pub struct TaskStatistics {
    /// Number of samples (rows of `Task::features` when the task was created).
    pub n_samples: usize,
    /// Number of features (columns of `Task::features` when the task was created).
    pub n_features: usize,
    /// Difficulty score: 1.0 after `Task::new`, overwritten by
    /// `Task::estimate_difficulty`.
    pub difficulty: f64,
    /// Named performance metrics; empty after `Task::new`.
    pub performance: HashMap<String, f64>,
    /// Time spent learning the task; 0.0 after `Task::new`.
    /// NOTE(review): units are not established in this code — confirm.
    pub learning_time: f64,
}
51
52impl Task {
53 #[must_use]
55 pub fn new(id: String, features: Array2<f64>, targets: Array1<f64>) -> Self {
56 let n_samples = features.nrows();
57 let n_features = features.ncols();
58
59 Self {
60 id,
61 features,
62 targets,
63 metadata: HashMap::new(),
64 importance_weights: None,
65 statistics: TaskStatistics {
66 n_samples,
67 n_features,
68 difficulty: 1.0, performance: HashMap::new(),
70 learning_time: 0.0,
71 },
72 }
73 }
74
75 #[must_use]
77 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
78 self.metadata = metadata;
79 self
80 }
81
82 #[must_use]
84 pub fn with_importance_weights(mut self, weights: HashMap<String, f64>) -> Self {
85 self.importance_weights = Some(weights);
86 self
87 }
88
89 pub fn estimate_difficulty(&mut self) {
91 let feature_variance = self.features.var_axis(Axis(0), 1.0).mean().unwrap_or(1.0);
92 let target_variance = self.targets.var(1.0);
93
94 self.statistics.difficulty = (feature_variance + target_variance).max(0.1);
96 }
97}
98
/// Continual-learning strategy applied when the pipeline is fitted.
///
/// `ContinualLearningPipeline::new` sizes the memory buffer from
/// `ExperienceReplay::buffer_size`, `GradientEpisodicMemory::memory_size` or
/// `MemoryAugmented::memory_size`; the remaining strategies get a 1000-sample buffer.
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Elastic Weight Consolidation.
    ElasticWeightConsolidation {
        /// Scale applied to every computed importance value.
        lambda: f64,
        /// Upper bound (together with the number of training rows) on how many
        /// importance entries are computed.
        fisher_samples: usize,
    },
    /// Progressive networks.
    ProgressiveNetworks {
        /// Maximum number of columns.
        /// NOTE(review): not read by the visible strategy code — confirm intent.
        max_columns: usize,
        /// Lateral connection strength; reported in the importance weights.
        lateral_strength: f64,
    },
    /// Rehearsal of stored samples from a memory buffer.
    ExperienceReplay {
        /// Capacity of the memory buffer.
        buffer_size: usize,
        /// Number of remembered samples drawn per replay step.
        replay_batch_size: usize,
        /// Number of fit-then-replay rounds performed during the initial fit.
        replay_frequency: usize,
    },
    /// Learning without Forgetting (knowledge distillation).
    LearningWithoutForgetting {
        /// Distillation temperature; reported in the importance weights.
        temperature: f64,
        /// Weight of the distillation term; reported in the importance weights.
        distillation_weight: f64,
    },
    /// Memory-augmented learning.
    MemoryAugmented {
        /// Capacity of the memory buffer.
        memory_size: usize,
        /// Number of read heads; reported in the importance weights.
        read_heads: usize,
        /// Write strength; reported in the importance weights.
        write_strength: f64,
    },
    /// Gradient Episodic Memory.
    GradientEpisodicMemory {
        /// Capacity of the episodic memory buffer.
        memory_size: usize,
        /// Tolerance; reported in the importance weights.
        tolerance: f64,
    },
}
149
150#[derive(Debug, Clone)]
152pub struct MemoryBuffer {
153 max_size: usize,
155 samples: VecDeque<MemorySample>,
157 task_distributions: HashMap<String, usize>,
159 sampling_strategy: SamplingStrategy,
161}
162
163#[derive(Debug, Clone)]
165pub struct MemorySample {
166 pub features: Array1<f64>,
168 pub target: f64,
170 pub task_id: String,
172 pub importance: f64,
174 pub gradient_info: Option<HashMap<String, f64>>,
176}
177
178#[derive(Debug, Clone)]
180pub enum SamplingStrategy {
181 Random,
183 Reservoir,
185 ImportanceBased,
187 TaskBalanced,
189 GradientBased,
191}
192
193impl MemoryBuffer {
194 #[must_use]
196 pub fn new(max_size: usize, sampling_strategy: SamplingStrategy) -> Self {
197 Self {
198 max_size,
199 samples: VecDeque::new(),
200 task_distributions: HashMap::new(),
201 sampling_strategy,
202 }
203 }
204
205 pub fn add_sample(&mut self, sample: MemorySample) {
207 *self
209 .task_distributions
210 .entry(sample.task_id.clone())
211 .or_insert(0) += 1;
212
213 if self.samples.len() >= self.max_size {
214 match self.sampling_strategy {
215 SamplingStrategy::Random => {
216 let replace_idx = thread_rng().gen_range(0..self.samples.len());
217 if let Some(old_task_id) =
218 self.samples.get(replace_idx).map(|s| s.task_id.clone())
219 {
220 if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
221 *count -= 1;
222 if *count == 0 {
223 self.task_distributions.remove(&old_task_id);
224 }
225 }
226 }
227 self.samples[replace_idx] = sample;
228 }
229 SamplingStrategy::Reservoir => {
230 let replace_idx = thread_rng().gen_range(0..(self.samples.len() + 1));
232 if replace_idx < self.samples.len() {
233 if let Some(old_task_id) =
234 self.samples.get(replace_idx).map(|s| s.task_id.clone())
235 {
236 if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
237 *count -= 1;
238 if *count == 0 {
239 self.task_distributions.remove(&old_task_id);
240 }
241 }
242 }
243 self.samples[replace_idx] = sample;
244 }
245 }
246 SamplingStrategy::ImportanceBased => {
247 let min_importance_idx = self
249 .samples
250 .iter()
251 .enumerate()
252 .min_by(|(_, a), (_, b)| {
253 a.importance
254 .partial_cmp(&b.importance)
255 .unwrap_or(std::cmp::Ordering::Equal)
256 })
257 .map_or(0, |(idx, _)| idx);
258
259 if sample.importance > self.samples[min_importance_idx].importance {
260 if let Some(old_task_id) = self
261 .samples
262 .get(min_importance_idx)
263 .map(|s| s.task_id.clone())
264 {
265 if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
266 *count -= 1;
267 if *count == 0 {
268 self.task_distributions.remove(&old_task_id);
269 }
270 }
271 }
272 self.samples[min_importance_idx] = sample;
273 }
274 }
275 SamplingStrategy::TaskBalanced => {
276 let max_task = self
278 .task_distributions
279 .iter()
280 .max_by_key(|(_, &count)| count)
281 .map(|(task_id, _)| task_id.clone());
282
283 if let Some(overrep_task) = max_task {
284 if let Some(idx) =
285 self.samples.iter().position(|s| s.task_id == overrep_task)
286 {
287 if let Some(count) = self.task_distributions.get_mut(&overrep_task) {
288 *count -= 1;
289 if *count == 0 {
290 self.task_distributions.remove(&overrep_task);
291 }
292 }
293 self.samples[idx] = sample;
294 }
295 }
296 }
297 SamplingStrategy::GradientBased => {
298 let replace_idx = thread_rng().gen_range(0..self.samples.len());
301 if let Some(old_task_id) =
302 self.samples.get(replace_idx).map(|s| s.task_id.clone())
303 {
304 if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
305 *count -= 1;
306 if *count == 0 {
307 self.task_distributions.remove(&old_task_id);
308 }
309 }
310 }
311 self.samples[replace_idx] = sample;
312 }
313 }
314 } else {
315 self.samples.push_back(sample);
316 }
317 }
318
319 #[must_use]
321 pub fn sample(&self, n_samples: usize) -> Vec<&MemorySample> {
322 if self.samples.is_empty() {
323 return Vec::new();
324 }
325
326 let n_samples = n_samples.min(self.samples.len());
327 let mut sampled = Vec::new();
328
329 match self.sampling_strategy {
330 SamplingStrategy::Random | SamplingStrategy::Reservoir => {
331 for _ in 0..n_samples {
332 let idx = thread_rng().gen_range(0..self.samples.len());
333 sampled.push(&self.samples[idx]);
334 }
335 }
336 SamplingStrategy::ImportanceBased => {
337 let total_importance: f64 = self.samples.iter().map(|s| s.importance).sum();
339 for _ in 0..n_samples {
340 let target = thread_rng().random::<f64>() * total_importance;
341 let mut cumulative = 0.0;
342 for sample in &self.samples {
343 cumulative += sample.importance;
344 if cumulative >= target {
345 sampled.push(sample);
346 break;
347 }
348 }
349 }
350 }
351 SamplingStrategy::TaskBalanced => {
352 let unique_tasks: Vec<String> = self.task_distributions.keys().cloned().collect();
354 if !unique_tasks.is_empty() {
355 let samples_per_task = n_samples / unique_tasks.len();
356 let extra_samples = n_samples % unique_tasks.len();
357
358 for (i, task_id) in unique_tasks.iter().enumerate() {
359 let task_samples: Vec<&MemorySample> = self
360 .samples
361 .iter()
362 .filter(|s| &s.task_id == task_id)
363 .collect();
364
365 let task_sample_count = samples_per_task + usize::from(i < extra_samples);
366 for _ in 0..task_sample_count.min(task_samples.len()) {
367 let idx = thread_rng().gen_range(0..task_samples.len());
368 sampled.push(task_samples[idx]);
369 }
370 }
371 }
372 }
373 SamplingStrategy::GradientBased => {
374 for _ in 0..n_samples {
377 let idx = thread_rng().gen_range(0..self.samples.len());
378 sampled.push(&self.samples[idx]);
379 }
380 }
381 }
382
383 sampled
384 }
385
386 #[must_use]
388 pub fn statistics(&self) -> HashMap<String, f64> {
389 let mut stats = HashMap::new();
390 stats.insert("total_samples".to_string(), self.samples.len() as f64);
391 stats.insert(
392 "unique_tasks".to_string(),
393 self.task_distributions.len() as f64,
394 );
395
396 if !self.samples.is_empty() {
397 let avg_importance =
398 self.samples.iter().map(|s| s.importance).sum::<f64>() / self.samples.len() as f64;
399 stats.insert("average_importance".to_string(), avg_importance);
400 }
401
402 stats
403 }
404}
405
/// Pipeline that wraps a base estimator with a continual-learning strategy.
///
/// `S` is the typestate: `Untrained` before fitting, and
/// `ContinualLearningPipelineTrained` (which holds the fitted estimator) afterwards.
#[derive(Debug)]
pub struct ContinualLearningPipeline<S = Untrained> {
    /// Typestate payload.
    state: S,
    /// Estimator to train; taken out (left as `None`) by `fit`.
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    /// Strategy applied during `fit`.
    strategy: ContinualLearningStrategy,
    /// Rehearsal memory, sized from the strategy in `new`.
    memory_buffer: MemoryBuffer,
    /// Ids of tasks learned so far.
    learned_tasks: Vec<String>,
    /// Task id used by the next `fit`; set via `set_current_task`.
    current_task_id: Option<String>,
}
416
/// State of a fitted [`ContinualLearningPipeline`].
#[derive(Debug)]
pub struct ContinualLearningPipelineTrained {
    /// The base estimator after training.
    fitted_estimator: Box<dyn PipelinePredictor>,
    /// Strategy the pipeline was built with; consulted by `learn_task`.
    strategy: ContinualLearningStrategy,
    /// Rehearsal memory carried over from fitting.
    memory_buffer: MemoryBuffer,
    /// Ids of learned tasks, in the order they were first learned.
    learned_tasks: Vec<String>,
    /// Per-task metrics, keyed by task id and then by metric name.
    task_performance: HashMap<String, HashMap<String, f64>>,
    /// Strategy-specific weights produced when the strategy was applied in `fit`.
    importance_weights: HashMap<String, f64>,
    /// Number of feature columns seen during `fit`.
    n_features_in: usize,
    /// Feature names seen during `fit`; `fit` currently always stores `None`.
    feature_names_in: Option<Vec<String>>,
}
429
430impl ContinualLearningPipeline<Untrained> {
431 #[must_use]
433 pub fn new(
434 base_estimator: Box<dyn PipelinePredictor>,
435 strategy: ContinualLearningStrategy,
436 ) -> Self {
437 let memory_buffer = match &strategy {
438 ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => {
439 MemoryBuffer::new(*buffer_size, SamplingStrategy::Random)
440 }
441 ContinualLearningStrategy::GradientEpisodicMemory { memory_size, .. } => {
442 MemoryBuffer::new(*memory_size, SamplingStrategy::GradientBased)
443 }
444 ContinualLearningStrategy::MemoryAugmented { memory_size, .. } => {
445 MemoryBuffer::new(*memory_size, SamplingStrategy::ImportanceBased)
446 }
447 _ => MemoryBuffer::new(1000, SamplingStrategy::Random), };
449
450 Self {
451 state: Untrained,
452 base_estimator: Some(base_estimator),
453 strategy,
454 memory_buffer,
455 learned_tasks: Vec::new(),
456 current_task_id: None,
457 }
458 }
459
460 #[must_use]
462 pub fn elastic_weight_consolidation(
463 base_estimator: Box<dyn PipelinePredictor>,
464 lambda: f64,
465 fisher_samples: usize,
466 ) -> Self {
467 Self::new(
468 base_estimator,
469 ContinualLearningStrategy::ElasticWeightConsolidation {
470 lambda,
471 fisher_samples,
472 },
473 )
474 }
475
476 #[must_use]
478 pub fn experience_replay(
479 base_estimator: Box<dyn PipelinePredictor>,
480 buffer_size: usize,
481 replay_batch_size: usize,
482 replay_frequency: usize,
483 ) -> Self {
484 Self::new(
485 base_estimator,
486 ContinualLearningStrategy::ExperienceReplay {
487 buffer_size,
488 replay_batch_size,
489 replay_frequency,
490 },
491 )
492 }
493
494 #[must_use]
496 pub fn learning_without_forgetting(
497 base_estimator: Box<dyn PipelinePredictor>,
498 temperature: f64,
499 distillation_weight: f64,
500 ) -> Self {
501 Self::new(
502 base_estimator,
503 ContinualLearningStrategy::LearningWithoutForgetting {
504 temperature,
505 distillation_weight,
506 },
507 )
508 }
509
510 #[must_use]
512 pub fn set_current_task(mut self, task_id: String) -> Self {
513 self.current_task_id = Some(task_id);
514 self
515 }
516}
517
518impl Estimator for ContinualLearningPipeline<Untrained> {
519 type Config = ();
520 type Error = SklearsError;
521 type Float = Float;
522
523 fn config(&self) -> &Self::Config {
524 &()
525 }
526}
527
528impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
529 for ContinualLearningPipeline<Untrained>
530{
531 type Fitted = ContinualLearningPipeline<ContinualLearningPipelineTrained>;
532
533 fn fit(
534 mut self,
535 x: &ArrayView2<'_, Float>,
536 y: &Option<&ArrayView1<'_, Float>>,
537 ) -> SklResult<Self::Fitted> {
538 if let Some(y_values) = y.as_ref() {
539 let mut base_estimator = self.base_estimator.take().ok_or_else(|| {
540 SklearsError::InvalidInput("No base estimator provided".to_string())
541 })?;
542
543 let importance_weights =
545 self.apply_continual_learning_strategy(&mut base_estimator, x, y_values)?;
546
547 let task_id = self
548 .current_task_id
549 .clone()
550 .unwrap_or_else(|| "default_task".to_string());
551 self.learned_tasks.push(task_id.clone());
552
553 let mut task_performance = HashMap::new();
554 let mut perf_metrics = HashMap::new();
555 perf_metrics.insert("training_completed".to_string(), 1.0);
556 task_performance.insert(task_id, perf_metrics);
557
558 Ok(ContinualLearningPipeline {
559 state: ContinualLearningPipelineTrained {
560 fitted_estimator: base_estimator,
561 strategy: self.strategy,
562 memory_buffer: self.memory_buffer,
563 learned_tasks: self.learned_tasks,
564 task_performance,
565 importance_weights,
566 n_features_in: x.ncols(),
567 feature_names_in: None,
568 },
569 base_estimator: None,
570 strategy: ContinualLearningStrategy::ExperienceReplay {
571 buffer_size: 1000,
572 replay_batch_size: 32,
573 replay_frequency: 10,
574 },
575 memory_buffer: MemoryBuffer::new(1000, SamplingStrategy::Random),
576 learned_tasks: Vec::new(),
577 current_task_id: None,
578 })
579 } else {
580 Err(SklearsError::InvalidInput(
581 "Target values required for continual learning".to_string(),
582 ))
583 }
584 }
585}
586
587impl ContinualLearningPipeline<Untrained> {
588 fn apply_continual_learning_strategy(
590 &mut self,
591 estimator: &mut Box<dyn PipelinePredictor>,
592 x: &ArrayView2<'_, Float>,
593 y: &ArrayView1<'_, Float>,
594 ) -> SklResult<HashMap<String, f64>> {
595 let mut importance_weights = HashMap::new();
596
597 match &self.strategy {
598 ContinualLearningStrategy::ElasticWeightConsolidation {
599 lambda,
600 fisher_samples,
601 } => {
602 for i in 0..*fisher_samples.min(&x.nrows()) {
604 let param_name = format!("param_{i}");
605 let importance = self.compute_fisher_information(x, y, i);
606 importance_weights.insert(param_name, importance * lambda);
607 }
608
609 estimator.fit(x, y)?;
611 }
612 ContinualLearningStrategy::ExperienceReplay {
613 replay_batch_size,
614 replay_frequency,
615 ..
616 } => {
617 for i in 0..x.nrows() {
619 let sample = MemorySample {
620 features: x.row(i).mapv(|v| v),
621 target: y[i],
622 task_id: self
623 .current_task_id
624 .clone()
625 .unwrap_or_else(|| "default".to_string()),
626 importance: 1.0,
627 gradient_info: None,
628 };
629 self.memory_buffer.add_sample(sample);
630 }
631
632 for epoch in 0..*replay_frequency {
634 estimator.fit(x, y)?;
636
637 let replay_samples = self.memory_buffer.sample(*replay_batch_size);
639 if !replay_samples.is_empty() {
640 let replay_x = Array2::from_shape_vec(
642 (replay_samples.len(), x.ncols()),
643 replay_samples
644 .iter()
645 .flat_map(|s| s.features.iter().copied())
646 .collect(),
647 )
648 .map_err(|e| SklearsError::InvalidData {
649 reason: format!("Replay batch creation failed: {e}"),
650 })?;
651
652 let replay_y = Array1::from_vec(
653 replay_samples.iter().map(|s| s.target as Float).collect(),
654 );
655
656 estimator.fit(&replay_x.view(), &replay_y.view())?;
657 }
658 }
659 }
660 ContinualLearningStrategy::LearningWithoutForgetting {
661 temperature,
662 distillation_weight,
663 } => {
664 importance_weights.insert("temperature".to_string(), *temperature);
666 importance_weights.insert("distillation_weight".to_string(), *distillation_weight);
667
668 estimator.fit(x, y)?;
669 }
670 ContinualLearningStrategy::ProgressiveNetworks {
671 max_columns,
672 lateral_strength,
673 } => {
674 importance_weights.insert("columns".to_string(), self.learned_tasks.len() as f64);
676 importance_weights.insert("lateral_strength".to_string(), *lateral_strength);
677
678 estimator.fit(x, y)?;
679 }
680 ContinualLearningStrategy::MemoryAugmented {
681 memory_size,
682 read_heads,
683 write_strength,
684 } => {
685 importance_weights.insert(
687 "memory_usage".to_string(),
688 self.memory_buffer.samples.len() as f64 / *memory_size as f64,
689 );
690 importance_weights.insert("read_heads".to_string(), *read_heads as f64);
691 importance_weights.insert("write_strength".to_string(), *write_strength);
692
693 estimator.fit(x, y)?;
694 }
695 ContinualLearningStrategy::GradientEpisodicMemory {
696 memory_size,
697 tolerance,
698 } => {
699 for i in 0..x.nrows() {
701 let mut gradient_info = HashMap::new();
702 gradient_info.insert("grad_norm".to_string(), thread_rng().random::<f64>()); let sample = MemorySample {
705 features: x.row(i).mapv(|v| v),
706 target: y[i],
707 task_id: self
708 .current_task_id
709 .clone()
710 .unwrap_or_else(|| "default".to_string()),
711 importance: 1.0,
712 gradient_info: Some(gradient_info),
713 };
714 self.memory_buffer.add_sample(sample);
715 }
716
717 importance_weights.insert(
718 "memory_utilization".to_string(),
719 self.memory_buffer.samples.len() as f64 / *memory_size as f64,
720 );
721 importance_weights.insert("tolerance".to_string(), *tolerance);
722
723 estimator.fit(x, y)?;
724 }
725 }
726
727 Ok(importance_weights)
728 }
729
730 fn compute_fisher_information(
732 &self,
733 x: &ArrayView2<'_, Float>,
734 y: &ArrayView1<'_, Float>,
735 param_idx: usize,
736 ) -> f64 {
737 if param_idx < x.ncols() {
739 let feature_variance = x.column(param_idx).var(1.0);
740 feature_variance.max(1e-8) } else {
742 1e-4 }
744 }
745}
746
747impl ContinualLearningPipeline<ContinualLearningPipelineTrained> {
748 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
750 self.state.fitted_estimator.predict(x)
751 }
752
753 pub fn learn_task(&mut self, task: Task) -> SklResult<()> {
755 let mut task_perf = HashMap::new();
757 task_perf.insert("samples".to_string(), task.statistics.n_samples as f64);
758 task_perf.insert("difficulty".to_string(), task.statistics.difficulty);
759
760 self.state
761 .task_performance
762 .insert(task.id.clone(), task_perf);
763
764 let x_view = task.features.view().mapv(|v| v as Float);
766 let y_view = task.targets.view().mapv(|v| v as Float);
767
768 match &self.state.strategy {
769 ContinualLearningStrategy::ExperienceReplay {
770 replay_batch_size, ..
771 } => {
772 for i in 0..task.features.nrows() {
774 let sample = MemorySample {
775 features: task.features.row(i).to_owned(),
776 target: task.targets[i],
777 task_id: task.id.clone(),
778 importance: 1.0,
779 gradient_info: None,
780 };
781 self.state.memory_buffer.add_sample(sample);
782 }
783
784 self.state
786 .fitted_estimator
787 .fit(&x_view.view(), &y_view.view())?;
788
789 let replay_samples = self.state.memory_buffer.sample(*replay_batch_size);
791 if !replay_samples.is_empty() {
792 let n_features = task.features.ncols();
794 let replay_x = Array2::from_shape_vec(
795 (replay_samples.len(), n_features),
796 replay_samples
797 .iter()
798 .flat_map(|s| s.features.iter().copied().map(|v| v as Float))
799 .collect(),
800 )
801 .map_err(|e| SklearsError::InvalidData {
802 reason: format!("Replay batch creation failed: {e}"),
803 })?;
804
805 let replay_y = Array1::from_vec(
806 replay_samples.iter().map(|s| s.target as Float).collect(),
807 );
808
809 self.state
810 .fitted_estimator
811 .fit(&replay_x.view(), &replay_y.view())?;
812 }
813 }
814 _ => {
815 self.state
817 .fitted_estimator
818 .fit(&x_view.view(), &y_view.view())?;
819 }
820 }
821
822 if !self.state.learned_tasks.contains(&task.id) {
823 self.state.learned_tasks.push(task.id);
824 }
825
826 Ok(())
827 }
828
829 pub fn evaluate_forgetting(&self, previous_tasks: &[Task]) -> SklResult<HashMap<String, f64>> {
831 let mut forgetting_metrics = HashMap::new();
832
833 for task in previous_tasks {
834 let x_view = task.features.view().mapv(|v| v as Float);
835 let predictions = self.predict(&x_view.view())?;
836
837 let correct = predictions
839 .iter()
840 .zip(task.targets.iter())
841 .filter(|(&pred, &actual)| (pred - actual).abs() < 0.5)
842 .count();
843
844 let accuracy = correct as f64 / task.targets.len() as f64;
845 forgetting_metrics.insert(format!("task_{}_accuracy", task.id), accuracy);
846 }
847
848 if !forgetting_metrics.is_empty() {
850 let avg_accuracy =
851 forgetting_metrics.values().sum::<f64>() / forgetting_metrics.len() as f64;
852 forgetting_metrics.insert("average_accuracy".to_string(), avg_accuracy);
853 }
854
855 Ok(forgetting_metrics)
856 }
857
858 #[must_use]
860 pub fn memory_statistics(&self) -> HashMap<String, f64> {
861 self.state.memory_buffer.statistics()
862 }
863
864 #[must_use]
866 pub fn learned_tasks(&self) -> &[String] {
867 &self.state.learned_tasks
868 }
869
870 #[must_use]
872 pub fn task_performance(&self) -> &HashMap<String, HashMap<String, f64>> {
873 &self.state.task_performance
874 }
875
876 #[must_use]
878 pub fn importance_weights(&self) -> &HashMap<String, f64> {
879 &self.state.importance_weights
880 }
881}
882
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    /// Two samples with two features each, shared by several tests.
    fn toy_dataset() -> (Array2<f64>, Array1<f64>) {
        (array![[1.0, 2.0], [3.0, 4.0]], array![1.0, 0.0])
    }

    #[test]
    fn test_task_creation() {
        let (features, targets) = toy_dataset();
        let mut task = Task::new("task1".to_string(), features, targets);
        task.estimate_difficulty();

        assert_eq!(task.id, "task1");
        let stats = &task.statistics;
        assert_eq!(stats.n_samples, 2);
        assert_eq!(stats.n_features, 2);
        assert!(stats.difficulty > 0.0);
    }

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(3, SamplingStrategy::Random);
        buffer.add_sample(MemorySample {
            features: array![1.0, 2.0],
            target: 1.0,
            task_id: "task1".to_string(),
            importance: 1.0,
            gradient_info: None,
        });

        assert_eq!(buffer.samples.len(), 1);
        assert_eq!(buffer.sample(1).len(), 1);
    }

    #[test]
    fn test_continual_learning_pipeline() {
        let (x, y) = toy_dataset();

        let fitted =
            ContinualLearningPipeline::experience_replay(Box::new(MockPredictor::new()), 100, 10, 5)
                .set_current_task("task1".to_string())
                .fit(&x.view(), &Some(&y.view()))
                .expect("operation should succeed");

        let predictions = fitted.predict(&x.view()).unwrap_or_default();
        assert_eq!(predictions.len(), x.nrows());
        assert!(fitted.learned_tasks().iter().any(|t| t == "task1"));
    }

    #[test]
    fn test_ewc_pipeline() {
        let (x, y) = toy_dataset();

        let fitted = ContinualLearningPipeline::elastic_weight_consolidation(
            Box::new(MockPredictor::new()),
            0.1,
            10,
        )
        .fit(&x.view(), &Some(&y.view()))
        .expect("operation should succeed");

        assert!(!fitted.importance_weights().is_empty());

        let predictions = fitted.predict(&x.view()).unwrap_or_default();
        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    fn test_new_task_learning() {
        let (x_first, y_first) = toy_dataset();

        let mut fitted =
            ContinualLearningPipeline::experience_replay(Box::new(MockPredictor::new()), 100, 10, 5)
                .fit(&x_first.view(), &Some(&y_first.view()))
                .expect("operation should succeed");

        let second_task = Task::new(
            "task2".to_string(),
            array![[5.0, 6.0], [7.0, 8.0]],
            array![0.0, 1.0],
        );
        fitted.learn_task(second_task).unwrap_or_default();

        let learned = fitted.learned_tasks();
        assert_eq!(learned.len(), 2);
        assert!(learned.iter().any(|t| t == "task2"));
    }
}