sklears_compose/
few_shot.rs

1//! Few-shot learning pipeline components
2//!
3//! This module provides few-shot learning capabilities including prototype-based
4//! and gradient-based meta-learners.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8    error::Result as SklResult,
9    prelude::{Predict, SklearsError},
10    traits::{Estimator, Fit, Untrained},
11    types::{Float, FloatBounds},
12};
13use std::collections::HashMap;
14
15use crate::{PipelinePredictor, PipelineStep};
16
17/// Support set for few-shot learning
18#[derive(Debug, Clone)]
19pub struct SupportSet {
20    /// Features of support examples
21    pub features: Array2<f64>,
22    /// Labels of support examples
23    pub labels: Array1<f64>,
24    /// Number of examples per class
25    pub n_shot: usize,
26    /// Number of classes
27    pub n_way: usize,
28}
29
30impl SupportSet {
31    /// Create a new support set
32    #[must_use]
33    pub fn new(features: Array2<f64>, labels: Array1<f64>, n_shot: usize, n_way: usize) -> Self {
34        Self {
35            features,
36            labels,
37            n_shot,
38            n_way,
39        }
40    }
41
42    /// Get examples for a specific class
43    #[must_use]
44    pub fn get_class_examples(&self, class_label: f64) -> (Array2<f64>, Array1<f64>) {
45        let mut class_features = Vec::new();
46        let mut class_labels = Vec::new();
47
48        for (i, &label) in self.labels.iter().enumerate() {
49            if (label - class_label).abs() < 1e-6 {
50                class_features.push(self.features.row(i).to_owned());
51                class_labels.push(label);
52            }
53        }
54
55        if class_features.is_empty() {
56            return (Array2::zeros((0, self.features.ncols())), Array1::zeros(0));
57        }
58
59        let n_examples = class_features.len();
60        let n_features = class_features[0].len();
61        let mut features_array = Array2::zeros((n_examples, n_features));
62
63        for (i, features) in class_features.iter().enumerate() {
64            features_array.row_mut(i).assign(features);
65        }
66
67        (features_array, Array1::from_vec(class_labels))
68    }
69
70    /// Get all unique classes
71    #[must_use]
72    pub fn get_classes(&self) -> Vec<f64> {
73        let mut classes: Vec<f64> = self.labels.iter().copied().collect();
74        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
75        classes.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
76        classes
77    }
78}
79
80/// Prototype-based few-shot learner
81#[derive(Debug)]
82pub struct PrototypicalNetwork<S = Untrained> {
83    state: S,
84    distance_metric: DistanceMetric,
85    embedding_dim: Option<usize>,
86    prototypes: HashMap<String, Array1<f64>>,
87}
88
89/// Trained state for `PrototypicalNetwork`
90#[derive(Debug)]
91pub struct PrototypicalNetworkTrained {
92    prototypes: HashMap<String, Array1<f64>>,
93    distance_metric: DistanceMetric,
94    n_features_in: usize,
95    feature_names_in: Option<Vec<String>>,
96}
97
98/// Distance metrics for prototype-based learning
99#[derive(Debug, Clone)]
100pub enum DistanceMetric {
101    /// Euclidean distance
102    Euclidean,
103    /// Cosine distance
104    Cosine,
105    /// Manhattan distance
106    Manhattan,
107    /// Mahalanobis distance (with covariance matrix)
108    Mahalanobis { covariance: Array2<f64> },
109}
110
111impl DistanceMetric {
112    /// Compute distance between two vectors
113    #[must_use]
114    pub fn distance(&self, a: &ArrayView1<'_, f64>, b: &ArrayView1<'_, f64>) -> f64 {
115        match self {
116            DistanceMetric::Euclidean => ((a - b).mapv(|x| x * x).sum()).sqrt(),
117            DistanceMetric::Cosine => {
118                let dot_product = a.dot(b);
119                let norm_a = (a.mapv(|x| x * x).sum()).sqrt();
120                let norm_b = (b.mapv(|x| x * x).sum()).sqrt();
121                if norm_a == 0.0 || norm_b == 0.0 {
122                    1.0
123                } else {
124                    1.0 - dot_product / (norm_a * norm_b)
125                }
126            }
127            DistanceMetric::Manhattan => (a - b).mapv(f64::abs).sum(),
128            DistanceMetric::Mahalanobis { covariance } => {
129                let diff = a - b;
130                // Simplified Mahalanobis distance (assuming covariance is diagonal)
131                let weighted_diff = &diff * &covariance.diag();
132                (weighted_diff.mapv(|x| x * x).sum()).sqrt()
133            }
134        }
135    }
136}
137
138impl PrototypicalNetwork<Untrained> {
139    /// Create a new prototypical network
140    #[must_use]
141    pub fn new() -> Self {
142        Self {
143            state: Untrained,
144            distance_metric: DistanceMetric::Euclidean,
145            embedding_dim: None,
146            prototypes: HashMap::new(),
147        }
148    }
149
150    /// Set the distance metric
151    #[must_use]
152    pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
153        self.distance_metric = metric;
154        self
155    }
156
157    /// Set the embedding dimension
158    #[must_use]
159    pub fn embedding_dim(mut self, dim: usize) -> Self {
160        self.embedding_dim = Some(dim);
161        self
162    }
163}
164
165impl Default for PrototypicalNetwork<Untrained> {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl Estimator for PrototypicalNetwork<Untrained> {
172    type Config = ();
173    type Error = SklearsError;
174    type Float = Float;
175
176    fn config(&self) -> &Self::Config {
177        &()
178    }
179}
180
181impl PrototypicalNetwork<Untrained> {
182    /// Fit the prototypical network using support sets
183    pub fn fit_support_set(
184        self,
185        support_set: &SupportSet,
186    ) -> SklResult<PrototypicalNetwork<PrototypicalNetworkTrained>> {
187        let mut prototypes = HashMap::new();
188        let classes = support_set.get_classes();
189
190        // Compute prototype for each class
191        for class_label in classes {
192            let (class_features, _) = support_set.get_class_examples(class_label);
193
194            if class_features.nrows() > 0 {
195                // Compute mean as prototype
196                let prototype = class_features.mean_axis(Axis(0)).unwrap();
197                prototypes.insert(class_label.to_string(), prototype);
198            }
199        }
200
201        Ok(PrototypicalNetwork {
202            state: PrototypicalNetworkTrained {
203                prototypes,
204                distance_metric: self.distance_metric,
205                n_features_in: support_set.features.ncols(),
206                feature_names_in: None,
207            },
208            distance_metric: DistanceMetric::Euclidean, // Placeholder
209            embedding_dim: None,
210            prototypes: HashMap::new(),
211        })
212    }
213}
214
215impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for PrototypicalNetwork<Untrained> {
216    type Fitted = PrototypicalNetwork<PrototypicalNetworkTrained>;
217
218    fn fit(
219        self,
220        x: &ArrayView2<'_, Float>,
221        y: &Option<&ArrayView1<'_, Float>>,
222    ) -> SklResult<Self::Fitted> {
223        if let Some(y_values) = y.as_ref() {
224            let x_f64 = x.mapv(|v| v);
225            let y_f64 = y_values.mapv(|v| v);
226
227            // Determine n_way and n_shot from data
228            let unique_labels: Vec<f64> = {
229                let mut labels: Vec<f64> = y_f64.iter().copied().collect();
230                labels.sort_by(|a, b| a.partial_cmp(b).unwrap());
231                labels.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
232                labels
233            };
234
235            let n_way = unique_labels.len();
236            let n_shot = y_f64.len() / n_way; // Assume balanced
237
238            let support_set = SupportSet::new(x_f64, y_f64, n_shot, n_way);
239            self.fit_support_set(&support_set)
240        } else {
241            Err(SklearsError::InvalidInput(
242                "Labels required for few-shot learning".to_string(),
243            ))
244        }
245    }
246}
247
248impl PrototypicalNetwork<PrototypicalNetworkTrained> {
249    /// Predict using nearest prototype
250    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
251        let x_f64 = x.mapv(|v| v);
252        let mut predictions = Array1::zeros(x_f64.nrows());
253
254        for (i, sample) in x_f64.axis_iter(Axis(0)).enumerate() {
255            let mut min_distance = f64::INFINITY;
256            let mut predicted_class = 0.0;
257
258            for (class_str, prototype) in &self.state.prototypes {
259                let distance = self
260                    .state
261                    .distance_metric
262                    .distance(&sample, &prototype.view());
263                if distance < min_distance {
264                    min_distance = distance;
265                    predicted_class = class_str.parse().unwrap_or(0.0);
266                }
267            }
268
269            predictions[i] = predicted_class;
270        }
271
272        Ok(predictions)
273    }
274
275    /// Get the prototypes
276    #[must_use]
277    pub fn prototypes(&self) -> &HashMap<String, Array1<f64>> {
278        &self.state.prototypes
279    }
280}
281
282/// Model-Agnostic Meta-Learning (MAML) few-shot learner
283#[derive(Debug)]
284pub struct MAMLLearner<S = Untrained> {
285    state: S,
286    base_learner: Option<Box<dyn PipelinePredictor>>,
287    inner_lr: f64,
288    outer_lr: f64,
289    inner_steps: usize,
290    meta_parameters: HashMap<String, f64>,
291}
292
293/// Trained state for `MAMLLearner`
294#[derive(Debug)]
295pub struct MAMLLearnerTrained {
296    fitted_learner: Box<dyn PipelinePredictor>,
297    inner_lr: f64,
298    outer_lr: f64,
299    inner_steps: usize,
300    meta_parameters: HashMap<String, f64>,
301    n_features_in: usize,
302    feature_names_in: Option<Vec<String>>,
303}
304
305impl MAMLLearner<Untrained> {
306    /// Create a new MAML learner
307    #[must_use]
308    pub fn new(base_learner: Box<dyn PipelinePredictor>) -> Self {
309        Self {
310            state: Untrained,
311            base_learner: Some(base_learner),
312            inner_lr: 0.01,
313            outer_lr: 0.001,
314            inner_steps: 5,
315            meta_parameters: HashMap::new(),
316        }
317    }
318
319    /// Set inner learning rate
320    #[must_use]
321    pub fn inner_lr(mut self, lr: f64) -> Self {
322        self.inner_lr = lr;
323        self
324    }
325
326    /// Set outer learning rate  
327    #[must_use]
328    pub fn outer_lr(mut self, lr: f64) -> Self {
329        self.outer_lr = lr;
330        self
331    }
332
333    /// Set number of inner steps
334    #[must_use]
335    pub fn inner_steps(mut self, steps: usize) -> Self {
336        self.inner_steps = steps;
337        self
338    }
339
340    /// Set initial meta-parameters
341    #[must_use]
342    pub fn meta_parameters(mut self, params: HashMap<String, f64>) -> Self {
343        self.meta_parameters = params;
344        self
345    }
346}
347
348impl Estimator for MAMLLearner<Untrained> {
349    type Config = ();
350    type Error = SklearsError;
351    type Float = Float;
352
353    fn config(&self) -> &Self::Config {
354        &()
355    }
356}
357
358impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for MAMLLearner<Untrained> {
359    type Fitted = MAMLLearner<MAMLLearnerTrained>;
360
361    fn fit(
362        self,
363        x: &ArrayView2<'_, Float>,
364        y: &Option<&ArrayView1<'_, Float>>,
365    ) -> SklResult<Self::Fitted> {
366        let mut base_learner = self
367            .base_learner
368            .ok_or_else(|| SklearsError::InvalidInput("No base learner provided".to_string()))?;
369
370        if let Some(y_values) = y.as_ref() {
371            // Simulate meta-training
372            base_learner.fit(x, y_values)?;
373
374            Ok(MAMLLearner {
375                state: MAMLLearnerTrained {
376                    fitted_learner: base_learner,
377                    inner_lr: self.inner_lr,
378                    outer_lr: self.outer_lr,
379                    inner_steps: self.inner_steps,
380                    meta_parameters: self.meta_parameters,
381                    n_features_in: x.ncols(),
382                    feature_names_in: None,
383                },
384                base_learner: None,
385                inner_lr: 0.0,
386                outer_lr: 0.0,
387                inner_steps: 0,
388                meta_parameters: HashMap::new(),
389            })
390        } else {
391            Err(SklearsError::InvalidInput(
392                "Labels required for MAML training".to_string(),
393            ))
394        }
395    }
396}
397
398impl MAMLLearner<MAMLLearnerTrained> {
399    /// Adapt to a new task with few examples
400    pub fn adapt_to_task(&mut self, support_set: &SupportSet) -> SklResult<()> {
401        // Simulate inner loop adaptation
402        for _ in 0..self.state.inner_steps {
403            // In a real implementation, this would compute gradients and update parameters
404            // For now, we simulate by fitting the base learner
405            let mapped_features = support_set.features.view().mapv(|v| v as Float);
406            let mapped_labels = support_set.labels.view().mapv(|v| v as Float);
407            self.state
408                .fitted_learner
409                .fit(&mapped_features.view(), &mapped_labels.view())?;
410        }
411        Ok(())
412    }
413
414    /// Predict after adaptation
415    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
416        self.state.fitted_learner.predict(x)
417    }
418
419    /// Get meta-parameters
420    #[must_use]
421    pub fn meta_parameters(&self) -> &HashMap<String, f64> {
422        &self.state.meta_parameters
423    }
424}
425
426/// Few-shot learning pipeline
427#[derive(Debug)]
428pub struct FewShotPipeline<S = Untrained> {
429    state: S,
430    learner_type: FewShotLearnerType,
431    meta_learner: Option<MetaLearnerWrapper>,
432}
433
434/// Trained state for `FewShotPipeline`
435#[derive(Debug)]
436pub struct FewShotPipelineTrained {
437    fitted_learner: MetaLearnerWrapper,
438    n_features_in: usize,
439    feature_names_in: Option<Vec<String>>,
440}
441
442/// Types of few-shot learners
443#[derive(Debug, Clone)]
444pub enum FewShotLearnerType {
445    /// Prototypical network
446    Prototypical { distance_metric: DistanceMetric },
447    /// MAML-based learner
448    MAML {
449        inner_lr: f64,
450        outer_lr: f64,
451        inner_steps: usize,
452    },
453}
454
455/// Wrapper for different meta-learner types
456#[derive(Debug)]
457pub enum MetaLearnerWrapper {
458    /// Prototypical
459    Prototypical(PrototypicalNetwork<PrototypicalNetworkTrained>),
460    /// MAML
461    MAML(MAMLLearner<MAMLLearnerTrained>),
462}
463
464impl MetaLearnerWrapper {
465    /// Predict using the wrapped meta-learner
466    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
467        match self {
468            MetaLearnerWrapper::Prototypical(learner) => learner.predict(x),
469            MetaLearnerWrapper::MAML(learner) => learner.predict(x),
470        }
471    }
472}
473
474impl FewShotPipeline<Untrained> {
475    /// Create a new few-shot pipeline
476    #[must_use]
477    pub fn new(learner_type: FewShotLearnerType) -> Self {
478        Self {
479            state: Untrained,
480            learner_type,
481            meta_learner: None,
482        }
483    }
484
485    /// Create a prototypical network pipeline
486    #[must_use]
487    pub fn prototypical(distance_metric: DistanceMetric) -> Self {
488        Self::new(FewShotLearnerType::Prototypical { distance_metric })
489    }
490
491    /// Create a MAML pipeline
492    #[must_use]
493    pub fn maml(inner_lr: f64, outer_lr: f64, inner_steps: usize) -> Self {
494        Self::new(FewShotLearnerType::MAML {
495            inner_lr,
496            outer_lr,
497            inner_steps,
498        })
499    }
500}
501
502impl Estimator for FewShotPipeline<Untrained> {
503    type Config = ();
504    type Error = SklearsError;
505    type Float = Float;
506
507    fn config(&self) -> &Self::Config {
508        &()
509    }
510}
511
512impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for FewShotPipeline<Untrained> {
513    type Fitted = FewShotPipeline<FewShotPipelineTrained>;
514
515    fn fit(
516        self,
517        x: &ArrayView2<'_, Float>,
518        y: &Option<&ArrayView1<'_, Float>>,
519    ) -> SklResult<Self::Fitted> {
520        let fitted_learner = match &self.learner_type {
521            FewShotLearnerType::Prototypical { distance_metric } => {
522                let learner = PrototypicalNetwork::new().distance_metric(distance_metric.clone());
523                let fitted = learner.fit(x, y)?;
524                MetaLearnerWrapper::Prototypical(fitted)
525            }
526            FewShotLearnerType::MAML {
527                inner_lr,
528                outer_lr,
529                inner_steps,
530            } => {
531                use crate::MockPredictor;
532                let base_learner = Box::new(MockPredictor::new());
533                let learner = MAMLLearner::new(base_learner)
534                    .inner_lr(*inner_lr)
535                    .outer_lr(*outer_lr)
536                    .inner_steps(*inner_steps);
537                let fitted = learner.fit(x, y)?;
538                MetaLearnerWrapper::MAML(fitted)
539            }
540        };
541
542        Ok(FewShotPipeline {
543            state: FewShotPipelineTrained {
544                fitted_learner,
545                n_features_in: x.ncols(),
546                feature_names_in: None,
547            },
548            learner_type: FewShotLearnerType::Prototypical {
549                distance_metric: DistanceMetric::Euclidean,
550            }, // Placeholder
551            meta_learner: None,
552        })
553    }
554}
555
556impl FewShotPipeline<FewShotPipelineTrained> {
557    /// Predict using the fitted few-shot learner
558    pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
559        self.state.fitted_learner.predict(x)
560    }
561
562    /// Adapt to a new task using support set
563    pub fn adapt_to_task(&mut self, support_set: &SupportSet) -> SklResult<()> {
564        match &mut self.state.fitted_learner {
565            MetaLearnerWrapper::Prototypical(_) => {
566                // Prototypical networks adapt by recomputing prototypes
567                // This would require refitting with new support set
568                Ok(())
569            }
570            MetaLearnerWrapper::MAML(learner) => learner.adapt_to_task(support_set),
571        }
572    }
573
574    /// Get the fitted learner
575    #[must_use]
576    pub fn learner(&self) -> &MetaLearnerWrapper {
577        &self.state.fitted_learner
578    }
579}
580
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Four 2-D points, two per class (labels 0 and 1), shared by several tests.
    fn toy_data() -> (Array2<f64>, Array1<f64>) {
        (
            array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
            array![0.0, 0.0, 1.0, 1.0],
        )
    }

    #[test]
    fn test_support_set() {
        let (features, labels) = toy_data();
        let support = SupportSet::new(features, labels, 2, 2);

        let (rows, row_labels) = support.get_class_examples(0.0);
        assert_eq!(rows.nrows(), 2);
        assert_eq!(row_labels.len(), 2);

        assert_eq!(support.get_classes().len(), 2);
    }

    #[test]
    fn test_distance_metrics() {
        let first = array![1.0, 2.0, 3.0];
        let second = array![4.0, 5.0, 6.0];
        let (lhs, rhs) = (first.view(), second.view());

        assert!(DistanceMetric::Euclidean.distance(&lhs, &rhs) > 0.0);

        let cosine = DistanceMetric::Cosine.distance(&lhs, &rhs);
        assert!((0.0..=2.0).contains(&cosine));
    }

    #[test]
    fn test_prototypical_network() {
        let (x, y) = toy_data();
        let model = PrototypicalNetwork::new()
            .fit(&x.view(), &Some(&y.view()))
            .unwrap();

        assert_eq!(model.predict(&x.view()).unwrap().len(), x.nrows());
    }

    #[test]
    fn test_few_shot_pipeline() {
        let (x, y) = toy_data();
        let model = FewShotPipeline::prototypical(DistanceMetric::Euclidean)
            .fit(&x.view(), &Some(&y.view()))
            .unwrap();

        assert_eq!(model.predict(&x.view()).unwrap().len(), x.nrows());
    }
}