1use crate::{TrainError, TrainResult};
44use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
45
/// How many labelled support examples ("shots") each class contributes to an episode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotType {
    /// Exactly one support example per class.
    OneShot,
    /// `k` support examples per class.
    FewShot(usize),
    /// A caller-chosen number `k` of support examples per class; [`ShotType::k`] treats it
    /// exactly like [`ShotType::FewShot`].
    Custom(usize),
}

impl ShotType {
    /// Number of support examples per class described by this shot type.
    pub fn k(&self) -> usize {
        match *self {
            ShotType::OneShot => 1,
            ShotType::FewShot(shots) | ShotType::Custom(shots) => shots,
        }
    }
}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum DistanceMetric {
71 Euclidean,
73 Cosine,
75 Manhattan,
77 SquaredEuclidean,
79}
80
81impl DistanceMetric {
82 pub fn compute(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
84 match self {
85 DistanceMetric::Euclidean => {
86 let diff = a.to_owned() - b.to_owned();
87 diff.dot(&diff).sqrt()
88 }
89 DistanceMetric::Cosine => {
90 let dot = a.dot(b);
91 let norm_a = a.dot(a).sqrt();
92 let norm_b = b.dot(b).sqrt();
93 if norm_a == 0.0 || norm_b == 0.0 {
94 0.0
95 } else {
96 1.0 - (dot / (norm_a * norm_b))
97 }
98 }
99 DistanceMetric::Manhattan => {
100 let diff = a.to_owned() - b.to_owned();
101 diff.iter().map(|x| x.abs()).sum()
102 }
103 DistanceMetric::SquaredEuclidean => {
104 let diff = a.to_owned() - b.to_owned();
105 diff.dot(&diff)
106 }
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
115pub struct SupportSet {
116 pub features: Array2<f64>,
118 pub labels: Array1<usize>,
120 pub num_classes: usize,
122}
123
124impl SupportSet {
125 pub fn new(features: Array2<f64>, labels: Array1<usize>) -> TrainResult<Self> {
134 if features.nrows() != labels.len() {
135 return Err(TrainError::InvalidParameter(format!(
136 "Feature rows ({}) must match label count ({})",
137 features.nrows(),
138 labels.len()
139 )));
140 }
141
142 let num_classes = labels.iter().max().copied().unwrap_or(0) + 1;
143
144 Ok(Self {
145 features,
146 labels,
147 num_classes,
148 })
149 }
150
151 pub fn get_class_examples(&self, class_id: usize) -> Array2<f64> {
153 let indices: Vec<usize> = self
154 .labels
155 .iter()
156 .enumerate()
157 .filter(|(_, &label)| label == class_id)
158 .map(|(idx, _)| idx)
159 .collect();
160
161 if indices.is_empty() {
162 return Array2::zeros((0, self.features.ncols()));
163 }
164
165 let mut result = Array2::zeros((indices.len(), self.features.ncols()));
166 for (i, &idx) in indices.iter().enumerate() {
167 result.row_mut(i).assign(&self.features.row(idx));
168 }
169 result
170 }
171
172 pub fn size(&self) -> usize {
174 self.features.nrows()
175 }
176}
177
178#[derive(Debug, Clone)]
183pub struct PrototypicalDistance {
184 metric: DistanceMetric,
186 prototypes: Option<Array2<f64>>,
188}
189
190impl PrototypicalDistance {
191 pub fn euclidean() -> Self {
193 Self {
194 metric: DistanceMetric::Euclidean,
195 prototypes: None,
196 }
197 }
198
199 pub fn cosine() -> Self {
201 Self {
202 metric: DistanceMetric::Cosine,
203 prototypes: None,
204 }
205 }
206
207 pub fn new(metric: DistanceMetric) -> Self {
209 Self {
210 metric,
211 prototypes: None,
212 }
213 }
214
215 pub fn compute_prototypes(&mut self, support: &SupportSet) {
220 let mut prototypes = Array2::zeros((support.num_classes, support.features.ncols()));
221
222 for class_id in 0..support.num_classes {
223 let class_examples = support.get_class_examples(class_id);
224 if class_examples.nrows() > 0 {
225 let prototype = class_examples.mean_axis(Axis(0)).unwrap();
226 prototypes.row_mut(class_id).assign(&prototype);
227 }
228 }
229
230 self.prototypes = Some(prototypes);
231 }
232
233 pub fn compute_distances(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
241 let prototypes = self
242 .prototypes
243 .as_ref()
244 .ok_or_else(|| TrainError::Other("Prototypes not computed".to_string()))?;
245
246 let mut distances = Array1::zeros(prototypes.nrows());
247 for (i, prototype) in prototypes.axis_iter(Axis(0)).enumerate() {
248 distances[i] = self.metric.compute(query, &prototype);
249 }
250
251 Ok(distances)
252 }
253
254 pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
262 let distances = self.compute_distances(query)?;
263
264 let mut min_idx = 0;
266 let mut min_dist = distances[0];
267 for (i, &dist) in distances.iter().enumerate() {
268 if dist < min_dist {
269 min_dist = dist;
270 min_idx = i;
271 }
272 }
273
274 Ok(min_idx)
275 }
276
277 pub fn predict_proba(
286 &self,
287 query: &ArrayView1<f64>,
288 temperature: f64,
289 ) -> TrainResult<Array1<f64>> {
290 let distances = self.compute_distances(query)?;
291
292 let logits = distances.mapv(|d| -d / temperature);
294
295 let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
297 let exp_logits = logits.mapv(|x| (x - max_logit).exp());
298 let sum_exp = exp_logits.sum();
299 let probs = exp_logits.mapv(|x| x / sum_exp);
300
301 Ok(probs)
302 }
303}
304
305#[derive(Debug, Clone)]
309pub struct EpisodeSampler {
310 n_way: usize,
312 shot_type: ShotType,
314 n_query: usize,
316}
317
318impl EpisodeSampler {
319 pub fn new(n_way: usize, shot_type: ShotType, n_query: usize) -> Self {
326 Self {
327 n_way,
328 shot_type,
329 n_query,
330 }
331 }
332
333 pub fn support_size(&self) -> usize {
335 self.n_way * self.shot_type.k()
336 }
337
338 pub fn query_size(&self) -> usize {
340 self.n_way * self.n_query
341 }
342
343 pub fn description(&self) -> String {
345 format!(
346 "{}-way {}-shot (query: {} per class)",
347 self.n_way,
348 self.shot_type.k(),
349 self.n_query
350 )
351 }
352}
353
354#[derive(Debug, Clone)]
358pub struct MatchingNetwork {
359 metric: DistanceMetric,
361 support: Option<SupportSet>,
363}
364
365impl MatchingNetwork {
366 pub fn new(metric: DistanceMetric) -> Self {
368 Self {
369 metric,
370 support: None,
371 }
372 }
373
374 pub fn set_support(&mut self, support: SupportSet) {
376 self.support = Some(support);
377 }
378
379 pub fn compute_attention(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
387 let support = self
388 .support
389 .as_ref()
390 .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
391
392 let n_support = support.size();
393 let mut similarities = Array1::zeros(n_support);
394
395 for i in 0..n_support {
397 let support_example = support.features.row(i);
398 similarities[i] = -self.metric.compute(query, &support_example);
399 }
400
401 let max_sim = similarities
403 .iter()
404 .copied()
405 .fold(f64::NEG_INFINITY, f64::max);
406 let exp_sims = similarities.mapv(|x| (x - max_sim).exp());
407 let sum_exp = exp_sims.sum();
408 let weights = exp_sims.mapv(|x| x / sum_exp);
409
410 Ok(weights)
411 }
412
413 pub fn predict_proba(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
421 let support = self
422 .support
423 .as_ref()
424 .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
425
426 let attention = self.compute_attention(query)?;
427 let mut class_probs = Array1::zeros(support.num_classes);
428
429 for (i, &weight) in attention.iter().enumerate() {
431 let label = support.labels[i];
432 class_probs[label] += weight;
433 }
434
435 Ok(class_probs)
436 }
437
438 pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
440 let probs = self.predict_proba(query)?;
441 let mut max_idx = 0;
442 let mut max_prob = probs[0];
443 for (i, &prob) in probs.iter().enumerate() {
444 if prob > max_prob {
445 max_prob = prob;
446 max_idx = i;
447 }
448 }
449 Ok(max_idx)
450 }
451}
452
/// Running classification accuracy over few-shot predictions.
#[derive(Debug, Clone, Default)]
pub struct FewShotAccuracy {
    /// Number of recorded predictions that matched the true label.
    correct: usize,
    /// Number of recorded predictions.
    total: usize,
}

impl FewShotAccuracy {
    /// Creates a tracker with no recorded predictions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one prediction against its true label.
    pub fn update(&mut self, predicted: usize, actual: usize) {
        self.total += 1;
        self.correct += usize::from(predicted == actual);
    }

    /// Fraction of recorded predictions that were correct, or `0.0` if none were recorded.
    pub fn accuracy(&self) -> f64 {
        match self.total {
            0 => 0.0,
            total => self.correct as f64 / total as f64,
        }
    }

    /// Discards all recorded predictions.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Returns `(correct, total)`.
    pub fn counts(&self) -> (usize, usize) {
        let Self { correct, total } = *self;
        (correct, total)
    }
}
497
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Four 2-feature examples in two classes:
    /// class 0 = {(1, 2), (3, 4)} and class 1 = {(5, 6), (7, 8)}.
    fn two_class_support() -> SupportSet {
        let features =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        SupportSet::new(features, labels).unwrap()
    }

    /// Query at (2, 3), which is the mean of the class-0 examples.
    fn query_near_class_0() -> Array1<f64> {
        Array1::from_vec(vec![2.0, 3.0])
    }

    #[test]
    fn test_shot_type() {
        assert_eq!(ShotType::OneShot.k(), 1);
        assert_eq!(ShotType::FewShot(5).k(), 5);
        assert_eq!(ShotType::Custom(10).k(), 10);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        // sqrt(3^2 + 3^2 + 3^2) = sqrt(27)
        let dist = DistanceMetric::Euclidean.compute(&a.view(), &b.view());
        assert_relative_eq!(dist, 5.196152, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_distance() {
        // Orthogonal vectors: cosine similarity 0, so distance 1.
        let a = Array1::from_vec(vec![1.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0]);

        let dist = DistanceMetric::Cosine.compute(&a.view(), &b.view());
        assert_relative_eq!(dist, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_support_set_creation() {
        let support = two_class_support();

        assert_eq!(support.size(), 4);
        assert_eq!(support.num_classes, 2);
    }

    #[test]
    fn test_support_set_get_class_examples() {
        let class_0 = two_class_support().get_class_examples(0);

        assert_eq!(class_0.nrows(), 2);
        assert_eq!(class_0[[0, 0]], 1.0);
        assert_eq!(class_0[[1, 0]], 3.0);
    }

    #[test]
    fn test_prototypical_distance() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&two_class_support());

        // Prototypes are (2, 3) and (6, 7); the query sits on the first one.
        let prediction = proto.predict(&query_near_class_0().view()).unwrap();
        assert_eq!(prediction, 0);
    }

    #[test]
    fn test_prototypical_predict_proba() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&two_class_support());

        let probs = proto
            .predict_proba(&query_near_class_0().view(), 1.0)
            .unwrap();

        assert_eq!(probs.len(), 2);
        // The nearer prototype must receive more probability mass.
        assert!(probs[0] > probs[1]);
        assert_relative_eq!(probs.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_episode_sampler() {
        let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);

        // 5 classes x 1 shot, and 5 classes x 15 queries.
        assert_eq!(sampler.support_size(), 5);
        assert_eq!(sampler.query_size(), 75);
        assert!(sampler.description().contains("5-way"));
    }

    #[test]
    fn test_matching_network() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(two_class_support());

        let prediction = matcher.predict(&query_near_class_0().view()).unwrap();
        assert_eq!(prediction, 0);
    }

    #[test]
    fn test_matching_network_attention() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(two_class_support());

        let attention = matcher
            .compute_attention(&query_near_class_0().view())
            .unwrap();

        // One weight per support example, forming a probability distribution.
        assert_eq!(attention.len(), 4);
        assert_relative_eq!(attention.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_few_shot_accuracy() {
        let mut acc = FewShotAccuracy::new();

        acc.update(0, 0); // correct
        acc.update(1, 1); // correct
        acc.update(1, 0); // wrong
        assert_eq!(acc.accuracy(), 2.0 / 3.0);
        assert_eq!(acc.counts(), (2, 3));

        acc.reset();
        assert_eq!(acc.accuracy(), 0.0);
    }

    #[test]
    fn test_manhattan_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        // |3| + |3| + |3|
        let dist = DistanceMetric::Manhattan.compute(&a.view(), &b.view());
        assert_eq!(dist, 9.0);
    }

    #[test]
    fn test_squared_euclidean_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![4.0, 6.0]);

        // 3^2 + 4^2
        let dist = DistanceMetric::SquaredEuclidean.compute(&a.view(), &b.view());
        assert_eq!(dist, 25.0);
    }
}