Skip to main content

tensorlogic_train/
few_shot.rs

//! Few-shot learning utilities for learning from limited examples.
//!
//! This module provides infrastructure for few-shot learning scenarios where
//! models must learn to generalize from only a small number of labeled examples.
//!
//! # Overview
//!
//! Few-shot learning addresses the challenge of learning new tasks with minimal
//! training data. This module implements:
//!
//! - **Episode configuration**: N-way K-shot episode sizing via `EpisodeSampler`
//! - **Prototypical networks**: Prototype-based classification, where each class is
//!   represented by the mean of its support examples in a given feature space
//! - **Matching networks**: Attention-based matching between query and support sets
//! - **Support set management**: Storage and per-class retrieval of support examples
//! - **Distance metrics**: Various similarity functions for few-shot learning
//!
//! # Key Concepts
//!
//! - **Support set**: Small set of labeled examples used for adaptation
//! - **Query set**: Examples to classify based on the support set
//! - **N-way K-shot**: Task with N classes, K examples per class
//! - **Episode**: Single few-shot task instance during training
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_train::{
//!     EpisodeSampler, PrototypicalDistance, DistanceMetric, ShotType
//! };
//! use scirs2_core::ndarray::Array2;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a 5-way 1-shot episode sampler
//! let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);
//!
//! // Use prototypical distance for classification
//! let distance = PrototypicalDistance::euclidean();
//!
//! # Ok(())
//! # }
//! ```
42
43use crate::{TrainError, TrainResult};
44use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
45
/// Shot configuration of a few-shot task: how many labelled support examples
/// are available for each class.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotType {
    /// Exactly one labelled example per class (1-shot).
    OneShot,
    /// A handful of examples per class (5-shot is the usual choice).
    FewShot(usize),
    /// An arbitrary, caller-chosen number of examples per class.
    Custom(usize),
}

