1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
14use scirs2_core::numeric::{Float, FromPrimitive};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use crate::error::{ClusteringError, Result};
19
/// Squared Euclidean distance between two slices.
///
/// Elements are paired with `zip`, so when the slices differ in length only
/// the common prefix contributes to the result.
#[inline]
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(lhs, rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum()
}

/// Euclidean (L2) distance: the square root of [`sq_euclid`].
#[inline]
fn euclid(a: &[f64], b: &[f64]) -> f64 {
    let squared = sq_euclid(a, b);
    squared.sqrt()
}
35
/// Advances a 64-bit linear congruential generator in place and returns a
/// uniform sample in `[0, 1)`.
///
/// The top 53 bits of the new state (`state >> 11`) are divided by `2^53`,
/// which yields every representable multiple of `2^-53` below 1.
fn lcg_next(state: &mut u64) -> f64 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    (*state >> 11) as f64 / (1u64 << 53) as f64
}

/// Draws a pseudo-random index in `0..n` from the LCG stream.
///
/// The `[0, 1)` sample from [`lcg_next`] is scaled by `n` before truncation.
/// Casting the raw sample to an integer first would always give 0. The
/// result is clamped to `n - 1` so that floating-point rounding can never
/// produce `n` itself.
///
/// # Panics
///
/// Panics if `n == 0`, because the range `0..0` is empty.
#[inline]
fn lcg_usize(state: &mut u64, n: usize) -> usize {
    assert!(n > 0, "lcg_usize: n must be > 0");
    let scaled = (lcg_next(state) * n as f64) as usize;
    scaled.min(n - 1)
}
50
/// Output of [`NeuralGas::fit`].
#[derive(Debug, Clone)]
pub struct NeuralGasResult {
    /// Learned prototype vectors, one row per unit (`n_units × n_features`).
    pub prototypes: Array2<f64>,
    /// For each input row, the index of its nearest prototype (squared
    /// Euclidean distance).
    pub labels: Array1<usize>,
    /// Number of prototype units that were trained.
    pub n_units: usize,
    /// Mean squared Euclidean distance from each input row to its nearest
    /// prototype.
    pub quantization_error: f64,
}
67
/// Neural Gas vector quantizer.
///
/// Every presented sample pulls *all* prototypes toward it. Each prototype's
/// step is scaled by `exp(-rank / lambda)`, where `rank` is its position
/// when prototypes are sorted by distance to the sample. Both the learning
/// rate and `lambda` decay geometrically from their initial to their final
/// value during [`NeuralGas::fit`].
#[derive(Debug, Clone)]
pub struct NeuralGas {
    /// Learning rate at the start of training.
    pub lr_i: f64,
    /// Learning rate approached at the end of training.
    pub lr_f: f64,
    /// Neighbourhood range at the start of training. `None` means
    /// `n_units / 2`. Whichever value is used is raised to at least 0.5.
    pub lambda_i: Option<f64>,
    /// Neighbourhood range approached at the end of training.
    pub lambda_f: f64,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}

