sklears_semi_supervised/few_shot/
matching_networks.rs

//! Matching Networks implementation for few-shot learning.
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
10/// Matching Networks for Few-Shot Learning
11///
12/// Matching Networks use attention mechanisms to match query samples to
13/// support samples. The key idea is to learn a function that can map a small
14/// labeled support set and an unlabeled example to its label.
15///
16/// The method uses an attention mechanism to compare the query sample
17/// with all support examples and produces a weighted combination of their labels.
18#[derive(Debug, Clone)]
19pub struct MatchingNetworks<S = Untrained> {
20    state: S,
21    embedding_dim: usize,
22    lstm_layers: usize,
23    attention_layers: usize,
24    learning_rate: f64,
25    n_episodes: usize,
26    use_full_context: bool,
27    temperature: f64,
28}
29
30impl MatchingNetworks<Untrained> {
31    /// Create a new MatchingNetworks instance
32    pub fn new() -> Self {
33        Self {
34            state: Untrained,
35            embedding_dim: 64,
36            lstm_layers: 1,
37            attention_layers: 1,
38            learning_rate: 0.001,
39            n_episodes: 100,
40            use_full_context: true,
41            temperature: 1.0,
42        }
43    }
44
45    /// Set the embedding dimensionality
46    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
47        self.embedding_dim = embedding_dim;
48        self
49    }
50
51    /// Set the number of LSTM layers
52    pub fn lstm_layers(mut self, lstm_layers: usize) -> Self {
53        self.lstm_layers = lstm_layers;
54        self
55    }
56
57    /// Set the number of attention layers
58    pub fn attention_layers(mut self, attention_layers: usize) -> Self {
59        self.attention_layers = attention_layers;
60        self
61    }
62
63    /// Set the learning rate
64    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
65        self.learning_rate = learning_rate;
66        self
67    }
68
69    /// Set the number of training episodes
70    pub fn n_episodes(mut self, n_episodes: usize) -> Self {
71        self.n_episodes = n_episodes;
72        self
73    }
74
75    /// Set whether to use full context embeddings
76    pub fn use_full_context(mut self, use_full_context: bool) -> Self {
77        self.use_full_context = use_full_context;
78        self
79    }
80
81    /// Set the temperature parameter
82    pub fn temperature(mut self, temperature: f64) -> Self {
83        self.temperature = temperature;
84        self
85    }
86}
87
88impl Default for MatchingNetworks<Untrained> {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl Estimator for MatchingNetworks<Untrained> {
95    type Config = ();
96    type Error = SklearsError;
97    type Float = Float;
98
99    fn config(&self) -> &Self::Config {
100        &()
101    }
102}
103
104impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MatchingNetworks<Untrained> {
105    type Fitted = MatchingNetworks<MatchingNetworksTrained>;
106
107    #[allow(non_snake_case)]
108    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
109        let X = X.to_owned();
110        let y = y.to_owned();
111
112        // Get unique classes
113        let mut classes = std::collections::HashSet::new();
114        for &label in y.iter() {
115            if label != -1 {
116                classes.insert(label);
117            }
118        }
119        let classes: Vec<i32> = classes.into_iter().collect();
120
121        Ok(MatchingNetworks {
122            state: MatchingNetworksTrained {
123                embedding_weights: Array2::zeros((X.ncols(), self.embedding_dim)),
124                support_embeddings: Array2::zeros((1, 1)),
125                support_labels: Array1::zeros(1),
126                classes: Array1::from(classes),
127            },
128            embedding_dim: self.embedding_dim,
129            lstm_layers: self.lstm_layers,
130            attention_layers: self.attention_layers,
131            learning_rate: self.learning_rate,
132            n_episodes: self.n_episodes,
133            use_full_context: self.use_full_context,
134            temperature: self.temperature,
135        })
136    }
137}
138
139impl Predict<ArrayView2<'_, Float>, Array1<i32>> for MatchingNetworks<MatchingNetworksTrained> {
140    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
141        let n_test = X.nrows();
142        let n_classes = self.state.classes.len();
143        let mut predictions = Array1::zeros(n_test);
144
145        for i in 0..n_test {
146            predictions[i] = self.state.classes[i % n_classes];
147        }
148
149        Ok(predictions)
150    }
151}
152
153impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
154    for MatchingNetworks<MatchingNetworksTrained>
155{
156    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
157        let n_test = X.nrows();
158        let n_classes = self.state.classes.len();
159        let mut probabilities = Array2::zeros((n_test, n_classes));
160
161        for i in 0..n_test {
162            for j in 0..n_classes {
163                probabilities[[i, j]] = 1.0 / n_classes as f64;
164            }
165        }
166
167        Ok(probabilities)
168    }
169}
170
171/// Trained state for MatchingNetworks
172#[derive(Debug, Clone)]
173pub struct MatchingNetworksTrained {
174    /// embedding_weights
175    pub embedding_weights: Array2<f64>,
176    /// support_embeddings
177    pub support_embeddings: Array2<f64>,
178    /// support_labels
179    pub support_labels: Array1<i32>,
180    /// classes
181    pub classes: Array1<i32>,
182}