sklears_semi_supervised/few_shot/
prototypical_networks.rs

//! Prototypical Networks implementation

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Prototypical Networks for Few-Shot Learning
///
/// Prototypical Networks learn a metric space in which classification is performed
/// by computing distances to a prototype representation of each class. Each prototype
/// is the mean of the embedded support examples for its class.
///
/// The method is particularly effective in few-shot learning scenarios where
/// only a few labeled examples are available per class.
///
/// # Parameters
///
/// * `embedding_dim` - Dimensionality of the embedding space
/// * `hidden_layers` - Hidden layer dimensions for the embedding network
/// * `distance_metric` - Distance metric to use ("euclidean", "cosine", "manhattan")
/// * `learning_rate` - Learning rate for embedding network training
/// * `n_episodes` - Number of training episodes
/// * `n_way` - Number of classes per episode
/// * `n_shot` - Number of support examples per class
/// * `n_query` - Number of query examples per class
///
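/// Each training episode uses `n_way * n_shot` support samples and
/// `n_way * n_query` query samples, so every class drawn into an episode
/// must have at least `n_shot + n_query` labeled examples (this is the
/// condition checked by `sample_episode`).
///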
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::ndarray_ext::array;
/// use sklears_semi_supervised::PrototypicalNetworks;
/// use sklears_core::traits::{Fit, Predict};
///
/// let X = array![
///     [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
///     [1.1, 2.1], [2.1, 3.1], [3.1, 4.1], [4.1, 5.1]
/// ];
/// let y = array![0, 1, 0, 1, 0, 1, 0, 1];
///
/// let proto_net = PrototypicalNetworks::new()
///     .embedding_dim(32)
///     .n_way(2)
///     .n_shot(1)
///     .n_query(3);
/// let fitted = proto_net.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
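///
/// # Classification rule
///
/// For an embedded query point `f(x)` and class prototypes `c_k`, class
/// probabilities are a softmax over negative scaled distances (this mirrors
/// `softmax_distances` below):
///
/// ```text
/// p(y = k | x) = exp(-d(f(x), c_k) / T) / sum_j exp(-d(f(x), c_j) / T)
/// ```
///
/// where `d` is the configured distance metric and `T` the `temperature`.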
#[derive(Debug, Clone)]
pub struct PrototypicalNetworks<S = Untrained> {
    state: S,
    embedding_dim: usize,
    hidden_layers: Vec<usize>,
    distance_metric: String,
    learning_rate: f64,
    n_episodes: usize,
    n_way: usize,
    n_shot: usize,
    n_query: usize,
    temperature: f64,
}

impl PrototypicalNetworks<Untrained> {
    /// Create a new PrototypicalNetworks instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            embedding_dim: 64,
            hidden_layers: vec![128, 64],
            distance_metric: "euclidean".to_string(),
            learning_rate: 0.001,
            n_episodes: 100,
            n_way: 5,
            n_shot: 1,
            n_query: 15,
            temperature: 1.0,
        }
    }

    /// Set the embedding dimensionality
    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    /// Set the hidden layer dimensions
    pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Self {
        self.hidden_layers = hidden_layers;
        self
    }

    /// Set the distance metric
    pub fn distance_metric(mut self, metric: String) -> Self {
        self.distance_metric = metric;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the number of training episodes
    pub fn n_episodes(mut self, n_episodes: usize) -> Self {
        self.n_episodes = n_episodes;
        self
    }

    /// Set the number of classes per episode (N-way)
    pub fn n_way(mut self, n_way: usize) -> Self {
        self.n_way = n_way;
        self
    }

    /// Set the number of support examples per class (N-shot)
    pub fn n_shot(mut self, n_shot: usize) -> Self {
        self.n_shot = n_shot;
        self
    }

    /// Set the number of query examples per class
    pub fn n_query(mut self, n_query: usize) -> Self {
        self.n_query = n_query;
        self
    }

    /// Set the temperature parameter for softmax
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    /// Compute embedding for input data
    fn compute_embedding(
        &self,
        X: &Array2<f64>,
        weights: &[Array2<f64>],
        biases: &[Array1<f64>],
    ) -> Array2<f64> {
        let mut current = X.clone();

        for (i, (w, b)) in weights.iter().zip(biases.iter()).enumerate() {
            current = current.dot(w);

            // Add bias
            for mut row in current.axis_iter_mut(Axis(0)) {
                for (j, &bias_val) in b.iter().enumerate() {
                    row[j] += bias_val;
                }
            }

            // Apply ReLU activation (except for last layer)
            if i < weights.len() - 1 {
                current.mapv_inplace(|x| x.max(0.0));
            }
        }

        current
    }

    /// Compute distance between embeddings
    fn compute_distance(&self, a: &Array1<f64>, b: &Array1<f64>) -> f64 {
        match self.distance_metric.as_str() {
            "euclidean" => {
                let diff = a - b;
                diff.mapv(|x| x * x).sum().sqrt()
            }
            "cosine" => {
                let dot_product = a.dot(b);
                let norm_a = a.mapv(|x| x * x).sum().sqrt();
                let norm_b = b.mapv(|x| x * x).sum().sqrt();
                // Guard against division by zero for zero-norm embeddings
                let denom = (norm_a * norm_b).max(1e-12);
                1.0 - (dot_product / denom)
            }
            "manhattan" => {
                let diff = a - b;
                diff.mapv(|x| x.abs()).sum()
            }
            _ => {
                // Unknown metric: default to euclidean
                let diff = a - b;
                diff.mapv(|x| x * x).sum().sqrt()
            }
        }
    }

    /// Compute prototypes for each class
    fn compute_prototypes(
        &self,
        support_embeddings: &Array2<f64>,
        support_labels: &Array1<i32>,
        classes: &[i32],
    ) -> Array2<f64> {
        let n_classes = classes.len();
        let embedding_dim = support_embeddings.ncols();
        let mut prototypes = Array2::zeros((n_classes, embedding_dim));

        for (class_idx, &class_label) in classes.iter().enumerate() {
            let mut class_embeddings = Vec::new();

            for (sample_idx, &label) in support_labels.iter().enumerate() {
                if label == class_label {
                    class_embeddings.push(support_embeddings.row(sample_idx).to_owned());
                }
            }

            if !class_embeddings.is_empty() {
                // Compute mean embedding as prototype
                for dim in 0..embedding_dim {
                    let mean_val: f64 = class_embeddings.iter().map(|emb| emb[dim]).sum::<f64>()
                        / class_embeddings.len() as f64;
                    prototypes[[class_idx, dim]] = mean_val;
                }
            }
        }

        prototypes
    }

    /// Apply a temperature-scaled softmax over negative distances
    /// (smaller distance yields higher probability)
    fn softmax_distances(&self, distances: &Array1<f64>) -> Array1<f64> {
        let scaled_distances = distances.mapv(|d| -d / self.temperature);
        let max_dist = scaled_distances
            .iter()
            .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));

        // Subtract the max before exponentiating for numerical stability
        let exp_distances = scaled_distances.mapv(|d| (d - max_dist).exp());
        let sum_exp = exp_distances.sum();

        exp_distances.mapv(|x| x / sum_exp)
    }

    /// Sample a support/query episode for training
    #[allow(clippy::type_complexity)]
    fn sample_episode(
        &self,
        X: &Array2<f64>,
        y: &Array1<i32>,
        classes: &[i32],
    ) -> SklResult<(Array2<f64>, Array1<i32>, Array2<f64>, Array1<i32>)> {
        let n_features = X.ncols();

        // Group sample indices by class
        let mut class_samples: HashMap<i32, Vec<usize>> = HashMap::new();
        for (i, &label) in y.iter().enumerate() {
            class_samples.entry(label).or_default().push(i);
        }

        // Check that every requested class has enough samples
        for &class_label in classes {
            if let Some(samples) = class_samples.get(&class_label) {
                if samples.len() < self.n_shot + self.n_query {
                    return Err(SklearsError::InvalidInput(format!(
                        "Not enough samples for class {}: need {}, have {}",
                        class_label,
                        self.n_shot + self.n_query,
                        samples.len()
                    )));
                }
            } else {
                return Err(SklearsError::InvalidInput(format!(
                    "Class {} not found in data",
                    class_label
                )));
            }
        }

        // Allocate the support and query sets
        let total_support = self.n_way * self.n_shot;
        let total_query = self.n_way * self.n_query;

        let mut support_X = Array2::zeros((total_support, n_features));
        let mut support_y = Array1::zeros(total_support);
        let mut query_X = Array2::zeros((total_query, n_features));
        let mut query_y = Array1::zeros(total_query);

        let mut support_idx = 0;
        let mut query_idx = 0;

        for (class_idx, &class_label) in classes.iter().take(self.n_way).enumerate() {
            if let Some(samples) = class_samples.get(&class_label) {
                // Simplified sampling: take the first n_shot + n_query samples.
                // A full implementation would sample these at random per episode.
                let selected_samples: Vec<usize> = samples
                    .iter()
                    .take(self.n_shot + self.n_query)
                    .cloned()
                    .collect();

                // Support set (labels use episode-local class indices)
                #[allow(clippy::needless_range_loop)]
                for i in 0..self.n_shot {
                    let sample_idx = selected_samples[i];
                    support_X.row_mut(support_idx).assign(&X.row(sample_idx));
                    support_y[support_idx] = class_idx as i32;
                    support_idx += 1;
                }

                // Query set (labels use episode-local class indices)
                #[allow(clippy::needless_range_loop)]
                for i in self.n_shot..self.n_shot + self.n_query {
                    let sample_idx = selected_samples[i];
                    query_X.row_mut(query_idx).assign(&X.row(sample_idx));
                    query_y[query_idx] = class_idx as i32;
                    query_idx += 1;
                }
            }
        }

        Ok((support_X, support_y, query_X, query_y))
    }
}

impl Default for PrototypicalNetworks<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for PrototypicalNetworks<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for PrototypicalNetworks<Untrained> {
    type Fitted = PrototypicalNetworks<PrototypicalNetworksTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        let n_features = X.ncols();

        // Collect the unique labeled classes (-1 marks unlabeled samples)
        let mut classes = std::collections::HashSet::new();
        for &label in y.iter() {
            if label != -1 {
                classes.insert(label);
            }
        }
        let mut classes: Vec<i32> = classes.into_iter().collect();
        // Sort for a deterministic class order (HashSet iteration order is arbitrary)
        classes.sort_unstable();

        if classes.len() < self.n_way {
            return Err(SklearsError::InvalidInput(format!(
                "Need at least {} classes for {}-way classification, found {}",
                self.n_way,
                self.n_way,
                classes.len()
            )));
        }

        // Embedding network layer sizes: n_features -> hidden_layers -> embedding_dim
        let mut layer_sizes = vec![n_features];
        layer_sizes.extend(&self.hidden_layers);
        layer_sizes.push(self.embedding_dim);

        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for layer in 0..layer_sizes.len() - 1 {
            let in_size = layer_sizes[layer];
            let out_size = layer_sizes[layer + 1];

            // Xavier scale with a deterministic pseudo-random fill
            // (a full implementation would draw the weights at random)
            let scale = (2.0 / (in_size + out_size) as f64).sqrt();
            let mut w = Array2::zeros((in_size, out_size));
            let b = Array1::zeros(out_size);

            for i in 0..in_size {
                for j in 0..out_size {
                    w[[i, j]] = scale * ((i + j) as f64 * 0.1).sin();
                }
            }

            weights.push(w);
            biases.push(b);
        }

        // Episodic training
        for episode in 0..self.n_episodes {
            // Build an episode over the first n_way classes (deterministic after sorting)
            let episode_classes: Vec<i32> = classes.iter().take(self.n_way).cloned().collect();

            let (support_X, support_y, query_X, query_y) =
                self.sample_episode(&X, &y, &episode_classes)?;

            // Forward pass: compute embeddings
            let support_embeddings = self.compute_embedding(&support_X, &weights, &biases);
            let query_embeddings = self.compute_embedding(&query_X, &weights, &biases);

            // Compute prototypes over the episode-local class indices 0..n_way
            let episode_class_indices: Vec<i32> = (0..self.n_way as i32).collect();
            let prototypes =
                self.compute_prototypes(&support_embeddings, &support_y, &episode_class_indices);

            // Compute distances and probabilities for the query set
            let n_query_samples = query_embeddings.nrows();
            let mut total_loss = 0.0;

            for query_idx in 0..n_query_samples {
                let query_embedding = query_embeddings.row(query_idx);
                let true_class = query_y[query_idx] as usize;

                // Skip if true_class is out of bounds
                if true_class >= self.n_way {
                    continue;
                }

                // Compute distances to all prototypes
                let mut distances = Array1::zeros(self.n_way);
                for class_idx in 0..self.n_way {
                    let prototype = prototypes.row(class_idx);
                    distances[class_idx] =
                        self.compute_distance(&query_embedding.to_owned(), &prototype.to_owned());
                }

                // Convert distances to class probabilities
                let probabilities = self.softmax_distances(&distances);

                // Cross-entropy loss on the true class
                let prob = probabilities[true_class].max(1e-10);
                total_loss -= prob.ln();

                // Simplified gradient update with a decaying learning rate
                // (a crude stand-in for full backpropagation)
                let lr = self.learning_rate / (episode + 1) as f64;

                // Update last-layer weights only
                if let (Some(last_w), Some(last_b)) = (weights.last_mut(), biases.last_mut()) {
                    let max_features = query_X.ncols().min(last_w.nrows());
                    for i in 0..max_features {
                        // The gradient term does not depend on the output unit j
                        let grad_w = (probabilities[true_class] - 1.0) * query_X[[query_idx, i]];
                        for j in 0..last_w.ncols() {
                            last_w[[i, j]] -= lr * grad_w;
                        }
                    }

                    let grad_b = probabilities[true_class] - 1.0;
                    for j in 0..last_b.len() {
                        last_b[j] -= lr * grad_b;
                    }
                }
            }

            // Track average loss occasionally (hook for progress logging)
            if episode % 20 == 0 {
                let avg_loss = total_loss / n_query_samples as f64;
                let _ = avg_loss; // Suppress unused variable warning
            }
        }

        Ok(PrototypicalNetworks {
            state: PrototypicalNetworksTrained {
                weights,
                biases,
                classes: Array1::from(classes),
                // Placeholder: prototypes are not persisted yet (see `predict_proba`)
                prototypes: Array2::zeros((1, 1)),
            },
            embedding_dim: self.embedding_dim,
            hidden_layers: self.hidden_layers,
            distance_metric: self.distance_metric,
            learning_rate: self.learning_rate,
            n_episodes: self.n_episodes,
            n_way: self.n_way,
            n_shot: self.n_shot,
            n_query: self.n_query,
            temperature: self.temperature,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for PrototypicalNetworks<PrototypicalNetworksTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Predict the class with the highest probability
            let max_idx = probabilities
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for PrototypicalNetworks<PrototypicalNetworksTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();

        // Prediction needs support examples to compute prototypes. This is a
        // limitation of the current implementation: in practice we would store
        // representative prototypes from training and embed the queries here.
        // For now, return uniform probabilities as a placeholder.
        let probabilities = Array2::from_elem((n_test, n_classes), 1.0 / n_classes as f64);

        Ok(probabilities)
    }
}

/// Trained state for PrototypicalNetworks
#[derive(Debug, Clone)]
pub struct PrototypicalNetworksTrained {
    /// Embedding network weight matrices, one per layer
    pub weights: Vec<Array2<f64>>,
    /// Embedding network bias vectors, one per layer
    pub biases: Vec<Array1<f64>>,
    /// Sorted class labels observed during training
    pub classes: Array1<i32>,
    /// Class prototypes (currently an unused placeholder; see `predict_proba`)
    pub prototypes: Array2<f64>,
}
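
// ---------------------------------------------------------------------------
// Illustrative tests: a minimal sketch exercising the helpers above. These
// assume `Float` is `f64`; arrays are built with `from_shape_vec`/`from` so
// no extra macro imports are needed.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    // A prototype must be the mean of its class's support embeddings.
    #[test]
    fn prototypes_are_class_means() {
        let model = PrototypicalNetworks::new();
        let support =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 2.0, 2.0, 4.0, 0.0, 6.0, 2.0]).unwrap();
        let labels = Array1::from(vec![0, 0, 1, 1]);
        let protos = model.compute_prototypes(&support, &labels, &[0, 1]);
        assert_eq!(protos[[0, 0]], 1.0); // mean of 0.0 and 2.0
        assert_eq!(protos[[0, 1]], 1.0); // mean of 0.0 and 2.0
        assert_eq!(protos[[1, 0]], 5.0); // mean of 4.0 and 6.0
        assert_eq!(protos[[1, 1]], 1.0); // mean of 0.0 and 2.0
    }

    // Smaller distances must map to larger probabilities, summing to one.
    #[test]
    fn softmax_favors_nearest_prototype() {
        let model = PrototypicalNetworks::new();
        let distances = Array1::from(vec![0.5, 2.0, 4.0]);
        let probs = model.softmax_distances(&distances);
        assert!(probs[0] > probs[1] && probs[1] > probs[2]);
        assert!((probs.sum() - 1.0).abs() < 1e-9);
    }

    // Sanity-check two of the supported metrics on a 3-4-5 triangle.
    #[test]
    fn distance_metrics_on_simple_case() {
        let a = Array1::from(vec![0.0, 0.0]);
        let b = Array1::from(vec![3.0, 4.0]);
        let euclidean = PrototypicalNetworks::new(); // default metric
        let manhattan = PrototypicalNetworks::new().distance_metric("manhattan".to_string());
        assert!((euclidean.compute_distance(&a, &b) - 5.0).abs() < 1e-12);
        assert!((manhattan.compute_distance(&a, &b) - 7.0).abs() < 1e-12);
    }

    // End-to-end smoke test mirroring the doc example: fit on a tiny
    // dataset and check that predict returns one label per sample.
    #[test]
    fn fit_predict_smoke() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 1.1, 2.1, 2.1, 3.1, 3.1, 4.1, 4.1, 5.1,
            ],
        )
        .unwrap();
        let y = Array1::from(vec![0, 1, 0, 1, 0, 1, 0, 1]);
        let fitted = PrototypicalNetworks::new()
            .n_way(2)
            .n_shot(1)
            .n_query(3)
            .n_episodes(5)
            .fit(&x.view(), &y.view())
            .unwrap();
        let preds = fitted.predict(&x.view()).unwrap();
        assert_eq!(preds.len(), 8);
    }
}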