impl Default for NeuralGas {
    /// Learning rate 0.5 → 0.01, neighbourhood `n_units / 2` → 0.01, seed 42.
    fn default() -> Self {
        Self {
            lr_i: 0.5,
            lr_f: 0.01,
            lambda_i: None,
            lambda_f: 0.01,
            seed: 42,
        }
    }
}
100
101impl NeuralGas {
102 pub fn fit(
109 &self,
110 x: ArrayView2<f64>,
111 n_units: usize,
112 max_iter: usize,
113 ) -> Result<NeuralGasResult> {
114 let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
115 if n_samples == 0 {
116 return Err(ClusteringError::InvalidInput("Empty input data".into()));
117 }
118 if n_units == 0 {
119 return Err(ClusteringError::InvalidInput("n_units must be > 0".into()));
120 }
121 if max_iter == 0 {
122 return Err(ClusteringError::InvalidInput("max_iter must be > 0".into()));
123 }
124
125 let mut rng = self.seed;
126
127 let mut protos: Vec<Vec<f64>> = (0..n_units)
129 .map(|_| {
130 let idx = lcg_usize(&mut rng, n_samples);
131 x.row(idx).to_vec()
132 })
133 .collect();
134
135 let total_steps = max_iter * n_samples;
136 let lambda_i = self.lambda_i.unwrap_or((n_units as f64) / 2.0).max(0.5);
137
138 for epoch in 0..max_iter {
139 let mut order: Vec<usize> = (0..n_samples).collect();
141 for i in (1..n_samples).rev() {
142 let j = lcg_usize(&mut rng, i + 1);
143 order.swap(i, j);
144 }
145
146 for &sample_idx in &order {
147 let step = epoch * n_samples + sample_idx;
149 let t = step as f64 / total_steps.max(1) as f64;
150
151 let lr = self.lr_i * (self.lr_f / self.lr_i).powf(t);
153 let lam = lambda_i * (self.lambda_f / lambda_i).powf(t);
154
155 let input = x.row(sample_idx).to_vec();
156
157 let mut ranked: Vec<(f64, usize)> = protos
159 .iter()
160 .enumerate()
161 .map(|(j, p)| (euclid(&input, p), j))
162 .collect();
163 ranked.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
164
165 for (rank, (_, proto_idx)) in ranked.iter().enumerate() {
167 let h = (-(rank as f64) / lam).exp();
168 let p = &mut protos[*proto_idx];
169 for k in 0..n_features {
170 p[k] += lr * h * (input[k] - p[k]);
171 }
172 }
173 }
174 }
175
176 let mut labels = vec![0usize; n_samples];
178 let mut total_qe = 0.0f64;
179 for i in 0..n_samples {
180 let row = x.row(i).to_vec();
181 let (best, best_dist) = protos
182 .iter()
183 .enumerate()
184 .map(|(j, p)| (j, sq_euclid(&row, p)))
185 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
186 .unwrap_or((0, 0.0));
187 labels[i] = best;
188 total_qe += best_dist;
189 }
190 let quantization_error = total_qe / n_samples as f64;
191
192 let mut proto_arr = Array2::<f64>::zeros((n_units, n_features));
194 for (j, p) in protos.iter().enumerate() {
195 for k in 0..n_features {
196 proto_arr[[j, k]] = p[k];
197 }
198 }
199
200 Ok(NeuralGasResult {
201 prototypes: proto_arr,
202 labels: Array1::from_vec(labels),
203 n_units,
204 quantization_error,
205 })
206 }
207}
208
/// Bookkeeping for one undirected edge of the GNG graph.
///
/// The endpoints are not stored here. They form the `(lo, hi)` key of the
/// edge map used by `GrowingNeuralGas::fit`.
#[derive(Debug, Clone)]
struct GngEdge {
    /// Incremented whenever one endpoint is the winning node. Reset to 0 when
    /// the edge joins the two nearest nodes. Dropped once it exceeds
    /// `GngConfig::max_age`.
    age: usize,
}

/// One GNG unit: a prototype vector plus its accumulated error.
#[derive(Debug, Clone)]
struct GngNode {
    /// Prototype position in input space.
    weights: Vec<f64>,
    /// Sum of squared distances to samples this node won. Multiplied by
    /// `GngConfig::beta` every step and by `alpha` when a node is inserted
    /// next to it.
    error: f64,
}
229
/// Hyperparameters for [`GrowingNeuralGas`].
#[derive(Debug, Clone)]
pub struct GngConfig {
    /// Fraction of the way the winning node moves toward each sample.
    pub lr_winner: f64,
    /// Fraction of the way each graph neighbour of the winner moves toward
    /// the sample.
    pub lr_neighbor: f64,
    /// Edges whose age exceeds this value are removed.
    pub max_age: usize,
    /// A new node is considered once every `insert_interval` steps. Must be
    /// non-zero.
    pub insert_interval: usize,
    /// Factor applied to the errors of the two nodes a new node is inserted
    /// between. The new node starts with the scaled error.
    pub alpha: f64,
    /// Per-step multiplicative decay applied to every node's error.
    pub beta: f64,
    /// No node is inserted once the graph has this many nodes.
    pub max_units: usize,
    /// Number of samples presented during training.
    pub max_steps: usize,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}

