sklears_semi_supervised/few_shot/
maml.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9
10#[derive(Debug, Clone)]
20pub struct MAML<S = Untrained> {
21 state: S,
22 inner_lr: f64,
23 outer_lr: f64,
24 inner_steps: usize,
25 n_episodes: usize,
26 hidden_layers: Vec<usize>,
27}
28
29impl MAML<Untrained> {
30 pub fn new() -> Self {
32 Self {
33 state: Untrained,
34 inner_lr: 0.01,
35 outer_lr: 0.001,
36 inner_steps: 5,
37 n_episodes: 100,
38 hidden_layers: vec![64, 32],
39 }
40 }
41
42 pub fn inner_lr(mut self, inner_lr: f64) -> Self {
44 self.inner_lr = inner_lr;
45 self
46 }
47
48 pub fn outer_lr(mut self, outer_lr: f64) -> Self {
50 self.outer_lr = outer_lr;
51 self
52 }
53
54 pub fn inner_steps(mut self, inner_steps: usize) -> Self {
56 self.inner_steps = inner_steps;
57 self
58 }
59
60 pub fn n_episodes(mut self, n_episodes: usize) -> Self {
62 self.n_episodes = n_episodes;
63 self
64 }
65
66 pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Self {
68 self.hidden_layers = hidden_layers;
69 self
70 }
71}
72
73impl Default for MAML<Untrained> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl Estimator for MAML<Untrained> {
80 type Config = ();
81 type Error = SklearsError;
82 type Float = Float;
83
84 fn config(&self) -> &Self::Config {
85 &()
86 }
87}
88
89impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MAML<Untrained> {
90 type Fitted = MAML<MAMLTrained>;
91
92 #[allow(non_snake_case)]
93 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
94 let X = X.to_owned();
95 let y = y.to_owned();
96
97 let (n_samples, n_features) = X.dim();
98
99 let mut classes = std::collections::HashSet::new();
101 for &label in y.iter() {
102 if label != -1 {
103 classes.insert(label);
104 }
105 }
106 let classes: Vec<i32> = classes.into_iter().collect();
107 let n_classes = classes.len();
108
109 let mut layer_sizes = vec![n_features];
111 layer_sizes.extend(&self.hidden_layers);
112 layer_sizes.push(n_classes);
113
114 let mut weights = Vec::new();
115 let mut biases = Vec::new();
116
117 for i in 0..layer_sizes.len() - 1 {
118 let in_size = layer_sizes[i];
119 let out_size = layer_sizes[i + 1];
120
121 let w = Array2::zeros((in_size, out_size));
122 let b = Array1::zeros(out_size);
123
124 weights.push(w);
125 biases.push(b);
126 }
127
128 Ok(MAML {
129 state: MAMLTrained {
130 meta_weights: weights,
131 meta_biases: biases,
132 classes: Array1::from(classes),
133 },
134 inner_lr: self.inner_lr,
135 outer_lr: self.outer_lr,
136 inner_steps: self.inner_steps,
137 n_episodes: self.n_episodes,
138 hidden_layers: self.hidden_layers,
139 })
140 }
141}
142
143impl Predict<ArrayView2<'_, Float>, Array1<i32>> for MAML<MAMLTrained> {
144 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
145 let n_test = X.nrows();
146 let n_classes = self.state.classes.len();
147 let mut predictions = Array1::zeros(n_test);
148
149 for i in 0..n_test {
150 predictions[i] = self.state.classes[i % n_classes];
151 }
152
153 Ok(predictions)
154 }
155}
156
157impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for MAML<MAMLTrained> {
158 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
159 let n_test = X.nrows();
160 let n_classes = self.state.classes.len();
161 let mut probabilities = Array2::zeros((n_test, n_classes));
162
163 for i in 0..n_test {
164 for j in 0..n_classes {
165 probabilities[[i, j]] = 1.0 / n_classes as f64;
166 }
167 }
168
169 Ok(probabilities)
170 }
171}
172
173#[derive(Debug, Clone)]
175pub struct MAMLTrained {
176 pub meta_weights: Vec<Array2<f64>>,
178 pub meta_biases: Vec<Array1<f64>>,
180 pub classes: Array1<i32>,
182}