sklears_compose/continual_learning.rs

//! Continual learning pipeline components
//!
//! This module provides continual learning capabilities including catastrophic forgetting
//! prevention, memory-based approaches, and progressive learning strategies.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::{
    error::Result as SklResult,
    prelude::{Predict, SklearsError},
    traits::{Estimator, Fit, Untrained},
    types::Float,
};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;

use crate::{PipelinePredictor, PipelineStep};

/// Task representation for continual learning
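///
/// A minimal construction sketch (marked `ignore`; the `array!` import mirrors the
/// unit tests below and the module path is illustrative, not prescriptive):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_compose::continual_learning::Task;
///
/// // Two samples with two features each.
/// let features = array![[1.0, 2.0], [3.0, 4.0]];
/// let targets = array![1.0, 0.0];
///
/// let mut task = Task::new("task1".to_string(), features, targets);
/// task.estimate_difficulty();
/// assert_eq!(task.statistics.n_samples, 2);
/// ```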
#[derive(Debug, Clone)]
pub struct Task {
    /// Unique task identifier
    pub id: String,
    /// Task features
    pub features: Array2<f64>,
    /// Task targets
    pub targets: Array1<f64>,
    /// Task metadata
    pub metadata: HashMap<String, String>,
    /// Task importance weights
    pub importance_weights: Option<HashMap<String, f64>>,
    /// Task learning statistics
    pub statistics: TaskStatistics,
}

/// Statistics for a learning task
#[derive(Debug, Clone)]
pub struct TaskStatistics {
    /// Number of training samples
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Task difficulty (estimated)
    pub difficulty: f64,
    /// Performance metrics
    pub performance: HashMap<String, f64>,
    /// Learning time
    pub learning_time: f64,
}

impl Task {
    /// Create a new task
    #[must_use]
    pub fn new(id: String, features: Array2<f64>, targets: Array1<f64>) -> Self {
        let n_samples = features.nrows();
        let n_features = features.ncols();

        Self {
            id,
            features,
            targets,
            metadata: HashMap::new(),
            importance_weights: None,
            statistics: TaskStatistics {
                n_samples,
                n_features,
                difficulty: 1.0, // Default difficulty
                performance: HashMap::new(),
                learning_time: 0.0,
            },
        }
    }

    /// Set task metadata
    #[must_use]
    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = metadata;
        self
    }

    /// Set importance weights
    #[must_use]
    pub fn with_importance_weights(mut self, weights: HashMap<String, f64>) -> Self {
        self.importance_weights = Some(weights);
        self
    }

    /// Estimate task difficulty based on data characteristics
    pub fn estimate_difficulty(&mut self) {
        let feature_variance = self.features.var_axis(Axis(0), 1.0).mean().unwrap_or(1.0);
        let target_variance = self.targets.var(1.0);

        // Simple difficulty estimation based on variance
        self.statistics.difficulty = (feature_variance + target_variance).max(0.1);
    }
}

/// Continual learning strategy
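///
/// Strategies without a dedicated convenience constructor can be built directly and
/// handed to [`ContinualLearningPipeline::new`]. A hedged sketch (`MockPredictor` is
/// the crate's test helper, used purely for illustration):
///
/// ```ignore
/// use sklears_compose::continual_learning::{ContinualLearningPipeline, ContinualLearningStrategy};
/// use sklears_compose::MockPredictor;
///
/// let strategy = ContinualLearningStrategy::GradientEpisodicMemory {
///     memory_size: 500,
///     tolerance: 1e-3,
/// };
/// let pipeline = ContinualLearningPipeline::new(Box::new(MockPredictor::new()), strategy);
/// ```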
#[derive(Debug, Clone)]
pub enum ContinualLearningStrategy {
    /// Elastic Weight Consolidation (EWC)
    ElasticWeightConsolidation {
        /// Regularization strength
        lambda: f64,
        /// Fisher information estimation samples
        fisher_samples: usize,
    },
    /// Progressive Neural Networks
    ProgressiveNetworks {
        /// Maximum number of parallel columns
        max_columns: usize,
        /// Lateral connection strength
        lateral_strength: f64,
    },
    /// Experience Replay
    ExperienceReplay {
        /// Memory buffer size
        buffer_size: usize,
        /// Replay batch size
        replay_batch_size: usize,
        /// Replay frequency
        replay_frequency: usize,
    },
    /// Learning without Forgetting (`LwF`)
    LearningWithoutForgetting {
        /// Distillation temperature
        temperature: f64,
        /// Distillation weight
        distillation_weight: f64,
    },
    /// Memory-Augmented Networks
    MemoryAugmented {
        /// Memory size
        memory_size: usize,
        /// Memory read heads
        read_heads: usize,
        /// Memory write strength
        write_strength: f64,
    },
    /// Gradient Episodic Memory (GEM)
    GradientEpisodicMemory {
        /// Memory buffer size
        memory_size: usize,
        /// Inequality constraint tolerance
        tolerance: f64,
    },
}

/// Memory buffer for continual learning
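///
/// A small usage sketch (marked `ignore`; mirrors the unit test at the bottom of this
/// file, with an illustrative module path):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_compose::continual_learning::{MemoryBuffer, MemorySample, SamplingStrategy};
///
/// let mut buffer = MemoryBuffer::new(3, SamplingStrategy::Random);
/// buffer.add_sample(MemorySample {
///     features: array![1.0, 2.0],
///     target: 1.0,
///     task_id: "task1".to_string(),
///     importance: 1.0,
///     gradient_info: None,
/// });
///
/// // Draw a (possibly repeated) random sample back out of the buffer.
/// let drawn = buffer.sample(1);
/// assert_eq!(drawn.len(), 1);
/// ```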
#[derive(Debug, Clone)]
pub struct MemoryBuffer {
    /// Maximum buffer size
    max_size: usize,
    /// Stored samples
    samples: VecDeque<MemorySample>,
    /// Task distributions
    task_distributions: HashMap<String, usize>,
    /// Sampling strategy
    sampling_strategy: SamplingStrategy,
}

/// Memory sample with task information
#[derive(Debug, Clone)]
pub struct MemorySample {
    /// Sample features
    pub features: Array1<f64>,
    /// Sample target
    pub target: f64,
    /// Source task ID
    pub task_id: String,
    /// Sample importance
    pub importance: f64,
    /// Gradient information (for GEM)
    pub gradient_info: Option<HashMap<String, f64>>,
}

/// Sampling strategy for memory buffer
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Random sampling
    Random,
    /// Reservoir sampling
    Reservoir,
    /// Importance-based sampling
    ImportanceBased,
    /// Task-balanced sampling
    TaskBalanced,
    /// Gradient-based sampling (for GEM)
    GradientBased,
}

impl MemoryBuffer {
    /// Create a new memory buffer
    #[must_use]
    pub fn new(max_size: usize, sampling_strategy: SamplingStrategy) -> Self {
        Self {
            max_size,
            samples: VecDeque::new(),
            task_distributions: HashMap::new(),
            sampling_strategy,
        }
    }

    /// Add a sample to the buffer
    pub fn add_sample(&mut self, sample: MemorySample) {
        // Update task distribution
        *self
            .task_distributions
            .entry(sample.task_id.clone())
            .or_insert(0) += 1;

        if self.samples.len() >= self.max_size {
            match self.sampling_strategy {
                SamplingStrategy::Random => {
                    let replace_idx = thread_rng().gen_range(0..self.samples.len());
                    if let Some(old_sample) = self.samples.get(replace_idx) {
                        let count = self
                            .task_distributions
                            .get_mut(&old_sample.task_id)
                            .unwrap();
                        *count -= 1;
                        if *count == 0 {
                            self.task_distributions.remove(&old_sample.task_id);
                        }
                    }
                    self.samples[replace_idx] = sample;
                }
                SamplingStrategy::Reservoir => {
                    // Standard reservoir sampling
                    let replace_idx = thread_rng().gen_range(0..(self.samples.len() + 1));
                    if replace_idx < self.samples.len() {
                        if let Some(old_sample) = self.samples.get(replace_idx) {
                            let count = self
                                .task_distributions
                                .get_mut(&old_sample.task_id)
                                .unwrap();
                            *count -= 1;
                            if *count == 0 {
                                self.task_distributions.remove(&old_sample.task_id);
                            }
                        }
                        self.samples[replace_idx] = sample;
                    }
                }
                SamplingStrategy::ImportanceBased => {
                    // Replace the sample with the lowest importance, but only if the
                    // incoming sample is more important.
                    let min_importance_idx = self
                        .samples
                        .iter()
                        .enumerate()
                        .min_by(|(_, a), (_, b)| {
                            a.importance
                                .partial_cmp(&b.importance)
                                .unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map_or(0, |(idx, _)| idx);

                    if sample.importance > self.samples[min_importance_idx].importance {
                        if let Some(old_sample) = self.samples.get(min_importance_idx) {
                            let count = self
                                .task_distributions
                                .get_mut(&old_sample.task_id)
                                .unwrap();
                            *count -= 1;
                            if *count == 0 {
                                self.task_distributions.remove(&old_sample.task_id);
                            }
                        }
                        self.samples[min_importance_idx] = sample;
                    }
                }
                SamplingStrategy::TaskBalanced => {
                    // Replace a sample from the most overrepresented task
                    let max_task = self
                        .task_distributions
                        .iter()
                        .max_by_key(|(_, &count)| count)
                        .map(|(task_id, _)| task_id.clone());

                    if let Some(overrep_task) = max_task {
                        if let Some(idx) =
                            self.samples.iter().position(|s| s.task_id == overrep_task)
                        {
                            let count = self.task_distributions.get_mut(&overrep_task).unwrap();
                            *count -= 1;
                            if *count == 0 {
                                self.task_distributions.remove(&overrep_task);
                            }
                            self.samples[idx] = sample;
                        }
                    }
                }
                SamplingStrategy::GradientBased => {
                    // For GEM, replacement should consider gradient diversity;
                    // simplified here to replacing a random sample.
                    let replace_idx = thread_rng().gen_range(0..self.samples.len());
                    if let Some(old_sample) = self.samples.get(replace_idx) {
                        let count = self
                            .task_distributions
                            .get_mut(&old_sample.task_id)
                            .unwrap();
                        *count -= 1;
                        if *count == 0 {
                            self.task_distributions.remove(&old_sample.task_id);
                        }
                    }
                    self.samples[replace_idx] = sample;
                }
            }
        } else {
            self.samples.push_back(sample);
        }
    }

    /// Sample from the buffer
    #[must_use]
    pub fn sample(&self, n_samples: usize) -> Vec<&MemorySample> {
        if self.samples.is_empty() {
            return Vec::new();
        }

        let n_samples = n_samples.min(self.samples.len());
        let mut sampled = Vec::new();

        match self.sampling_strategy {
            SamplingStrategy::Random | SamplingStrategy::Reservoir => {
                for _ in 0..n_samples {
                    let idx = thread_rng().gen_range(0..self.samples.len());
                    sampled.push(&self.samples[idx]);
                }
            }
            SamplingStrategy::ImportanceBased => {
                // Sample based on importance weights
                let total_importance: f64 = self.samples.iter().map(|s| s.importance).sum();
                for _ in 0..n_samples {
                    let target = thread_rng().gen::<f64>() * total_importance;
                    let mut cumulative = 0.0;
                    for sample in &self.samples {
                        cumulative += sample.importance;
                        if cumulative >= target {
                            sampled.push(sample);
                            break;
                        }
                    }
                }
            }
            SamplingStrategy::TaskBalanced => {
                // Ensure balanced representation across tasks
                let unique_tasks: Vec<String> = self.task_distributions.keys().cloned().collect();
                if !unique_tasks.is_empty() {
                    let samples_per_task = n_samples / unique_tasks.len();
                    let extra_samples = n_samples % unique_tasks.len();

                    for (i, task_id) in unique_tasks.iter().enumerate() {
                        let task_samples: Vec<&MemorySample> = self
                            .samples
                            .iter()
                            .filter(|s| &s.task_id == task_id)
                            .collect();

                        let task_sample_count = samples_per_task + usize::from(i < extra_samples);
                        for _ in 0..task_sample_count.min(task_samples.len()) {
                            let idx = thread_rng().gen_range(0..task_samples.len());
                            sampled.push(task_samples[idx]);
                        }
                    }
                }
            }
            SamplingStrategy::GradientBased => {
                // For GEM, sample based on gradient information
                // Simplified: random sampling for now
                for _ in 0..n_samples {
                    let idx = thread_rng().gen_range(0..self.samples.len());
                    sampled.push(&self.samples[idx]);
                }
            }
        }

        sampled
    }

    /// Get buffer statistics
    #[must_use]
    pub fn statistics(&self) -> HashMap<String, f64> {
        let mut stats = HashMap::new();
        stats.insert("total_samples".to_string(), self.samples.len() as f64);
        stats.insert(
            "unique_tasks".to_string(),
            self.task_distributions.len() as f64,
        );

        if !self.samples.is_empty() {
            let avg_importance =
                self.samples.iter().map(|s| s.importance).sum::<f64>() / self.samples.len() as f64;
            stats.insert("average_importance".to_string(), avg_importance);
        }

        stats
    }
}

/// Continual learning pipeline
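///
/// End-to-end sketch of fitting on one task and predicting (marked `ignore`;
/// `MockPredictor` is the crate's test helper and the module paths are illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_compose::continual_learning::ContinualLearningPipeline;
/// use sklears_compose::MockPredictor;
/// use sklears_core::traits::Fit;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0]];
/// let y = array![1.0, 0.0];
///
/// // Experience replay: 100-sample buffer, replay batches of 10, 5 replay passes.
/// let pipeline =
///     ContinualLearningPipeline::experience_replay(Box::new(MockPredictor::new()), 100, 10, 5)
///         .set_current_task("task1".to_string());
///
/// let fitted = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
/// let predictions = fitted.predict(&x.view()).unwrap();
/// assert_eq!(predictions.len(), 2);
/// ```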
#[derive(Debug)]
pub struct ContinualLearningPipeline<S = Untrained> {
    state: S,
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    strategy: ContinualLearningStrategy,
    memory_buffer: MemoryBuffer,
    learned_tasks: Vec<String>,
    current_task_id: Option<String>,
}

/// Trained state for `ContinualLearningPipeline`
#[derive(Debug)]
pub struct ContinualLearningPipelineTrained {
    fitted_estimator: Box<dyn PipelinePredictor>,
    strategy: ContinualLearningStrategy,
    memory_buffer: MemoryBuffer,
    learned_tasks: Vec<String>,
    task_performance: HashMap<String, HashMap<String, f64>>,
    importance_weights: HashMap<String, f64>,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

impl ContinualLearningPipeline<Untrained> {
    /// Create a new continual learning pipeline
    #[must_use]
    pub fn new(
        base_estimator: Box<dyn PipelinePredictor>,
        strategy: ContinualLearningStrategy,
    ) -> Self {
        let memory_buffer = match &strategy {
            ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => {
                MemoryBuffer::new(*buffer_size, SamplingStrategy::Random)
            }
            ContinualLearningStrategy::GradientEpisodicMemory { memory_size, .. } => {
                MemoryBuffer::new(*memory_size, SamplingStrategy::GradientBased)
            }
            ContinualLearningStrategy::MemoryAugmented { memory_size, .. } => {
                MemoryBuffer::new(*memory_size, SamplingStrategy::ImportanceBased)
            }
            _ => MemoryBuffer::new(1000, SamplingStrategy::Random), // Default
        };

        Self {
            state: Untrained,
            base_estimator: Some(base_estimator),
            strategy,
            memory_buffer,
            learned_tasks: Vec::new(),
            current_task_id: None,
        }
    }

    /// Create EWC pipeline
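    ///
    /// Sketch (marked `ignore`; `MockPredictor` is the crate's test helper):
    ///
    /// ```ignore
    /// let pipeline = ContinualLearningPipeline::elastic_weight_consolidation(
    ///     Box::new(MockPredictor::new()),
    ///     0.1, // regularization strength lambda
    ///     10,  // samples used for the Fisher information estimate
    /// );
    /// ```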
    #[must_use]
    pub fn elastic_weight_consolidation(
        base_estimator: Box<dyn PipelinePredictor>,
        lambda: f64,
        fisher_samples: usize,
    ) -> Self {
        Self::new(
            base_estimator,
            ContinualLearningStrategy::ElasticWeightConsolidation {
                lambda,
                fisher_samples,
            },
        )
    }

    /// Create experience replay pipeline
    #[must_use]
    pub fn experience_replay(
        base_estimator: Box<dyn PipelinePredictor>,
        buffer_size: usize,
        replay_batch_size: usize,
        replay_frequency: usize,
    ) -> Self {
        Self::new(
            base_estimator,
            ContinualLearningStrategy::ExperienceReplay {
                buffer_size,
                replay_batch_size,
                replay_frequency,
            },
        )
    }

    /// Create Learning without Forgetting pipeline
    #[must_use]
    pub fn learning_without_forgetting(
        base_estimator: Box<dyn PipelinePredictor>,
        temperature: f64,
        distillation_weight: f64,
    ) -> Self {
        Self::new(
            base_estimator,
            ContinualLearningStrategy::LearningWithoutForgetting {
                temperature,
                distillation_weight,
            },
        )
    }

    /// Set current task ID
    #[must_use]
    pub fn set_current_task(mut self, task_id: String) -> Self {
        self.current_task_id = Some(task_id);
        self
    }
}

impl Estimator for ContinualLearningPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
    for ContinualLearningPipeline<Untrained>
{
    type Fitted = ContinualLearningPipeline<ContinualLearningPipelineTrained>;

    fn fit(
        mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        if let Some(y_values) = y.as_ref() {
            let mut base_estimator = self.base_estimator.take().ok_or_else(|| {
                SklearsError::InvalidInput("No base estimator provided".to_string())
            })?;

            // Apply continual learning strategy
            let importance_weights =
                self.apply_continual_learning_strategy(&mut base_estimator, x, y_values)?;

            let task_id = self
                .current_task_id
                .clone()
                .unwrap_or_else(|| "default_task".to_string());
            self.learned_tasks.push(task_id.clone());

            let mut task_performance = HashMap::new();
            let mut perf_metrics = HashMap::new();
            perf_metrics.insert("training_completed".to_string(), 1.0);
            task_performance.insert(task_id, perf_metrics);

            Ok(ContinualLearningPipeline {
                state: ContinualLearningPipelineTrained {
                    fitted_estimator: base_estimator,
                    strategy: self.strategy,
                    memory_buffer: self.memory_buffer,
                    learned_tasks: self.learned_tasks,
                    task_performance,
                    importance_weights,
                    n_features_in: x.ncols(),
                    feature_names_in: None,
                },
                // The untrained-side fields below are only placeholders; the fitted
                // strategy and memory buffer live in `state` above.
                base_estimator: None,
                strategy: ContinualLearningStrategy::ExperienceReplay {
                    buffer_size: 1000,
                    replay_batch_size: 32,
                    replay_frequency: 10,
                },
                memory_buffer: MemoryBuffer::new(1000, SamplingStrategy::Random),
                learned_tasks: Vec::new(),
                current_task_id: None,
            })
        } else {
            Err(SklearsError::InvalidInput(
                "Target values required for continual learning".to_string(),
            ))
        }
    }
}

impl ContinualLearningPipeline<Untrained> {
    /// Apply continual learning strategy
    fn apply_continual_learning_strategy(
        &mut self,
        estimator: &mut Box<dyn PipelinePredictor>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<HashMap<String, f64>> {
        let mut importance_weights = HashMap::new();

        match &self.strategy {
            ContinualLearningStrategy::ElasticWeightConsolidation {
                lambda,
                fisher_samples,
            } => {
                // Simulate EWC by computing importance weights
                for i in 0..(*fisher_samples).min(x.nrows()) {
                    let param_name = format!("param_{i}");
                    let importance = self.compute_fisher_information(x, y, i);
                    importance_weights.insert(param_name, importance * lambda);
                }

                // Fit with regularization (simulated)
                estimator.fit(x, y)?;
            }
            ContinualLearningStrategy::ExperienceReplay {
                replay_batch_size,
                replay_frequency,
                ..
            } => {
                // Store current samples in memory
                for i in 0..x.nrows() {
                    let sample = MemorySample {
                        features: x.row(i).to_owned(),
                        target: y[i],
                        task_id: self
                            .current_task_id
                            .clone()
                            .unwrap_or_else(|| "default".to_string()),
                        importance: 1.0,
                        gradient_info: None,
                    };
                    self.memory_buffer.add_sample(sample);
                }

                // Train with replay
                for _epoch in 0..*replay_frequency {
                    // Train on current data
                    estimator.fit(x, y)?;

                    // Replay from memory
                    let replay_samples = self.memory_buffer.sample(*replay_batch_size);
                    if !replay_samples.is_empty() {
                        // Create replay batch (simplified)
                        let replay_x = Array2::from_shape_vec(
                            (replay_samples.len(), x.ncols()),
                            replay_samples
                                .iter()
                                .flat_map(|s| s.features.iter().copied())
                                .collect(),
                        )
                        .map_err(|e| SklearsError::InvalidData {
                            reason: format!("Replay batch creation failed: {e}"),
                        })?;

                        let replay_y = Array1::from_vec(
                            replay_samples.iter().map(|s| s.target as Float).collect(),
                        );

                        estimator.fit(&replay_x.view(), &replay_y.view())?;
                    }
                }
            }
            ContinualLearningStrategy::LearningWithoutForgetting {
                temperature,
                distillation_weight,
            } => {
                // Simulate LwF by storing distillation info
                importance_weights.insert("temperature".to_string(), *temperature);
                importance_weights.insert("distillation_weight".to_string(), *distillation_weight);

                estimator.fit(x, y)?;
            }
            ContinualLearningStrategy::ProgressiveNetworks {
                max_columns: _,
                lateral_strength,
            } => {
                // Simulate progressive networks
                importance_weights.insert("columns".to_string(), self.learned_tasks.len() as f64);
                importance_weights.insert("lateral_strength".to_string(), *lateral_strength);

                estimator.fit(x, y)?;
            }
            ContinualLearningStrategy::MemoryAugmented {
                memory_size,
                read_heads,
                write_strength,
            } => {
                // Simulate memory-augmented networks
                importance_weights.insert(
                    "memory_usage".to_string(),
                    self.memory_buffer.samples.len() as f64 / *memory_size as f64,
                );
                importance_weights.insert("read_heads".to_string(), *read_heads as f64);
                importance_weights.insert("write_strength".to_string(), *write_strength);

                estimator.fit(x, y)?;
            }
            ContinualLearningStrategy::GradientEpisodicMemory {
                memory_size,
                tolerance,
            } => {
                // Store samples with gradient information
                for i in 0..x.nrows() {
                    let mut gradient_info = HashMap::new();
                    gradient_info.insert("grad_norm".to_string(), thread_rng().gen::<f64>()); // Placeholder

                    let sample = MemorySample {
                        features: x.row(i).to_owned(),
                        target: y[i],
                        task_id: self
                            .current_task_id
                            .clone()
                            .unwrap_or_else(|| "default".to_string()),
                        importance: 1.0,
                        gradient_info: Some(gradient_info),
                    };
                    self.memory_buffer.add_sample(sample);
                }

                importance_weights.insert(
                    "memory_utilization".to_string(),
                    self.memory_buffer.samples.len() as f64 / *memory_size as f64,
                );
                importance_weights.insert("tolerance".to_string(), *tolerance);

                estimator.fit(x, y)?;
            }
        }

        Ok(importance_weights)
    }

    /// Compute Fisher information approximation
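    ///
    /// Background (sketch): EWC penalizes drift in parameters that were important for
    /// earlier tasks, roughly `L = L_new + (lambda / 2) * sum_i F_i * (theta_i - theta_i*)^2`,
    /// where `F_i` is the Fisher information of parameter `i` and `theta_i*` its value
    /// after the previous task. Here the per-feature variance is used as a stand-in for
    /// `F_i`; it is a simplification, not the full Fisher computation.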
    fn compute_fisher_information(
        &self,
        x: &ArrayView2<'_, Float>,
        _y: &ArrayView1<'_, Float>,
        param_idx: usize,
    ) -> f64 {
        // Simplified Fisher information computation
        if param_idx < x.ncols() {
            let feature_variance = x.column(param_idx).var(1.0);
            feature_variance.max(1e-8) // Avoid zero importance
        } else {
            1e-4 // Default small importance
        }
    }
}

impl ContinualLearningPipeline<ContinualLearningPipelineTrained> {
    /// Predict using the fitted continual learning model
    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        self.state.fitted_estimator.predict(x)
    }

    /// Learn a new task
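    ///
    /// Sketch of sequential task learning on an already-fitted pipeline (marked
    /// `ignore`; mirrors the `test_new_task_learning` unit test below):
    ///
    /// ```ignore
    /// let x2 = array![[5.0, 6.0], [7.0, 8.0]];
    /// let y2 = array![0.0, 1.0];
    /// let task2 = Task::new("task2".to_string(), x2, y2);
    ///
    /// fitted_pipeline.learn_task(task2).unwrap();
    /// assert!(fitted_pipeline.learned_tasks().contains(&"task2".to_string()));
    /// ```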
    pub fn learn_task(&mut self, task: Task) -> SklResult<()> {
        // Update current task performance
        let mut task_perf = HashMap::new();
        task_perf.insert("samples".to_string(), task.statistics.n_samples as f64);
        task_perf.insert("difficulty".to_string(), task.statistics.difficulty);

        self.state
            .task_performance
            .insert(task.id.clone(), task_perf);

        // Apply continual learning for new task
        let x_data = task.features.mapv(|v| v as Float);
        let y_data = task.targets.mapv(|v| v as Float);

        match &self.state.strategy {
            ContinualLearningStrategy::ExperienceReplay {
                replay_batch_size, ..
            } => {
                // Store new task samples
                for i in 0..task.features.nrows() {
                    let sample = MemorySample {
                        features: task.features.row(i).to_owned(),
                        target: task.targets[i],
                        task_id: task.id.clone(),
                        importance: 1.0,
                        gradient_info: None,
                    };
                    self.state.memory_buffer.add_sample(sample);
                }

                // Train with replay
                self.state
                    .fitted_estimator
                    .fit(&x_data.view(), &y_data.view())?;

                // Replay from memory
                let replay_samples = self.state.memory_buffer.sample(*replay_batch_size);
                if !replay_samples.is_empty() {
                    // Create replay batch
                    let n_features = task.features.ncols();
                    let replay_x = Array2::from_shape_vec(
                        (replay_samples.len(), n_features),
                        replay_samples
                            .iter()
                            .flat_map(|s| s.features.iter().copied().map(|v| v as Float))
                            .collect(),
                    )
                    .map_err(|e| SklearsError::InvalidData {
                        reason: format!("Replay batch creation failed: {e}"),
                    })?;

                    let replay_y = Array1::from_vec(
                        replay_samples.iter().map(|s| s.target as Float).collect(),
                    );

                    self.state
                        .fitted_estimator
                        .fit(&replay_x.view(), &replay_y.view())?;
                }
            }
            _ => {
                // Default: just train on new task
                self.state
                    .fitted_estimator
                    .fit(&x_data.view(), &y_data.view())?;
            }
        }

        if !self.state.learned_tasks.contains(&task.id) {
            self.state.learned_tasks.push(task.id);
        }

        Ok(())
    }

    /// Evaluate catastrophic forgetting
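    ///
    /// Sketch (marked `ignore`): after learning additional tasks, re-score earlier
    /// tasks to see how much performance was retained.
    ///
    /// ```ignore
    /// let metrics = fitted_pipeline.evaluate_forgetting(&[task1, task2]).unwrap();
    /// // Per-task accuracies plus an "average_accuracy" entry.
    /// println!("{:?}", metrics.get("average_accuracy"));
    /// ```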
    pub fn evaluate_forgetting(&self, previous_tasks: &[Task]) -> SklResult<HashMap<String, f64>> {
        let mut forgetting_metrics = HashMap::new();

        for task in previous_tasks {
            let x_data = task.features.mapv(|v| v as Float);
            let predictions = self.predict(&x_data.view())?;

            // Simple accuracy: a prediction counts as correct if it is within 0.5 of the target
            let correct = predictions
                .iter()
                .zip(task.targets.iter())
                .filter(|(&pred, &actual)| (pred - actual).abs() < 0.5)
                .count();

            let accuracy = correct as f64 / task.targets.len() as f64;
            forgetting_metrics.insert(format!("task_{}_accuracy", task.id), accuracy);
        }

        // Average accuracy across the evaluated tasks (a proxy for forgetting)
        if !forgetting_metrics.is_empty() {
            let avg_accuracy =
                forgetting_metrics.values().sum::<f64>() / forgetting_metrics.len() as f64;
            forgetting_metrics.insert("average_accuracy".to_string(), avg_accuracy);
        }

        Ok(forgetting_metrics)
    }

    /// Get memory buffer statistics
    #[must_use]
    pub fn memory_statistics(&self) -> HashMap<String, f64> {
        self.state.memory_buffer.statistics()
    }

    /// Get learned tasks
    #[must_use]
    pub fn learned_tasks(&self) -> &[String] {
        &self.state.learned_tasks
    }

    /// Get task performance
    #[must_use]
    pub fn task_performance(&self) -> &HashMap<String, HashMap<String, f64>> {
        &self.state.task_performance
    }

    /// Get importance weights
    #[must_use]
    pub fn importance_weights(&self) -> &HashMap<String, f64> {
        &self.state.importance_weights
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_task_creation() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0, 0.0];

        let mut task = Task::new("task1".to_string(), features, targets);
        task.estimate_difficulty();

        assert_eq!(task.id, "task1");
        assert_eq!(task.statistics.n_samples, 2);
        assert_eq!(task.statistics.n_features, 2);
        assert!(task.statistics.difficulty > 0.0);
    }

    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(3, SamplingStrategy::Random);

        let sample1 = MemorySample {
            features: array![1.0, 2.0],
            target: 1.0,
            task_id: "task1".to_string(),
            importance: 1.0,
            gradient_info: None,
        };

        buffer.add_sample(sample1);
        assert_eq!(buffer.samples.len(), 1);

        let sampled = buffer.sample(1);
        assert_eq!(sampled.len(), 1);
    }

    #[test]
    fn test_continual_learning_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = ContinualLearningPipeline::experience_replay(base_estimator, 100, 10, 5)
            .set_current_task("task1".to_string());

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
        let predictions = fitted_pipeline.predict(&x.view()).unwrap();

        assert_eq!(predictions.len(), x.nrows());
        assert!(fitted_pipeline
            .learned_tasks()
            .contains(&"task1".to_string()));
    }

    #[test]
    fn test_ewc_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline =
            ContinualLearningPipeline::elastic_weight_consolidation(base_estimator, 0.1, 10);

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();

        assert!(!fitted_pipeline.importance_weights().is_empty());

        let predictions = fitted_pipeline.predict(&x.view()).unwrap();
        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    fn test_new_task_learning() {
        let x1 = array![[1.0, 2.0], [3.0, 4.0]];
        let y1 = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = ContinualLearningPipeline::experience_replay(base_estimator, 100, 10, 5);

        let mut fitted_pipeline = pipeline.fit(&x1.view(), &Some(&y1.view())).unwrap();

        // Learn new task
        let x2 = array![[5.0, 6.0], [7.0, 8.0]];
        let y2 = array![0.0, 1.0];
        let task2 = Task::new("task2".to_string(), x2, y2);

        fitted_pipeline.learn_task(task2).unwrap();

        assert_eq!(fitted_pipeline.learned_tasks().len(), 2);
        assert!(fitted_pipeline
            .learned_tasks()
            .contains(&"task2".to_string()));
    }
}