impl ShotType {
    /// Returns `K`, the number of support examples per class.
    pub fn k(&self) -> usize {
        match *self {
            ShotType::OneShot => 1,
            ShotType::FewShot(shots) | ShotType::Custom(shots) => shots,
        }
    }
}
67
68/// Distance metric for few-shot learning.
69#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum DistanceMetric {
71    /// Euclidean distance (L2 norm).
72    Euclidean,
73    /// Cosine similarity (normalized dot product).
74    Cosine,
75    /// Manhattan distance (L1 norm).
76    Manhattan,
77    /// Squared Euclidean distance.
78    SquaredEuclidean,
79}
80
81impl DistanceMetric {
82    /// Compute distance between two vectors.
83    pub fn compute(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
84        match self {
85            DistanceMetric::Euclidean => {
86                let diff = a.to_owned() - b.to_owned();
87                diff.dot(&diff).sqrt()
88            }
89            DistanceMetric::Cosine => {
90                let dot = a.dot(b);
91                let norm_a = a.dot(a).sqrt();
92                let norm_b = b.dot(b).sqrt();
93                if norm_a == 0.0 || norm_b == 0.0 {
94                    0.0
95                } else {
96                    1.0 - (dot / (norm_a * norm_b))
97                }
98            }
99            DistanceMetric::Manhattan => {
100                let diff = a.to_owned() - b.to_owned();
101                diff.iter().map(|x| x.abs()).sum()
102            }
103            DistanceMetric::SquaredEuclidean => {
104                let diff = a.to_owned() - b.to_owned();
105                diff.dot(&diff)
106            }
107        }
108    }
109}
110
111/// Support set for few-shot learning.
112///
113/// Contains labeled examples used for classification or regression.
114#[derive(Debug, Clone)]
115pub struct SupportSet {
116    /// Feature vectors for support examples.
117    pub features: Array2<f64>,
118    /// Labels for support examples (class indices).
119    pub labels: Array1<usize>,
120    /// Number of classes.
121    pub num_classes: usize,
122}
123
124impl SupportSet {
125    /// Create a new support set.
126    ///
127    /// # Arguments
128    /// * `features` - Feature matrix (n_examples × n_features)
129    /// * `labels` - Class labels (n_examples,)
130    ///
131    /// # Returns
132    /// New support set
133    pub fn new(features: Array2<f64>, labels: Array1<usize>) -> TrainResult<Self> {
134        if features.nrows() != labels.len() {
135            return Err(TrainError::InvalidParameter(format!(
136                "Feature rows ({}) must match label count ({})",
137                features.nrows(),
138                labels.len()
139            )));
140        }
141
142        let num_classes = labels.iter().max().copied().unwrap_or(0) + 1;
143
144        Ok(Self {
145            features,
146            labels,
147            num_classes,
148        })
149    }
150
151    /// Get examples for a specific class.
152    pub fn get_class_examples(&self, class_id: usize) -> Array2<f64> {
153        let indices: Vec<usize> = self
154            .labels
155            .iter()
156            .enumerate()
157            .filter(|(_, &label)| label == class_id)
158            .map(|(idx, _)| idx)
159            .collect();
160
161        if indices.is_empty() {
162            return Array2::zeros((0, self.features.ncols()));
163        }
164
165        let mut result = Array2::zeros((indices.len(), self.features.ncols()));
166        for (i, &idx) in indices.iter().enumerate() {
167            result.row_mut(i).assign(&self.features.row(idx));
168        }
169        result
170    }
171
172    /// Get number of support examples.
173    pub fn size(&self) -> usize {
174        self.features.nrows()
175    }
176}
177
178/// Prototypical distance calculator for few-shot learning.
179///
180/// Computes distances between query examples and class prototypes.
181/// Prototypes are computed as the mean of support examples for each class.
182#[derive(Debug, Clone)]
183pub struct PrototypicalDistance {
184    /// Distance metric to use.
185    metric: DistanceMetric,
186    /// Class prototypes (computed from support set).
187    prototypes: Option<Array2<f64>>,
188}
189
190impl PrototypicalDistance {
191    /// Create with Euclidean distance.
192    pub fn euclidean() -> Self {
193        Self {
194            metric: DistanceMetric::Euclidean,
195            prototypes: None,
196        }
197    }
198
199    /// Create with cosine distance.
200    pub fn cosine() -> Self {
201        Self {
202            metric: DistanceMetric::Cosine,
203            prototypes: None,
204        }
205    }
206
207    /// Create with custom distance metric.
208    pub fn new(metric: DistanceMetric) -> Self {
209        Self {
210            metric,
211            prototypes: None,
212        }
213    }
214
215    /// Compute prototypes from support set.
216    ///
217    /// # Arguments
218    /// * `support` - Support set with labeled examples
219    pub fn compute_prototypes(&mut self, support: &SupportSet) {
220        let mut prototypes = Array2::zeros((support.num_classes, support.features.ncols()));
221
222        for class_id in 0..support.num_classes {
223            let class_examples = support.get_class_examples(class_id);
224            if class_examples.nrows() > 0 {
225                let prototype = class_examples.mean_axis(Axis(0)).unwrap();
226                prototypes.row_mut(class_id).assign(&prototype);
227            }
228        }
229
230        self.prototypes = Some(prototypes);
231    }
232
233    /// Compute distances from query to all prototypes.
234    ///
235    /// # Arguments
236    /// * `query` - Query feature vector
237    ///
238    /// # Returns
239    /// Distance to each prototype (class)
240    pub fn compute_distances(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
241        let prototypes = self
242            .prototypes
243            .as_ref()
244            .ok_or_else(|| TrainError::Other("Prototypes not computed".to_string()))?;
245
246        let mut distances = Array1::zeros(prototypes.nrows());
247        for (i, prototype) in prototypes.axis_iter(Axis(0)).enumerate() {
248            distances[i] = self.metric.compute(query, &prototype);
249        }
250
251        Ok(distances)
252    }
253
254    /// Predict class for query example.
255    ///
256    /// # Arguments
257    /// * `query` - Query feature vector
258    ///
259    /// # Returns
260    /// Predicted class (nearest prototype)
261    pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
262        let distances = self.compute_distances(query)?;
263
264        // Find minimum distance
265        let mut min_idx = 0;
266        let mut min_dist = distances[0];
267        for (i, &dist) in distances.iter().enumerate() {
268            if dist < min_dist {
269                min_dist = dist;
270                min_idx = i;
271            }
272        }
273
274        Ok(min_idx)
275    }
276
277    /// Predict probabilities using softmax over negative distances.
278    ///
279    /// # Arguments
280    /// * `query` - Query feature vector
281    /// * `temperature` - Temperature for softmax (default 1.0)
282    ///
283    /// # Returns
284    /// Probability distribution over classes
285    pub fn predict_proba(
286        &self,
287        query: &ArrayView1<f64>,
288        temperature: f64,
289    ) -> TrainResult<Array1<f64>> {
290        let distances = self.compute_distances(query)?;
291
292        // Convert to logits (negative distances)
293        let logits = distances.mapv(|d| -d / temperature);
294
295        // Softmax
296        let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
297        let exp_logits = logits.mapv(|x| (x - max_logit).exp());
298        let sum_exp = exp_logits.sum();
299        let probs = exp_logits.mapv(|x| x / sum_exp);
300
301        Ok(probs)
302    }
303}
304
305/// Episode sampler for N-way K-shot tasks.
306///
307/// Generates episodes for episodic training in few-shot learning.
308#[derive(Debug, Clone)]
309pub struct EpisodeSampler {
310    /// Number of classes per episode (N-way).
311    n_way: usize,
312    /// Number of shots per class (K-shot).
313    shot_type: ShotType,
314    /// Number of query examples per class.
315    n_query: usize,
316}
317
318impl EpisodeSampler {
319    /// Create a new episode sampler.
320    ///
321    /// # Arguments
322    /// * `n_way` - Number of classes per episode
323    /// * `shot_type` - Number of shots per class
324    /// * `n_query` - Number of query examples per class
325    pub fn new(n_way: usize, shot_type: ShotType, n_query: usize) -> Self {
326        Self {
327            n_way,
328            shot_type,
329            n_query,
330        }
331    }
332
333    /// Get total support examples per episode.
334    pub fn support_size(&self) -> usize {
335        self.n_way * self.shot_type.k()
336    }
337
338    /// Get total query examples per episode.
339    pub fn query_size(&self) -> usize {
340        self.n_way * self.n_query
341    }
342
343    /// Get episode description.
344    pub fn description(&self) -> String {
345        format!(
346            "{}-way {}-shot (query: {} per class)",
347            self.n_way,
348            self.shot_type.k(),
349            self.n_query
350        )
351    }
352}
353
354/// Matching network for few-shot learning.
355///
356/// Uses attention mechanism to match query examples to support examples.
357#[derive(Debug, Clone)]
358pub struct MatchingNetwork {
359    /// Distance metric for similarity.
360    metric: DistanceMetric,
361    /// Support set.
362    support: Option<SupportSet>,
363}
364
365impl MatchingNetwork {
366    /// Create a new matching network.
367    pub fn new(metric: DistanceMetric) -> Self {
368        Self {
369            metric,
370            support: None,
371        }
372    }
373
374    /// Set support set.
375    pub fn set_support(&mut self, support: SupportSet) {
376        self.support = Some(support);
377    }
378
379    /// Compute attention weights between query and all support examples.
380    ///
381    /// # Arguments
382    /// * `query` - Query feature vector
383    ///
384    /// # Returns
385    /// Attention weights for each support example
386    pub fn compute_attention(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
387        let support = self
388            .support
389            .as_ref()
390            .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
391
392        let n_support = support.size();
393        let mut similarities = Array1::zeros(n_support);
394
395        // Compute similarities
396        for i in 0..n_support {
397            let support_example = support.features.row(i);
398            similarities[i] = -self.metric.compute(query, &support_example);
399        }
400
401        // Softmax to get attention weights
402        let max_sim = similarities
403            .iter()
404            .copied()
405            .fold(f64::NEG_INFINITY, f64::max);
406        let exp_sims = similarities.mapv(|x| (x - max_sim).exp());
407        let sum_exp = exp_sims.sum();
408        let weights = exp_sims.mapv(|x| x / sum_exp);
409
410        Ok(weights)
411    }
412
413    /// Predict class using attention-weighted voting.
414    ///
415    /// # Arguments
416    /// * `query` - Query feature vector
417    ///
418    /// # Returns
419    /// Predicted class probabilities
420    pub fn predict_proba(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
421        let support = self
422            .support
423            .as_ref()
424            .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
425
426        let attention = self.compute_attention(query)?;
427        let mut class_probs = Array1::zeros(support.num_classes);
428
429        // Weighted voting
430        for (i, &weight) in attention.iter().enumerate() {
431            let label = support.labels[i];
432            class_probs[label] += weight;
433        }
434
435        Ok(class_probs)
436    }
437
438    /// Predict class label.
439    pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
440        let probs = self.predict_proba(query)?;
441        let mut max_idx = 0;
442        let mut max_prob = probs[0];
443        for (i, &prob) in probs.iter().enumerate() {
444            if prob > max_prob {
445                max_prob = prob;
446                max_idx = i;
447            }
448        }
449        Ok(max_idx)
450    }
451}
452
/// Running accuracy tracker for few-shot evaluation.
///
/// Counts how many recorded predictions matched their ground-truth label.
#[derive(Debug, Clone, Default)]
pub struct FewShotAccuracy {
    correct: usize,
    total: usize,
}

