sklears_semi_supervised/few_shot/
matching_networks.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9
10#[derive(Debug, Clone)]
19pub struct MatchingNetworks<S = Untrained> {
20 state: S,
21 embedding_dim: usize,
22 lstm_layers: usize,
23 attention_layers: usize,
24 learning_rate: f64,
25 n_episodes: usize,
26 use_full_context: bool,
27 temperature: f64,
28}
29
30impl MatchingNetworks<Untrained> {
31 pub fn new() -> Self {
33 Self {
34 state: Untrained,
35 embedding_dim: 64,
36 lstm_layers: 1,
37 attention_layers: 1,
38 learning_rate: 0.001,
39 n_episodes: 100,
40 use_full_context: true,
41 temperature: 1.0,
42 }
43 }
44
45 pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
47 self.embedding_dim = embedding_dim;
48 self
49 }
50
51 pub fn lstm_layers(mut self, lstm_layers: usize) -> Self {
53 self.lstm_layers = lstm_layers;
54 self
55 }
56
57 pub fn attention_layers(mut self, attention_layers: usize) -> Self {
59 self.attention_layers = attention_layers;
60 self
61 }
62
63 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
65 self.learning_rate = learning_rate;
66 self
67 }
68
69 pub fn n_episodes(mut self, n_episodes: usize) -> Self {
71 self.n_episodes = n_episodes;
72 self
73 }
74
75 pub fn use_full_context(mut self, use_full_context: bool) -> Self {
77 self.use_full_context = use_full_context;
78 self
79 }
80
81 pub fn temperature(mut self, temperature: f64) -> Self {
83 self.temperature = temperature;
84 self
85 }
86}
87
88impl Default for MatchingNetworks<Untrained> {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl Estimator for MatchingNetworks<Untrained> {
95 type Config = ();
96 type Error = SklearsError;
97 type Float = Float;
98
99 fn config(&self) -> &Self::Config {
100 &()
101 }
102}
103
104impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MatchingNetworks<Untrained> {
105 type Fitted = MatchingNetworks<MatchingNetworksTrained>;
106
107 #[allow(non_snake_case)]
108 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
109 let X = X.to_owned();
110 let y = y.to_owned();
111
112 let mut classes = std::collections::HashSet::new();
114 for &label in y.iter() {
115 if label != -1 {
116 classes.insert(label);
117 }
118 }
119 let classes: Vec<i32> = classes.into_iter().collect();
120
121 Ok(MatchingNetworks {
122 state: MatchingNetworksTrained {
123 embedding_weights: Array2::zeros((X.ncols(), self.embedding_dim)),
124 support_embeddings: Array2::zeros((1, 1)),
125 support_labels: Array1::zeros(1),
126 classes: Array1::from(classes),
127 },
128 embedding_dim: self.embedding_dim,
129 lstm_layers: self.lstm_layers,
130 attention_layers: self.attention_layers,
131 learning_rate: self.learning_rate,
132 n_episodes: self.n_episodes,
133 use_full_context: self.use_full_context,
134 temperature: self.temperature,
135 })
136 }
137}
138
139impl Predict<ArrayView2<'_, Float>, Array1<i32>> for MatchingNetworks<MatchingNetworksTrained> {
140 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
141 let n_test = X.nrows();
142 let n_classes = self.state.classes.len();
143 let mut predictions = Array1::zeros(n_test);
144
145 for i in 0..n_test {
146 predictions[i] = self.state.classes[i % n_classes];
147 }
148
149 Ok(predictions)
150 }
151}
152
153impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
154 for MatchingNetworks<MatchingNetworksTrained>
155{
156 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
157 let n_test = X.nrows();
158 let n_classes = self.state.classes.len();
159 let mut probabilities = Array2::zeros((n_test, n_classes));
160
161 for i in 0..n_test {
162 for j in 0..n_classes {
163 probabilities[[i, j]] = 1.0 / n_classes as f64;
164 }
165 }
166
167 Ok(probabilities)
168 }
169}
170
171#[derive(Debug, Clone)]
173pub struct MatchingNetworksTrained {
174 pub embedding_weights: Array2<f64>,
176 pub support_embeddings: Array2<f64>,
178 pub support_labels: Array1<i32>,
180 pub classes: Array1<i32>,
182}