Skip to main content

sklears_compose/
continual_learning.rs

1//! Continual learning pipeline components
2//!
3//! This module provides continual learning capabilities including catastrophic forgetting
4//! prevention, memory-based approaches, and progressive learning strategies.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use scirs2_core::random::thread_rng;
8use scirs2_core::random::Rng;
9use sklears_core::{
10    error::Result as SklResult,
11    prelude::{Predict, SklearsError},
12    traits::{Estimator, Fit, Untrained},
13    types::Float,
14};
15use std::collections::{HashMap, VecDeque};
16use std::fmt::Debug;
17
18use crate::{PipelinePredictor, PipelineStep};
19
20/// Task representation for continual learning
21#[derive(Debug, Clone)]
22pub struct Task {
23    /// Unique task identifier
24    pub id: String,
25    /// Task features
26    pub features: Array2<f64>,
27    /// Task targets
28    pub targets: Array1<f64>,
29    /// Task metadata
30    pub metadata: HashMap<String, String>,
31    /// Task importance weights
32    pub importance_weights: Option<HashMap<String, f64>>,
33    /// Task learning statistics
34    pub statistics: TaskStatistics,
35}
36
/// Bookkeeping statistics for a learning task.
///
/// Derives `PartialEq` so that two statistics records can be compared
/// directly (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
pub struct TaskStatistics {
    /// Number of training samples
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Estimated task difficulty
    pub difficulty: f64,
    /// Named performance metrics
    pub performance: HashMap<String, f64>,
    /// Learning time
    pub learning_time: f64,
}
51
52impl Task {
53    /// Create a new task
54    #[must_use]
55    pub fn new(id: String, features: Array2<f64>, targets: Array1<f64>) -> Self {
56        let n_samples = features.nrows();
57        let n_features = features.ncols();
58
59        Self {
60            id,
61            features,
62            targets,
63            metadata: HashMap::new(),
64            importance_weights: None,
65            statistics: TaskStatistics {
66                n_samples,
67                n_features,
68                difficulty: 1.0, // Default difficulty
69                performance: HashMap::new(),
70                learning_time: 0.0,
71            },
72        }
73    }
74
75    /// Set task metadata
76    #[must_use]
77    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
78        self.metadata = metadata;
79        self
80    }
81
82    /// Set importance weights
83    #[must_use]
84    pub fn with_importance_weights(mut self, weights: HashMap<String, f64>) -> Self {
85        self.importance_weights = Some(weights);
86        self
87    }
88
89    /// Estimate task difficulty based on data characteristics
90    pub fn estimate_difficulty(&mut self) {
91        let feature_variance = self.features.var_axis(Axis(0), 1.0).mean().unwrap_or(1.0);
92        let target_variance = self.targets.var(1.0);
93
94        // Simple difficulty estimation based on variance
95        self.statistics.difficulty = (feature_variance + target_variance).max(0.1);
96    }
97}
98
/// Continual learning strategy and its hyper-parameters.
///
/// Derives `PartialEq` so configurations can be compared directly
/// (`Eq` is not available because several variants carry `f64` fields).
#[derive(Debug, Clone, PartialEq)]
pub enum ContinualLearningStrategy {
    /// Elastic Weight Consolidation (EWC)
    ElasticWeightConsolidation {
        /// Regularization strength
        lambda: f64,
        /// Fisher information estimation samples
        fisher_samples: usize,
    },
    /// Progressive Neural Networks
    ProgressiveNetworks {
        /// Maximum number of parallel columns
        max_columns: usize,
        /// Lateral connection strength
        lateral_strength: f64,
    },
    /// Experience Replay
    ExperienceReplay {
        /// Memory buffer size
        buffer_size: usize,
        /// Replay batch size
        replay_batch_size: usize,
        /// Replay frequency
        replay_frequency: usize,
    },
    /// Learning without Forgetting (`LwF`)
    LearningWithoutForgetting {
        /// Distillation temperature
        temperature: f64,
        /// Distillation weight
        distillation_weight: f64,
    },
    /// Memory-Augmented Networks
    MemoryAugmented {
        /// Memory size
        memory_size: usize,
        /// Memory read heads
        read_heads: usize,
        /// Memory write strength
        write_strength: f64,
    },
    /// Gradient Episodic Memory (GEM)
    GradientEpisodicMemory {
        /// Memory buffer size
        memory_size: usize,
        /// Inequality constraint tolerance
        tolerance: f64,
    },
}
149
150/// Memory buffer for continual learning
151#[derive(Debug, Clone)]
152pub struct MemoryBuffer {
153    /// Maximum buffer size
154    max_size: usize,
155    /// Stored samples
156    samples: VecDeque<MemorySample>,
157    /// Task distributions
158    task_distributions: HashMap<String, usize>,
159    /// Sampling strategy
160    sampling_strategy: SamplingStrategy,
161}
162
163/// Memory sample with task information
164#[derive(Debug, Clone)]
165pub struct MemorySample {
166    /// Sample features
167    pub features: Array1<f64>,
168    /// Sample target
169    pub target: f64,
170    /// Source task ID
171    pub task_id: String,
172    /// Sample importance
173    pub importance: f64,
174    /// Gradient information (for GEM)
175    pub gradient_info: Option<HashMap<String, f64>>,
176}
177
/// Sampling strategy for the memory buffer.
///
/// The strategy governs both which stored sample is replaced when the
/// buffer is full and how samples are drawn from it. Derives
/// `PartialEq`/`Eq` so strategies can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Random sampling
    Random,
    /// Reservoir sampling
    Reservoir,
    /// Importance-based sampling
    ImportanceBased,
    /// Task-balanced sampling
    TaskBalanced,
    /// Gradient-based sampling (for GEM)
    GradientBased,
}
192
193impl MemoryBuffer {
194    /// Create a new memory buffer
195    #[must_use]
196    pub fn new(max_size: usize, sampling_strategy: SamplingStrategy) -> Self {
197        Self {
198            max_size,
199            samples: VecDeque::new(),
200            task_distributions: HashMap::new(),
201            sampling_strategy,
202        }
203    }
204
205    /// Add a sample to the buffer
206    pub fn add_sample(&mut self, sample: MemorySample) {
207        // Update task distribution
208        *self
209            .task_distributions
210            .entry(sample.task_id.clone())
211            .or_insert(0) += 1;
212
213        if self.samples.len() >= self.max_size {
214            match self.sampling_strategy {
215                SamplingStrategy::Random => {
216                    let replace_idx = thread_rng().gen_range(0..self.samples.len());
217                    if let Some(old_task_id) =
218                        self.samples.get(replace_idx).map(|s| s.task_id.clone())
219                    {
220                        if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
221                            *count -= 1;
222                            if *count == 0 {
223                                self.task_distributions.remove(&old_task_id);
224                            }
225                        }
226                    }
227                    self.samples[replace_idx] = sample;
228                }
229                SamplingStrategy::Reservoir => {
230                    // Standard reservoir sampling
231                    let replace_idx = thread_rng().gen_range(0..(self.samples.len() + 1));
232                    if replace_idx < self.samples.len() {
233                        if let Some(old_task_id) =
234                            self.samples.get(replace_idx).map(|s| s.task_id.clone())
235                        {
236                            if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
237                                *count -= 1;
238                                if *count == 0 {
239                                    self.task_distributions.remove(&old_task_id);
240                                }
241                            }
242                        }
243                        self.samples[replace_idx] = sample;
244                    }
245                }
246                SamplingStrategy::ImportanceBased => {
247                    // Replace sample with lowest importance
248                    let min_importance_idx = self
249                        .samples
250                        .iter()
251                        .enumerate()
252                        .min_by(|(_, a), (_, b)| {
253                            a.importance
254                                .partial_cmp(&b.importance)
255                                .unwrap_or(std::cmp::Ordering::Equal)
256                        })
257                        .map_or(0, |(idx, _)| idx);
258
259                    if sample.importance > self.samples[min_importance_idx].importance {
260                        if let Some(old_task_id) = self
261                            .samples
262                            .get(min_importance_idx)
263                            .map(|s| s.task_id.clone())
264                        {
265                            if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
266                                *count -= 1;
267                                if *count == 0 {
268                                    self.task_distributions.remove(&old_task_id);
269                                }
270                            }
271                        }
272                        self.samples[min_importance_idx] = sample;
273                    }
274                }
275                SamplingStrategy::TaskBalanced => {
276                    // Replace from overrepresented task
277                    let max_task = self
278                        .task_distributions
279                        .iter()
280                        .max_by_key(|(_, &count)| count)
281                        .map(|(task_id, _)| task_id.clone());
282
283                    if let Some(overrep_task) = max_task {
284                        if let Some(idx) =
285                            self.samples.iter().position(|s| s.task_id == overrep_task)
286                        {
287                            if let Some(count) = self.task_distributions.get_mut(&overrep_task) {
288                                *count -= 1;
289                                if *count == 0 {
290                                    self.task_distributions.remove(&overrep_task);
291                                }
292                            }
293                            self.samples[idx] = sample;
294                        }
295                    }
296                }
297                SamplingStrategy::GradientBased => {
298                    // For GEM, replace based on gradient diversity
299                    // Simplified: replace random sample
300                    let replace_idx = thread_rng().gen_range(0..self.samples.len());
301                    if let Some(old_task_id) =
302                        self.samples.get(replace_idx).map(|s| s.task_id.clone())
303                    {
304                        if let Some(count) = self.task_distributions.get_mut(&old_task_id) {
305                            *count -= 1;
306                            if *count == 0 {
307                                self.task_distributions.remove(&old_task_id);
308                            }
309                        }
310                    }
311                    self.samples[replace_idx] = sample;
312                }
313            }
314        } else {
315            self.samples.push_back(sample);
316        }
317    }
318
319    /// Sample from the buffer
320    #[must_use]
321    pub fn sample(&self, n_samples: usize) -> Vec<&MemorySample> {
322        if self.samples.is_empty() {
323            return Vec::new();
324        }
325
326        let n_samples = n_samples.min(self.samples.len());
327        let mut sampled = Vec::new();
328
329        match self.sampling_strategy {
330            SamplingStrategy::Random | SamplingStrategy::Reservoir => {
331                for _ in 0..n_samples {
332                    let idx = thread_rng().gen_range(0..self.samples.len());
333                    sampled.push(&self.samples[idx]);
334                }
335            }
336            SamplingStrategy::ImportanceBased => {
337                // Sample based on importance weights
338                let total_importance: f64 = self.samples.iter().map(|s| s.importance).sum();
339                for _ in 0..n_samples {
340                    let target = thread_rng().random::<f64>() * total_importance;
341                    let mut cumulative = 0.0;
342                    for sample in &self.samples {
343                        cumulative += sample.importance;
344                        if cumulative >= target {
345                            sampled.push(sample);
346                            break;
347                        }
348                    }
349                }
350            }
351            SamplingStrategy::TaskBalanced => {
352                // Ensure balanced representation across tasks
353                let unique_tasks: Vec<String> = self.task_distributions.keys().cloned().collect();
354                if !unique_tasks.is_empty() {
355                    let samples_per_task = n_samples / unique_tasks.len();
356                    let extra_samples = n_samples % unique_tasks.len();
357
358                    for (i, task_id) in unique_tasks.iter().enumerate() {
359                        let task_samples: Vec<&MemorySample> = self
360                            .samples
361                            .iter()
362                            .filter(|s| &s.task_id == task_id)
363                            .collect();
364
365                        let task_sample_count = samples_per_task + usize::from(i < extra_samples);
366                        for _ in 0..task_sample_count.min(task_samples.len()) {
367                            let idx = thread_rng().gen_range(0..task_samples.len());
368                            sampled.push(task_samples[idx]);
369                        }
370                    }
371                }
372            }
373            SamplingStrategy::GradientBased => {
374                // For GEM, sample based on gradient information
375                // Simplified: random sampling for now
376                for _ in 0..n_samples {
377                    let idx = thread_rng().gen_range(0..self.samples.len());
378                    sampled.push(&self.samples[idx]);
379                }
380            }
381        }
382
383        sampled
384    }
385
386    /// Get buffer statistics
387    #[must_use]
388    pub fn statistics(&self) -> HashMap<String, f64> {
389        let mut stats = HashMap::new();
390        stats.insert("total_samples".to_string(), self.samples.len() as f64);
391        stats.insert(
392            "unique_tasks".to_string(),
393            self.task_distributions.len() as f64,
394        );
395
396        if !self.samples.is_empty() {
397            let avg_importance =
398                self.samples.iter().map(|s| s.importance).sum::<f64>() / self.samples.len() as f64;
399            stats.insert("average_importance".to_string(), avg_importance);
400        }
401
402        stats
403    }
404}
405
406/// Continual learning pipeline
407#[derive(Debug)]
408pub struct ContinualLearningPipeline<S = Untrained> {
409    state: S,
410    base_estimator: Option<Box<dyn PipelinePredictor>>,
411    strategy: ContinualLearningStrategy,
412    memory_buffer: MemoryBuffer,
413    learned_tasks: Vec<String>,
414    current_task_id: Option<String>,
415}
416
417/// Trained state for `ContinualLearningPipeline`
418#[derive(Debug)]
419pub struct ContinualLearningPipelineTrained {
420    fitted_estimator: Box<dyn PipelinePredictor>,
421    strategy: ContinualLearningStrategy,
422    memory_buffer: MemoryBuffer,
423    learned_tasks: Vec<String>,
424    task_performance: HashMap<String, HashMap<String, f64>>,
425    importance_weights: HashMap<String, f64>,
426    n_features_in: usize,
427    feature_names_in: Option<Vec<String>>,
428}
429
430impl ContinualLearningPipeline<Untrained> {
431    /// Create a new continual learning pipeline
432    #[must_use]
433    pub fn new(
434        base_estimator: Box<dyn PipelinePredictor>,
435        strategy: ContinualLearningStrategy,
436    ) -> Self {
437        let memory_buffer = match &strategy {
438            ContinualLearningStrategy::ExperienceReplay { buffer_size, .. } => {
439                MemoryBuffer::new(*buffer_size, SamplingStrategy::Random)
440            }
441            ContinualLearningStrategy::GradientEpisodicMemory { memory_size, .. } => {
442                MemoryBuffer::new(*memory_size, SamplingStrategy::GradientBased)
443            }
444            ContinualLearningStrategy::MemoryAugmented { memory_size, .. } => {
445                MemoryBuffer::new(*memory_size, SamplingStrategy::ImportanceBased)
446            }
447            _ => MemoryBuffer::new(1000, SamplingStrategy::Random), // Default
448        };
449
450        Self {
451            state: Untrained,
452            base_estimator: Some(base_estimator),
453            strategy,
454            memory_buffer,
455            learned_tasks: Vec::new(),
456            current_task_id: None,
457        }
458    }
459
460    /// Create EWC pipeline
461    #[must_use]
462    pub fn elastic_weight_consolidation(
463        base_estimator: Box<dyn PipelinePredictor>,
464        lambda: f64,
465        fisher_samples: usize,
466    ) -> Self {
467        Self::new(
468            base_estimator,
469            ContinualLearningStrategy::ElasticWeightConsolidation {
470                lambda,
471                fisher_samples,
472            },
473        )
474    }
475
476    /// Create experience replay pipeline
477    #[must_use]
478    pub fn experience_replay(
479        base_estimator: Box<dyn PipelinePredictor>,
480        buffer_size: usize,
481        replay_batch_size: usize,
482        replay_frequency: usize,
483    ) -> Self {
484        Self::new(
485            base_estimator,
486            ContinualLearningStrategy::ExperienceReplay {
487                buffer_size,
488                replay_batch_size,
489                replay_frequency,
490            },
491        )
492    }
493
494    /// Create Learning without Forgetting pipeline
495    #[must_use]
496    pub fn learning_without_forgetting(
497        base_estimator: Box<dyn PipelinePredictor>,
498        temperature: f64,
499        distillation_weight: f64,
500    ) -> Self {
501        Self::new(
502            base_estimator,
503            ContinualLearningStrategy::LearningWithoutForgetting {
504                temperature,
505                distillation_weight,
506            },
507        )
508    }
509
510    /// Set current task ID
511    #[must_use]
512    pub fn set_current_task(mut self, task_id: String) -> Self {
513        self.current_task_id = Some(task_id);
514        self
515    }
516}
517
518impl Estimator for ContinualLearningPipeline<Untrained> {
519    type Config = ();
520    type Error = SklearsError;
521    type Float = Float;
522
523    fn config(&self) -> &Self::Config {
524        &()
525    }
526}
527
528impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
529    for ContinualLearningPipeline<Untrained>
530{
531    type Fitted = ContinualLearningPipeline<ContinualLearningPipelineTrained>;
532
533    fn fit(
534        mut self,
535        x: &ArrayView2<'_, Float>,
536        y: &Option<&ArrayView1<'_, Float>>,
537    ) -> SklResult<Self::Fitted> {
538        if let Some(y_values) = y.as_ref() {
539            let mut base_estimator = self.base_estimator.take().ok_or_else(|| {
540                SklearsError::InvalidInput("No base estimator provided".to_string())
541            })?;
542
543            // Apply continual learning strategy
544            let importance_weights =
545                self.apply_continual_learning_strategy(&mut base_estimator, x, y_values)?;
546
547            let task_id = self
548                .current_task_id
549                .clone()
550                .unwrap_or_else(|| "default_task".to_string());
551            self.learned_tasks.push(task_id.clone());
552
553            let mut task_performance = HashMap::new();
554            let mut perf_metrics = HashMap::new();
555            perf_metrics.insert("training_completed".to_string(), 1.0);
556            task_performance.insert(task_id, perf_metrics);
557
558            Ok(ContinualLearningPipeline {
559                state: ContinualLearningPipelineTrained {
560                    fitted_estimator: base_estimator,
561                    strategy: self.strategy,
562                    memory_buffer: self.memory_buffer,
563                    learned_tasks: self.learned_tasks,
564                    task_performance,
565                    importance_weights,
566                    n_features_in: x.ncols(),
567                    feature_names_in: None,
568                },
569                base_estimator: None,
570                strategy: ContinualLearningStrategy::ExperienceReplay {
571                    buffer_size: 1000,
572                    replay_batch_size: 32,
573                    replay_frequency: 10,
574                },
575                memory_buffer: MemoryBuffer::new(1000, SamplingStrategy::Random),
576                learned_tasks: Vec::new(),
577                current_task_id: None,
578            })
579        } else {
580            Err(SklearsError::InvalidInput(
581                "Target values required for continual learning".to_string(),
582            ))
583        }
584    }
585}
586
587impl ContinualLearningPipeline<Untrained> {
588    /// Apply continual learning strategy
589    fn apply_continual_learning_strategy(
590        &mut self,
591        estimator: &mut Box<dyn PipelinePredictor>,
592        x: &ArrayView2<'_, Float>,
593        y: &ArrayView1<'_, Float>,
594    ) -> SklResult<HashMap<String, f64>> {
595        let mut importance_weights = HashMap::new();
596
597        match &self.strategy {
598            ContinualLearningStrategy::ElasticWeightConsolidation {
599                lambda,
600                fisher_samples,
601            } => {
602                // Simulate EWC by computing importance weights
603                for i in 0..*fisher_samples.min(&x.nrows()) {
604                    let param_name = format!("param_{i}");
605                    let importance = self.compute_fisher_information(x, y, i);
606                    importance_weights.insert(param_name, importance * lambda);
607                }
608
609                // Fit with regularization (simulated)
610                estimator.fit(x, y)?;
611            }
612            ContinualLearningStrategy::ExperienceReplay {
613                replay_batch_size,
614                replay_frequency,
615                ..
616            } => {
617                // Store current samples in memory
618                for i in 0..x.nrows() {
619                    let sample = MemorySample {
620                        features: x.row(i).mapv(|v| v),
621                        target: y[i],
622                        task_id: self
623                            .current_task_id
624                            .clone()
625                            .unwrap_or_else(|| "default".to_string()),
626                        importance: 1.0,
627                        gradient_info: None,
628                    };
629                    self.memory_buffer.add_sample(sample);
630                }
631
632                // Train with replay
633                for epoch in 0..*replay_frequency {
634                    // Train on current data
635                    estimator.fit(x, y)?;
636
637                    // Replay from memory
638                    let replay_samples = self.memory_buffer.sample(*replay_batch_size);
639                    if !replay_samples.is_empty() {
640                        // Create replay batch (simplified)
641                        let replay_x = Array2::from_shape_vec(
642                            (replay_samples.len(), x.ncols()),
643                            replay_samples
644                                .iter()
645                                .flat_map(|s| s.features.iter().copied())
646                                .collect(),
647                        )
648                        .map_err(|e| SklearsError::InvalidData {
649                            reason: format!("Replay batch creation failed: {e}"),
650                        })?;
651
652                        let replay_y = Array1::from_vec(
653                            replay_samples.iter().map(|s| s.target as Float).collect(),
654                        );
655
656                        estimator.fit(&replay_x.view(), &replay_y.view())?;
657                    }
658                }
659            }
660            ContinualLearningStrategy::LearningWithoutForgetting {
661                temperature,
662                distillation_weight,
663            } => {
664                // Simulate LwF by storing distillation info
665                importance_weights.insert("temperature".to_string(), *temperature);
666                importance_weights.insert("distillation_weight".to_string(), *distillation_weight);
667
668                estimator.fit(x, y)?;
669            }
670            ContinualLearningStrategy::ProgressiveNetworks {
671                max_columns,
672                lateral_strength,
673            } => {
674                // Simulate progressive networks
675                importance_weights.insert("columns".to_string(), self.learned_tasks.len() as f64);
676                importance_weights.insert("lateral_strength".to_string(), *lateral_strength);
677
678                estimator.fit(x, y)?;
679            }
680            ContinualLearningStrategy::MemoryAugmented {
681                memory_size,
682                read_heads,
683                write_strength,
684            } => {
685                // Simulate memory-augmented networks
686                importance_weights.insert(
687                    "memory_usage".to_string(),
688                    self.memory_buffer.samples.len() as f64 / *memory_size as f64,
689                );
690                importance_weights.insert("read_heads".to_string(), *read_heads as f64);
691                importance_weights.insert("write_strength".to_string(), *write_strength);
692
693                estimator.fit(x, y)?;
694            }
695            ContinualLearningStrategy::GradientEpisodicMemory {
696                memory_size,
697                tolerance,
698            } => {
699                // Store samples with gradient information
700                for i in 0..x.nrows() {
701                    let mut gradient_info = HashMap::new();
702                    gradient_info.insert("grad_norm".to_string(), thread_rng().random::<f64>()); // Placeholder
703
704                    let sample = MemorySample {
705                        features: x.row(i).mapv(|v| v),
706                        target: y[i],
707                        task_id: self
708                            .current_task_id
709                            .clone()
710                            .unwrap_or_else(|| "default".to_string()),
711                        importance: 1.0,
712                        gradient_info: Some(gradient_info),
713                    };
714                    self.memory_buffer.add_sample(sample);
715                }
716
717                importance_weights.insert(
718                    "memory_utilization".to_string(),
719                    self.memory_buffer.samples.len() as f64 / *memory_size as f64,
720                );
721                importance_weights.insert("tolerance".to_string(), *tolerance);
722
723                estimator.fit(x, y)?;
724            }
725        }
726
727        Ok(importance_weights)
728    }
729
730    /// Compute Fisher information approximation
731    fn compute_fisher_information(
732        &self,
733        x: &ArrayView2<'_, Float>,
734        y: &ArrayView1<'_, Float>,
735        param_idx: usize,
736    ) -> f64 {
737        // Simplified Fisher information computation
738        if param_idx < x.ncols() {
739            let feature_variance = x.column(param_idx).var(1.0);
740            feature_variance.max(1e-8) // Avoid zero importance
741        } else {
742            1e-4 // Default small importance
743        }
744    }
745}
746
747impl ContinualLearningPipeline<ContinualLearningPipelineTrained> {
748    /// Predict using the fitted continual learning model
749    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
750        self.state.fitted_estimator.predict(x)
751    }
752
753    /// Learn a new task
754    pub fn learn_task(&mut self, task: Task) -> SklResult<()> {
755        // Update current task performance
756        let mut task_perf = HashMap::new();
757        task_perf.insert("samples".to_string(), task.statistics.n_samples as f64);
758        task_perf.insert("difficulty".to_string(), task.statistics.difficulty);
759
760        self.state
761            .task_performance
762            .insert(task.id.clone(), task_perf);
763
764        // Apply continual learning for new task
765        let x_view = task.features.view().mapv(|v| v as Float);
766        let y_view = task.targets.view().mapv(|v| v as Float);
767
768        match &self.state.strategy {
769            ContinualLearningStrategy::ExperienceReplay {
770                replay_batch_size, ..
771            } => {
772                // Store new task samples
773                for i in 0..task.features.nrows() {
774                    let sample = MemorySample {
775                        features: task.features.row(i).to_owned(),
776                        target: task.targets[i],
777                        task_id: task.id.clone(),
778                        importance: 1.0,
779                        gradient_info: None,
780                    };
781                    self.state.memory_buffer.add_sample(sample);
782                }
783
784                // Train with replay
785                self.state
786                    .fitted_estimator
787                    .fit(&x_view.view(), &y_view.view())?;
788
789                // Replay from memory
790                let replay_samples = self.state.memory_buffer.sample(*replay_batch_size);
791                if !replay_samples.is_empty() {
792                    // Create replay batch
793                    let n_features = task.features.ncols();
794                    let replay_x = Array2::from_shape_vec(
795                        (replay_samples.len(), n_features),
796                        replay_samples
797                            .iter()
798                            .flat_map(|s| s.features.iter().copied().map(|v| v as Float))
799                            .collect(),
800                    )
801                    .map_err(|e| SklearsError::InvalidData {
802                        reason: format!("Replay batch creation failed: {e}"),
803                    })?;
804
805                    let replay_y = Array1::from_vec(
806                        replay_samples.iter().map(|s| s.target as Float).collect(),
807                    );
808
809                    self.state
810                        .fitted_estimator
811                        .fit(&replay_x.view(), &replay_y.view())?;
812                }
813            }
814            _ => {
815                // Default: just train on new task
816                self.state
817                    .fitted_estimator
818                    .fit(&x_view.view(), &y_view.view())?;
819            }
820        }
821
822        if !self.state.learned_tasks.contains(&task.id) {
823            self.state.learned_tasks.push(task.id);
824        }
825
826        Ok(())
827    }
828
829    /// Evaluate catastrophic forgetting
830    pub fn evaluate_forgetting(&self, previous_tasks: &[Task]) -> SklResult<HashMap<String, f64>> {
831        let mut forgetting_metrics = HashMap::new();
832
833        for task in previous_tasks {
834            let x_view = task.features.view().mapv(|v| v as Float);
835            let predictions = self.predict(&x_view.view())?;
836
837            // Simple accuracy computation
838            let correct = predictions
839                .iter()
840                .zip(task.targets.iter())
841                .filter(|(&pred, &actual)| (pred - actual).abs() < 0.5)
842                .count();
843
844            let accuracy = correct as f64 / task.targets.len() as f64;
845            forgetting_metrics.insert(format!("task_{}_accuracy", task.id), accuracy);
846        }
847
848        // Compute average forgetting
849        if !forgetting_metrics.is_empty() {
850            let avg_accuracy =
851                forgetting_metrics.values().sum::<f64>() / forgetting_metrics.len() as f64;
852            forgetting_metrics.insert("average_accuracy".to_string(), avg_accuracy);
853        }
854
855        Ok(forgetting_metrics)
856    }
857
858    /// Get memory buffer statistics
859    #[must_use]
860    pub fn memory_statistics(&self) -> HashMap<String, f64> {
861        self.state.memory_buffer.statistics()
862    }
863
864    /// Get learned tasks
865    #[must_use]
866    pub fn learned_tasks(&self) -> &[String] {
867        &self.state.learned_tasks
868    }
869
870    /// Get task performance
871    #[must_use]
872    pub fn task_performance(&self) -> &HashMap<String, HashMap<String, f64>> {
873        &self.state.task_performance
874    }
875
876    /// Get importance weights
877    #[must_use]
878    pub fn importance_weights(&self) -> &HashMap<String, f64> {
879        &self.state.importance_weights
880    }
881}
882
// Unit tests for task creation, the memory buffer and the continual-learning pipeline.
// NOTE(review): nothing visible in this module needs `non_snake_case`; confirm whether
// the allow can be removed.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    /// A new task records its dimensions and gets a positive difficulty estimate.
    #[test]
    fn test_task_creation() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0, 0.0];

        let mut task = Task::new("task1".to_string(), features, targets);
        task.estimate_difficulty();

        assert_eq!(task.id, "task1");
        assert_eq!(task.statistics.n_samples, 2);
        assert_eq!(task.statistics.n_features, 2);
        assert!(task.statistics.difficulty > 0.0);
    }

    /// A sample added to the buffer is stored and can be drawn again.
    #[test]
    fn test_memory_buffer() {
        let mut buffer = MemoryBuffer::new(3, SamplingStrategy::Random);

        let sample1 = MemorySample {
            features: array![1.0, 2.0],
            target: 1.0,
            task_id: "task1".to_string(),
            importance: 1.0,
            gradient_info: None,
        };

        buffer.add_sample(sample1);
        assert_eq!(buffer.samples.len(), 1);

        let sampled = buffer.sample(1);
        assert_eq!(sampled.len(), 1);
    }

    /// Fitting an experience-replay pipeline records the current task and
    /// yields one prediction per input row.
    #[test]
    fn test_continual_learning_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = ContinualLearningPipeline::experience_replay(base_estimator, 100, 10, 5)
            .set_current_task("task1".to_string());

        let fitted_pipeline = pipeline
            .fit(&x.view(), &Some(&y.view()))
            .expect("operation should succeed");
        // Fail at the prediction itself instead of on a later length mismatch.
        let predictions = fitted_pipeline
            .predict(&x.view())
            .expect("prediction should succeed");

        assert_eq!(predictions.len(), x.nrows());
        assert!(fitted_pipeline
            .learned_tasks()
            .contains(&"task1".to_string()));
    }

    /// Fitting an EWC pipeline populates importance weights and still predicts
    /// one value per input row.
    #[test]
    fn test_ewc_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline =
            ContinualLearningPipeline::elastic_weight_consolidation(base_estimator, 0.1, 10);

        let fitted_pipeline = pipeline
            .fit(&x.view(), &Some(&y.view()))
            .expect("operation should succeed");

        assert!(!fitted_pipeline.importance_weights().is_empty());

        // Fail at the prediction itself instead of on a later length mismatch.
        let predictions = fitted_pipeline
            .predict(&x.view())
            .expect("prediction should succeed");
        assert_eq!(predictions.len(), x.nrows());
    }

    /// Learning a second task after the initial fit adds it to the learned-task list.
    #[test]
    fn test_new_task_learning() {
        let x1 = array![[1.0, 2.0], [3.0, 4.0]];
        let y1 = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = ContinualLearningPipeline::experience_replay(base_estimator, 100, 10, 5);

        let mut fitted_pipeline = pipeline
            .fit(&x1.view(), &Some(&y1.view()))
            .expect("operation should succeed");

        // Learn new task
        let x2 = array![[5.0, 6.0], [7.0, 8.0]];
        let y2 = array![0.0, 1.0];
        let task2 = Task::new("task2".to_string(), x2, y2);

        // Surface a learning error directly rather than discarding it.
        fitted_pipeline
            .learn_task(task2)
            .expect("learning task2 should succeed");

        assert_eq!(fitted_pipeline.learned_tasks().len(), 2);
        assert!(fitted_pipeline
            .learned_tasks()
            .contains(&"task2".to_string()));
    }
}