impl Default for GngConfig {
    fn default() -> GngConfig {
        let (lr_winner, lr_neighbor) = (0.1, 0.01);
        let (alpha, beta) = (0.5, 0.995);
        GngConfig {
            lr_winner,
            lr_neighbor,
            alpha,
            beta,
            max_age: 50,
            insert_interval: 100,
            max_units: 200,
            max_steps: 5000,
            seed: 42,
        }
    }
}
268
/// Output of [`GrowingNeuralGas::fit`].
#[derive(Debug, Clone)]
pub struct GngResult {
    /// Final node positions, one row per node.
    pub prototypes: Array2<f64>,
    /// Surviving graph edges as `(i, j)` row indices into `prototypes`, with
    /// `i < j`.
    pub edges: Vec<(usize, usize)>,
    /// For each input row, the index of its nearest node (squared Euclidean
    /// distance).
    pub labels: Array1<usize>,
    /// Mean squared Euclidean distance from each input row to its nearest node.
    pub quantization_error: f64,
}
281
282pub struct GrowingNeuralGas {
289 pub config: GngConfig,
291}
292
293impl Default for GrowingNeuralGas {
294 fn default() -> Self {
295 Self {
296 config: GngConfig::default(),
297 }
298 }
299}
300
301impl GrowingNeuralGas {
302 pub fn new(config: GngConfig) -> Self {
304 Self { config }
305 }
306
307 pub fn fit(&self, x: ArrayView2<f64>) -> Result<GngResult> {
309 let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
310 if n_samples < 2 {
311 return Err(ClusteringError::InvalidInput(
312 "Need at least 2 samples for GNG".into(),
313 ));
314 }
315
316 let cfg = &self.config;
317 let mut rng = cfg.seed;
318
319 let idx0 = lcg_usize(&mut rng, n_samples);
321 let idx1 = (idx0 + 1 + lcg_usize(&mut rng, n_samples - 1)) % n_samples;
322 let mut nodes: Vec<GngNode> = vec![
323 GngNode {
324 weights: x.row(idx0).to_vec(),
325 error: 0.0,
326 },
327 GngNode {
328 weights: x.row(idx1).to_vec(),
329 error: 0.0,
330 },
331 ];
332 let mut edge_map: HashMap<(usize, usize), GngEdge> = HashMap::new();
335 edge_map.insert((0, 1), GngEdge { age: 0 });
337
338 let mut step = 0usize;
339 let data_vec: Vec<Vec<f64>> = (0..n_samples).map(|i| x.row(i).to_vec()).collect();
340
341 while step < cfg.max_steps {
342 let sample = &data_vec[lcg_usize(&mut rng, n_samples)];
344
345 let mut dists: Vec<(f64, usize)> = nodes
347 .iter()
348 .enumerate()
349 .map(|(j, n)| (sq_euclid(sample, &n.weights), j))
350 .collect();
351 dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
352
353 if dists.len() < 2 {
354 step += 1;
355 continue;
356 }
357
358 let s1 = dists[0].1;
359 let s2 = dists[1].1;
360 let dist_s1 = dists[0].0;
361
362 let edge_keys: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
364 for key in &edge_keys {
365 if key.0 == s1 || key.1 == s1 {
366 if let Some(e) = edge_map.get_mut(key) {
367 e.age += 1;
368 }
369 }
370 }
371
372 let edge_key = if s1 < s2 { (s1, s2) } else { (s2, s1) };
374 edge_map.insert(edge_key, GngEdge { age: 0 });
375
376 nodes[s1].error += dist_s1;
378
379 let n_nodes = nodes.len();
381 let winner_w: Vec<f64> = nodes[s1].weights.clone();
382 for k in 0..n_features {
383 nodes[s1].weights[k] += cfg.lr_winner * (sample[k] - winner_w[k]);
384 }
385
386 let neighbor_ids: Vec<usize> = edge_map
387 .keys()
388 .filter_map(|&(a, b)| {
389 if a == s1 {
390 Some(b)
391 } else if b == s1 {
392 Some(a)
393 } else {
394 None
395 }
396 })
397 .collect();
398
399 for nb in &neighbor_ids {
400 let nb_w: Vec<f64> = nodes[*nb].weights.clone();
401 for k in 0..n_features {
402 nodes[*nb].weights[k] += cfg.lr_neighbor * (sample[k] - nb_w[k]);
403 }
404 }
405
406 edge_map.retain(|_, e| e.age <= cfg.max_age);
408
409 let connected: std::collections::HashSet<usize> =
412 edge_map.keys().flat_map(|&(a, b)| [a, b]).collect();
413 for node in nodes.iter_mut() {
417 node.error *= cfg.beta;
418 }
419
420 if step % cfg.insert_interval == 0 && nodes.len() < cfg.max_units && nodes.len() >= 2 {
422 let q = nodes
424 .iter()
425 .enumerate()
426 .max_by(|a, b| {
427 a.1.error
428 .partial_cmp(&b.1.error)
429 .unwrap_or(std::cmp::Ordering::Equal)
430 })
431 .map(|(i, _)| i)
432 .unwrap_or(0);
433
434 let q_neighbors: Vec<usize> = edge_map
436 .keys()
437 .filter_map(|&(a, b)| {
438 if a == q {
439 Some(b)
440 } else if b == q {
441 Some(a)
442 } else {
443 None
444 }
445 })
446 .collect();
447
448 if !q_neighbors.is_empty() {
449 let f = q_neighbors
450 .iter()
451 .max_by(|&&a, &&b| {
452 nodes[a]
453 .error
454 .partial_cmp(&nodes[b].error)
455 .unwrap_or(std::cmp::Ordering::Equal)
456 })
457 .cloned()
458 .unwrap_or(q_neighbors[0]);
459
460 let new_weights: Vec<f64> = nodes[q]
462 .weights
463 .iter()
464 .zip(nodes[f].weights.iter())
465 .map(|(a, b)| (a + b) / 2.0)
466 .collect();
467 let new_idx = nodes.len();
468 nodes.push(GngNode {
469 weights: new_weights,
470 error: nodes[q].error * cfg.alpha,
471 });
472
473 nodes[q].error *= cfg.alpha;
475 nodes[f].error *= cfg.alpha;
476
477 let qf_key = if q < f { (q, f) } else { (f, q) };
479 edge_map.remove(&qf_key);
480 let qn_key = if q < new_idx {
481 (q, new_idx)
482 } else {
483 (new_idx, q)
484 };
485 let fn_key = if f < new_idx {
486 (f, new_idx)
487 } else {
488 (new_idx, f)
489 };
490 edge_map.insert(qn_key, GngEdge { age: 0 });
491 edge_map.insert(fn_key, GngEdge { age: 0 });
492 }
493 }
494
495 step += 1;
496 }
497
498 let n_units = nodes.len();
499 let mut proto_arr = Array2::<f64>::zeros((n_units, n_features));
500 for (j, node) in nodes.iter().enumerate() {
501 for k in 0..n_features {
502 proto_arr[[j, k]] = node.weights[k];
503 }
504 }
505
506 let edges: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
507
508 let mut labels = vec![0usize; n_samples];
510 let mut total_qe = 0.0f64;
511 for i in 0..n_samples {
512 let row = x.row(i).to_vec();
513 let (best, best_dist) = nodes
514 .iter()
515 .enumerate()
516 .map(|(j, node)| (j, sq_euclid(&row, &node.weights)))
517 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
518 .unwrap_or((0, 0.0));
519 labels[i] = best;
520 total_qe += best_dist;
521 }
522
523 Ok(GngResult {
524 prototypes: proto_arr,
525 edges,
526 labels: Array1::from_vec(labels),
527 quantization_error: total_qe / n_samples as f64,
528 })
529 }
530}
531
/// Hyperparameters for [`LVQ`].
#[derive(Debug, Clone)]
pub struct LvqConfig {
    /// How many prototypes are initialised for each class that occurs in the
    /// training labels.
    pub prototypes_per_class: usize,
    /// Learning rate at the first training step.
    pub lr_init: f64,
    /// Learning rate approached at the end of training (geometric decay).
    pub lr_final: f64,
    /// Number of passes over the training data.
    pub max_epochs: usize,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}

