use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use super::config::TaskData;
use crate::error::Result;

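/// Prototypical Networks for few-shot classification: inputs are embedded with
/// a small MLP, each class is represented by the mean (prototype) of its
/// support embeddings, and queries are assigned to the nearest prototype.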
#[derive(Debug)]
pub struct PrototypicalNetworks<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Flattened MLP parameters (weights and biases for every layer).
    feature_extractor: Array2<F>,
    /// Dimensionality of the raw inputs.
    input_dim: usize,
    /// Dimensionality of the embedding space.
    feature_dim: usize,
    /// Sizes of the hidden layers of the feature extractor.
    hidden_dims: Vec<usize>,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
    PrototypicalNetworks<F>
{
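    /// Creates a network for `input_dim`-dimensional inputs, a
    /// `feature_dim`-dimensional embedding, and the given hidden layer sizes.
    /// Parameters are initialized deterministically with a Xavier-style scale.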
    pub fn new(input_dim: usize, feature_dim: usize, hidden_dims: Vec<usize>) -> Self {
        // Count the parameters of the fully connected stack: weights plus biases per layer.
        let mut total_params = 0;
        let mut layer_sizes = vec![input_dim];
        layer_sizes.extend(&hidden_dims);
        layer_sizes.push(feature_dim);

        for i in 0..layer_sizes.len() - 1 {
            total_params += layer_sizes[i] * layer_sizes[i + 1] + layer_sizes[i + 1];
        }

        // Xavier-style scale based on the input and embedding dimensionality.
        let scale = F::from(2.0).unwrap() / F::from(input_dim + feature_dim).unwrap();
        let std_dev = scale.sqrt();

        // Deterministic pseudo-random initialization of the flat parameter vector.
        let mut feature_extractor = Array2::zeros((1, total_params));
        for i in 0..total_params {
            let val = ((i * 43) % 1000) as f64 / 1000.0 - 0.5;
            feature_extractor[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        Self {
            feature_extractor,
            input_dim,
            feature_dim,
            hidden_dims,
        }
    }

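    /// Runs the inputs through the MLP and returns one embedding per row.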
    pub fn extract_features(&self, input: &Array2<F>) -> Result<Array2<F>> {
        let batch_size = input.nrows();
        let mut current_input = input.clone();

        let layer_params = self.extract_layer_parameters();

        // Forward pass: a dense layer followed by ReLU for every layer.
        for (weights, biases) in layer_params {
            let mut layer_output = Array2::zeros((batch_size, biases.len()));

            for i in 0..batch_size {
                for j in 0..biases.len() {
                    let mut sum = biases[j];
                    for k in 0..current_input.ncols() {
                        if k < weights.ncols() {
                            sum = sum + current_input[[i, k]] * weights[[j, k]];
                        }
                    }
                    layer_output[[i, j]] = self.relu(sum);
                }
            }

            current_input = layer_output;
        }

        Ok(current_input)
    }

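    /// Computes one prototype per class as the mean embedding of that class's
    /// support examples. Prototypes are ordered by ascending class label.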
    pub fn compute_prototypes(
        &self,
        support_features: &Array2<F>,
        support_labels: &Array1<usize>,
    ) -> Result<Array2<F>> {
        let mut unique_classes = Vec::new();
        for &label in support_labels {
            if !unique_classes.contains(&label) {
                unique_classes.push(label);
            }
        }
        unique_classes.sort();

        let num_classes = unique_classes.len();
        let mut prototypes = Array2::zeros((num_classes, self.feature_dim));

        for (class_idx, &class_label) in unique_classes.iter().enumerate() {
            let mut class_features = Vec::new();
            for (i, &label) in support_labels.iter().enumerate() {
                if label == class_label {
                    class_features.push(support_features.row(i).to_owned());
                }
            }

            if !class_features.is_empty() {
                for j in 0..self.feature_dim {
                    let mut sum = F::zero();
                    for features in &class_features {
                        sum = sum + features[j];
                    }
                    prototypes[[class_idx, j]] = sum / F::from(class_features.len()).unwrap();
                }
            }
        }

        Ok(prototypes)
    }

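    /// Assigns each query embedding to the index of its nearest prototype
    /// under Euclidean distance.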
    pub fn classify_queries(
        &self,
        query_features: &Array2<F>,
        prototypes: &Array2<F>,
    ) -> Result<Array1<usize>> {
        let num_queries = query_features.nrows();
        let num_classes = prototypes.nrows();
        let mut predictions = Array1::zeros(num_queries);

        for i in 0..num_queries {
            let mut min_distance = F::infinity();
            let mut predicted_class = 0;

            for j in 0..num_classes {
                let distance = self.euclidean_distance(
                    &query_features.row(i).to_owned(),
                    &prototypes.row(j).to_owned(),
                )?;

                if distance < min_distance {
                    min_distance = distance;
                    predicted_class = j;
                }
            }

            predictions[i] = predicted_class;
        }

        Ok(predictions)
    }

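    /// Runs a full episode: embeds the support and query sets, builds
    /// prototypes from the support set, and classifies the queries against them.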
    pub fn few_shot_episode(
        &self,
        support_x: &Array2<F>,
        support_y: &Array1<usize>,
        query_x: &Array2<F>,
    ) -> Result<Array1<usize>> {
        let support_features = self.extract_features(support_x)?;
        let query_features = self.extract_features(query_x)?;

        let prototypes = self.compute_prototypes(&support_features, support_y)?;

        self.classify_queries(&query_features, &prototypes)
    }

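    /// Meta-trains the feature extractor over a batch of episodes and returns
    /// the mean episode loss (query misclassification rate).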
    pub fn meta_train(&mut self, episodes: &[FewShotEpisode<F>]) -> Result<F> {
        let mut total_loss = F::zero();
        let mut total_gradients = Array2::zeros(self.feature_extractor.dim());

        for episode in episodes {
            let predictions =
                self.few_shot_episode(&episode.support_x, &episode.support_y, &episode.query_x)?;

            // 0-1 loss: fraction of query examples that were misclassified.
            let mut episode_loss = F::zero();
            for (i, &pred) in predictions.iter().enumerate() {
                if i < episode.query_y.len() {
                    let target = episode.query_y[i];
                    if pred != target {
                        episode_loss = episode_loss + F::one();
                    }
                }
            }
            episode_loss = episode_loss / F::from(predictions.len()).unwrap();

            let gradients = self.compute_gradients(episode)?;
            total_gradients = total_gradients + gradients;
            total_loss = total_loss + episode_loss;
        }

        // Apply the gradient averaged over episodes with a fixed learning rate.
        let learning_rate = F::from(0.001).unwrap();
        let num_episodes = F::from(episodes.len()).unwrap();
        total_gradients = total_gradients / num_episodes;

        self.feature_extractor = self.feature_extractor.clone() - total_gradients * learning_rate;

        Ok(total_loss / num_episodes)
    }

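    /// Unpacks the flat parameter vector into per-layer `(weights, biases)`
    /// pairs, following the layer sizes used at construction time.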
    fn extract_layer_parameters(&self) -> Vec<(Array2<F>, Array1<F>)> {
        let param_vec = self.feature_extractor.row(0);
        let mut layer_params = Vec::new();
        let mut param_idx = 0;

        let mut layer_sizes = vec![self.input_dim];
        layer_sizes.extend(&self.hidden_dims);
        layer_sizes.push(self.feature_dim);

        for i in 0..layer_sizes.len() - 1 {
            let input_size = layer_sizes[i];
            let output_size = layer_sizes[i + 1];

            let mut weights = Array2::zeros((output_size, input_size));
            for j in 0..output_size {
                for k in 0..input_size {
                    if param_idx < param_vec.len() {
                        weights[[j, k]] = param_vec[param_idx];
                        param_idx += 1;
                    }
                }
            }

            let mut biases = Array1::zeros(output_size);
            for j in 0..output_size {
                if param_idx < param_vec.len() {
                    biases[j] = param_vec[param_idx];
                    param_idx += 1;
                }
            }

            layer_params.push((weights, biases));
        }

        layer_params
    }

    fn euclidean_distance(&self, a: &Array1<F>, b: &Array1<F>) -> Result<F> {
        let mut sum = F::zero();
        for i in 0..a.len().min(b.len()) {
            let diff = a[i] - b[i];
            sum = sum + diff * diff;
        }
        Ok(sum.sqrt())
    }

    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }

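    /// Estimates gradients of the episode's 0-1 query loss with a forward
    /// finite-difference approximation over every parameter.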
    fn compute_gradients(&self, episode: &FewShotEpisode<F>) -> Result<Array2<F>> {
        let epsilon = F::from(1e-5).unwrap();
        let mut gradients = Array2::zeros(self.feature_extractor.dim());

        let base_predictions =
            self.few_shot_episode(&episode.support_x, &episode.support_y, &episode.query_x)?;
        let mut base_loss = F::zero();
        for (i, &pred) in base_predictions.iter().enumerate() {
            if i < episode.query_y.len() && pred != episode.query_y[i] {
                base_loss = base_loss + F::one();
            }
        }

        for i in 0..self.feature_extractor.ncols() {
            let mut perturbed_extractor = self.feature_extractor.clone();
            perturbed_extractor[[0, i]] = perturbed_extractor[[0, i]] + epsilon;

            let mut temp_network = self.clone();
            temp_network.feature_extractor = perturbed_extractor;

            let perturbed_predictions = temp_network.few_shot_episode(
                &episode.support_x,
                &episode.support_y,
                &episode.query_x,
            )?;
            let mut perturbed_loss = F::zero();
            for (j, &pred) in perturbed_predictions.iter().enumerate() {
                if j < episode.query_y.len() && pred != episode.query_y[j] {
                    perturbed_loss = perturbed_loss + F::one();
                }
            }

            gradients[[0, i]] = (perturbed_loss - base_loss) / epsilon;
        }

        Ok(gradients)
    }
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> Clone
    for PrototypicalNetworks<F>
{
    fn clone(&self) -> Self {
        Self {
            feature_extractor: self.feature_extractor.clone(),
            input_dim: self.input_dim,
            feature_dim: self.feature_dim,
            hidden_dims: self.hidden_dims.clone(),
        }
    }
}

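/// A single few-shot episode: a labeled support set for building prototypes
/// and a labeled query set for evaluation.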
#[derive(Debug, Clone)]
pub struct FewShotEpisode<F: Float + Debug> {
    /// Support inputs, one example per row.
    pub support_x: Array2<F>,
    /// Class labels of the support examples.
    pub support_y: Array1<usize>,
    /// Query inputs, one example per row.
    pub query_x: Array2<F>,
    /// Class labels of the query examples.
    pub query_y: Array1<usize>,
}

impl<F: Float + Debug> FewShotEpisode<F> {
    pub fn new(
        support_x: Array2<F>,
        support_y: Array1<usize>,
        query_x: Array2<F>,
        query_y: Array1<usize>,
    ) -> Self {
        Self {
            support_x,
            support_y,
            query_x,
            query_y,
        }
    }

    /// Number of support examples.
    pub fn support_size(&self) -> usize {
        self.support_x.nrows()
    }

    /// Number of query examples.
    pub fn query_size(&self) -> usize {
        self.query_x.nrows()
    }

    /// Sorted list of the distinct class labels appearing in the episode.
    pub fn unique_classes(&self) -> Vec<usize> {
        let mut classes = Vec::new();
        for &label in &self.support_y {
            if !classes.contains(&label) {
                classes.push(label);
            }
        }
        for &label in &self.query_y {
            if !classes.contains(&label) {
                classes.push(label);
            }
        }
        classes.sort();
        classes
    }
}

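/// Reptile-style meta-learner over a two-layer MLP: it adapts a copy of the
/// parameters to each task with a few SGD steps, then moves the
/// meta-parameters toward the adapted solution.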
#[derive(Debug)]
pub struct REPTILE<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Flattened meta-parameters of the two-layer MLP.
    parameters: Array2<F>,
    /// Step size of the meta (outer) update.
    meta_lr: F,
    /// Step size of the task-specific (inner) updates.
    inner_lr: F,
    /// Number of inner-loop adaptation steps per task.
    inner_steps: usize,
    /// Dimensionality of the inputs.
    input_dim: usize,
    /// Size of the hidden layer.
    hidden_dim: usize,
    /// Dimensionality of the outputs.
    output_dim: usize,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> REPTILE<F> {
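    /// Creates a meta-learner with deterministically initialized parameters
    /// for an `input_dim -> hidden_dim -> output_dim` network.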
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        meta_lr: F,
        inner_lr: F,
        inner_steps: usize,
    ) -> Self {
        let total_params =
            input_dim * hidden_dim + hidden_dim + hidden_dim * output_dim + output_dim;
        let scale = F::from(2.0).unwrap() / F::from(input_dim + output_dim).unwrap();
        let std_dev = scale.sqrt();

        let mut parameters = Array2::zeros((1, total_params));
        for i in 0..total_params {
            let val = ((i * 59) % 1000) as f64 / 1000.0 - 0.5;
            parameters[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        Self {
            parameters,
            meta_lr,
            inner_lr,
            inner_steps,
            input_dim,
            hidden_dim,
            output_dim,
        }
    }

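    /// Performs one Reptile meta-update over a batch of tasks and returns the
    /// mean post-adaptation support loss.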
    pub fn meta_train(&mut self, tasks: &[TaskData<F>]) -> Result<F> {
        let mut total_loss = F::zero();
        let mut parameter_updates = Array2::zeros(self.parameters.dim());

        for task in tasks {
            let initial_params = self.parameters.clone();

            // Inner loop: adapt a copy of the parameters to this task with SGD.
            let mut task_params = initial_params.clone();
            for _ in 0..self.inner_steps {
                let gradients = self.compute_task_gradients(&task_params, task)?;
                task_params = task_params - gradients * self.inner_lr;
            }

            let task_loss = self.forward(&task_params, &task.support_x, &task.support_y)?;
            total_loss = total_loss + task_loss;

            // Reptile update direction: adapted parameters minus the starting point.
            let update = task_params - initial_params;
            parameter_updates = parameter_updates + update;
        }

        let num_tasks = F::from(tasks.len()).unwrap();
        parameter_updates = parameter_updates / num_tasks;
        total_loss = total_loss / num_tasks;

        self.parameters = self.parameters.clone() + parameter_updates * self.meta_lr;

        Ok(total_loss)
    }

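    /// Adapts the current meta-parameters to a new task described by the given
    /// support set and returns the adapted parameter vector.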
    pub fn fast_adapt(&self, support_x: &Array2<F>, support_y: &Array2<F>) -> Result<Array2<F>> {
        let task = TaskData {
            support_x: support_x.clone(),
            support_y: support_y.clone(),
            query_x: support_x.clone(),
            query_y: support_y.clone(),
        };

        let mut adapted_params = self.parameters.clone();
        for _ in 0..self.inner_steps {
            let gradients = self.compute_task_gradients(&adapted_params, &task)?;
            adapted_params = adapted_params - gradients * self.inner_lr;
        }

        Ok(adapted_params)
    }

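    /// Mean squared error of the network's predictions against the targets.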
    fn forward(&self, params: &Array2<F>, inputs: &Array2<F>, targets: &Array2<F>) -> Result<F> {
        let predictions = self.predict(params, inputs)?;

        let mut loss = F::zero();
        let (batch_size, _) = predictions.dim();

        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let diff = predictions[[i, j]] - targets[[i, j]];
                loss = loss + diff * diff;
            }
        }

        Ok(loss / F::from(batch_size).unwrap())
    }

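    /// Forward pass of the two-layer MLP (ReLU hidden layer, linear output)
    /// using the supplied parameter vector.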
    pub fn predict(&self, params: &Array2<F>, inputs: &Array2<F>) -> Result<Array2<F>> {
        let (batch_size, _) = inputs.dim();

        let (w1, b1, w2, b2) = self.extract_weights(params);

        let mut hidden = Array2::zeros((batch_size, self.hidden_dim));

        for i in 0..batch_size {
            for j in 0..self.hidden_dim {
                let mut sum = b1[j];
                for k in 0..self.input_dim {
                    sum = sum + inputs[[i, k]] * w1[[j, k]];
                }
                hidden[[i, j]] = self.relu(sum);
            }
        }

        let mut output = Array2::zeros((batch_size, self.output_dim));
        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let mut sum = b2[j];
                for k in 0..self.hidden_dim {
                    sum = sum + hidden[[i, k]] * w2[[j, k]];
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }


    fn extract_weights(&self, params: &Array2<F>) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
        let param_vec = params.row(0);
        let mut idx = 0;

        let mut w1 = Array2::zeros((self.hidden_dim, self.input_dim));
        for i in 0..self.hidden_dim {
            for j in 0..self.input_dim {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        let mut b1 = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            b1[i] = param_vec[idx];
            idx += 1;
        }

        let mut w2 = Array2::zeros((self.output_dim, self.hidden_dim));
        for i in 0..self.output_dim {
            for j in 0..self.hidden_dim {
                w2[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        let mut b2 = Array1::zeros(self.output_dim);
        for i in 0..self.output_dim {
            b2[i] = param_vec[idx];
            idx += 1;
        }

        (w1, b1, w2, b2)
    }

    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }

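    /// Finite-difference estimate of the gradient of the support-set MSE loss
    /// with respect to every parameter.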
    fn compute_task_gradients(&self, params: &Array2<F>, task: &TaskData<F>) -> Result<Array2<F>> {
        let epsilon = F::from(1e-5).unwrap();
        let mut gradients = Array2::zeros(params.dim());

        let base_loss = self.forward(params, &task.support_x, &task.support_y)?;

        for i in 0..params.ncols() {
            let mut perturbed_params = params.clone();
            perturbed_params[[0, i]] = perturbed_params[[0, i]] + epsilon;

            let perturbed_loss =
                self.forward(&perturbed_params, &task.support_x, &task.support_y)?;
            gradients[[0, i]] = (perturbed_loss - base_loss) / epsilon;
        }

        Ok(gradients)
    }

    /// Returns the current meta-parameters.
    pub fn parameters(&self) -> &Array2<F> {
        &self.parameters
    }

    /// Replaces the meta-parameters.
    pub fn set_parameters(&mut self, parameters: Array2<F>) {
        self.parameters = parameters;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_prototypical_networks_creation() {
        let hidden_dims = vec![16, 32];
        let model = PrototypicalNetworks::<f64>::new(10, 8, hidden_dims.clone());

        assert_eq!(model.input_dim, 10);
        assert_eq!(model.feature_dim, 8);
        assert_eq!(model.hidden_dims, hidden_dims);
    }

    #[test]
    fn test_few_shot_episode() {
        let support_x =
            Array2::from_shape_vec((4, 3), (0..12).map(|i| i as f64).collect()).unwrap();
        let support_y = Array1::from_vec(vec![0, 0, 1, 1]);
        let query_x =
            Array2::from_shape_vec((2, 3), (12..18).map(|i| i as f64).collect()).unwrap();
        let query_y = Array1::from_vec(vec![0, 1]);

        let episode = FewShotEpisode::new(support_x, support_y, query_x, query_y);

        assert_eq!(episode.support_size(), 4);
        assert_eq!(episode.query_size(), 2);

        let classes = episode.unique_classes();
        assert_eq!(classes, vec![0, 1]);
    }

    #[test]
    fn test_prototypical_networks_features() {
        let model = PrototypicalNetworks::<f64>::new(5, 4, vec![8]);
        let input =
            Array2::from_shape_vec((3, 5), (0..15).map(|i| i as f64 * 0.1).collect()).unwrap();

        let features = model.extract_features(&input).unwrap();
        assert_eq!(features.dim(), (3, 4));

        for &val in features.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_prototypical_networks_classification() {
        let model = PrototypicalNetworks::<f64>::new(4, 6, vec![8]);

        let support_x =
            Array2::from_shape_vec((6, 4), (0..24).map(|i| i as f64 * 0.1).collect()).unwrap();
        let support_y = Array1::from_vec(vec![0, 0, 0, 1, 1, 1]);
        let query_x =
            Array2::from_shape_vec((2, 4), (24..32).map(|i| i as f64 * 0.1).collect()).unwrap();

        let predictions = model
            .few_shot_episode(&support_x, &support_y, &query_x)
            .unwrap();
        assert_eq!(predictions.len(), 2);

        for &pred in predictions.iter() {
            assert!(pred <= 1);
        }
    }

    #[test]
    fn test_reptile_creation() {
        let reptile = REPTILE::<f64>::new(5, 10, 3, 0.01, 0.1, 5);

        assert_eq!(reptile.input_dim, 5);
        assert_eq!(reptile.hidden_dim, 10);
        assert_eq!(reptile.output_dim, 3);
    }

    #[test]
    fn test_reptile_prediction() {
        let reptile = REPTILE::<f64>::new(4, 8, 2, 0.01, 0.1, 3);
        let input =
            Array2::from_shape_vec((3, 4), (0..12).map(|i| i as f64 * 0.1).collect()).unwrap();

        let output = reptile.predict(&reptile.parameters, &input).unwrap();
        assert_eq!(output.dim(), (3, 2));

        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_reptile_fast_adapt() {
        let reptile = REPTILE::<f64>::new(3, 6, 2, 0.01, 0.1, 2);
        let support_x =
            Array2::from_shape_vec((4, 3), (0..12).map(|i| i as f64 * 0.2).collect()).unwrap();
        let support_y =
            Array2::from_shape_vec((4, 2), (0..8).map(|i| i as f64 * 0.1).collect()).unwrap();

        let adapted_params = reptile.fast_adapt(&support_x, &support_y).unwrap();
        assert_eq!(adapted_params.dim(), reptile.parameters.dim());

        let params_changed = adapted_params
            .iter()
            .zip(reptile.parameters.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(params_changed);
    }

    #[test]
    fn test_euclidean_distance() {
        let model = PrototypicalNetworks::<f64>::new(3, 4, vec![]);
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let distance = model.euclidean_distance(&a, &b).unwrap();
        let expected = ((3.0_f64).powi(2) + (3.0_f64).powi(2) + (3.0_f64).powi(2)).sqrt();
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_compute_prototypes() {
        let model = PrototypicalNetworks::<f64>::new(4, 3, vec![]);

        let features = Array2::from_shape_vec(
            (6, 3),
            vec![
                1.0, 1.0, 1.0, // class 0
                2.0, 2.0, 2.0, // class 0
                3.0, 3.0, 3.0, // class 1
                4.0, 4.0, 4.0, // class 1
                5.0, 5.0, 5.0, // class 1
                6.0, 6.0, 6.0, // class 2
            ],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1, 1, 2]);

        let prototypes = model.compute_prototypes(&features, &labels).unwrap();
        assert_eq!(prototypes.dim(), (3, 3));

        // Class 0: mean of rows 0 and 1.
        assert_abs_diff_eq!(prototypes[[0, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[0, 1]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[0, 2]], 1.5, epsilon = 1e-10);

        // Class 1: mean of rows 2, 3, and 4.
        assert_abs_diff_eq!(prototypes[[1, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[1, 2]], 4.0, epsilon = 1e-10);

        // Class 2: row 5 alone.
        assert_abs_diff_eq!(prototypes[[2, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[2, 1]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[2, 2]], 6.0, epsilon = 1e-10);
    }
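
    // The two tests below are additional smoke tests on small, assumed
    // synthetic workloads; they only check that meta-training runs, returns a
    // finite loss, and (for Reptile) actually moves the meta-parameters.
    #[test]
    fn test_prototypical_meta_train_smoke() {
        let mut model = PrototypicalNetworks::<f64>::new(3, 4, vec![6]);
        let support_x =
            Array2::from_shape_vec((4, 3), (0..12).map(|i| i as f64 * 0.1).collect()).unwrap();
        let support_y = Array1::from_vec(vec![0, 0, 1, 1]);
        let query_x =
            Array2::from_shape_vec((2, 3), (12..18).map(|i| i as f64 * 0.1).collect()).unwrap();
        let query_y = Array1::from_vec(vec![0, 1]);
        let episode = FewShotEpisode::new(support_x, support_y, query_x, query_y);

        // The episode loss is a misclassification rate, so it must lie in [0, 1].
        let loss = model.meta_train(&[episode]).unwrap();
        assert!(loss.is_finite());
        assert!((0.0..=1.0).contains(&loss));
    }

    #[test]
    fn test_reptile_meta_train_smoke() {
        let mut reptile = REPTILE::<f64>::new(3, 6, 2, 0.1, 0.05, 2);
        let before = reptile.parameters().clone();

        let xs =
            Array2::from_shape_vec((4, 3), (0..12).map(|i| i as f64 * 0.2).collect()).unwrap();
        let ys =
            Array2::from_shape_vec((4, 2), (0..8).map(|i| i as f64 * 0.1).collect()).unwrap();
        let task = TaskData {
            support_x: xs.clone(),
            support_y: ys.clone(),
            query_x: xs,
            query_y: ys,
        };

        let loss = reptile.meta_train(&[task]).unwrap();
        assert!(loss.is_finite());

        // The meta-update should have moved at least one parameter.
        let moved = reptile
            .parameters()
            .iter()
            .zip(before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-12);
        assert!(moved);
    }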
}