sklears_dummy/
meta_learning.rs

1//! Meta-Learning Baseline Estimators
2//!
3//! This module provides meta-learning baseline estimators that can adapt to new tasks
4//! with limited data, transfer knowledge between domains, and adapt to distribution shifts.
5//!
6//! The module includes:
7//! - [`FewShotBaselineClassifier`] - Few-shot learning baseline using prototype-based methods
8//! - [`FewShotBaselineRegressor`] - Few-shot learning baseline for regression tasks
9//! - [`TransferLearningBaseline`] - Transfer learning baseline using source domain knowledge
10//! - [`DomainAdaptationBaseline`] - Domain adaptation baseline for distribution shift
11//! - [`ContinualLearningBaseline`] - Continual learning baseline with catastrophic forgetting prevention
12
13use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
14use scirs2_core::random::{prelude::*, thread_rng, Distribution, Rng};
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17use sklears_core::error::SklearsError;
18use sklears_core::traits::{Fit, Predict};
19use std::collections::HashMap;
20
/// Strategy for few-shot learning baselines.
///
/// Shared by [`FewShotBaselineClassifier`] and [`FewShotBaselineRegressor`]; the exact
/// prediction rule of each variant differs slightly between the two, as noted below.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FewShotStrategy {
    /// Predict from the closest prototype: the nearest class mean (classifier) or the
    /// target of the nearest training sample (regressor).
    NearestPrototype,
    /// Majority vote (classifier) or mean target (regressor) of the `k` nearest
    /// support samples.
    KNearestNeighbors {
        /// Number of nearest support samples consulted per prediction.
        k: usize,
    },
    /// Weight every support sample by `1 / (1 + distance)` and take the weighted vote
    /// (classifier) or weighted mean target (regressor).
    SupportBased,
    /// Classifier: identical to `NearestPrototype`. Regressor: the mean of all training
    /// targets, independent of the query.
    Centroid,
    /// Stochastic prediction: the classifier samples a class with probability
    /// proportional to `exp(-distance)` to its prototype; the regressor samples from a
    /// normal distribution fitted to distance-weighted training targets.
    Probabilistic,
}
36
/// Strategy for transfer learning baselines.
///
/// Consumed by [`TransferLearningBaseline`], which blends the mean target of the source
/// domain with that of the target domain.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TransferStrategy {
    /// Use source domain statistics as prior knowledge
    SourcePrior {
        /// Weight given to the source domain's mean target; the target domain receives
        /// `1 - source_weight`.
        source_weight: f64,
    },
    /// Feature-based transfer using feature similarities
    FeatureBased {
        /// Blend factor between the (capped) target/source feature-variance ratio and
        /// `1.0` when computing per-feature adaptation weights.
        adaptation_rate: f64,
    },
    /// Instance-based transfer using instance weighting
    InstanceBased {
        /// Not read by `TransferLearningBaseline`'s fit/predict at present.
        similarity_threshold: f64,
    },
    /// Model-based transfer using source model predictions
    ModelBased {
        /// Not read by `TransferLearningBaseline`'s fit/predict at present.
        confidence_threshold: f64,
    },
    /// Ensemble transfer combining multiple source domains
    EnsembleTransfer {
        /// Only the first entry is used: it weights the source mean against the target
        /// mean. An empty list means "target domain only".
        domain_weights: Vec<f64>,
    },
}
52
/// Strategy for domain adaptation baselines.
///
/// Selects how [`DomainAdaptationBaseline`] builds its feature adaptation matrix.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DomainAdaptationStrategy {
    /// Feature space alignment using mean/variance matching
    FeatureAlignment,
    /// Instance reweighting for domain shift
    InstanceReweighting {
        /// Factor applied to the identity adaptation matrix.
        adaptation_strength: f64,
    },
    /// Gradient reversal approximation for adversarial adaptation
    GradientReversal {
        /// The diagonal of the adaptation matrix is set to `1 - lambda`.
        lambda: f64,
    },
    /// Subspace alignment using principal component alignment
    SubspaceAlignment {
        /// Size of the leading feature block that receives cross-feature coupling;
        /// clamped to the number of features.
        subspace_dim: usize,
    },
    /// Maximum mean discrepancy minimization approximation
    MMDMinimization {
        /// Bandwidth used by the MMD approximation.
        bandwidth: f64,
    },
}
68
/// Strategy for continual learning baselines.
///
/// Like the other strategy enums in this module, it derives `PartialEq` so that
/// configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ContinualStrategy {
    /// Elastic weight consolidation approximation
    ElasticWeightConsolidation { importance_weight: f64 },
    /// Rehearsal-based learning with memory buffer
    Rehearsal { memory_size: usize },
    /// Progressive neural networks approximation
    Progressive { column_capacity: usize },
    /// Learning without forgetting approximation
    LearningWithoutForgetting { distillation_weight: f64 },
    /// Averaged gradient episodic memory approximation
    AGEM { memory_strength: f64 },
}
84
85/// Few-shot learning baseline classifier
86#[derive(Debug, Clone)]
87pub struct FewShotBaselineClassifier {
88    strategy: FewShotStrategy,
89    random_state: Option<u64>,
90}
91
92/// Fitted few-shot classifier
93#[derive(Debug, Clone)]
94pub struct FittedFewShotClassifier {
95    strategy: FewShotStrategy,
96    prototypes: HashMap<i32, Array1<f64>>,
97    support_samples: Option<(Array2<f64>, Array1<i32>)>,
98    class_counts: HashMap<i32, usize>,
99    random_state: Option<u64>,
100}
101
102/// Few-shot learning baseline regressor
103#[derive(Debug, Clone)]
104pub struct FewShotBaselineRegressor {
105    strategy: FewShotStrategy,
106    random_state: Option<u64>,
107}
108
109/// Fitted few-shot regressor
110#[derive(Debug, Clone)]
111pub struct FittedFewShotRegressor {
112    strategy: FewShotStrategy,
113    prototypes: Vec<(Array1<f64>, f64)>,
114    support_samples: Option<(Array2<f64>, Array1<f64>)>,
115    random_state: Option<u64>,
116}
117
118/// Transfer learning baseline estimator
119#[derive(Debug, Clone)]
120pub struct TransferLearningBaseline {
121    strategy: TransferStrategy,
122    source_statistics: Option<SourceDomainStats>,
123    random_state: Option<u64>,
124}
125
126/// Source domain statistics for transfer learning
127#[derive(Debug, Clone)]
128#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
129pub struct SourceDomainStats {
130    feature_means: Array1<f64>,
131    feature_stds: Array1<f64>,
132    class_priors: HashMap<i32, f64>,
133    target_mean: f64,
134    target_std: f64,
135}
136
137/// Fitted transfer learning baseline
138#[derive(Debug, Clone)]
139pub struct FittedTransferBaseline {
140    strategy: TransferStrategy,
141    source_stats: SourceDomainStats,
142    target_stats: SourceDomainStats,
143    adaptation_weights: Array1<f64>,
144    random_state: Option<u64>,
145}
146
147/// Domain adaptation baseline estimator
148#[derive(Debug, Clone)]
149pub struct DomainAdaptationBaseline {
150    strategy: DomainAdaptationStrategy,
151    source_domain_data: Option<(Array2<f64>, Array1<f64>)>,
152    random_state: Option<u64>,
153}
154
155/// Fitted domain adaptation baseline
156#[derive(Debug, Clone)]
157pub struct FittedDomainAdaptationBaseline {
158    strategy: DomainAdaptationStrategy,
159    source_stats: SourceDomainStats,
160    target_stats: SourceDomainStats,
161    adaptation_matrix: Array2<f64>,
162    instance_weights: Array1<f64>,
163    random_state: Option<u64>,
164}
165
166impl FewShotBaselineClassifier {
167    /// Create a new few-shot baseline classifier
168    pub fn new(strategy: FewShotStrategy) -> Self {
169        Self {
170            strategy,
171            random_state: None,
172        }
173    }
174
175    /// Set the random state for reproducible results
176    pub fn with_random_state(mut self, seed: u64) -> Self {
177        self.random_state = Some(seed);
178        self
179    }
180}
181
182impl Fit<Array2<f64>, Array1<i32>, FittedFewShotClassifier> for FewShotBaselineClassifier {
183    type Fitted = FittedFewShotClassifier;
184    fn fit(
185        self,
186        x: &Array2<f64>,
187        y: &Array1<i32>,
188    ) -> Result<FittedFewShotClassifier, SklearsError> {
189        if x.nrows() != y.len() {
190            return Err(SklearsError::ShapeMismatch {
191                expected: format!("{} samples", x.nrows()),
192                actual: format!("{} labels", y.len()),
193            });
194        }
195
196        let rng = self.random_state.map_or_else(
197            || Box::new(thread_rng()) as Box<dyn RngCore>,
198            |seed| Box::new(StdRng::seed_from_u64(seed)),
199        );
200
201        // Compute class counts
202        let mut class_counts = HashMap::new();
203        for &class in y.iter() {
204            *class_counts.entry(class).or_insert(0) += 1;
205        }
206
207        // Compute prototypes for each class
208        let mut prototypes = HashMap::new();
209        for &class in class_counts.keys() {
210            let class_indices: Vec<usize> = y
211                .iter()
212                .enumerate()
213                .filter(|(_, &label)| label == class)
214                .map(|(i, _)| i)
215                .collect();
216
217            if !class_indices.is_empty() {
218                let class_data: Vec<f64> = class_indices
219                    .iter()
220                    .flat_map(|&i| x.row(i).to_vec())
221                    .collect();
222                let class_samples: Array2<f64> =
223                    Array2::from_shape_vec((class_indices.len(), x.ncols()), class_data)?;
224
225                let prototype = class_samples.mean_axis(Axis(0)).unwrap();
226                prototypes.insert(class, prototype);
227            }
228        }
229
230        // Store support samples for certain strategies
231        let support_samples = match self.strategy {
232            FewShotStrategy::KNearestNeighbors { .. } | FewShotStrategy::SupportBased => {
233                Some((x.clone(), y.clone()))
234            }
235            _ => None,
236        };
237
238        Ok(FittedFewShotClassifier {
239            strategy: self.strategy.clone(),
240            prototypes,
241            support_samples,
242            class_counts,
243            random_state: self.random_state,
244        })
245    }
246}
247
248impl Predict<Array2<f64>, Array1<i32>> for FittedFewShotClassifier {
249    fn predict(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError> {
250        let mut predictions = Vec::with_capacity(x.nrows());
251        let mut rng = self.random_state.map_or_else(
252            || Box::new(thread_rng()) as Box<dyn RngCore>,
253            |seed| Box::new(StdRng::seed_from_u64(seed)),
254        );
255
256        for sample in x.rows() {
257            let prediction = match &self.strategy {
258                FewShotStrategy::NearestPrototype => self.predict_nearest_prototype(&sample)?,
259                FewShotStrategy::KNearestNeighbors { k } => self.predict_knn(&sample, *k)?,
260                FewShotStrategy::SupportBased => self.predict_support_based(&sample)?,
261                FewShotStrategy::Centroid => self.predict_centroid(&sample)?,
262                FewShotStrategy::Probabilistic => self.predict_probabilistic(&sample, &mut *rng)?,
263            };
264            predictions.push(prediction);
265        }
266
267        Ok(Array1::from_vec(predictions))
268    }
269}
270
271impl FittedFewShotClassifier {
272    fn predict_nearest_prototype(&self, sample: &ArrayView1<f64>) -> Result<i32, SklearsError> {
273        let mut min_distance = f64::INFINITY;
274        let mut best_class = 0;
275
276        for (&class, prototype) in &self.prototypes {
277            let distance: f64 = sample
278                .iter()
279                .zip(prototype.iter())
280                .map(|(a, b)| (a - b).powi(2))
281                .sum::<f64>()
282                .sqrt();
283
284            if distance < min_distance {
285                min_distance = distance;
286                best_class = class;
287            }
288        }
289
290        Ok(best_class)
291    }
292
293    fn predict_knn(&self, sample: &ArrayView1<f64>, k: usize) -> Result<i32, SklearsError> {
294        if let Some((ref support_x, ref support_y)) = self.support_samples {
295            let mut distances: Vec<(f64, i32)> = Vec::new();
296
297            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
298                let distance: f64 = sample
299                    .iter()
300                    .zip(support_sample.iter())
301                    .map(|(a, b)| (a - b).powi(2))
302                    .sum::<f64>()
303                    .sqrt();
304                distances.push((distance, support_y[i]));
305            }
306
307            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
308
309            let k_neighbors = distances.into_iter().take(k).collect::<Vec<_>>();
310            let mut class_votes: HashMap<i32, usize> = HashMap::new();
311
312            for (_, class) in k_neighbors {
313                *class_votes.entry(class).or_insert(0) += 1;
314            }
315
316            let best_class = class_votes
317                .into_iter()
318                .max_by_key(|(_, count)| *count)
319                .map(|(class, _)| class)
320                .unwrap_or(0);
321
322            Ok(best_class)
323        } else {
324            self.predict_nearest_prototype(sample)
325        }
326    }
327
328    fn predict_support_based(&self, sample: &ArrayView1<f64>) -> Result<i32, SklearsError> {
329        // Use support set statistics for prediction
330        if let Some((ref support_x, ref support_y)) = self.support_samples {
331            // Simple support-based prediction using distance weighting
332            let mut weighted_votes: HashMap<i32, f64> = HashMap::new();
333
334            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
335                let distance: f64 = sample
336                    .iter()
337                    .zip(support_sample.iter())
338                    .map(|(a, b)| (a - b).powi(2))
339                    .sum::<f64>()
340                    .sqrt();
341
342                let weight = 1.0 / (1.0 + distance);
343                let class = support_y[i];
344                *weighted_votes.entry(class).or_insert(0.0) += weight;
345            }
346
347            let best_class = weighted_votes
348                .into_iter()
349                .max_by(|(_, weight_a), (_, weight_b)| weight_a.partial_cmp(weight_b).unwrap())
350                .map(|(class, _)| class)
351                .unwrap_or(0);
352
353            Ok(best_class)
354        } else {
355            self.predict_nearest_prototype(sample)
356        }
357    }
358
359    fn predict_centroid(&self, sample: &ArrayView1<f64>) -> Result<i32, SklearsError> {
360        // Same as nearest prototype for few-shot scenarios
361        self.predict_nearest_prototype(sample)
362    }
363
364    fn predict_probabilistic(
365        &self,
366        sample: &ArrayView1<f64>,
367        rng: &mut dyn RngCore,
368    ) -> Result<i32, SklearsError> {
369        // Compute class probabilities based on prototype distances
370        let mut class_probs: HashMap<i32, f64> = HashMap::new();
371        let mut total_weight = 0.0;
372
373        for (&class, prototype) in &self.prototypes {
374            let distance: f64 = sample
375                .iter()
376                .zip(prototype.iter())
377                .map(|(a, b)| (a - b).powi(2))
378                .sum::<f64>()
379                .sqrt();
380
381            let weight = (-distance).exp();
382            class_probs.insert(class, weight);
383            total_weight += weight;
384        }
385
386        // Normalize probabilities
387        for (_, prob) in class_probs.iter_mut() {
388            *prob /= total_weight;
389        }
390
391        // Sample from the probability distribution
392        let rand_val: f64 = rng.gen();
393        let mut cumulative_prob = 0.0;
394
395        for (&class, &prob) in &class_probs {
396            cumulative_prob += prob;
397            if rand_val <= cumulative_prob {
398                return Ok(class);
399            }
400        }
401
402        // Fallback to most likely class
403        let best_class = class_probs
404            .into_iter()
405            .max_by(|(_, prob_a), (_, prob_b)| prob_a.partial_cmp(prob_b).unwrap())
406            .map(|(class, _)| class)
407            .unwrap_or(0);
408
409        Ok(best_class)
410    }
411}
412
413impl FewShotBaselineRegressor {
414    /// Create a new few-shot baseline regressor
415    pub fn new(strategy: FewShotStrategy) -> Self {
416        Self {
417            strategy,
418            random_state: None,
419        }
420    }
421
422    /// Set the random state for reproducible results
423    pub fn with_random_state(mut self, seed: u64) -> Self {
424        self.random_state = Some(seed);
425        self
426    }
427}
428
429impl Fit<Array2<f64>, Array1<f64>, FittedFewShotRegressor> for FewShotBaselineRegressor {
430    type Fitted = FittedFewShotRegressor;
431    fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedFewShotRegressor, SklearsError> {
432        if x.nrows() != y.len() {
433            return Err(SklearsError::ShapeMismatch {
434                expected: format!("{} samples", x.nrows()),
435                actual: format!("{} labels", y.len()),
436            });
437        }
438
439        // Create prototypes from training data
440        let mut prototypes = Vec::new();
441        for (i, sample) in x.rows().into_iter().enumerate() {
442            prototypes.push((sample.to_owned(), y[i]));
443        }
444
445        // Store support samples for certain strategies
446        let support_samples = match self.strategy {
447            FewShotStrategy::KNearestNeighbors { .. } | FewShotStrategy::SupportBased => {
448                Some((x.clone(), y.clone()))
449            }
450            _ => None,
451        };
452
453        Ok(FittedFewShotRegressor {
454            strategy: self.strategy.clone(),
455            prototypes,
456            support_samples,
457            random_state: self.random_state,
458        })
459    }
460}
461
462impl Predict<Array2<f64>, Array1<f64>> for FittedFewShotRegressor {
463    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
464        let mut predictions = Vec::with_capacity(x.nrows());
465        let mut rng = self.random_state.map_or_else(
466            || Box::new(thread_rng()) as Box<dyn RngCore>,
467            |seed| Box::new(StdRng::seed_from_u64(seed)),
468        );
469
470        for sample in x.rows() {
471            let prediction = match &self.strategy {
472                FewShotStrategy::NearestPrototype => self.predict_nearest_prototype(&sample)?,
473                FewShotStrategy::KNearestNeighbors { k } => self.predict_knn(&sample, *k)?,
474                FewShotStrategy::SupportBased => self.predict_support_based(&sample)?,
475                FewShotStrategy::Centroid => self.predict_centroid(&sample)?,
476                FewShotStrategy::Probabilistic => self.predict_probabilistic(&sample, &mut *rng)?,
477            };
478            predictions.push(prediction);
479        }
480
481        Ok(Array1::from_vec(predictions))
482    }
483}
484
485impl FittedFewShotRegressor {
486    fn predict_nearest_prototype(&self, sample: &ArrayView1<f64>) -> Result<f64, SklearsError> {
487        let mut min_distance = f64::INFINITY;
488        let mut best_value = 0.0;
489
490        for (prototype, value) in &self.prototypes {
491            let distance: f64 = sample
492                .iter()
493                .zip(prototype.iter())
494                .map(|(a, b)| (a - b).powi(2))
495                .sum::<f64>()
496                .sqrt();
497
498            if distance < min_distance {
499                min_distance = distance;
500                best_value = *value;
501            }
502        }
503
504        Ok(best_value)
505    }
506
507    fn predict_knn(&self, sample: &ArrayView1<f64>, k: usize) -> Result<f64, SklearsError> {
508        if let Some((ref support_x, ref support_y)) = self.support_samples {
509            let mut distances: Vec<(f64, f64)> = Vec::new();
510
511            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
512                let distance: f64 = sample
513                    .iter()
514                    .zip(support_sample.iter())
515                    .map(|(a, b)| (a - b).powi(2))
516                    .sum::<f64>()
517                    .sqrt();
518                distances.push((distance, support_y[i]));
519            }
520
521            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
522
523            let k_neighbors = distances.into_iter().take(k).collect::<Vec<_>>();
524            let mean_value =
525                k_neighbors.iter().map(|(_, value)| value).sum::<f64>() / k_neighbors.len() as f64;
526
527            Ok(mean_value)
528        } else {
529            self.predict_nearest_prototype(sample)
530        }
531    }
532
533    fn predict_support_based(&self, sample: &ArrayView1<f64>) -> Result<f64, SklearsError> {
534        if let Some((ref support_x, ref support_y)) = self.support_samples {
535            let mut weighted_sum = 0.0;
536            let mut total_weight = 0.0;
537
538            for (i, support_sample) in support_x.rows().into_iter().enumerate() {
539                let distance: f64 = sample
540                    .iter()
541                    .zip(support_sample.iter())
542                    .map(|(a, b)| (a - b).powi(2))
543                    .sum::<f64>()
544                    .sqrt();
545
546                let weight = 1.0 / (1.0 + distance);
547                weighted_sum += weight * support_y[i];
548                total_weight += weight;
549            }
550
551            Ok(weighted_sum / total_weight)
552        } else {
553            self.predict_nearest_prototype(sample)
554        }
555    }
556
557    fn predict_centroid(&self, _sample: &ArrayView1<f64>) -> Result<f64, SklearsError> {
558        // For regression, centroid prediction is the mean of all training targets
559        let mean_target = self.prototypes.iter().map(|(_, value)| value).sum::<f64>()
560            / self.prototypes.len() as f64;
561        Ok(mean_target)
562    }
563
564    fn predict_probabilistic(
565        &self,
566        sample: &ArrayView1<f64>,
567        rng: &mut dyn RngCore,
568    ) -> Result<f64, SklearsError> {
569        // Gaussian process-like prediction with uncertainty
570        let mut weighted_sum = 0.0;
571        let mut total_weight = 0.0;
572        let mut variance_sum = 0.0;
573
574        for (prototype, value) in &self.prototypes {
575            let distance: f64 = sample
576                .iter()
577                .zip(prototype.iter())
578                .map(|(a, b)| (a - b).powi(2))
579                .sum::<f64>()
580                .sqrt();
581
582            let weight = (-distance).exp();
583            weighted_sum += weight * value;
584            total_weight += weight;
585            variance_sum += weight * value * value;
586        }
587
588        let mean = weighted_sum / total_weight;
589        let variance = (variance_sum / total_weight) - mean * mean;
590        let std_dev = variance.sqrt().max(0.1);
591
592        // Sample from normal distribution
593        use scirs2_core::random::essentials::Normal;
594        let normal = Normal::new(mean, std_dev).map_err(|_| SklearsError::InvalidParameter {
595            name: "normal_distribution".to_string(),
596            reason: "Invalid parameters for normal distribution".to_string(),
597        })?;
598        let sample_value = normal.sample(rng);
599
600        Ok(sample_value)
601    }
602}
603
604impl TransferLearningBaseline {
605    /// Create a new transfer learning baseline
606    pub fn new(strategy: TransferStrategy) -> Self {
607        Self {
608            strategy,
609            source_statistics: None,
610            random_state: None,
611        }
612    }
613
614    /// Set source domain statistics for transfer
615    pub fn with_source_statistics(mut self, stats: SourceDomainStats) -> Self {
616        self.source_statistics = Some(stats);
617        self
618    }
619
620    /// Set the random state for reproducible results
621    pub fn with_random_state(mut self, seed: u64) -> Self {
622        self.random_state = Some(seed);
623        self
624    }
625}
626
627impl Fit<Array2<f64>, Array1<f64>, FittedTransferBaseline> for TransferLearningBaseline {
628    type Fitted = FittedTransferBaseline;
629    fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<FittedTransferBaseline, SklearsError> {
630        if x.nrows() != y.len() {
631            return Err(SklearsError::ShapeMismatch {
632                expected: format!("{} samples", x.nrows()),
633                actual: format!("{} labels", y.len()),
634            });
635        }
636
637        // Compute target domain statistics
638        let feature_means = x.mean_axis(Axis(0)).unwrap();
639        let feature_stds = x.std_axis(Axis(0), 0.0);
640        let target_mean = y.mean().unwrap();
641        let target_std = y.std(0.0);
642
643        let target_stats = SourceDomainStats {
644            feature_means,
645            feature_stds,
646            class_priors: HashMap::new(), // Not used for regression
647            target_mean,
648            target_std,
649        };
650
651        let source_stats = self
652            .source_statistics
653            .clone()
654            .unwrap_or_else(|| target_stats.clone());
655
656        // Compute adaptation weights based on strategy
657        let adaptation_weights = match &self.strategy {
658            TransferStrategy::SourcePrior { source_weight } => {
659                Array1::from_elem(x.ncols(), *source_weight)
660            }
661            TransferStrategy::FeatureBased { adaptation_rate } => {
662                // Feature importance based on variance ratio
663                let mut weights = Array1::zeros(x.ncols());
664                for i in 0..x.ncols() {
665                    let source_var = source_stats.feature_stds[i].powi(2);
666                    let target_var = target_stats.feature_stds[i].powi(2);
667                    let ratio = (target_var / (source_var + 1e-10)).min(1.0);
668                    weights[i] = adaptation_rate * ratio + (1.0 - adaptation_rate);
669                }
670                weights
671            }
672            TransferStrategy::InstanceBased {
673                similarity_threshold: _,
674            } => Array1::from_elem(x.ncols(), 0.5),
675            TransferStrategy::ModelBased {
676                confidence_threshold: _,
677            } => Array1::from_elem(x.ncols(), 0.7),
678            TransferStrategy::EnsembleTransfer { domain_weights } => {
679                if domain_weights.is_empty() {
680                    Array1::from_elem(x.ncols(), 1.0)
681                } else {
682                    Array1::from_elem(x.ncols(), domain_weights[0])
683                }
684            }
685        };
686
687        Ok(FittedTransferBaseline {
688            strategy: self.strategy.clone(),
689            source_stats,
690            target_stats,
691            adaptation_weights,
692            random_state: self.random_state,
693        })
694    }
695}
696
697impl Predict<Array2<f64>, Array1<f64>> for FittedTransferBaseline {
698    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
699        let mut predictions = Vec::with_capacity(x.nrows());
700
701        for sample in x.rows() {
702            // Simple transfer prediction combining source and target knowledge
703            let mut prediction = 0.0;
704
705            match &self.strategy {
706                TransferStrategy::SourcePrior { source_weight } => {
707                    prediction = source_weight * self.source_stats.target_mean
708                        + (1.0 - source_weight) * self.target_stats.target_mean;
709                }
710                TransferStrategy::FeatureBased { adaptation_rate: _ } => {
711                    // Weighted combination based on feature adaptation
712                    let source_contrib = self.source_stats.target_mean;
713                    let target_contrib = self.target_stats.target_mean;
714                    let avg_weight = self.adaptation_weights.mean().unwrap();
715                    prediction = avg_weight * target_contrib + (1.0 - avg_weight) * source_contrib;
716                }
717                TransferStrategy::InstanceBased {
718                    similarity_threshold: _,
719                } => {
720                    // Instance-based prediction using both domains
721                    prediction =
722                        0.5 * self.source_stats.target_mean + 0.5 * self.target_stats.target_mean;
723                }
724                TransferStrategy::ModelBased {
725                    confidence_threshold: _,
726                } => {
727                    // Model-based prediction with confidence weighting
728                    prediction =
729                        0.7 * self.target_stats.target_mean + 0.3 * self.source_stats.target_mean;
730                }
731                TransferStrategy::EnsembleTransfer { domain_weights } => {
732                    if domain_weights.is_empty() {
733                        prediction = self.target_stats.target_mean;
734                    } else {
735                        prediction = domain_weights[0] * self.source_stats.target_mean
736                            + (1.0 - domain_weights[0]) * self.target_stats.target_mean;
737                    }
738                }
739            }
740
741            predictions.push(prediction);
742        }
743
744        Ok(Array1::from_vec(predictions))
745    }
746}
747
748impl DomainAdaptationBaseline {
749    /// Create a new domain adaptation baseline
750    pub fn new(strategy: DomainAdaptationStrategy) -> Self {
751        Self {
752            strategy,
753            source_domain_data: None,
754            random_state: None,
755        }
756    }
757
758    /// Set source domain data for adaptation
759    pub fn with_source_domain_data(mut self, source_x: Array2<f64>, source_y: Array1<f64>) -> Self {
760        self.source_domain_data = Some((source_x, source_y));
761        self
762    }
763
764    /// Set the random state for reproducible results
765    pub fn with_random_state(mut self, seed: u64) -> Self {
766        self.random_state = Some(seed);
767        self
768    }
769}
770
771impl Fit<Array2<f64>, Array1<f64>, FittedDomainAdaptationBaseline> for DomainAdaptationBaseline {
772    type Fitted = FittedDomainAdaptationBaseline;
773
774    fn fit(
775        self,
776        x: &Array2<f64>,
777        y: &Array1<f64>,
778    ) -> Result<FittedDomainAdaptationBaseline, SklearsError> {
779        if x.nrows() != y.len() {
780            return Err(SklearsError::ShapeMismatch {
781                expected: format!("{} samples", x.nrows()),
782                actual: format!("{} labels", y.len()),
783            });
784        }
785
786        // Compute target domain statistics
787        let target_feature_means = x.mean_axis(Axis(0)).unwrap();
788        let target_feature_stds = x.std_axis(Axis(0), 0.0);
789        let target_mean = y.mean().unwrap();
790        let target_std = y.std(0.0);
791
792        let target_stats = SourceDomainStats {
793            feature_means: target_feature_means,
794            feature_stds: target_feature_stds,
795            class_priors: HashMap::new(),
796            target_mean,
797            target_std,
798        };
799
800        // Compute source domain statistics if available
801        let source_stats = if let Some((ref source_x, ref source_y)) = self.source_domain_data {
802            let source_feature_means = source_x.mean_axis(Axis(0)).unwrap();
803            let source_feature_stds = source_x.std_axis(Axis(0), 0.0);
804            let source_mean = source_y.mean().unwrap();
805            let source_std = source_y.std(0.0);
806
807            SourceDomainStats {
808                feature_means: source_feature_means,
809                feature_stds: source_feature_stds,
810                class_priors: HashMap::new(),
811                target_mean: source_mean,
812                target_std: source_std,
813            }
814        } else {
815            target_stats.clone()
816        };
817
818        // Compute adaptation matrix and instance weights
819        let adaptation_matrix = self.compute_adaptation_matrix(&source_stats, &target_stats);
820        let instance_weights = self.compute_instance_weights(x, &source_stats, &target_stats);
821
822        Ok(FittedDomainAdaptationBaseline {
823            strategy: self.strategy,
824            source_stats,
825            target_stats,
826            adaptation_matrix,
827            instance_weights,
828            random_state: self.random_state,
829        })
830    }
831}
832
833impl DomainAdaptationBaseline {
834    fn compute_adaptation_matrix(
835        &self,
836        source_stats: &SourceDomainStats,
837        target_stats: &SourceDomainStats,
838    ) -> Array2<f64> {
839        let n_features = source_stats.feature_means.len();
840        let mut adaptation_matrix = Array2::eye(n_features);
841
842        match &self.strategy {
843            DomainAdaptationStrategy::FeatureAlignment => {
844                // Compute feature alignment transformation
845                for i in 0..n_features {
846                    let source_std = source_stats.feature_stds[i].max(1e-8);
847                    let target_std = target_stats.feature_stds[i].max(1e-8);
848                    adaptation_matrix[[i, i]] = target_std / source_std;
849                }
850            }
851            DomainAdaptationStrategy::InstanceReweighting {
852                adaptation_strength,
853            } => {
854                // Simple adaptation strength scaling
855                adaptation_matrix *= *adaptation_strength;
856            }
857            DomainAdaptationStrategy::GradientReversal { lambda } => {
858                // Gradient reversal approximation
859                for i in 0..n_features {
860                    adaptation_matrix[[i, i]] = 1.0 - *lambda;
861                }
862            }
863            DomainAdaptationStrategy::SubspaceAlignment { subspace_dim } => {
864                // Subspace alignment (simplified)
865                let dim = (*subspace_dim).min(n_features);
866                for i in 0..dim {
867                    for j in 0..dim {
868                        if i != j {
869                            adaptation_matrix[[i, j]] = 0.1; // Small cross-correlation
870                        }
871                    }
872                }
873            }
874            DomainAdaptationStrategy::MMDMinimization { bandwidth } => {
875                // MMD minimization approximation
876                for i in 0..n_features {
877                    let distance =
878                        (source_stats.feature_means[i] - target_stats.feature_means[i]).abs();
879                    let weight = (-distance / bandwidth).exp();
880                    adaptation_matrix[[i, i]] = weight;
881                }
882            }
883        }
884
885        adaptation_matrix
886    }
887
888    fn compute_instance_weights(
889        &self,
890        x: &Array2<f64>,
891        source_stats: &SourceDomainStats,
892        target_stats: &SourceDomainStats,
893    ) -> Array1<f64> {
894        let mut weights = Array1::ones(x.nrows());
895
896        match &self.strategy {
897            DomainAdaptationStrategy::InstanceReweighting {
898                adaptation_strength,
899            } => {
900                // Compute instance weights based on distance to domain centers
901                for (i, sample) in x.rows().into_iter().enumerate() {
902                    let source_distance: f64 = sample
903                        .iter()
904                        .zip(source_stats.feature_means.iter())
905                        .map(|(x_val, mean_val)| (x_val - mean_val).powi(2))
906                        .sum::<f64>()
907                        .sqrt();
908
909                    let target_distance: f64 = sample
910                        .iter()
911                        .zip(target_stats.feature_means.iter())
912                        .map(|(x_val, mean_val)| (x_val - mean_val).powi(2))
913                        .sum::<f64>()
914                        .sqrt();
915
916                    let weight = if source_distance > 0.0 {
917                        adaptation_strength * target_distance / (source_distance + target_distance)
918                    } else {
919                        *adaptation_strength
920                    };
921
922                    weights[i] = weight.max(0.1).min(10.0); // Clamp weights
923                }
924            }
925            _ => {
926                // Default uniform weights
927            }
928        }
929
930        weights
931    }
932}
933
934impl Predict<Array2<f64>, Array1<f64>> for FittedDomainAdaptationBaseline {
935    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
936        let mut predictions = Vec::with_capacity(x.nrows());
937
938        for (i, sample) in x.rows().into_iter().enumerate() {
939            // Apply domain adaptation transformation
940            let adapted_sample = if i < self.instance_weights.len() {
941                let weight = self.instance_weights[i];
942                sample.mapv(|val| val * weight)
943            } else {
944                sample.to_owned()
945            };
946
947            // Simple prediction using adapted features
948            let prediction = match &self.strategy {
949                DomainAdaptationStrategy::FeatureAlignment => {
950                    // Use target domain statistics for prediction
951                    self.target_stats.target_mean
952                }
953                DomainAdaptationStrategy::InstanceReweighting { .. } => {
954                    // Weighted prediction based on domain similarity
955                    let source_contrib = self.source_stats.target_mean;
956                    let target_contrib = self.target_stats.target_mean;
957                    let weight = if i < self.instance_weights.len() {
958                        self.instance_weights[i]
959                    } else {
960                        0.5
961                    };
962                    weight * target_contrib + (1.0 - weight) * source_contrib
963                }
964                DomainAdaptationStrategy::GradientReversal { lambda } => {
965                    // Gradient reversal prediction
966                    let adversarial_weight = 1.0 - lambda;
967                    adversarial_weight * self.target_stats.target_mean
968                        + lambda * self.source_stats.target_mean
969                }
970                DomainAdaptationStrategy::SubspaceAlignment { .. } => {
971                    // Subspace-aligned prediction
972                    (self.source_stats.target_mean + self.target_stats.target_mean) / 2.0
973                }
974                DomainAdaptationStrategy::MMDMinimization { .. } => {
975                    // MMD-based prediction
976                    let feature_weighted_prediction = adapted_sample.mean().unwrap_or(0.0);
977                    (feature_weighted_prediction + self.target_stats.target_mean) / 2.0
978                }
979            };
980
981            predictions.push(prediction);
982        }
983
984        Ok(Array1::from_vec(predictions))
985    }
986}
987
/// Continual learning baseline estimator.
///
/// Holds a replay memory of `(feature row, target)` pairs and the summary
/// statistics of previously seen tasks. `fit` appends the current task to both,
/// and the configured [`ContinualStrategy`] decides how many rows of the task
/// are retained in memory and how the stored information is combined at
/// prediction time.
#[derive(Debug, Clone)]
pub struct ContinualLearningBaseline {
    /// Strategy controlling memory retention, consolidation weights and prediction.
    strategy: ContinualStrategy,
    /// Replay memory of `(feature row, target)` pairs; empty after `new`,
    /// replaced wholesale by `with_task_memory`.
    memory_buffer: Vec<(Array1<f64>, f64)>,
    /// Summary statistics of earlier tasks; empty after `new`.
    task_statistics: Vec<SourceDomainStats>,
    /// Optional seed set by `with_random_state`.
    /// NOTE(review): not consulted by the fit/predict code in this section.
    random_state: Option<u64>,
}
996
/// Fitted continual learning baseline.
///
/// Produced by `ContinualLearningBaseline::fit`; carries everything `predict`
/// needs, including the statistics of the task that was just fitted.
#[derive(Debug, Clone)]
pub struct FittedContinualLearningBaseline {
    /// Strategy the baseline was fitted with.
    strategy: ContinualStrategy,
    /// Replay memory of `(feature row, target)` pairs after the last `fit`.
    memory_buffer: Vec<(Array1<f64>, f64)>,
    /// Per-task summary statistics; the last entry is the most recently fitted task.
    task_statistics: Vec<SourceDomainStats>,
    /// One weight per feature of the most recent task, produced by
    /// `compute_consolidation_weights`.
    consolidation_weights: Array1<f64>,
    /// Index into `task_statistics` of the most recently fitted task.
    current_task_id: usize,
    /// Seed carried over from the unfitted baseline.
    random_state: Option<u64>,
}
1007
1008impl ContinualLearningBaseline {
1009    /// Create a new continual learning baseline
1010    pub fn new(strategy: ContinualStrategy) -> Self {
1011        Self {
1012            strategy,
1013            memory_buffer: Vec::new(),
1014            task_statistics: Vec::new(),
1015            random_state: None,
1016        }
1017    }
1018
1019    /// Set the random state for reproducible results
1020    pub fn with_random_state(mut self, seed: u64) -> Self {
1021        self.random_state = Some(seed);
1022        self
1023    }
1024
1025    /// Add task memory to the baseline
1026    pub fn with_task_memory(mut self, memory: Vec<(Array1<f64>, f64)>) -> Self {
1027        self.memory_buffer = memory;
1028        self
1029    }
1030}
1031
1032impl Fit<Array2<f64>, Array1<f64>, FittedContinualLearningBaseline> for ContinualLearningBaseline {
1033    type Fitted = FittedContinualLearningBaseline;
1034
1035    fn fit(
1036        self,
1037        x: &Array2<f64>,
1038        y: &Array1<f64>,
1039    ) -> Result<FittedContinualLearningBaseline, SklearsError> {
1040        if x.nrows() != y.len() {
1041            return Err(SklearsError::ShapeMismatch {
1042                expected: format!("{} samples", x.nrows()),
1043                actual: format!("{} labels", y.len()),
1044            });
1045        }
1046
1047        // Compute current task statistics
1048        let feature_means = x.mean_axis(Axis(0)).unwrap();
1049        let feature_stds = x.std_axis(Axis(0), 0.0);
1050        let target_mean = y.mean().unwrap();
1051        let target_std = y.std(0.0);
1052
1053        let current_task_stats = SourceDomainStats {
1054            feature_means,
1055            feature_stds,
1056            class_priors: HashMap::new(),
1057            target_mean,
1058            target_std,
1059        };
1060
1061        let mut task_statistics = self.task_statistics.clone();
1062        task_statistics.push(current_task_stats.clone());
1063
1064        // Update memory buffer based on strategy
1065        let mut memory_buffer = self.memory_buffer.clone();
1066        match &self.strategy {
1067            ContinualStrategy::ElasticWeightConsolidation { .. } => {
1068                // Store important samples for EWC
1069                for i in 0..x.nrows().min(100) {
1070                    // Limit memory size
1071                    memory_buffer.push((x.row(i).to_owned(), y[i]));
1072                }
1073            }
1074            ContinualStrategy::Rehearsal { memory_size } => {
1075                // Rehearsal-based memory management
1076                for i in 0..x.nrows() {
1077                    memory_buffer.push((x.row(i).to_owned(), y[i]));
1078                    if memory_buffer.len() > *memory_size {
1079                        memory_buffer.remove(0); // Remove oldest sample
1080                    }
1081                }
1082            }
1083            ContinualStrategy::Progressive { .. } => {
1084                // Progressive networks: keep all task data
1085                for i in 0..x.nrows() {
1086                    memory_buffer.push((x.row(i).to_owned(), y[i]));
1087                }
1088            }
1089            ContinualStrategy::LearningWithoutForgetting { .. } => {
1090                // LwF: store representative samples
1091                let samples_to_store = x.nrows().min(50);
1092                for i in 0..samples_to_store {
1093                    memory_buffer.push((x.row(i).to_owned(), y[i]));
1094                }
1095            }
1096            ContinualStrategy::AGEM { .. } => {
1097                // A-GEM: gradient episodic memory
1098                let memory_samples = x.nrows().min(20);
1099                for i in 0..memory_samples {
1100                    memory_buffer.push((x.row(i).to_owned(), y[i]));
1101                }
1102            }
1103        }
1104
1105        // Compute consolidation weights
1106        let consolidation_weights =
1107            self.compute_consolidation_weights(&task_statistics, &current_task_stats);
1108
1109        let current_task_id = task_statistics.len() - 1;
1110
1111        Ok(FittedContinualLearningBaseline {
1112            strategy: self.strategy,
1113            memory_buffer,
1114            task_statistics,
1115            consolidation_weights,
1116            current_task_id,
1117            random_state: self.random_state,
1118        })
1119    }
1120}
1121
1122impl ContinualLearningBaseline {
1123    fn compute_consolidation_weights(
1124        &self,
1125        task_stats: &[SourceDomainStats],
1126        current_stats: &SourceDomainStats,
1127    ) -> Array1<f64> {
1128        let n_features = current_stats.feature_means.len();
1129        let mut weights = Array1::ones(n_features);
1130
1131        match &self.strategy {
1132            ContinualStrategy::ElasticWeightConsolidation { importance_weight } => {
1133                // Compute importance weights based on feature variance across tasks
1134                for i in 0..n_features {
1135                    let feature_variance: f64 = task_stats
1136                        .iter()
1137                        .map(|stats| stats.feature_means[i])
1138                        .map(|mean| (mean - current_stats.feature_means[i]).powi(2))
1139                        .sum::<f64>()
1140                        / task_stats.len().max(1) as f64;
1141
1142                    weights[i] = importance_weight * feature_variance.sqrt();
1143                }
1144            }
1145            ContinualStrategy::Rehearsal { memory_size: _ } => {
1146                // Uniform importance for rehearsal
1147                weights.fill(1.0);
1148            }
1149            ContinualStrategy::Progressive { column_capacity } => {
1150                // Progressive importance based on column capacity
1151                let capacity_weight = 1.0 / (*column_capacity as f64);
1152                weights.fill(capacity_weight);
1153            }
1154            ContinualStrategy::LearningWithoutForgetting {
1155                distillation_weight,
1156            } => {
1157                // Distillation-based importance
1158                weights.fill(*distillation_weight);
1159            }
1160            ContinualStrategy::AGEM { memory_strength } => {
1161                // A-GEM memory strength
1162                weights.fill(*memory_strength);
1163            }
1164        }
1165
1166        weights
1167    }
1168}
1169
1170impl Predict<Array2<f64>, Array1<f64>> for FittedContinualLearningBaseline {
1171    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
1172        let mut predictions = Vec::with_capacity(x.nrows());
1173
1174        for sample in x.rows() {
1175            let prediction = match &self.strategy {
1176                ContinualStrategy::ElasticWeightConsolidation { .. } => {
1177                    // EWC: weighted prediction considering task importance
1178                    let mut weighted_sum = 0.0;
1179                    let mut total_weight = 0.0;
1180
1181                    for (task_id, task_stats) in self.task_statistics.iter().enumerate() {
1182                        let task_weight = if task_id < self.consolidation_weights.len() {
1183                            self.consolidation_weights[task_id]
1184                        } else {
1185                            1.0
1186                        };
1187
1188                        weighted_sum += task_weight * task_stats.target_mean;
1189                        total_weight += task_weight;
1190                    }
1191
1192                    if total_weight > 0.0 {
1193                        weighted_sum / total_weight
1194                    } else {
1195                        self.task_statistics
1196                            .last()
1197                            .map_or(0.0, |stats| stats.target_mean)
1198                    }
1199                }
1200                ContinualStrategy::Rehearsal { .. } => {
1201                    // Rehearsal: use memory buffer for prediction
1202                    if !self.memory_buffer.is_empty() {
1203                        let memory_prediction: f64 =
1204                            self.memory_buffer.iter().map(|(_, y)| *y).sum::<f64>()
1205                                / self.memory_buffer.len() as f64;
1206
1207                        let current_prediction = self
1208                            .task_statistics
1209                            .last()
1210                            .map_or(0.0, |stats| stats.target_mean);
1211
1212                        (memory_prediction + current_prediction) / 2.0
1213                    } else {
1214                        self.task_statistics
1215                            .last()
1216                            .map_or(0.0, |stats| stats.target_mean)
1217                    }
1218                }
1219                ContinualStrategy::Progressive { .. } => {
1220                    // Progressive: combine all task predictions
1221                    let task_predictions: f64 = self
1222                        .task_statistics
1223                        .iter()
1224                        .map(|stats| stats.target_mean)
1225                        .sum();
1226
1227                    if !self.task_statistics.is_empty() {
1228                        task_predictions / self.task_statistics.len() as f64
1229                    } else {
1230                        0.0
1231                    }
1232                }
1233                ContinualStrategy::LearningWithoutForgetting {
1234                    distillation_weight,
1235                } => {
1236                    // LwF: distillation-weighted prediction
1237                    let current_prediction = self
1238                        .task_statistics
1239                        .last()
1240                        .map_or(0.0, |stats| stats.target_mean);
1241
1242                    if self.task_statistics.len() > 1 {
1243                        let previous_predictions: f64 = self
1244                            .task_statistics
1245                            .iter()
1246                            .take(self.task_statistics.len() - 1)
1247                            .map(|stats| stats.target_mean)
1248                            .sum::<f64>()
1249                            / (self.task_statistics.len() - 1) as f64;
1250
1251                        distillation_weight * previous_predictions
1252                            + (1.0 - distillation_weight) * current_prediction
1253                    } else {
1254                        current_prediction
1255                    }
1256                }
1257                ContinualStrategy::AGEM { .. } => {
1258                    // A-GEM: episodic memory-based prediction
1259                    if !self.memory_buffer.is_empty() {
1260                        // Find nearest memory samples
1261                        let mut min_distance = f64::INFINITY;
1262                        let mut nearest_value = 0.0;
1263
1264                        for (memory_sample, memory_value) in &self.memory_buffer {
1265                            let distance: f64 = sample
1266                                .iter()
1267                                .zip(memory_sample.iter())
1268                                .map(|(a, b)| (a - b).powi(2))
1269                                .sum::<f64>()
1270                                .sqrt();
1271
1272                            if distance < min_distance {
1273                                min_distance = distance;
1274                                nearest_value = *memory_value;
1275                            }
1276                        }
1277
1278                        nearest_value
1279                    } else {
1280                        self.task_statistics
1281                            .last()
1282                            .map_or(0.0, |stats| stats.target_mean)
1283                    }
1284                }
1285            };
1286
1287            predictions.push(prediction);
1288        }
1289
1290        Ok(Array1::from_vec(predictions))
1291    }
1292}
1293
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Assert that `actual` matches `expected` element-wise within `1e-9`.
    fn assert_all_close(actual: &Array1<f64>, expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < 1e-9,
                "element {}: got {}, expected {}",
                i,
                a,
                e
            );
        }
    }

    /// Four samples on the diagonal `(k, k)` with target `k`, for `k = 1..=4`.
    /// Target mean is exactly 2.5.
    fn ramp_data() -> (Array2<f64>, Array1<f64>) {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        (x, y)
    }

    #[test]
    fn test_few_shot_classifier() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.1, 1.1, 5.0, 5.0, 5.1, 5.1, 3.0, 3.0, 3.1, 3.1],
        )
        .unwrap();
        let y = array![0, 0, 1, 1, 2, 2];

        let classifier = FewShotBaselineClassifier::new(FewShotStrategy::NearestPrototype);
        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_few_shot_regressor() {
        let (x, y) = ramp_data();

        let regressor = FewShotBaselineRegressor::new(FewShotStrategy::KNearestNeighbors { k: 2 });
        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_transfer_learning() {
        let (x, y) = ramp_data();

        let source_stats = SourceDomainStats {
            feature_means: array![2.0, 2.0],
            feature_stds: array![1.0, 1.0],
            class_priors: HashMap::new(),
            target_mean: 2.0,
            target_std: 1.0,
        };

        let baseline =
            TransferLearningBaseline::new(TransferStrategy::SourcePrior { source_weight: 0.3 })
                .with_source_statistics(source_stats);
        let fitted = baseline.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_few_shot_strategies() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 5.0, 5.0, 6.0, 6.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let strategies = vec![
            FewShotStrategy::NearestPrototype,
            FewShotStrategy::KNearestNeighbors { k: 2 },
            FewShotStrategy::SupportBased,
            FewShotStrategy::Centroid,
            FewShotStrategy::Probabilistic,
        ];

        for strategy in strategies {
            let classifier = FewShotBaselineClassifier::new(strategy).with_random_state(42);
            let fitted = classifier.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();
            assert_eq!(predictions.len(), 4);
        }
    }

    #[test]
    fn test_transfer_strategies() {
        let (x, y) = ramp_data();

        let strategies = vec![
            TransferStrategy::SourcePrior { source_weight: 0.5 },
            TransferStrategy::FeatureBased {
                adaptation_rate: 0.3,
            },
            TransferStrategy::InstanceBased {
                similarity_threshold: 0.8,
            },
            TransferStrategy::ModelBased {
                confidence_threshold: 0.7,
            },
            TransferStrategy::EnsembleTransfer {
                domain_weights: vec![0.6],
            },
        ];

        for strategy in strategies {
            let baseline = TransferLearningBaseline::new(strategy);
            let fitted = baseline.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();
            assert_eq!(predictions.len(), 4);
        }
    }

    #[test]
    fn test_domain_adaptation() {
        let (x, y) = ramp_data();

        let source_x = Array2::from_shape_vec((3, 2), vec![0.5, 0.5, 1.5, 1.5, 2.5, 2.5]).unwrap();
        let source_y = array![0.5, 1.5, 2.5];

        let baseline = DomainAdaptationBaseline::new(DomainAdaptationStrategy::FeatureAlignment)
            .with_source_domain_data(source_x, source_y);
        let fitted = baseline.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        // FeatureAlignment predicts the target-domain mean target.
        assert_all_close(&predictions, &[2.5; 4]);
        assert!(fitted.adaptation_matrix.shape() == &[2, 2]);

        // Diagonal = target std / source std (population stds: sqrt(1.25) vs sqrt(2/3)).
        let expected_ratio = (1.25f64 / (2.0 / 3.0)).sqrt();
        for i in 0..2 {
            assert!((fitted.adaptation_matrix[[i, i]] - expected_ratio).abs() < 1e-9);
        }
        assert_eq!(fitted.adaptation_matrix[[0, 1]], 0.0);
        assert_eq!(fitted.adaptation_matrix[[1, 0]], 0.0);
    }

    #[test]
    fn test_domain_adaptation_strategies() {
        let (x, y) = ramp_data();

        // Without source data the source statistics mirror the target ones, so
        // every blend of source/target means collapses to the target mean (2.5).
        // MMD averages each row's feature mean with 2.5.
        let cases = vec![
            (DomainAdaptationStrategy::FeatureAlignment, vec![2.5; 4]),
            (
                DomainAdaptationStrategy::InstanceReweighting {
                    adaptation_strength: 0.8,
                },
                vec![2.5; 4],
            ),
            (
                DomainAdaptationStrategy::GradientReversal { lambda: 0.1 },
                vec![2.5; 4],
            ),
            (
                DomainAdaptationStrategy::SubspaceAlignment { subspace_dim: 2 },
                vec![2.5; 4],
            ),
            (
                DomainAdaptationStrategy::MMDMinimization { bandwidth: 1.0 },
                vec![1.75, 2.25, 2.75, 3.25],
            ),
        ];

        for (strategy, expected) in cases {
            let baseline = DomainAdaptationBaseline::new(strategy);
            let fitted = baseline.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();
            assert_all_close(&predictions, &expected);
        }
    }

    #[test]
    fn test_domain_adaptation_shape_mismatch() {
        let (x, _) = ramp_data();
        let y_short = array![1.0, 2.0, 3.0];

        let baseline = DomainAdaptationBaseline::new(DomainAdaptationStrategy::FeatureAlignment);
        assert!(baseline.fit(&x, &y_short).is_err());
    }

    #[test]
    fn test_continual_learning() {
        let (x, y) = ramp_data();

        let baseline =
            ContinualLearningBaseline::new(ContinualStrategy::ElasticWeightConsolidation {
                importance_weight: 1.0,
            });
        let fitted = baseline.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        // A single task: EWC predicts that task's mean target.
        assert_all_close(&predictions, &[2.5; 4]);
        assert_eq!(fitted.task_statistics.len(), 1);
    }

    #[test]
    fn test_continual_strategies() {
        let (x, y) = ramp_data();

        let cases = vec![
            (
                ContinualStrategy::ElasticWeightConsolidation {
                    importance_weight: 1.0,
                },
                vec![2.5; 4],
            ),
            // Memory mean and task mean are both 2.5.
            (ContinualStrategy::Rehearsal { memory_size: 100 }, vec![2.5; 4]),
            (
                ContinualStrategy::Progressive {
                    column_capacity: 10,
                },
                vec![2.5; 4],
            ),
            (
                ContinualStrategy::LearningWithoutForgetting {
                    distillation_weight: 0.8,
                },
                vec![2.5; 4],
            ),
            // Every training row is in memory, so its nearest neighbour is itself.
            (
                ContinualStrategy::AGEM {
                    memory_strength: 0.5,
                },
                vec![1.0, 2.0, 3.0, 4.0],
            ),
        ];

        for (strategy, expected) in cases {
            let baseline = ContinualLearningBaseline::new(strategy).with_random_state(42);
            let fitted = baseline.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();
            assert_all_close(&predictions, &expected);
        }
    }

    #[test]
    fn test_continual_rehearsal_with_task_memory() {
        let (x, y) = ramp_data();
        let memory = vec![(array![10.0, 10.0], 10.0)];

        let fitted = ContinualLearningBaseline::new(ContinualStrategy::Rehearsal { memory_size: 100 })
            .with_task_memory(memory)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.memory_buffer.len(), 5);
        // Memory mean (10 + 1 + 2 + 3 + 4) / 5 = 4.0; task mean 2.5; average 3.25.
        assert_all_close(&fitted.predict(&x).unwrap(), &[3.25; 4]);
    }

    #[test]
    fn test_continual_shape_mismatch() {
        let (x, _) = ramp_data();
        let y_short = array![1.0, 2.0, 3.0];

        let baseline = ContinualLearningBaseline::new(ContinualStrategy::Rehearsal { memory_size: 10 });
        assert!(baseline.fit(&x, &y_short).is_err());
    }
}