Skip to main content

tensorlogic_train/
few_shot.rs

//! Few-shot learning utilities for learning from limited examples.
//!
//! This module provides infrastructure for few-shot learning scenarios where
//! models must learn to generalize from only a small number of labeled examples.
//!
//! # Overview
//!
//! Few-shot learning addresses the challenge of learning new tasks with minimal
//! training data. This module implements:
//!
//! - **Episode configuration**: N-way K-shot task sizing (support and query counts)
//! - **Prototypical networks**: Classify queries by their distance to per-class mean prototypes
//! - **Matching networks**: Attention-based matching between query and support sets
//! - **Support set management**: Storage and per-class retrieval of support examples
//! - **Distance metrics**: Several distance functions for comparing feature vectors
//!
//! # Key Concepts
//!
//! - **Support set**: Small set of labeled examples used for adaptation
//! - **Query set**: Examples to classify based on the support set
//! - **N-way K-shot**: Task with N classes, K examples per class
//! - **Episode**: Single few-shot task instance during training
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_train::{
//!     EpisodeSampler, PrototypicalDistance, DistanceMetric, ShotType
//! };
//! use scirs2_core::ndarray::Array2;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a 5-way 1-shot episode sampler
//! let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);
//!
//! // Use prototypical distance for classification
//! let distance = PrototypicalDistance::euclidean();
//!
//! # Ok(())
//! # }
//! ```
42
43use crate::{TrainError, TrainResult};
44use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
45
/// How many labelled examples ("shots") each class contributes to a few-shot task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotType {
    /// Exactly one example per class (1-shot).
    OneShot,
    /// A small number of examples per class (commonly 5-shot).
    FewShot(usize),
    /// Any user-chosen number of examples per class.
    Custom(usize),
}