impl Default for LvqConfig {
    fn default() -> LvqConfig {
        let (lr_init, lr_final) = (0.1, 0.001);
        LvqConfig {
            seed: 42,
            max_epochs: 50,
            lr_init,
            lr_final,
            prototypes_per_class: 1,
        }
    }
}
562
/// Output of [`LVQ::fit`].
#[derive(Debug, Clone)]
pub struct LvqResult {
    /// Trained prototype vectors, one row per prototype.
    pub prototypes: Array2<f64>,
    /// Class label of each prototype, in the same order as the rows of
    /// `prototypes`.
    pub prototype_labels: Vec<usize>,
    /// Fraction of training rows whose nearest prototype carries their label.
    pub train_accuracy: f64,
}
573
574impl LvqResult {
575 pub fn predict(&self, x: ArrayView2<f64>) -> Vec<usize> {
577 let n = x.shape()[0];
578 let n_proto = self.prototypes.shape()[0];
579 (0..n)
580 .map(|i| {
581 let row = x.row(i).to_vec();
582 let best = (0..n_proto)
583 .map(|j| {
584 let p: Vec<f64> = self.prototypes.row(j).to_vec();
585 (j, sq_euclid(&row, &p))
586 })
587 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
588 .map(|(j, _)| j)
589 .unwrap_or(0);
590 self.prototype_labels[best]
591 })
592 .collect()
593 }
594}
595
596pub struct LVQ {
601 pub config: LvqConfig,
603}
604
605impl Default for LVQ {
606 fn default() -> Self {
607 Self {
608 config: LvqConfig::default(),
609 }
610 }
611}
612
613impl LVQ {
614 pub fn new(config: LvqConfig) -> Self {
616 Self { config }
617 }
618
619 pub fn fit(&self, x: ArrayView2<f64>, y: &[usize]) -> Result<LvqResult> {
625 let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
626 if n_samples == 0 {
627 return Err(ClusteringError::InvalidInput("Empty input data".into()));
628 }
629 if y.len() != n_samples {
630 return Err(ClusteringError::InvalidInput(
631 "y must have the same length as x rows".into(),
632 ));
633 }
634
635 let n_classes = y.iter().cloned().max().map(|m| m + 1).unwrap_or(0);
636 if n_classes == 0 {
637 return Err(ClusteringError::InvalidInput("Empty class labels".into()));
638 }
639
640 let ppc = self.config.prototypes_per_class;
641 let mut rng = self.config.seed;
642
643 let mut class_samples: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
645 for (i, &label) in y.iter().enumerate() {
646 if label < n_classes {
647 class_samples[label].push(i);
648 }
649 }
650
651 let mut proto_weights: Vec<Vec<f64>> = Vec::new();
652 let mut proto_labels: Vec<usize> = Vec::new();
653
654 for class in 0..n_classes {
655 let samples = &class_samples[class];
656 if samples.is_empty() {
657 continue;
658 }
659 for _ in 0..ppc {
660 let idx = samples[lcg_usize(&mut rng, samples.len())];
661 proto_weights.push(x.row(idx).to_vec());
662 proto_labels.push(class);
663 }
664 }
665
666 let n_proto = proto_weights.len();
667 if n_proto == 0 {
668 return Err(ClusteringError::ComputationError(
669 "No prototypes initialized".into(),
670 ));
671 }
672
673 let total_steps = self.config.max_epochs * n_samples;
674
675 for epoch in 0..self.config.max_epochs {
677 let mut order: Vec<usize> = (0..n_samples).collect();
679 for i in (1..n_samples).rev() {
680 let j = lcg_usize(&mut rng, i + 1);
681 order.swap(i, j);
682 }
683
684 for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
685 let step = epoch * n_samples + step_in_epoch;
686 let t = step as f64 / total_steps.max(1) as f64;
687 let lr = self.config.lr_init * (self.config.lr_final / self.config.lr_init).powf(t);
688
689 let input = x.row(sample_idx).to_vec();
690 let true_class = y[sample_idx];
691
692 let nearest = (0..n_proto)
694 .map(|j| (j, sq_euclid(&input, &proto_weights[j])))
695 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
696 .map(|(j, _)| j)
697 .unwrap_or(0);
698
699 let sign = if proto_labels[nearest] == true_class {
701 1.0f64
702 } else {
703 -1.0f64
704 };
705
706 let w = &mut proto_weights[nearest];
707 for k in 0..n_features {
708 w[k] += lr * sign * (input[k] - w[k]);
709 }
710 }
711 }
712
713 let mut proto_arr = Array2::<f64>::zeros((n_proto, n_features));
715 for (j, w) in proto_weights.iter().enumerate() {
716 for k in 0..n_features {
717 proto_arr[[j, k]] = w[k];
718 }
719 }
720
721 let predictions = {
723 let n = n_samples;
724 (0..n)
725 .map(|i| {
726 let row = x.row(i).to_vec();
727 let best = (0..n_proto)
728 .map(|j| (j, sq_euclid(&row, &proto_weights[j])))
729 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
730 .map(|(j, _)| j)
731 .unwrap_or(0);
732 proto_labels[best]
733 })
734 .collect::<Vec<usize>>()
735 };
736
737 let correct = predictions
738 .iter()
739 .zip(y.iter())
740 .filter(|(&p, &t)| p == t)
741 .count();
742 let train_accuracy = correct as f64 / n_samples as f64;
743
744 Ok(LvqResult {
745 prototypes: proto_arr,
746 prototype_labels: proto_labels,
747 train_accuracy,
748 })
749 }
750}
751
/// Output of [`GLVQ::fit`].
#[derive(Debug, Clone)]
pub struct GlvqResult {
    /// Trained prototype vectors, one row per prototype.
    pub prototypes: Array2<f64>,
    /// Class label of each prototype, in the same order as the rows of
    /// `prototypes`.
    pub prototype_labels: Vec<usize>,
    /// Fraction of training rows whose nearest prototype carries their label.
    pub train_accuracy: f64,
    /// Sum of the logistic cost `f(mu)` over the samples updated in the final
    /// epoch (0.0 when no epoch ran).
    pub cost: f64,
}
768
769impl GlvqResult {
770 pub fn predict(&self, x: ArrayView2<f64>) -> Vec<usize> {
772 let n = x.shape()[0];
773 let n_proto = self.prototypes.shape()[0];
774 (0..n)
775 .map(|i| {
776 let row = x.row(i).to_vec();
777 let best = (0..n_proto)
778 .map(|j| {
779 let p = self.prototypes.row(j).to_vec();
780 (j, sq_euclid(&row, &p))
781 })
782 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
783 .map(|(j, _)| j)
784 .unwrap_or(0);
785 self.prototype_labels[best]
786 })
787 .collect()
788 }
789}
790
/// Hyperparameters for [`GLVQ`].
#[derive(Debug, Clone)]
pub struct GlvqConfig {
    /// How many prototypes are initialised for each class present in the
    /// training labels.
    pub prototypes_per_class: usize,
    /// Constant gradient-descent step size (no decay schedule).
    pub lr: f64,
    /// Slope of the logistic function applied to the relative distance `mu`.
    pub sigma: f64,
    /// Number of passes over the training data.
    pub max_epochs: usize,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}

