1use crate::{TrainError, TrainResult};
44use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
45
/// How many labelled support examples each class contributes to an episode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotType {
    /// Exactly one support example per class.
    OneShot,
    /// `k` support examples per class.
    FewShot(usize),
    /// An arbitrary, user-chosen number of support examples per class.
    Custom(usize),
}

impl ShotType {
    /// Returns `k`, the number of support examples per class.
    pub fn k(&self) -> usize {
        match *self {
            ShotType::OneShot => 1,
            ShotType::FewShot(shots) | ShotType::Custom(shots) => shots,
        }
    }
}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum DistanceMetric {
71 Euclidean,
73 Cosine,
75 Manhattan,
77 SquaredEuclidean,
79}
80
81impl DistanceMetric {
82 pub fn compute(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
84 match self {
85 DistanceMetric::Euclidean => {
86 let diff = a.to_owned() - b.to_owned();
87 diff.dot(&diff).sqrt()
88 }
89 DistanceMetric::Cosine => {
90 let dot = a.dot(b);
91 let norm_a = a.dot(a).sqrt();
92 let norm_b = b.dot(b).sqrt();
93 if norm_a == 0.0 || norm_b == 0.0 {
94 0.0
95 } else {
96 1.0 - (dot / (norm_a * norm_b))
97 }
98 }
99 DistanceMetric::Manhattan => {
100 let diff = a.to_owned() - b.to_owned();
101 diff.iter().map(|x| x.abs()).sum()
102 }
103 DistanceMetric::SquaredEuclidean => {
104 let diff = a.to_owned() - b.to_owned();
105 diff.dot(&diff)
106 }
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
115pub struct SupportSet {
116 pub features: Array2<f64>,
118 pub labels: Array1<usize>,
120 pub num_classes: usize,
122}
123
124impl SupportSet {
125 pub fn new(features: Array2<f64>, labels: Array1<usize>) -> TrainResult<Self> {
134 if features.nrows() != labels.len() {
135 return Err(TrainError::InvalidParameter(format!(
136 "Feature rows ({}) must match label count ({})",
137 features.nrows(),
138 labels.len()
139 )));
140 }
141
142 let num_classes = labels.iter().max().copied().unwrap_or(0) + 1;
143
144 Ok(Self {
145 features,
146 labels,
147 num_classes,
148 })
149 }
150
151 pub fn get_class_examples(&self, class_id: usize) -> Array2<f64> {
153 let indices: Vec<usize> = self
154 .labels
155 .iter()
156 .enumerate()
157 .filter(|(_, &label)| label == class_id)
158 .map(|(idx, _)| idx)
159 .collect();
160
161 if indices.is_empty() {
162 return Array2::zeros((0, self.features.ncols()));
163 }
164
165 let mut result = Array2::zeros((indices.len(), self.features.ncols()));
166 for (i, &idx) in indices.iter().enumerate() {
167 result.row_mut(i).assign(&self.features.row(idx));
168 }
169 result
170 }
171
172 pub fn size(&self) -> usize {
174 self.features.nrows()
175 }
176}
177
178#[derive(Debug, Clone)]
183pub struct PrototypicalDistance {
184 metric: DistanceMetric,
186 prototypes: Option<Array2<f64>>,
188}
189
190impl PrototypicalDistance {
191 pub fn euclidean() -> Self {
193 Self {
194 metric: DistanceMetric::Euclidean,
195 prototypes: None,
196 }
197 }
198
199 pub fn cosine() -> Self {
201 Self {
202 metric: DistanceMetric::Cosine,
203 prototypes: None,
204 }
205 }
206
207 pub fn new(metric: DistanceMetric) -> Self {
209 Self {
210 metric,
211 prototypes: None,
212 }
213 }
214
215 pub fn compute_prototypes(&mut self, support: &SupportSet) {
220 let mut prototypes = Array2::zeros((support.num_classes, support.features.ncols()));
221
222 for class_id in 0..support.num_classes {
223 let class_examples = support.get_class_examples(class_id);
224 if class_examples.nrows() > 0 {
225 let prototype = class_examples
226 .mean_axis(Axis(0))
227 .expect("mean_axis on non-empty class examples");
228 prototypes.row_mut(class_id).assign(&prototype);
229 }
230 }
231
232 self.prototypes = Some(prototypes);
233 }
234
235 pub fn compute_distances(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
243 let prototypes = self
244 .prototypes
245 .as_ref()
246 .ok_or_else(|| TrainError::Other("Prototypes not computed".to_string()))?;
247
248 let mut distances = Array1::zeros(prototypes.nrows());
249 for (i, prototype) in prototypes.axis_iter(Axis(0)).enumerate() {
250 distances[i] = self.metric.compute(query, &prototype);
251 }
252
253 Ok(distances)
254 }
255
256 pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
264 let distances = self.compute_distances(query)?;
265
266 let mut min_idx = 0;
268 let mut min_dist = distances[0];
269 for (i, &dist) in distances.iter().enumerate() {
270 if dist < min_dist {
271 min_dist = dist;
272 min_idx = i;
273 }
274 }
275
276 Ok(min_idx)
277 }
278
279 pub fn predict_proba(
288 &self,
289 query: &ArrayView1<f64>,
290 temperature: f64,
291 ) -> TrainResult<Array1<f64>> {
292 let distances = self.compute_distances(query)?;
293
294 let logits = distances.mapv(|d| -d / temperature);
296
297 let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
299 let exp_logits = logits.mapv(|x| (x - max_logit).exp());
300 let sum_exp = exp_logits.sum();
301 let probs = exp_logits.mapv(|x| x / sum_exp);
302
303 Ok(probs)
304 }
305}
306
307#[derive(Debug, Clone)]
311pub struct EpisodeSampler {
312 n_way: usize,
314 shot_type: ShotType,
316 n_query: usize,
318}
319
320impl EpisodeSampler {
321 pub fn new(n_way: usize, shot_type: ShotType, n_query: usize) -> Self {
328 Self {
329 n_way,
330 shot_type,
331 n_query,
332 }
333 }
334
335 pub fn support_size(&self) -> usize {
337 self.n_way * self.shot_type.k()
338 }
339
340 pub fn query_size(&self) -> usize {
342 self.n_way * self.n_query
343 }
344
345 pub fn description(&self) -> String {
347 format!(
348 "{}-way {}-shot (query: {} per class)",
349 self.n_way,
350 self.shot_type.k(),
351 self.n_query
352 )
353 }
354}
355
356#[derive(Debug, Clone)]
360pub struct MatchingNetwork {
361 metric: DistanceMetric,
363 support: Option<SupportSet>,
365}
366
367impl MatchingNetwork {
368 pub fn new(metric: DistanceMetric) -> Self {
370 Self {
371 metric,
372 support: None,
373 }
374 }
375
376 pub fn set_support(&mut self, support: SupportSet) {
378 self.support = Some(support);
379 }
380
381 pub fn compute_attention(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
389 let support = self
390 .support
391 .as_ref()
392 .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
393
394 let n_support = support.size();
395 let mut similarities = Array1::zeros(n_support);
396
397 for i in 0..n_support {
399 let support_example = support.features.row(i);
400 similarities[i] = -self.metric.compute(query, &support_example);
401 }
402
403 let max_sim = similarities
405 .iter()
406 .copied()
407 .fold(f64::NEG_INFINITY, f64::max);
408 let exp_sims = similarities.mapv(|x| (x - max_sim).exp());
409 let sum_exp = exp_sims.sum();
410 let weights = exp_sims.mapv(|x| x / sum_exp);
411
412 Ok(weights)
413 }
414
415 pub fn predict_proba(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
423 let support = self
424 .support
425 .as_ref()
426 .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
427
428 let attention = self.compute_attention(query)?;
429 let mut class_probs = Array1::zeros(support.num_classes);
430
431 for (i, &weight) in attention.iter().enumerate() {
433 let label = support.labels[i];
434 class_probs[label] += weight;
435 }
436
437 Ok(class_probs)
438 }
439
440 pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
442 let probs = self.predict_proba(query)?;
443 let mut max_idx = 0;
444 let mut max_prob = probs[0];
445 for (i, &prob) in probs.iter().enumerate() {
446 if prob > max_prob {
447 max_prob = prob;
448 max_idx = i;
449 }
450 }
451 Ok(max_idx)
452 }
453}
454
/// Running accuracy counter for few-shot evaluation.
#[derive(Debug, Clone, Default)]
pub struct FewShotAccuracy {
    /// Number of predictions that matched the true label.
    correct: usize,
    /// Number of predictions recorded so far.
    total: usize,
}

impl FewShotAccuracy {
    /// Creates a counter with no recorded predictions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one prediction, counting it as correct when `predicted == actual`.
    pub fn update(&mut self, predicted: usize, actual: usize) {
        self.total += 1;
        self.correct += usize::from(predicted == actual);
    }

    /// Fraction of correct predictions, or `0.0` if nothing has been recorded.
    pub fn accuracy(&self) -> f64 {
        match self.total {
            0 => 0.0,
            total => self.correct as f64 / total as f64,
        }
    }

    /// Clears all recorded predictions.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Returns `(correct, total)`.
    pub fn counts(&self) -> (usize, usize) {
        (self.correct, self.total)
    }
}
499
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Features for four 2-D examples: (1, 2), (3, 4), (5, 6), (7, 8).
    fn toy_features() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("unwrap")
    }

    /// Labels putting the first two examples in class 0 and the last two in class 1.
    fn toy_labels() -> Array1<usize> {
        Array1::from_vec(vec![0, 0, 1, 1])
    }

    /// Support set built from the toy features and labels.
    fn toy_support() -> SupportSet {
        SupportSet::new(toy_features(), toy_labels()).expect("unwrap")
    }

    /// Query equal to the mean of the two class-0 examples.
    fn toy_query() -> Array1<f64> {
        Array1::from_vec(vec![2.0, 3.0])
    }

    #[test]
    fn test_shot_type() {
        assert_eq!(ShotType::OneShot.k(), 1);
        assert_eq!(ShotType::FewShot(5).k(), 5);
        assert_eq!(ShotType::Custom(10).k(), 10);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let dist = DistanceMetric::Euclidean.compute(&a.view(), &b.view());
        assert_relative_eq!(dist, 5.196152, epsilon = 1e-5);
    }

    #[test]
    fn test_cosine_distance() {
        let a = Array1::from_vec(vec![1.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0]);

        let dist = DistanceMetric::Cosine.compute(&a.view(), &b.view());
        assert_relative_eq!(dist, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_support_set_creation() {
        let support = toy_support();

        assert_eq!(support.size(), 4);
        assert_eq!(support.num_classes, 2);
    }

    #[test]
    fn test_support_set_get_class_examples() {
        let class_0 = toy_support().get_class_examples(0);

        assert_eq!(class_0.nrows(), 2);
        assert_eq!(class_0[[0, 0]], 1.0);
        assert_eq!(class_0[[1, 0]], 3.0);
    }

    #[test]
    fn test_prototypical_distance() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&toy_support());

        let query = toy_query();
        let prediction = proto.predict(&query.view()).expect("unwrap");

        assert_eq!(prediction, 0);
    }

    #[test]
    fn test_prototypical_predict_proba() {
        let mut proto = PrototypicalDistance::euclidean();
        proto.compute_prototypes(&toy_support());

        let query = toy_query();
        let probs = proto.predict_proba(&query.view(), 1.0).expect("unwrap");

        assert_eq!(probs.len(), 2);
        assert!(probs[0] > probs[1]);
        assert_relative_eq!(probs.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_episode_sampler() {
        let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);

        assert_eq!(sampler.support_size(), 5);
        assert_eq!(sampler.query_size(), 75);
        assert!(sampler.description().contains("5-way"));
    }

    #[test]
    fn test_matching_network() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(toy_support());

        let query = toy_query();
        let prediction = matcher.predict(&query.view()).expect("unwrap");

        assert_eq!(prediction, 0);
    }

    #[test]
    fn test_matching_network_attention() {
        let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
        matcher.set_support(toy_support());

        let query = toy_query();
        let attention = matcher.compute_attention(&query.view()).expect("unwrap");

        assert_eq!(attention.len(), 4);
        assert_relative_eq!(attention.sum(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_few_shot_accuracy() {
        let mut acc = FewShotAccuracy::new();

        acc.update(0, 0);
        acc.update(1, 1);
        acc.update(1, 0);
        assert_eq!(acc.accuracy(), 2.0 / 3.0);
        assert_eq!(acc.counts(), (2, 3));

        acc.reset();
        assert_eq!(acc.accuracy(), 0.0);
    }

    #[test]
    fn test_manhattan_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let dist = DistanceMetric::Manhattan.compute(&a.view(), &b.view());
        assert_eq!(dist, 9.0);
    }

    #[test]
    fn test_squared_euclidean_distance() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![4.0, 6.0]);

        let dist = DistanceMetric::SquaredEuclidean.compute(&a.view(), &b.view());
        assert_eq!(dist, 25.0);
    }
}