impl ShotType {
    /// Returns `K`, the number of examples per class this configuration stands for.
    pub fn k(&self) -> usize {
        match *self {
            Self::OneShot => 1,
            Self::FewShot(shots) | Self::Custom(shots) => shots,
        }
    }
}
67
68/// Distance metric for few-shot learning.
69#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum DistanceMetric {
71    /// Euclidean distance (L2 norm).
72    Euclidean,
73    /// Cosine similarity (normalized dot product).
74    Cosine,
75    /// Manhattan distance (L1 norm).
76    Manhattan,
77    /// Squared Euclidean distance.
78    SquaredEuclidean,
79}
80
81impl DistanceMetric {
82    /// Compute distance between two vectors.
83    pub fn compute(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
84        match self {
85            DistanceMetric::Euclidean => {
86                let diff = a.to_owned() - b.to_owned();
87                diff.dot(&diff).sqrt()
88            }
89            DistanceMetric::Cosine => {
90                let dot = a.dot(b);
91                let norm_a = a.dot(a).sqrt();
92                let norm_b = b.dot(b).sqrt();
93                if norm_a == 0.0 || norm_b == 0.0 {
94                    0.0
95                } else {
96                    1.0 - (dot / (norm_a * norm_b))
97                }
98            }
99            DistanceMetric::Manhattan => {
100                let diff = a.to_owned() - b.to_owned();
101                diff.iter().map(|x| x.abs()).sum()
102            }
103            DistanceMetric::SquaredEuclidean => {
104                let diff = a.to_owned() - b.to_owned();
105                diff.dot(&diff)
106            }
107        }
108    }
109}
110
111/// Support set for few-shot learning.
112///
113/// Contains labeled examples used for classification or regression.
114#[derive(Debug, Clone)]
115pub struct SupportSet {
116    /// Feature vectors for support examples.
117    pub features: Array2<f64>,
118    /// Labels for support examples (class indices).
119    pub labels: Array1<usize>,
120    /// Number of classes.
121    pub num_classes: usize,
122}
123
124impl SupportSet {
125    /// Create a new support set.
126    ///
127    /// # Arguments
128    /// * `features` - Feature matrix (n_examples × n_features)
129    /// * `labels` - Class labels (n_examples,)
130    ///
131    /// # Returns
132    /// New support set
133    pub fn new(features: Array2<f64>, labels: Array1<usize>) -> TrainResult<Self> {
134        if features.nrows() != labels.len() {
135            return Err(TrainError::InvalidParameter(format!(
136                "Feature rows ({}) must match label count ({})",
137                features.nrows(),
138                labels.len()
139            )));
140        }
141
142        let num_classes = labels.iter().max().copied().unwrap_or(0) + 1;
143
144        Ok(Self {
145            features,
146            labels,
147            num_classes,
148        })
149    }
150
151    /// Get examples for a specific class.
152    pub fn get_class_examples(&self, class_id: usize) -> Array2<f64> {
153        let indices: Vec<usize> = self
154            .labels
155            .iter()
156            .enumerate()
157            .filter(|(_, &label)| label == class_id)
158            .map(|(idx, _)| idx)
159            .collect();
160
161        if indices.is_empty() {
162            return Array2::zeros((0, self.features.ncols()));
163        }
164
165        let mut result = Array2::zeros((indices.len(), self.features.ncols()));
166        for (i, &idx) in indices.iter().enumerate() {
167            result.row_mut(i).assign(&self.features.row(idx));
168        }
169        result
170    }
171
172    /// Get number of support examples.
173    pub fn size(&self) -> usize {
174        self.features.nrows()
175    }
176}
177
178/// Prototypical distance calculator for few-shot learning.
179///
180/// Computes distances between query examples and class prototypes.
181/// Prototypes are computed as the mean of support examples for each class.
182#[derive(Debug, Clone)]
183pub struct PrototypicalDistance {
184    /// Distance metric to use.
185    metric: DistanceMetric,
186    /// Class prototypes (computed from support set).
187    prototypes: Option<Array2<f64>>,
188}
189
190impl PrototypicalDistance {
191    /// Create with Euclidean distance.
192    pub fn euclidean() -> Self {
193        Self {
194            metric: DistanceMetric::Euclidean,
195            prototypes: None,
196        }
197    }
198
199    /// Create with cosine distance.
200    pub fn cosine() -> Self {
201        Self {
202            metric: DistanceMetric::Cosine,
203            prototypes: None,
204        }
205    }
206
207    /// Create with custom distance metric.
208    pub fn new(metric: DistanceMetric) -> Self {
209        Self {
210            metric,
211            prototypes: None,
212        }
213    }
214
215    /// Compute prototypes from support set.
216    ///
217    /// # Arguments
218    /// * `support` - Support set with labeled examples
219    pub fn compute_prototypes(&mut self, support: &SupportSet) {
220        let mut prototypes = Array2::zeros((support.num_classes, support.features.ncols()));
221
222        for class_id in 0..support.num_classes {
223            let class_examples = support.get_class_examples(class_id);
224            if class_examples.nrows() > 0 {
225                let prototype = class_examples
226                    .mean_axis(Axis(0))
227                    .expect("mean_axis on non-empty class examples");
228                prototypes.row_mut(class_id).assign(&prototype);
229            }
230        }
231
232        self.prototypes = Some(prototypes);
233    }
234
235    /// Compute distances from query to all prototypes.
236    ///
237    /// # Arguments
238    /// * `query` - Query feature vector
239    ///
240    /// # Returns
241    /// Distance to each prototype (class)
242    pub fn compute_distances(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
243        let prototypes = self
244            .prototypes
245            .as_ref()
246            .ok_or_else(|| TrainError::Other("Prototypes not computed".to_string()))?;
247
248        let mut distances = Array1::zeros(prototypes.nrows());
249        for (i, prototype) in prototypes.axis_iter(Axis(0)).enumerate() {
250            distances[i] = self.metric.compute(query, &prototype);
251        }
252
253        Ok(distances)
254    }
255
256    /// Predict class for query example.
257    ///
258    /// # Arguments
259    /// * `query` - Query feature vector
260    ///
261    /// # Returns
262    /// Predicted class (nearest prototype)
263    pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
264        let distances = self.compute_distances(query)?;
265
266        // Find minimum distance
267        let mut min_idx = 0;
268        let mut min_dist = distances[0];
269        for (i, &dist) in distances.iter().enumerate() {
270            if dist < min_dist {
271                min_dist = dist;
272                min_idx = i;
273            }
274        }
275
276        Ok(min_idx)
277    }
278
279    /// Predict probabilities using softmax over negative distances.
280    ///
281    /// # Arguments
282    /// * `query` - Query feature vector
283    /// * `temperature` - Temperature for softmax (default 1.0)
284    ///
285    /// # Returns
286    /// Probability distribution over classes
287    pub fn predict_proba(
288        &self,
289        query: &ArrayView1<f64>,
290        temperature: f64,
291    ) -> TrainResult<Array1<f64>> {
292        let distances = self.compute_distances(query)?;
293
294        // Convert to logits (negative distances)
295        let logits = distances.mapv(|d| -d / temperature);
296
297        // Softmax
298        let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
299        let exp_logits = logits.mapv(|x| (x - max_logit).exp());
300        let sum_exp = exp_logits.sum();
301        let probs = exp_logits.mapv(|x| x / sum_exp);
302
303        Ok(probs)
304    }
305}
306
307/// Episode sampler for N-way K-shot tasks.
308///
309/// Generates episodes for episodic training in few-shot learning.
310#[derive(Debug, Clone)]
311pub struct EpisodeSampler {
312    /// Number of classes per episode (N-way).
313    n_way: usize,
314    /// Number of shots per class (K-shot).
315    shot_type: ShotType,
316    /// Number of query examples per class.
317    n_query: usize,
318}
319
320impl EpisodeSampler {
321    /// Create a new episode sampler.
322    ///
323    /// # Arguments
324    /// * `n_way` - Number of classes per episode
325    /// * `shot_type` - Number of shots per class
326    /// * `n_query` - Number of query examples per class
327    pub fn new(n_way: usize, shot_type: ShotType, n_query: usize) -> Self {
328        Self {
329            n_way,
330            shot_type,
331            n_query,
332        }
333    }
334
335    /// Get total support examples per episode.
336    pub fn support_size(&self) -> usize {
337        self.n_way * self.shot_type.k()
338    }
339
340    /// Get total query examples per episode.
341    pub fn query_size(&self) -> usize {
342        self.n_way * self.n_query
343    }
344
345    /// Get episode description.
346    pub fn description(&self) -> String {
347        format!(
348            "{}-way {}-shot (query: {} per class)",
349            self.n_way,
350            self.shot_type.k(),
351            self.n_query
352        )
353    }
354}
355
356/// Matching network for few-shot learning.
357///
358/// Uses attention mechanism to match query examples to support examples.
359#[derive(Debug, Clone)]
360pub struct MatchingNetwork {
361    /// Distance metric for similarity.
362    metric: DistanceMetric,
363    /// Support set.
364    support: Option<SupportSet>,
365}
366
367impl MatchingNetwork {
368    /// Create a new matching network.
369    pub fn new(metric: DistanceMetric) -> Self {
370        Self {
371            metric,
372            support: None,
373        }
374    }
375
376    /// Set support set.
377    pub fn set_support(&mut self, support: SupportSet) {
378        self.support = Some(support);
379    }
380
381    /// Compute attention weights between query and all support examples.
382    ///
383    /// # Arguments
384    /// * `query` - Query feature vector
385    ///
386    /// # Returns
387    /// Attention weights for each support example
388    pub fn compute_attention(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
389        let support = self
390            .support
391            .as_ref()
392            .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
393
394        let n_support = support.size();
395        let mut similarities = Array1::zeros(n_support);
396
397        // Compute similarities
398        for i in 0..n_support {
399            let support_example = support.features.row(i);
400            similarities[i] = -self.metric.compute(query, &support_example);
401        }
402
403        // Softmax to get attention weights
404        let max_sim = similarities
405            .iter()
406            .copied()
407            .fold(f64::NEG_INFINITY, f64::max);
408        let exp_sims = similarities.mapv(|x| (x - max_sim).exp());
409        let sum_exp = exp_sims.sum();
410        let weights = exp_sims.mapv(|x| x / sum_exp);
411
412        Ok(weights)
413    }
414
415    /// Predict class using attention-weighted voting.
416    ///
417    /// # Arguments
418    /// * `query` - Query feature vector
419    ///
420    /// # Returns
421    /// Predicted class probabilities
422    pub fn predict_proba(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
423        let support = self
424            .support
425            .as_ref()
426            .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
427
428        let attention = self.compute_attention(query)?;
429        let mut class_probs = Array1::zeros(support.num_classes);
430
431        // Weighted voting
432        for (i, &weight) in attention.iter().enumerate() {
433            let label = support.labels[i];
434            class_probs[label] += weight;
435        }
436
437        Ok(class_probs)
438    }
439
440    /// Predict class label.
441    pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
442        let probs = self.predict_proba(query)?;
443        let mut max_idx = 0;
444        let mut max_prob = probs[0];
445        for (i, &prob) in probs.iter().enumerate() {
446            if prob > max_prob {
447                max_prob = prob;
448                max_idx = i;
449            }
450        }
451        Ok(max_idx)
452    }
453}
454
/// Running accuracy counter for few-shot evaluation.
#[derive(Debug, Clone, Default)]
pub struct FewShotAccuracy {
    /// Predictions that matched the true label.
    correct: usize,
    /// Predictions recorded so far.
    total: usize,
}