impl Default for GlvqConfig {
    fn default() -> GlvqConfig {
        let (lr, sigma) = (0.01, 1.0);
        GlvqConfig {
            seed: 42,
            max_epochs: 100,
            sigma,
            lr,
            prototypes_per_class: 1,
        }
    }
}
817
818pub struct GLVQ {
830 pub config: GlvqConfig,
832}
833
834impl Default for GLVQ {
835 fn default() -> Self {
836 Self {
837 config: GlvqConfig::default(),
838 }
839 }
840}
841
842impl GLVQ {
843 pub fn new(config: GlvqConfig) -> Self {
845 Self { config }
846 }
847
848 pub fn fit(&self, x: ArrayView2<f64>, y: &[usize]) -> Result<GlvqResult> {
850 let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
851 if n_samples == 0 {
852 return Err(ClusteringError::InvalidInput("Empty input data".into()));
853 }
854 if y.len() != n_samples {
855 return Err(ClusteringError::InvalidInput("y length mismatch".into()));
856 }
857
858 let n_classes = y.iter().cloned().max().map(|m| m + 1).unwrap_or(0);
859 if n_classes < 2 {
860 return Err(ClusteringError::InvalidInput(
861 "GLVQ requires at least 2 classes".into(),
862 ));
863 }
864
865 let ppc = self.config.prototypes_per_class;
866 let mut rng = self.config.seed;
867
868 let mut class_samples: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
870 for (i, &label) in y.iter().enumerate() {
871 if label < n_classes {
872 class_samples[label].push(i);
873 }
874 }
875
876 let mut proto_weights: Vec<Vec<f64>> = Vec::new();
877 let mut proto_labels: Vec<usize> = Vec::new();
878
879 for class in 0..n_classes {
880 let samples = &class_samples[class];
881 if samples.is_empty() {
882 continue;
883 }
884 for _ in 0..ppc {
885 let idx = samples[lcg_usize(&mut rng, samples.len())];
886 proto_weights.push(x.row(idx).to_vec());
887 proto_labels.push(class);
888 }
889 }
890
891 let n_proto = proto_weights.len();
892 let lr = self.config.lr;
893 let sigma = self.config.sigma;
894
895 let mut total_cost = 0.0f64;
896
897 for _epoch in 0..self.config.max_epochs {
899 let mut order: Vec<usize> = (0..n_samples).collect();
901 for i in (1..n_samples).rev() {
902 let j = lcg_usize(&mut rng, i + 1);
903 order.swap(i, j);
904 }
905
906 total_cost = 0.0;
907 for &sample_idx in &order {
908 let input = x.row(sample_idx).to_vec();
909 let true_class = y[sample_idx];
910
911 let mut d_plus = f64::INFINITY;
913 let mut d_minus = f64::INFINITY;
914 let mut winner_plus = 0usize;
915 let mut winner_minus = 0usize;
916
917 for j in 0..n_proto {
918 let d = sq_euclid(&input, &proto_weights[j]);
919 if proto_labels[j] == true_class {
920 if d < d_plus {
921 d_plus = d;
922 winner_plus = j;
923 }
924 } else if d < d_minus {
925 d_minus = d;
926 winner_minus = j;
927 }
928 }
929
930 if d_plus.is_infinite() || d_minus.is_infinite() {
931 continue;
932 }
933
934 let denom = d_plus + d_minus;
935 if denom < 1e-12 {
936 continue;
937 }
938
939 let mu = (d_plus - d_minus) / denom;
940 let f_mu = 1.0 / (1.0 + (-sigma * mu).exp());
942 let f_prime = sigma * f_mu * (1.0 - f_mu);
944
945 total_cost += f_mu;
946
947 let grad_dp = f_prime * (2.0 * d_minus) / (denom * denom);
949 let grad_dm = -f_prime * (2.0 * d_plus) / (denom * denom);
951
952 let wp = &mut proto_weights[winner_plus];
955 for k in 0..n_features {
956 wp[k] -= lr * 2.0 * grad_dp * (wp[k] - input[k]);
957 }
958
959 let wm = &mut proto_weights[winner_minus];
962 for k in 0..n_features {
963 wm[k] -= lr * 2.0 * grad_dm * (wm[k] - input[k]);
964 }
965 }
966 }
967
968 let mut proto_arr = Array2::<f64>::zeros((n_proto, n_features));
970 for (j, w) in proto_weights.iter().enumerate() {
971 for k in 0..n_features {
972 proto_arr[[j, k]] = w[k];
973 }
974 }
975
976 let mut correct = 0usize;
978 for i in 0..n_samples {
979 let row = x.row(i).to_vec();
980 let best = (0..n_proto)
981 .map(|j| (j, sq_euclid(&row, &proto_weights[j])))
982 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
983 .map(|(j, _)| j)
984 .unwrap_or(0);
985 if proto_labels[best] == y[i] {
986 correct += 1;
987 }
988 }
989 let train_accuracy = correct as f64 / n_samples as f64;
990
991 Ok(GlvqResult {
992 prototypes: proto_arr,
993 prototype_labels: proto_labels,
994 train_accuracy,
995 cost: total_cost,
996 })
997 }
998}
999
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Twelve 2-D points in two well-separated groups of six (near the origin
    /// and near (5, 5)), labelled 0 and 1 respectively.
    fn two_cluster_data() -> (Array2<f64>, Vec<usize>) {
        let x = Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 0.2, 0.0, 0.1, 0.1, 0.0, 0.2, 5.0, 5.0, 5.1, 5.0,
                5.0, 5.1, 5.2, 5.0, 5.1, 5.1, 5.0, 5.2,
            ],
        )
        .expect("valid shape");
        let y = vec![0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_neural_gas_basic() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let result = ng.fit(x.view(), 2, 20).expect("neural gas fit");
        assert_eq!(result.n_units, 2);
        assert_eq!(result.labels.len(), 12);
        assert!(result.quantization_error >= 0.0);
    }

    #[test]
    fn test_neural_gas_n_units_gt_samples() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        // 15 units for 12 samples: initialisation must draw rows with replacement.
        let result = ng.fit(x.view(), 15, 10).expect("ng many units");
        assert_eq!(result.n_units, 15);
        assert_eq!(result.prototypes.shape(), &[15, 2]);
        assert!(result.labels.iter().all(|&l| l < 15));
    }

    #[test]
    fn test_neural_gas_single_unit() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let result = ng.fit(x.view(), 1, 10).expect("ng 1 unit");
        assert_eq!(result.n_units, 1);
        assert!(result.labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_neural_gas_invalid_args() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        assert!(ng.fit(x.view(), 0, 10).is_err(), "zero units should error");
        assert!(ng.fit(x.view(), 2, 0).is_err(), "zero iterations should error");
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(ng.fit(empty.view(), 2, 10).is_err(), "empty input should error");
    }

    #[test]
    fn test_growing_neural_gas_basic() {
        let (x, _) = two_cluster_data();
        let config = GngConfig {
            max_steps: 200,
            insert_interval: 20,
            max_units: 10,
            seed: 7,
            ..GngConfig::default()
        };
        let gng = GrowingNeuralGas::new(config);
        let result = gng.fit(x.view()).expect("gng fit");
        assert!(result.prototypes.shape()[0] >= 2, "should have grown");
        assert_eq!(result.labels.len(), 12);
    }

    #[test]
    fn test_gng_edges_reference_existing_nodes() {
        let (x, _) = two_cluster_data();
        let config = GngConfig {
            max_steps: 300,
            insert_interval: 25,
            max_units: 8,
            seed: 11,
            ..GngConfig::default()
        };
        let result = GrowingNeuralGas::new(config)
            .fit(x.view())
            .expect("gng fit");
        let n = result.prototypes.shape()[0];
        assert!(result.edges.iter().all(|&(a, b)| a < b && b < n));
        assert!(result.labels.iter().all(|&l| l < n));
    }

    #[test]
    fn test_gng_too_few_samples() {
        let x = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape");
        assert!(GrowingNeuralGas::default().fit(x.view()).is_err());
    }

    #[test]
    fn test_lvq_two_classes() {
        let (x, y) = two_cluster_data();
        let config = LvqConfig {
            prototypes_per_class: 1,
            lr_init: 0.3,
            lr_final: 0.01,
            max_epochs: 100,
            seed: 42,
        };
        let lvq = LVQ::new(config);
        let result = lvq.fit(x.view(), &y).expect("lvq fit");
        assert_eq!(result.prototypes.shape()[0], 2);
        assert!(
            result.train_accuracy > 0.8,
            "expected > 80% accuracy, got {}",
            result.train_accuracy
        );
    }

    #[test]
    fn test_lvq_predict() {
        let (x, y) = two_cluster_data();
        let lvq = LVQ::default();
        let result = lvq.fit(x.view(), &y).expect("lvq fit");
        let preds = result.predict(x.view());
        assert_eq!(preds.len(), 12);
    }

    #[test]
    fn test_lvq_zero_prototypes_per_class_errors() {
        let (x, y) = two_cluster_data();
        let lvq = LVQ::new(LvqConfig {
            prototypes_per_class: 0,
            ..LvqConfig::default()
        });
        assert!(lvq.fit(x.view(), &y).is_err());
    }

    #[test]
    fn test_glvq_two_classes() {
        let (x, y) = two_cluster_data();
        let config = GlvqConfig {
            prototypes_per_class: 1,
            lr: 0.05,
            sigma: 1.0,
            max_epochs: 200,
            seed: 42,
        };
        let glvq = GLVQ::new(config);
        let result = glvq.fit(x.view(), &y).expect("glvq fit");
        assert_eq!(result.prototypes.shape()[0], 2);
        assert!(
            result.train_accuracy > 0.8,
            "expected > 80% accuracy, got {}",
            result.train_accuracy
        );
    }

    #[test]
    fn test_glvq_predict() {
        let (x, y) = two_cluster_data();
        let glvq = GLVQ::default();
        let result = glvq.fit(x.view(), &y).expect("glvq fit");
        let preds = result.predict(x.view());
        assert_eq!(preds.len(), 12);
    }

    #[test]
    fn test_lvq_invalid_input() {
        let (x, _y) = two_cluster_data();
        let lvq = LVQ::default();
        // Three labels for twelve rows: length mismatch must be rejected.
        assert!(lvq.fit(x.view(), &[0, 1, 0]).is_err());
    }

    #[test]
    fn test_glvq_single_class_error() {
        let x = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.0, 0.3, 0.1])
            .expect("shape");
        let y = vec![0usize, 0, 0, 0];
        let glvq = GLVQ::default();
        assert!(glvq.fit(x.view(), &y).is_err(), "single class should error");
    }
}
1141}