scirs2_series/advanced_training_modules/few_shot.rs

//! Few-Shot Learning Algorithms
//!
//! This module implements advanced few-shot learning techniques including
//! Prototypical Networks and REPTILE for rapid adaptation to new tasks
//! with minimal training data.
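//!
//! # Example
//!
//! A minimal sketch of one prototypical-network episode. The crate path and
//! data shapes are assumptions for illustration, so the block is `ignore`d
//! rather than compiled as a doctest:
//!
//! ```ignore
//! use scirs2_core::ndarray::{Array1, Array2};
//! use scirs2_series::advanced_training_modules::few_shot::PrototypicalNetworks;
//!
//! // 2-way, 2-shot episode over 4-dimensional inputs.
//! let model = PrototypicalNetworks::<f64>::new(4, 8, vec![16]);
//! let support_x = Array2::<f64>::zeros((4, 4));
//! let support_y = Array1::from_vec(vec![0, 0, 1, 1]);
//! let query_x = Array2::<f64>::zeros((2, 4));
//! let predictions = model.few_shot_episode(&support_x, &support_y, &query_x).unwrap();
//! assert_eq!(predictions.len(), 2);
//! ```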

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use super::config::TaskData;
use crate::error::Result;

/// Prototypical Networks for few-shot learning: each class is represented by
/// the mean ("prototype") of its embedded support examples, and queries are
/// assigned to the nearest prototype.
#[derive(Debug)]
pub struct PrototypicalNetworks<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Feature extraction network parameters (flattened weights and biases)
    feature_extractor: Array2<F>,
    /// Input dimension
    input_dim: usize,
    /// Feature dimension
    feature_dim: usize,
    /// Hidden dimensions for feature extractor
    hidden_dims: Vec<usize>,
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
    PrototypicalNetworks<F>
{
    /// Create new Prototypical Networks model
    pub fn new(input_dim: usize, feature_dim: usize, hidden_dims: Vec<usize>) -> Self {
        // Calculate total parameters (weights + biases) for the feature extractor
        let mut total_params = 0;
        let mut layer_sizes = vec![input_dim];
        layer_sizes.extend(&hidden_dims);
        layer_sizes.push(feature_dim);

        for i in 0..layer_sizes.len() - 1 {
            total_params += layer_sizes[i] * layer_sizes[i + 1] + layer_sizes[i + 1];
        }

        // Initialize parameters with a Xavier/Glorot-style scale,
        // 2 / (fan_in + fan_out), using deterministic pseudo-random values
        // in [-0.5, 0.5) so construction is reproducible.
        let scale = F::from(2.0).unwrap() / F::from(input_dim + feature_dim).unwrap();
        let std_dev = scale.sqrt();

        let mut feature_extractor = Array2::zeros((1, total_params));
        for i in 0..total_params {
            let val = ((i * 43) % 1000) as f64 / 1000.0 - 0.5;
            feature_extractor[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        Self {
            feature_extractor,
            input_dim,
            feature_dim,
            hidden_dims,
        }
    }

    /// Extract features from input data via a forward pass through the
    /// fully connected feature extractor (ReLU activations)
    pub fn extract_features(&self, input: &Array2<F>) -> Result<Array2<F>> {
        let batch_size = input.nrows();
        let mut current_input = input.clone();

        // Extract layer weights and biases
        let layer_params = self.extract_layer_parameters();

        // Forward pass through feature extractor
        for (weights, biases) in layer_params {
            let mut layer_output = Array2::zeros((batch_size, biases.len()));

            // Apply linear transformation followed by ReLU
            for i in 0..batch_size {
                for j in 0..biases.len() {
                    let mut sum = biases[j];
                    for k in 0..current_input.ncols() {
                        // Guard against a dimension mismatch between layers
                        if k < weights.ncols() {
                            sum = sum + current_input[[i, k]] * weights[[j, k]];
                        }
                    }
                    layer_output[[i, j]] = self.relu(sum);
                }
            }

            current_input = layer_output;
        }

        Ok(current_input)
    }

    /// Compute prototypes (per-class means of the support features).
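    ///
    /// For class `k` with support-feature set `S_k`, the prototype is
    /// `c_k = (1 / |S_k|) * sum_{f in S_k} f`. Prototypes are returned in
    /// order of sorted unique label.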
    pub fn compute_prototypes(
        &self,
        support_features: &Array2<F>,
        support_labels: &Array1<usize>,
    ) -> Result<Array2<F>> {
        // Find unique classes
        let mut unique_classes = Vec::new();
        for &label in support_labels {
            if !unique_classes.contains(&label) {
                unique_classes.push(label);
            }
        }
        unique_classes.sort();

        let num_classes = unique_classes.len();
        let mut prototypes = Array2::zeros((num_classes, self.feature_dim));

        // Compute prototype for each class
        for (class_idx, &class_label) in unique_classes.iter().enumerate() {
            let mut class_features = Vec::new();
            for (i, &label) in support_labels.iter().enumerate() {
                if label == class_label {
                    class_features.push(support_features.row(i).to_owned());
                }
            }

            if !class_features.is_empty() {
                // Compute mean of the class features
                for j in 0..self.feature_dim {
                    let mut sum = F::zero();
                    for features in &class_features {
                        sum = sum + features[j];
                    }
                    prototypes[[class_idx, j]] = sum / F::from(class_features.len()).unwrap();
                }
            }
        }

        Ok(prototypes)
    }

    /// Classify query samples by nearest prototype (Euclidean distance).
    ///
    /// Returns, for each query, the row index of the closest prototype.
    pub fn classify_queries(
        &self,
        query_features: &Array2<F>,
        prototypes: &Array2<F>,
    ) -> Result<Array1<usize>> {
        let num_queries = query_features.nrows();
        let num_classes = prototypes.nrows();
        let mut predictions = Array1::zeros(num_queries);

        for i in 0..num_queries {
            let mut min_distance = F::infinity();
            let mut predicted_class = 0;

            // Find closest prototype
            for j in 0..num_classes {
                let distance = self.euclidean_distance(
                    &query_features.row(i).to_owned(),
                    &prototypes.row(j).to_owned(),
                )?;

                if distance < min_distance {
                    min_distance = distance;
                    predicted_class = j;
                }
            }

            predictions[i] = predicted_class;
        }

        Ok(predictions)
    }

    /// Run one few-shot learning episode: embed support and query sets,
    /// build class prototypes, and label each query with the class of its
    /// nearest prototype.
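    ///
    /// A minimal sketch of an episode (shapes are illustrative; the block is
    /// `ignore`d so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// // 2-way, 3-shot episode over 4-dimensional inputs.
    /// let model = PrototypicalNetworks::<f64>::new(4, 8, vec![16]);
    /// let support_x = Array2::<f64>::zeros((6, 4));
    /// let support_y = Array1::from_vec(vec![0, 0, 0, 1, 1, 1]);
    /// let query_x = Array2::<f64>::zeros((2, 4));
    /// let predictions = model.few_shot_episode(&support_x, &support_y, &query_x).unwrap();
    /// assert_eq!(predictions.len(), 2);
    /// ```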
    pub fn few_shot_episode(
        &self,
        support_x: &Array2<F>,
        support_y: &Array1<usize>,
        query_x: &Array2<F>,
    ) -> Result<Array1<usize>> {
        // Extract features
        let support_features = self.extract_features(support_x)?;
        let query_features = self.extract_features(query_x)?;

        // Compute prototypes
        let prototypes = self.compute_prototypes(&support_features, support_y)?;

        // Classify queries, then map prototype row indices back to the
        // original class labels (prototypes are ordered by sorted unique
        // label, so the index only equals the label for contiguous labels).
        let mut classes: Vec<usize> = support_y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let indices = self.classify_queries(&query_features, &prototypes)?;
        Ok(indices.mapv(|idx| classes[idx]))
    }

    /// Train the feature extractor on a batch of few-shot tasks.
    ///
    /// Returns the mean 0-1 query error across episodes.
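    ///
    /// Sketch of a meta-training step (shapes assumed; `ignore`d, not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let mut model = PrototypicalNetworks::<f64>::new(4, 8, vec![16]);
    /// let episode = FewShotEpisode::new(
    ///     Array2::<f64>::zeros((6, 4)),
    ///     Array1::from_vec(vec![0, 0, 0, 1, 1, 1]),
    ///     Array2::<f64>::zeros((2, 4)),
    ///     Array1::from_vec(vec![0, 1]),
    /// );
    /// let mean_error = model.meta_train(&[episode]).unwrap();
    /// ```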
    pub fn meta_train(&mut self, episodes: &[FewShotEpisode<F>]) -> Result<F> {
        let mut total_loss = F::zero();
        let mut total_gradients = Array2::zeros(self.feature_extractor.dim());

        for episode in episodes {
            // Forward pass
            let predictions =
                self.few_shot_episode(&episode.support_x, &episode.support_y, &episode.query_x)?;

            // Compute loss (0-1 misclassification rate as a crude
            // cross-entropy surrogate)
            let mut episode_loss = F::zero();
            for (i, &pred) in predictions.iter().enumerate() {
                if i < episode.query_y.len() {
                    let target = episode.query_y[i];
                    if pred != target {
                        episode_loss = episode_loss + F::one();
                    }
                }
            }
            episode_loss = episode_loss / F::from(predictions.len()).unwrap();

            // Compute gradients (simplified numerical differentiation)
            let gradients = self.compute_gradients(episode)?;
            total_gradients = total_gradients + gradients;
            total_loss = total_loss + episode_loss;
        }

        // Average gradients over episodes and take one SGD step
        let learning_rate = F::from(0.001).unwrap();
        let num_episodes = F::from(episodes.len()).unwrap();
        total_gradients = total_gradients / num_episodes;

        self.feature_extractor = self.feature_extractor.clone() - total_gradients * learning_rate;

        Ok(total_loss / num_episodes)
    }

    // Helper methods
    fn extract_layer_parameters(&self) -> Vec<(Array2<F>, Array1<F>)> {
        let param_vec = self.feature_extractor.row(0);
        let mut layer_params = Vec::new();
        let mut param_idx = 0;

        let mut layer_sizes = vec![self.input_dim];
        layer_sizes.extend(&self.hidden_dims);
        layer_sizes.push(self.feature_dim);

        for i in 0..layer_sizes.len() - 1 {
            let input_size = layer_sizes[i];
            let output_size = layer_sizes[i + 1];

            // Extract weights
            let mut weights = Array2::zeros((output_size, input_size));
            for j in 0..output_size {
                for k in 0..input_size {
                    if param_idx < param_vec.len() {
                        weights[[j, k]] = param_vec[param_idx];
                        param_idx += 1;
                    }
                }
            }

            // Extract biases
            let mut biases = Array1::zeros(output_size);
            for j in 0..output_size {
                if param_idx < param_vec.len() {
                    biases[j] = param_vec[param_idx];
                    param_idx += 1;
                }
            }

            layer_params.push((weights, biases));
        }

        layer_params
    }

    fn euclidean_distance(&self, a: &Array1<F>, b: &Array1<F>) -> Result<F> {
        let mut sum = F::zero();
        // Iterate over the shorter of the two vectors
        for i in 0..a.len().min(b.len()) {
            let diff = a[i] - b[i];
            sum = sum + diff * diff;
        }
        Ok(sum.sqrt())
    }

    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }

    fn compute_gradients(&self, episode: &FewShotEpisode<F>) -> Result<Array2<F>> {
        // Simplified forward-difference numerical gradient of the 0-1 query
        // loss; one extra forward pass per parameter
        let epsilon = F::from(1e-5).unwrap();
        let mut gradients = Array2::zeros(self.feature_extractor.dim());

        let base_predictions =
            self.few_shot_episode(&episode.support_x, &episode.support_y, &episode.query_x)?;
        let mut base_loss = F::zero();
        for (i, &pred) in base_predictions.iter().enumerate() {
            if i < episode.query_y.len() && pred != episode.query_y[i] {
                base_loss = base_loss + F::one();
            }
        }

        // Numerical differentiation for each parameter
        for i in 0..self.feature_extractor.ncols() {
            let mut perturbed_extractor = self.feature_extractor.clone();
            perturbed_extractor[[0, i]] = perturbed_extractor[[0, i]] + epsilon;

            // Create temporary network with perturbed parameters
            let mut temp_network = self.clone();
            temp_network.feature_extractor = perturbed_extractor;

            let perturbed_predictions = temp_network.few_shot_episode(
                &episode.support_x,
                &episode.support_y,
                &episode.query_x,
            )?;
            let mut perturbed_loss = F::zero();
            for (j, &pred) in perturbed_predictions.iter().enumerate() {
                if j < episode.query_y.len() && pred != episode.query_y[j] {
                    perturbed_loss = perturbed_loss + F::one();
                }
            }

            gradients[[0, i]] = (perturbed_loss - base_loss) / epsilon;
        }

        Ok(gradients)
    }
}

impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> Clone
    for PrototypicalNetworks<F>
{
    fn clone(&self) -> Self {
        Self {
            feature_extractor: self.feature_extractor.clone(),
            input_dim: self.input_dim,
            feature_dim: self.feature_dim,
            hidden_dims: self.hidden_dims.clone(),
        }
    }
}

/// Few-shot learning episode data structure
#[derive(Debug, Clone)]
pub struct FewShotEpisode<F: Float + Debug> {
    /// Support set inputs
    pub support_x: Array2<F>,
    /// Support set labels
    pub support_y: Array1<usize>,
    /// Query set inputs
    pub query_x: Array2<F>,
    /// Query set labels
    pub query_y: Array1<usize>,
}

impl<F: Float + Debug> FewShotEpisode<F> {
    /// Create a new few-shot episode
    pub fn new(
        support_x: Array2<F>,
        support_y: Array1<usize>,
        query_x: Array2<F>,
        query_y: Array1<usize>,
    ) -> Self {
        Self {
            support_x,
            support_y,
            query_x,
            query_y,
        }
    }

    /// Get the number of support samples
    pub fn support_size(&self) -> usize {
        self.support_x.nrows()
    }

    /// Get the number of query samples
    pub fn query_size(&self) -> usize {
        self.query_x.nrows()
    }

    /// Get unique classes in the episode
    pub fn unique_classes(&self) -> Vec<usize> {
        let mut classes = Vec::new();
        for &label in &self.support_y {
            if !classes.contains(&label) {
                classes.push(label);
            }
        }
        for &label in &self.query_y {
            if !classes.contains(&label) {
                classes.push(label);
            }
        }
        classes.sort();
        classes
    }
}

/// REPTILE: a first-order meta-learning algorithm (Nichol et al., 2018)
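///
/// Each meta-step adapts a copy of the parameters to a task with a few SGD
/// steps, then moves the meta-parameters toward the adapted weights:
/// `theta <- theta + meta_lr * mean_over_tasks(theta_task - theta)`.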
#[derive(Debug)]
pub struct REPTILE<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
    /// Base model parameters
    parameters: Array2<F>,
    /// Meta-learning rate
    meta_lr: F,
    /// Inner loop learning rate
    inner_lr: F,
    /// Number of inner gradient steps
    inner_steps: usize,
    /// Input dimension
    input_dim: usize,
    /// Hidden layer dimension
    hidden_dim: usize,
    /// Output dimension
    output_dim: usize,
}
impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> REPTILE<F> {
    /// Create new REPTILE instance
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        meta_lr: F,
        inner_lr: F,
        inner_steps: usize,
    ) -> Self {
        // Initialize a two-layer network (input -> hidden -> output) with a
        // Xavier-style scale and deterministic pseudo-random values
        let total_params =
            input_dim * hidden_dim + hidden_dim + hidden_dim * output_dim + output_dim;
        let scale = F::from(2.0).unwrap() / F::from(input_dim + output_dim).unwrap();
        let std_dev = scale.sqrt();

        let mut parameters = Array2::zeros((1, total_params));
        for i in 0..total_params {
            let val = ((i * 59) % 1000) as f64 / 1000.0 - 0.5;
            parameters[[0, i]] = F::from(val).unwrap() * std_dev;
        }

        Self {
            parameters,
            meta_lr,
            inner_lr,
            inner_steps,
            input_dim,
            hidden_dim,
            output_dim,
        }
    }

    /// REPTILE meta-training step: adapt to each task, then move the
    /// meta-parameters toward the average adapted parameters.
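    ///
    /// Sketch (assumes the `TaskData` fields used below are publicly
    /// constructible; `ignore`d, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut reptile = REPTILE::<f64>::new(4, 8, 2, 0.01, 0.1, 5);
    /// let task = TaskData {
    ///     support_x: Array2::<f64>::zeros((6, 4)),
    ///     support_y: Array2::<f64>::zeros((6, 2)),
    ///     query_x: Array2::<f64>::zeros((2, 4)),
    ///     query_y: Array2::<f64>::zeros((2, 2)),
    /// };
    /// let mean_loss = reptile.meta_train(&[task]).unwrap();
    /// ```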
    pub fn meta_train(&mut self, tasks: &[TaskData<F>]) -> Result<F> {
        let mut total_loss = F::zero();
        let mut parameter_updates = Array2::zeros(self.parameters.dim());

        for task in tasks {
            // Store initial parameters
            let initial_params = self.parameters.clone();

            // Inner loop training on task
            let mut task_params = initial_params.clone();
            for _ in 0..self.inner_steps {
                let gradients = self.compute_task_gradients(&task_params, task)?;
                task_params = task_params - gradients * self.inner_lr;
            }

            // Compute task loss
            let task_loss = self.forward(&task_params, &task.support_x, &task.support_y)?;
            total_loss = total_loss + task_loss;

            // REPTILE update: move towards task-adapted parameters
            let update = task_params - initial_params;
            parameter_updates = parameter_updates + update;
        }

        // Meta-update: average parameter updates across tasks
        let num_tasks = F::from(tasks.len()).unwrap();
        parameter_updates = parameter_updates / num_tasks;
        total_loss = total_loss / num_tasks;

        // Update meta-parameters
        self.parameters = self.parameters.clone() + parameter_updates * self.meta_lr;

        Ok(total_loss)
    }

    /// Fast adaptation for a new task (few-shot learning): runs the inner
    /// loop on the support set and returns the adapted parameters, leaving
    /// the meta-parameters unchanged.
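    ///
    /// Sketch (shapes assumed; `ignore`d, not compiled as a doctest):
    ///
    /// ```ignore
    /// let reptile = REPTILE::<f64>::new(3, 6, 2, 0.01, 0.1, 2);
    /// let support_x = Array2::<f64>::zeros((4, 3));
    /// let support_y = Array2::<f64>::zeros((4, 2));
    /// let adapted = reptile.fast_adapt(&support_x, &support_y).unwrap();
    /// let predictions = reptile.predict(&adapted, &support_x).unwrap();
    /// ```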
    pub fn fast_adapt(&self, support_x: &Array2<F>, support_y: &Array2<F>) -> Result<Array2<F>> {
        // Adapt on the support set only; the query set mirrors the support set
        let task = TaskData {
            support_x: support_x.clone(),
            support_y: support_y.clone(),
            query_x: support_x.clone(),
            query_y: support_y.clone(),
        };

        // Inner loop adaptation
        let mut adapted_params = self.parameters.clone();
        for _ in 0..self.inner_steps {
            let gradients = self.compute_task_gradients(&adapted_params, &task)?;
            adapted_params = adapted_params - gradients * self.inner_lr;
        }

        Ok(adapted_params)
    }

    /// Forward pass through the network, returning the mean squared error
    /// of the predictions against `targets`
    fn forward(&self, params: &Array2<F>, inputs: &Array2<F>, targets: &Array2<F>) -> Result<F> {
        let predictions = self.predict(params, inputs)?;

        // Mean squared error loss
        let mut loss = F::zero();
        let (batch_size, _) = predictions.dim();

        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let diff = predictions[[i, j]] - targets[[i, j]];
                loss = loss + diff * diff;
            }
        }

        Ok(loss / F::from(batch_size).unwrap())
    }

    /// Make predictions using the given parameters
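    ///
    /// Sketch (`ignore`d, not compiled as a doctest):
    ///
    /// ```ignore
    /// let reptile = REPTILE::<f64>::new(4, 8, 2, 0.01, 0.1, 3);
    /// let inputs = Array2::<f64>::zeros((3, 4));
    /// let outputs = reptile.predict(reptile.parameters(), &inputs).unwrap();
    /// assert_eq!(outputs.dim(), (3, 2));
    /// ```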
    pub fn predict(&self, params: &Array2<F>, inputs: &Array2<F>) -> Result<Array2<F>> {
        let (batch_size, _) = inputs.dim();

        // Extract weight matrices from flattened parameters
        let (w1, b1, w2, b2) = self.extract_weights(params);

        // Forward pass: input -> hidden -> output
        let mut hidden = Array2::zeros((batch_size, self.hidden_dim));

        // Input to hidden layer
        for i in 0..batch_size {
            for j in 0..self.hidden_dim {
                let mut sum = b1[j];
                for k in 0..self.input_dim {
                    sum = sum + inputs[[i, k]] * w1[[j, k]];
                }
                hidden[[i, j]] = self.relu(sum); // ReLU activation
            }
        }

        // Hidden to output layer
        let mut output = Array2::zeros((batch_size, self.output_dim));
        for i in 0..batch_size {
            for j in 0..self.output_dim {
                let mut sum = b2[j];
                for k in 0..self.hidden_dim {
                    sum = sum + hidden[[i, k]] * w2[[j, k]];
                }
                output[[i, j]] = sum; // Linear output
            }
        }

        Ok(output)
    }

    /// Extract weight matrices from the flattened parameter vector
    fn extract_weights(&self, params: &Array2<F>) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
        let param_vec = params.row(0);
        let mut idx = 0;

        // w1: hidden_dim x input_dim (row j holds the weights into hidden unit j)
        let mut w1 = Array2::zeros((self.hidden_dim, self.input_dim));
        for i in 0..self.hidden_dim {
            for j in 0..self.input_dim {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b1: hidden_dim
        let mut b1 = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            b1[i] = param_vec[idx];
            idx += 1;
        }

        // w2: output_dim x hidden_dim
        let mut w2 = Array2::zeros((self.output_dim, self.hidden_dim));
        for i in 0..self.output_dim {
            for j in 0..self.hidden_dim {
                w2[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }

        // b2: output_dim
        let mut b2 = Array1::zeros(self.output_dim);
        for i in 0..self.output_dim {
            b2[i] = param_vec[idx];
            idx += 1;
        }

        (w1, b1, w2, b2)
    }

    /// ReLU activation function
    fn relu(&self, x: F) -> F {
        x.max(F::zero())
    }

    /// Compute task-specific gradients by forward-difference numerical
    /// differentiation (one extra forward pass per parameter)
    fn compute_task_gradients(&self, params: &Array2<F>, task: &TaskData<F>) -> Result<Array2<F>> {
        let epsilon = F::from(1e-5).unwrap();
        let mut gradients = Array2::zeros(params.dim());

        let base_loss = self.forward(params, &task.support_x, &task.support_y)?;

        for i in 0..params.ncols() {
            let mut perturbed_params = params.clone();
            perturbed_params[[0, i]] = perturbed_params[[0, i]] + epsilon;

            let perturbed_loss =
                self.forward(&perturbed_params, &task.support_x, &task.support_y)?;
            gradients[[0, i]] = (perturbed_loss - base_loss) / epsilon;
        }

        Ok(gradients)
    }

    /// Get current parameters
    pub fn parameters(&self) -> &Array2<F> {
        &self.parameters
    }

    /// Set parameters
    pub fn set_parameters(&mut self, parameters: Array2<F>) {
        self.parameters = parameters;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_prototypical_networks_creation() {
        let hidden_dims = vec![16, 32];
        let model = PrototypicalNetworks::<f64>::new(10, 8, hidden_dims.clone());

        assert_eq!(model.input_dim, 10);
        assert_eq!(model.feature_dim, 8);
        assert_eq!(model.hidden_dims, hidden_dims);
    }

    #[test]
    fn test_few_shot_episode() {
        let support_x =
            Array2::from_shape_vec((4, 3), (0..12).map(|i| i as f64).collect()).unwrap();
        let support_y = Array1::from_vec(vec![0, 0, 1, 1]);
        let query_x = Array2::from_shape_vec((2, 3), (12..18).map(|i| i as f64).collect()).unwrap();
        let query_y = Array1::from_vec(vec![0, 1]);

        let episode = FewShotEpisode::new(support_x, support_y, query_x, query_y);

        assert_eq!(episode.support_size(), 4);
        assert_eq!(episode.query_size(), 2);

        let classes = episode.unique_classes();
        assert_eq!(classes, vec![0, 1]);
    }

    #[test]
    fn test_prototypical_networks_features() {
        let model = PrototypicalNetworks::<f64>::new(5, 4, vec![8]);
        let input =
            Array2::from_shape_vec((3, 5), (0..15).map(|i| i as f64 * 0.1).collect()).unwrap();

        let features = model.extract_features(&input).unwrap();
        assert_eq!(features.dim(), (3, 4));

        // Check that features are finite
        for &val in features.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_prototypical_networks_classification() {
        let model = PrototypicalNetworks::<f64>::new(4, 6, vec![8]);

        let support_x =
            Array2::from_shape_vec((6, 4), (0..24).map(|i| i as f64 * 0.1).collect()).unwrap();
        let support_y = Array1::from_vec(vec![0, 0, 0, 1, 1, 1]);
        let query_x =
            Array2::from_shape_vec((2, 4), (24..32).map(|i| i as f64 * 0.1).collect()).unwrap();

        let predictions = model
            .few_shot_episode(&support_x, &support_y, &query_x)
            .unwrap();
        assert_eq!(predictions.len(), 2);

        // Predictions should be within valid class range
        for &pred in predictions.iter() {
            assert!(pred <= 1);
        }
    }

    #[test]
    fn test_reptile_creation() {
        let reptile = REPTILE::<f64>::new(5, 10, 3, 0.01, 0.1, 5);

        assert_eq!(reptile.input_dim, 5);
        assert_eq!(reptile.hidden_dim, 10);
        assert_eq!(reptile.output_dim, 3);
    }

    #[test]
    fn test_reptile_prediction() {
        let reptile = REPTILE::<f64>::new(4, 8, 2, 0.01, 0.1, 3);
        let input =
            Array2::from_shape_vec((3, 4), (0..12).map(|i| i as f64 * 0.1).collect()).unwrap();

        let output = reptile.predict(&reptile.parameters, &input).unwrap();
        assert_eq!(output.dim(), (3, 2));

        // Check that output is finite
        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_reptile_fast_adapt() {
        let reptile = REPTILE::<f64>::new(3, 6, 2, 0.01, 0.1, 2);
        let support_x =
            Array2::from_shape_vec((4, 3), (0..12).map(|i| i as f64 * 0.2).collect()).unwrap();
        let support_y =
            Array2::from_shape_vec((4, 2), (0..8).map(|i| i as f64 * 0.1).collect()).unwrap();

        let adapted_params = reptile.fast_adapt(&support_x, &support_y).unwrap();
        assert_eq!(adapted_params.dim(), reptile.parameters.dim());

        // Adapted parameters should be different from original
        let params_changed = adapted_params
            .iter()
            .zip(reptile.parameters.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(params_changed);
    }

    #[test]
    fn test_euclidean_distance() {
        let model = PrototypicalNetworks::<f64>::new(3, 4, vec![]);
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let distance = model.euclidean_distance(&a, &b).unwrap();
        let expected = ((3.0_f64).powi(2) + (3.0_f64).powi(2) + (3.0_f64).powi(2)).sqrt();
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_compute_prototypes() {
        let model = PrototypicalNetworks::<f64>::new(4, 3, vec![]);

        // Create simple features where we know the expected prototypes
        let features = Array2::from_shape_vec(
            (6, 3),
            vec![
                1.0, 1.0, 1.0, // Class 0
                2.0, 2.0, 2.0, // Class 0
                3.0, 3.0, 3.0, // Class 1
                4.0, 4.0, 4.0, // Class 1
                5.0, 5.0, 5.0, // Class 1
                6.0, 6.0, 6.0, // Class 2
            ],
        )
        .unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1, 1, 2]);

        let prototypes = model.compute_prototypes(&features, &labels).unwrap();
        assert_eq!(prototypes.dim(), (3, 3)); // 3 classes, 3 features

        // Check class 0 prototype (mean of [1,1,1] and [2,2,2])
        assert_abs_diff_eq!(prototypes[[0, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[0, 1]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[0, 2]], 1.5, epsilon = 1e-10);

        // Check class 1 prototype (mean of [3,3,3], [4,4,4], [5,5,5])
        assert_abs_diff_eq!(prototypes[[1, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[1, 2]], 4.0, epsilon = 1e-10);

        // Check class 2 prototype ([6,6,6])
        assert_abs_diff_eq!(prototypes[[2, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[2, 1]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(prototypes[[2, 2]], 6.0, epsilon = 1e-10);
    }
}