impl FewShotAccuracy {
    /// Create a tracker with no recorded predictions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record one prediction against its true label.
    pub fn update(&mut self, predicted: usize, actual: usize) {
        self.total += 1;
        self.correct += usize::from(predicted == actual);
    }

    /// Fraction of recorded predictions that were correct (`0.0` when none recorded).
    pub fn accuracy(&self) -> f64 {
        match self.total {
            0 => 0.0,
            seen => self.correct as f64 / seen as f64,
        }
    }

    /// Forget all recorded predictions.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Return `(correct, total)` prediction counts.
    pub fn counts(&self) -> (usize, usize) {
        (self.correct, self.total)
    }
}
499
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Two 2-D examples per class: class 0 = {(1,2), (3,4)}, class 1 = {(5,6), (7,8)}.
    fn two_class_support() -> SupportSet {
        let features =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
                .expect("unwrap");
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        SupportSet::new(features, labels).expect("unwrap")
    }

    /// Query point (2, 3), which lies nearest to the class-0 examples.
    fn near_class_zero() -> Array1<f64> {
        Array1::from_vec(vec![2.0, 3.0])
    }

    #[test]
    fn test_shot_type() {
        let cases = [
            (ShotType::OneShot, 1),
            (ShotType::FewShot(5), 5),
            (ShotType::Custom(10), 10),
        ];
        for &(shot, expected) in cases.iter() {
            assert_eq!(shot.k(), expected);
        }
    }

    #[test]
    fn test_euclidean_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let dist = DistanceMetric::Euclidean.compute(&a.view(), &b.view());
        assert_relative_eq!(dist, 5.196152, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_distance() {
        let x_axis = Array1::from_vec(vec![1.0, 0.0]);
        let y_axis = Array1::from_vec(vec![0.0, 1.0]);
        let dist = DistanceMetric::Cosine.compute(&x_axis.view(), &y_axis.view());
        assert_relative_eq!(dist, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_support_set_creation() {
        let support = two_class_support();
        assert_eq!(support.size(), 4);
        assert_eq!(support.num_classes, 2);
    }

    #[test]
    fn test_support_set_get_class_examples() {
        let class_0 = two_class_support().get_class_examples(0);
        assert_eq!(class_0.nrows(), 2);
        assert_eq!(class_0[[0, 0]], 1.0);
        assert_eq!(class_0[[1, 0]], 3.0);
    }

    #[test]
    fn test_prototypical_distance() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&two_class_support());

        let query = near_class_zero();
        let prediction = proto.predict(&query.view()).expect("unwrap");
        assert_eq!(prediction, 0); // Closer to class 0
    }

    #[test]
    fn test_prototypical_predict_proba() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&two_class_support());

        let query = near_class_zero();
        let probs = proto.predict_proba(&query.view(), 1.0).expect("unwrap");

        assert_eq!(probs.len(), 2);
        assert!(probs[0] > probs[1]); // Higher probability for class 0
        assert_relative_eq!(probs.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_episode_sampler() {
        let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);
        assert_eq!(sampler.support_size(), 5); // 5 classes x 1 shot
        assert_eq!(sampler.query_size(), 75); // 5 classes x 15 queries
        assert!(sampler.description().contains("5-way"));
    }

    #[test]
    fn test_matching_network() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(two_class_support());

        let query = near_class_zero();
        let prediction = matcher.predict(&query.view()).expect("unwrap");
        assert_eq!(prediction, 0); // Should predict class 0
    }

    #[test]
    fn test_matching_network_attention() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(two_class_support());

        let query = near_class_zero();
        let attention = matcher.compute_attention(&query.view()).expect("unwrap");
        assert_eq!(attention.len(), 4);
        assert_relative_eq!(attention.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_few_shot_accuracy() {
        let mut acc = FewShotAccuracy::new();
        acc.update(0, 0); // Correct
        acc.update(1, 1); // Correct
        acc.update(1, 0); // Wrong

        assert_eq!(acc.accuracy(), 2.0 / 3.0);
        assert_eq!(acc.counts(), (2, 3));

        acc.reset();
        assert_eq!(acc.accuracy(), 0.0);
    }

    #[test]
    fn test_manhattan_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let dist = DistanceMetric::Manhattan.compute(&a.view(), &b.view());
        assert_eq!(dist, 9.0);
    }

    #[test]
    fn test_squared_euclidean_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![4.0, 6.0]);
        let dist = DistanceMetric::SquaredEuclidean.compute(&a.view(), &b.view());
        assert_eq!(dist, 25.0); // 3^2 + 4^2 = 9 + 16 = 25
    }
}