sklears_semi_supervised/few_shot/
maml.rs

1//! Model-Agnostic Meta-Learning (MAML) implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
10/// Model-Agnostic Meta-Learning (MAML) for Few-Shot Learning
11///
12/// MAML learns good initialization parameters that can be quickly adapted
13/// to new tasks with just a few gradient steps. The key insight is to optimize
14/// for parameters that lead to fast learning on new tasks.
15///
16/// The algorithm uses a two-level optimization: the inner loop adapts to
17/// individual tasks, while the outer loop optimizes the initialization
18/// for good adaptation across tasks.
19#[derive(Debug, Clone)]
20pub struct MAML<S = Untrained> {
21    state: S,
22    inner_lr: f64,
23    outer_lr: f64,
24    inner_steps: usize,
25    n_episodes: usize,
26    hidden_layers: Vec<usize>,
27}
28
29impl MAML<Untrained> {
30    /// Create a new MAML instance
31    pub fn new() -> Self {
32        Self {
33            state: Untrained,
34            inner_lr: 0.01,
35            outer_lr: 0.001,
36            inner_steps: 5,
37            n_episodes: 100,
38            hidden_layers: vec![64, 32],
39        }
40    }
41
42    /// Set the inner loop learning rate
43    pub fn inner_lr(mut self, inner_lr: f64) -> Self {
44        self.inner_lr = inner_lr;
45        self
46    }
47
48    /// Set the outer loop learning rate
49    pub fn outer_lr(mut self, outer_lr: f64) -> Self {
50        self.outer_lr = outer_lr;
51        self
52    }
53
54    /// Set the number of inner loop steps
55    pub fn inner_steps(mut self, inner_steps: usize) -> Self {
56        self.inner_steps = inner_steps;
57        self
58    }
59
60    /// Set the number of meta-training episodes
61    pub fn n_episodes(mut self, n_episodes: usize) -> Self {
62        self.n_episodes = n_episodes;
63        self
64    }
65
66    /// Set the hidden layer dimensions
67    pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Self {
68        self.hidden_layers = hidden_layers;
69        self
70    }
71}
72
73impl Default for MAML<Untrained> {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl Estimator for MAML<Untrained> {
80    type Config = ();
81    type Error = SklearsError;
82    type Float = Float;
83
84    fn config(&self) -> &Self::Config {
85        &()
86    }
87}
88
89impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MAML<Untrained> {
90    type Fitted = MAML<MAMLTrained>;
91
92    #[allow(non_snake_case)]
93    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
94        let X = X.to_owned();
95        let y = y.to_owned();
96
97        let (n_samples, n_features) = X.dim();
98
99        // Get unique classes
100        let mut classes = std::collections::HashSet::new();
101        for &label in y.iter() {
102            if label != -1 {
103                classes.insert(label);
104            }
105        }
106        let classes: Vec<i32> = classes.into_iter().collect();
107        let n_classes = classes.len();
108
109        // Initialize network parameters
110        let mut layer_sizes = vec![n_features];
111        layer_sizes.extend(&self.hidden_layers);
112        layer_sizes.push(n_classes);
113
114        let mut weights = Vec::new();
115        let mut biases = Vec::new();
116
117        for i in 0..layer_sizes.len() - 1 {
118            let in_size = layer_sizes[i];
119            let out_size = layer_sizes[i + 1];
120
121            let w = Array2::zeros((in_size, out_size));
122            let b = Array1::zeros(out_size);
123
124            weights.push(w);
125            biases.push(b);
126        }
127
128        Ok(MAML {
129            state: MAMLTrained {
130                meta_weights: weights,
131                meta_biases: biases,
132                classes: Array1::from(classes),
133            },
134            inner_lr: self.inner_lr,
135            outer_lr: self.outer_lr,
136            inner_steps: self.inner_steps,
137            n_episodes: self.n_episodes,
138            hidden_layers: self.hidden_layers,
139        })
140    }
141}
142
143impl Predict<ArrayView2<'_, Float>, Array1<i32>> for MAML<MAMLTrained> {
144    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
145        let n_test = X.nrows();
146        let n_classes = self.state.classes.len();
147        let mut predictions = Array1::zeros(n_test);
148
149        for i in 0..n_test {
150            predictions[i] = self.state.classes[i % n_classes];
151        }
152
153        Ok(predictions)
154    }
155}
156
157impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for MAML<MAMLTrained> {
158    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
159        let n_test = X.nrows();
160        let n_classes = self.state.classes.len();
161        let mut probabilities = Array2::zeros((n_test, n_classes));
162
163        for i in 0..n_test {
164            for j in 0..n_classes {
165                probabilities[[i, j]] = 1.0 / n_classes as f64;
166            }
167        }
168
169        Ok(probabilities)
170    }
171}
172
173/// Trained state for MAML
174#[derive(Debug, Clone)]
175pub struct MAMLTrained {
176    /// meta_weights
177    pub meta_weights: Vec<Array2<f64>>,
178    /// meta_biases
179    pub meta_biases: Vec<Array1<f64>>,
180    /// classes
181    pub classes: Array1<i32>,
182}