impl FewShotAccuracy {
    /// Create a tracker with both counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record one prediction together with its true label.
    pub fn update(&mut self, predicted: usize, actual: usize) {
        self.correct += usize::from(predicted == actual);
        self.total += 1;
    }

    /// Fraction of correct predictions so far; `0.0` before any update.
    pub fn accuracy(&self) -> f64 {
        match self.total {
            0 => 0.0,
            seen => self.correct as f64 / seen as f64,
        }
    }

    /// Clear both counters.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Return `(correct, total)`.
    pub fn counts(&self) -> (usize, usize) {
        (self.correct, self.total)
    }
}
497
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Four 2-D points: rows 0-1 are class 0 and rows 2-3 are class 1.
    ///
    /// Class means: class 0 -> (2, 3), class 1 -> (6, 7).
    fn toy_support() -> SupportSet {
        let features =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        SupportSet::new(features, labels).unwrap()
    }

    /// A query lying exactly on the class-0 mean, far from class 1.
    fn class_zero_query() -> Array1<f64> {
        Array1::from_vec(vec![2.0, 3.0])
    }

    #[test]
    fn test_shot_type() {
        let cases = [
            (ShotType::OneShot, 1),
            (ShotType::FewShot(5), 5),
            (ShotType::Custom(10), 10),
        ];
        for &(shot, expected) in cases.iter() {
            assert_eq!(shot.k(), expected);
        }
    }

    #[test]
    fn test_euclidean_distance() {
        let lhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        // sqrt(3^2 + 3^2 + 3^2) = sqrt(27)
        let dist = DistanceMetric::Euclidean.compute(&lhs.view(), &rhs.view());
        assert_relative_eq!(dist, 5.196152, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_distance() {
        // Orthogonal unit vectors: similarity 0, distance 1.
        let x_axis = Array1::from_vec(vec![1.0, 0.0]);
        let y_axis = Array1::from_vec(vec![0.0, 1.0]);

        let dist = DistanceMetric::Cosine.compute(&x_axis.view(), &y_axis.view());
        assert_relative_eq!(dist, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_support_set_creation() {
        let support = toy_support();
        assert_eq!(support.size(), 4);
        assert_eq!(support.num_classes, 2);
    }

    #[test]
    fn test_support_set_get_class_examples() {
        let class_0 = toy_support().get_class_examples(0);

        assert_eq!(class_0.nrows(), 2);
        assert_eq!(class_0[[0, 0]], 1.0);
        assert_eq!(class_0[[1, 0]], 3.0);
    }

    #[test]
    fn test_prototypical_distance() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&toy_support());

        let prediction = proto.predict(&class_zero_query().view()).unwrap();
        assert_eq!(prediction, 0); // Closer to class 0
    }

    #[test]
    fn test_prototypical_predict_proba() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&toy_support());

        let probs = proto
            .predict_proba(&class_zero_query().view(), 1.0)
            .unwrap();

        assert_eq!(probs.len(), 2);
        assert!(probs[0] > probs[1]); // Higher probability for class 0
        assert_relative_eq!(probs.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_episode_sampler() {
        let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);

        assert_eq!(sampler.support_size(), 5); // 5 classes × 1 shot
        assert_eq!(sampler.query_size(), 75); // 5 classes × 15 queries
        assert!(sampler.description().contains("5-way"));
    }

    #[test]
    fn test_matching_network() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(toy_support());

        let prediction = matcher.predict(&class_zero_query().view()).unwrap();
        assert_eq!(prediction, 0); // Should predict class 0
    }

    #[test]
    fn test_matching_network_attention() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(toy_support());

        let attention = matcher
            .compute_attention(&class_zero_query().view())
            .unwrap();

        assert_eq!(attention.len(), 4);
        assert_relative_eq!(attention.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_few_shot_accuracy() {
        let mut acc = FewShotAccuracy::new();

        for &(predicted, actual) in [(0, 0), (1, 1), (1, 0)].iter() {
            acc.update(predicted, actual); // two correct, one wrong
        }

        assert_eq!(acc.accuracy(), 2.0 / 3.0);
        assert_eq!(acc.counts(), (2, 3));

        acc.reset();
        assert_eq!(acc.accuracy(), 0.0);
    }

    #[test]
    fn test_manhattan_distance() {
        let lhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let rhs = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        // |−3| + |−3| + |−3|
        let dist = DistanceMetric::Manhattan.compute(&lhs.view(), &rhs.view());
        assert_eq!(dist, 9.0);
    }

    #[test]
    fn test_squared_euclidean_distance() {
        let lhs = Array1::from_vec(vec![1.0, 2.0]);
        let rhs = Array1::from_vec(vec![4.0, 6.0]);

        let dist = DistanceMetric::SquaredEuclidean.compute(&lhs.view(), &rhs.view());
        assert_eq!(dist, 25.0); // (3^2 + 4^2) = 9 + 16 = 25
    }
}