1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
30
31use crate::error::{ClusteringError, Result};
32
/// Squared Euclidean distance `Σ (a_i - b_i)²` between two slices.
///
/// Elements are paired with `zip`, so when the slices differ in length the
/// surplus tail of the longer one is ignored rather than reported.
#[inline]
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
42
43#[inline]
45fn euclid(a: &[f64], b: &[f64]) -> f64 {
46 sq_euclid(a, b).sqrt()
47}
48
/// Tiny deterministic pseudo-random number generator (a 64-bit linear
/// congruential generator). Every stochastic choice in this module is drawn
/// from one of these, so results depend only on the user-supplied seed.
struct Lcg(u64);

impl Lcg {
    /// LCG multiplier.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// LCG increment.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator whose initial state is exactly `seed`.
    fn new(seed: u64) -> Self {
        Lcg(seed)
    }

    /// Advances the state and returns a value in `[0, 1)` built from the top
    /// 53 bits of the new state.
    fn next_f64(&mut self) -> f64 {
        let state = self
            .0
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.0 = state;
        let top_bits = state >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }

    /// Returns an index in `0..n` by scaling [`Lcg::next_f64`].
    ///
    /// # Panics
    ///
    /// Panics if `n == 0` (remainder by zero).
    fn next_usize(&mut self, n: usize) -> usize {
        let scaled = self.next_f64() * n as f64;
        scaled as usize % n
    }

    /// Shuffles `v` in place with the Fisher–Yates algorithm, walking from the
    /// last position down to index 1.
    fn shuffle(&mut self, v: &mut [usize]) {
        let mut i = v.len();
        while i > 1 {
            i -= 1;
            let j = self.next_usize(i + 1);
            v.swap(i, j);
        }
    }
}
79
80fn find_bmu(input: &[f64], prototypes: &[Vec<f64>]) -> usize {
82 prototypes
83 .iter()
84 .enumerate()
85 .min_by(|(_, a), (_, b)| {
86 sq_euclid(input, a)
87 .partial_cmp(&sq_euclid(input, b))
88 .unwrap_or(std::cmp::Ordering::Equal)
89 })
90 .map(|(i, _)| i)
91 .unwrap_or(0)
92}
93
/// Configuration for winner-take-all (hard competitive) vector quantisation.
///
/// For every presented sample only the single closest prototype is moved
/// towards it.
#[derive(Debug, Clone)]
pub struct WinnerTakeAll {
    /// Learning rate used at the first training step.
    pub lr_init: f64,
    /// Optional final learning rate. When `Some`, the rate is interpolated
    /// linearly from `lr_init` towards this value over all training steps;
    /// `None` keeps the rate constant at `lr_init`.
    pub lr_final: Option<f64>,
    /// Number of shuffled passes over the training data.
    pub max_epochs: usize,
    /// Seed of the internal pseudo-random generator (initialisation and
    /// presentation order).
    pub seed: u64,
}
116
117impl Default for WinnerTakeAll {
118 fn default() -> Self {
119 Self {
120 lr_init: 0.3,
121 lr_final: None,
122 max_epochs: 100,
123 seed: 42,
124 }
125 }
126}
127
128impl WinnerTakeAll {
129 pub fn new(lr_init: f64, lr_final: Option<f64>, max_epochs: usize, seed: u64) -> Self {
131 Self {
132 lr_init,
133 lr_final,
134 max_epochs,
135 seed,
136 }
137 }
138
139 pub fn fit(&self, data: ArrayView2<f64>, n_prototypes: usize) -> Result<Array2<f64>> {
150 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
151
152 if n_samples == 0 {
153 return Err(ClusteringError::InvalidInput("Empty input data".into()));
154 }
155 if n_prototypes == 0 {
156 return Err(ClusteringError::InvalidInput(
157 "n_prototypes must be > 0".into(),
158 ));
159 }
160 if n_features == 0 {
161 return Err(ClusteringError::InvalidInput(
162 "Data must have at least one feature".into(),
163 ));
164 }
165
166 let mut rng = Lcg::new(self.seed);
167
168 let mut prototypes: Vec<Vec<f64>> = (0..n_prototypes)
170 .map(|_| {
171 let idx = rng.next_usize(n_samples);
172 data.row(idx).to_vec()
173 })
174 .collect();
175
176 let total_steps = self.max_epochs * n_samples;
177 let mut order: Vec<usize> = (0..n_samples).collect();
178
179 for epoch in 0..self.max_epochs {
180 rng.shuffle(&mut order);
181
182 for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
183 let global_step = epoch * n_samples + step_in_epoch;
184 let t = global_step as f64 / total_steps.max(1) as f64;
185
186 let lr = match self.lr_final {
187 Some(lr_f) => self.lr_init + t * (lr_f - self.lr_init),
188 None => self.lr_init,
189 };
190
191 let input = data.row(sample_idx).to_vec();
192 let bmu_idx = find_bmu(&input, &prototypes);
193
194 let bmu = &mut prototypes[bmu_idx];
195 for k in 0..n_features {
196 bmu[k] += lr * (input[k] - bmu[k]);
197 }
198 }
199 }
200
201 let mut out = Array2::<f64>::zeros((n_prototypes, n_features));
203 for (j, p) in prototypes.iter().enumerate() {
204 for k in 0..n_features {
205 out[[j, k]] = p[k];
206 }
207 }
208 Ok(out)
209 }
210
211 pub fn predict(&self, data: ArrayView2<f64>, prototypes: &Array2<f64>) -> Array1<usize> {
215 let n_samples = data.shape()[0];
216 let n_proto = prototypes.shape()[0];
217 let protos: Vec<Vec<f64>> = (0..n_proto).map(|j| prototypes.row(j).to_vec()).collect();
218
219 let labels: Vec<usize> = (0..n_samples)
220 .map(|i| {
221 let row = data.row(i).to_vec();
222 find_bmu(&row, &protos)
223 })
224 .collect();
225
226 Array1::from_vec(labels)
227 }
228}
229
/// A trained LVQ classifier: prototype vectors and the class each represents.
#[derive(Debug, Clone)]
pub struct LvqModel {
    /// Prototype vectors, one per row.
    pub prototypes: Array2<f64>,
    /// Class label of each prototype; `labels[j]` belongs to row `j` of
    /// `prototypes`.
    pub labels: Array1<usize>,
}
242
243impl LvqModel {
244 pub fn predict_one(&self, sample: &[f64]) -> usize {
246 let n_proto = self.prototypes.shape()[0];
247 let best = (0..n_proto)
248 .map(|j| {
249 let p = self.prototypes.row(j).to_vec();
250 (j, sq_euclid(sample, &p))
251 })
252 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
253 .map(|(j, _)| j)
254 .unwrap_or(0);
255 self.labels[best]
256 }
257
258 pub fn predict(&self, data: ArrayView2<f64>) -> Array1<usize> {
260 let n = data.shape()[0];
261 let preds: Vec<usize> = (0..n)
262 .map(|i| self.predict_one(&data.row(i).to_vec()))
263 .collect();
264 Array1::from_vec(preds)
265 }
266}
267
/// Learning Vector Quantization (LVQ1) trainer for labelled data.
///
/// Each class receives `n_prototypes_per_class` prototypes initialised from
/// its own samples. During training the prototype nearest to a sample is
/// pulled towards it when their classes agree and pushed away otherwise.
#[derive(Debug, Clone)]
pub struct LearningVectorQuantization {
    /// Number of prototypes created for every class present in the labels.
    pub n_prototypes_per_class: usize,
    /// Learning rate at the first training step.
    pub lr_init: f64,
    /// Target learning rate; the rate decays geometrically from `lr_init`
    /// towards this value over all training steps.
    pub lr_final: f64,
    /// Number of shuffled passes over the training data.
    pub max_epochs: usize,
    /// Seed for prototype initialisation and sample order.
    pub seed: u64,
}
289
290impl Default for LearningVectorQuantization {
291 fn default() -> Self {
292 Self {
293 n_prototypes_per_class: 1,
294 lr_init: 0.1,
295 lr_final: 0.001,
296 max_epochs: 50,
297 seed: 42,
298 }
299 }
300}
301
302impl LearningVectorQuantization {
303 pub fn new(
305 n_prototypes_per_class: usize,
306 lr_init: f64,
307 lr_final: f64,
308 max_epochs: usize,
309 seed: u64,
310 ) -> Self {
311 Self {
312 n_prototypes_per_class,
313 lr_init,
314 lr_final,
315 max_epochs,
316 seed,
317 }
318 }
319
320 pub fn fit(&self, data: ArrayView2<f64>, labels: &[usize]) -> Result<LvqModel> {
329 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
330
331 if n_samples == 0 {
332 return Err(ClusteringError::InvalidInput("Empty input data".into()));
333 }
334 if labels.len() != n_samples {
335 return Err(ClusteringError::InvalidInput(
336 "labels length must equal number of data rows".into(),
337 ));
338 }
339 if self.n_prototypes_per_class == 0 {
340 return Err(ClusteringError::InvalidInput(
341 "n_prototypes_per_class must be > 0".into(),
342 ));
343 }
344
345 let n_classes = labels.iter().cloned().max().map(|m| m + 1).unwrap_or(0);
346 if n_classes == 0 {
347 return Err(ClusteringError::InvalidInput(
348 "No valid class labels found".into(),
349 ));
350 }
351
352 let mut class_samples: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
354 for (i, &lbl) in labels.iter().enumerate() {
355 if lbl < n_classes {
356 class_samples[lbl].push(i);
357 }
358 }
359
360 let mut rng = Lcg::new(self.seed);
361
362 let mut proto_weights: Vec<Vec<f64>> = Vec::new();
364 let mut proto_labels: Vec<usize> = Vec::new();
365
366 for cls in 0..n_classes {
367 let samples = &class_samples[cls];
368 if samples.is_empty() {
369 continue;
371 }
372 for _ in 0..self.n_prototypes_per_class {
373 let idx = samples[rng.next_usize(samples.len())];
374 proto_weights.push(data.row(idx).to_vec());
375 proto_labels.push(cls);
376 }
377 }
378
379 if proto_weights.is_empty() {
380 return Err(ClusteringError::ComputationError(
381 "Could not initialise any prototypes".into(),
382 ));
383 }
384
385 let n_proto = proto_weights.len();
386 let total_steps = self.max_epochs * n_samples;
387 let mut order: Vec<usize> = (0..n_samples).collect();
388
389 for epoch in 0..self.max_epochs {
391 rng.shuffle(&mut order);
392
393 for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
394 let global_step = epoch * n_samples + step_in_epoch;
395 let t = global_step as f64 / total_steps.max(1) as f64;
396 let lr = self.lr_init * (self.lr_final / self.lr_init).powf(t);
397
398 let input = data.row(sample_idx).to_vec();
399 let true_class = labels[sample_idx];
400
401 let nearest = (0..n_proto)
403 .map(|j| (j, sq_euclid(&input, &proto_weights[j])))
404 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
405 .map(|(j, _)| j)
406 .unwrap_or(0);
407
408 let sign = if proto_labels[nearest] == true_class {
410 1.0f64
411 } else {
412 -1.0f64
413 };
414
415 let w = &mut proto_weights[nearest];
416 for k in 0..n_features {
417 w[k] += lr * sign * (input[k] - w[k]);
418 }
419 }
420 }
421
422 let mut proto_arr = Array2::<f64>::zeros((n_proto, n_features));
424 for (j, w) in proto_weights.iter().enumerate() {
425 for k in 0..n_features {
426 proto_arr[[j, k]] = w[k];
427 }
428 }
429
430 Ok(LvqModel {
431 prototypes: proto_arr,
432 labels: Array1::from_vec(proto_labels),
433 })
434 }
435
436 pub fn predict(model: &LvqModel, data: ArrayView2<f64>) -> Array1<usize> {
438 model.predict(data)
439 }
440}
441
/// Result of fitting a [`NeuralGas`] quantiser.
#[derive(Debug, Clone)]
pub struct NeuralGasModel {
    /// Learned prototype vectors, one per row.
    pub prototypes: Array2<f64>,
    /// For each training row, the index of its nearest prototype (squared
    /// Euclidean distance).
    pub labels: Array1<usize>,
    /// Mean squared Euclidean distance from each training row to its nearest
    /// prototype.
    pub quantization_error: f64,
}
456
/// Configuration for Neural Gas vector quantisation.
///
/// For every sample all prototypes are ranked by distance, and the prototype
/// of rank `r` moves towards the sample with weight `exp(-r / λ)`. Both the
/// learning rate and λ decay geometrically over training.
#[derive(Debug, Clone)]
pub struct NeuralGas {
    /// Learning rate at the first training step (applied in full to the
    /// closest prototype).
    pub lr_winner: f64,
    /// Value the learning rate decays towards by the end of training.
    pub lr_final: f64,
    /// Initial neighbourhood range λ. `None` uses half the number of neurons.
    /// The resulting value is raised to at least 0.5.
    pub lambda_init: Option<f64>,
    /// Value the neighbourhood range decays towards by the end of training.
    pub lambda_final: f64,
    /// Number of shuffled passes over the data; must be non-zero.
    pub max_epochs: usize,
    /// Seed for prototype initialisation and sample order.
    pub seed: u64,
}
480
481impl Default for NeuralGas {
482 fn default() -> Self {
483 Self {
484 lr_winner: 0.5,
485 lr_final: 0.01,
486 lambda_init: None,
487 lambda_final: 0.01,
488 max_epochs: 100,
489 seed: 42,
490 }
491 }
492}
493
494impl NeuralGas {
495 pub fn new(
497 lr_winner: f64,
498 lr_final: f64,
499 lambda_init: Option<f64>,
500 lambda_final: f64,
501 max_epochs: usize,
502 seed: u64,
503 ) -> Self {
504 Self {
505 lr_winner,
506 lr_final,
507 lambda_init,
508 lambda_final,
509 max_epochs,
510 seed,
511 }
512 }
513
514 pub fn fit(&self, data: ArrayView2<f64>, n_neurons: usize) -> Result<NeuralGasModel> {
523 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
524
525 if n_samples == 0 {
526 return Err(ClusteringError::InvalidInput("Empty input data".into()));
527 }
528 if n_neurons == 0 {
529 return Err(ClusteringError::InvalidInput(
530 "n_neurons must be > 0".into(),
531 ));
532 }
533 if self.max_epochs == 0 {
534 return Err(ClusteringError::InvalidInput(
535 "max_epochs must be > 0".into(),
536 ));
537 }
538
539 let mut rng = Lcg::new(self.seed);
540
541 let mut prototypes: Vec<Vec<f64>> = (0..n_neurons)
543 .map(|_| {
544 let idx = rng.next_usize(n_samples);
545 data.row(idx).to_vec()
546 })
547 .collect();
548
549 let total_steps = self.max_epochs * n_samples;
550 let lambda_i = self.lambda_init.unwrap_or(n_neurons as f64 / 2.0).max(0.5);
551 let mut order: Vec<usize> = (0..n_samples).collect();
552
553 for epoch in 0..self.max_epochs {
554 rng.shuffle(&mut order);
555
556 for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
557 let global_step = epoch * n_samples + step_in_epoch;
558 let t = global_step as f64 / total_steps.max(1) as f64;
559
560 let lr = self.lr_winner * (self.lr_final / self.lr_winner).powf(t);
562 let lam = lambda_i * (self.lambda_final / lambda_i).powf(t);
563
564 let input = data.row(sample_idx).to_vec();
565
566 let mut ranked: Vec<(f64, usize)> = prototypes
568 .iter()
569 .enumerate()
570 .map(|(j, p)| (euclid(&input, p), j))
571 .collect();
572 ranked.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
573
574 for (rank, (_, proto_idx)) in ranked.iter().enumerate() {
576 let h = (-(rank as f64) / lam).exp();
577 let p = &mut prototypes[*proto_idx];
578 for k in 0..n_features {
579 p[k] += lr * h * (input[k] - p[k]);
580 }
581 }
582 }
583 }
584
585 let mut labels_vec = vec![0usize; n_samples];
587 let mut total_qe = 0.0f64;
588 for i in 0..n_samples {
589 let row = data.row(i).to_vec();
590 let (best, best_dist) = prototypes
591 .iter()
592 .enumerate()
593 .map(|(j, p)| (j, sq_euclid(&row, p)))
594 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
595 .unwrap_or((0, 0.0));
596 labels_vec[i] = best;
597 total_qe += best_dist;
598 }
599 let quantization_error = if n_samples > 0 {
600 total_qe / n_samples as f64
601 } else {
602 0.0
603 };
604
605 let mut proto_arr = Array2::<f64>::zeros((n_neurons, n_features));
607 for (j, p) in prototypes.iter().enumerate() {
608 for k in 0..n_features {
609 proto_arr[[j, k]] = p[k];
610 }
611 }
612
613 Ok(NeuralGasModel {
614 prototypes: proto_arr,
615 labels: Array1::from_vec(labels_vec),
616 quantization_error,
617 })
618 }
619}
620
/// Bookkeeping for one undirected edge of the GNG topology.
#[derive(Debug, Clone)]
struct GngEdge {
    /// Incremented whenever one of the edge's endpoints is the winning unit.
    /// Reset to 0 when the edge is (re)created; the edge is dropped once this
    /// exceeds `max_age`.
    age: usize,
}
630
/// One GNG unit: its position in input space and its accumulated error.
#[derive(Debug, Clone)]
struct GngNode {
    /// Prototype vector (same length as an input row).
    weights: Vec<f64>,
    /// Sum of squared distances to the samples this unit won. Multiplied by
    /// `beta` after every step.
    error: f64,
}
637
/// Result of [`GrowingNeuralGas::fit`].
#[derive(Debug, Clone)]
pub struct GrowingNeuralGasModel {
    /// Unit positions, one unit per row.
    pub prototypes: Array2<f64>,
    /// Undirected edges between units, each stored as `(lower, higher)` unit
    /// index.
    pub edges: Vec<(usize, usize)>,
    /// For each training row, the index of its nearest unit.
    pub labels: Array1<usize>,
    /// Mean squared Euclidean distance from each training row to its nearest
    /// unit.
    pub quantization_error: f64,
}
650
/// Configuration for Growing Neural Gas: a set of units, connected by aging
/// edges, that grows incrementally while adapting to the data.
#[derive(Debug, Clone)]
pub struct GrowingNeuralGas {
    /// Step size for moving the winning (nearest) unit towards a sample.
    pub lr_winner: f64,
    /// Step size for moving the winner's edge-connected neighbours.
    pub lr_neighbor: f64,
    /// Edges whose age exceeds this value are deleted.
    pub max_age: usize,
    /// A unit insertion is attempted every `insert_interval` steps; must be
    /// non-zero.
    pub insert_interval: usize,
    /// On insertion, multiplies the errors of the two units the new unit is
    /// placed between. The new unit starts with the scaled error of the
    /// highest-error unit.
    pub alpha: f64,
    /// Every unit's accumulated error is multiplied by this after each step.
    pub beta: f64,
    /// No units are inserted once the network has this many.
    pub max_units: usize,
    /// Number of training steps; each step draws one row at random (with
    /// replacement).
    pub max_steps: usize,
    /// Seed for initialisation and sample selection.
    pub seed: u64,
}
679
680impl Default for GrowingNeuralGas {
681 fn default() -> Self {
682 Self {
683 lr_winner: 0.1,
684 lr_neighbor: 0.01,
685 max_age: 50,
686 insert_interval: 100,
687 alpha: 0.5,
688 beta: 0.995,
689 max_units: 200,
690 max_steps: 5000,
691 seed: 42,
692 }
693 }
694}
695
696impl GrowingNeuralGas {
697 #[allow(clippy::too_many_arguments)]
699 pub fn new(
700 lr_winner: f64,
701 lr_neighbor: f64,
702 max_age: usize,
703 insert_interval: usize,
704 alpha: f64,
705 beta: f64,
706 max_units: usize,
707 max_steps: usize,
708 seed: u64,
709 ) -> Self {
710 Self {
711 lr_winner,
712 lr_neighbor,
713 max_age,
714 insert_interval,
715 alpha,
716 beta,
717 max_units,
718 max_steps,
719 seed,
720 }
721 }
722
723 pub fn fit(&self, data: ArrayView2<f64>) -> Result<GrowingNeuralGasModel> {
731 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
732
733 if n_samples < 2 {
734 return Err(ClusteringError::InvalidInput(
735 "GNG requires at least 2 samples".into(),
736 ));
737 }
738 if n_features == 0 {
739 return Err(ClusteringError::InvalidInput(
740 "Data must have at least one feature".into(),
741 ));
742 }
743
744 let mut rng = Lcg::new(self.seed);
745
746 let idx0 = rng.next_usize(n_samples);
748 let idx1 = (idx0 + 1 + rng.next_usize(n_samples.saturating_sub(1).max(1))) % n_samples;
749
750 let mut nodes: Vec<GngNode> = vec![
751 GngNode {
752 weights: data.row(idx0).to_vec(),
753 error: 0.0,
754 },
755 GngNode {
756 weights: data.row(idx1).to_vec(),
757 error: 0.0,
758 },
759 ];
760
761 let mut edge_map: std::collections::HashMap<(usize, usize), GngEdge> =
763 std::collections::HashMap::new();
764 edge_map.insert((0, 1), GngEdge { age: 0 });
765
766 let data_rows: Vec<Vec<f64>> = (0..n_samples).map(|i| data.row(i).to_vec()).collect();
767
768 for step in 0..self.max_steps {
769 let sample = &data_rows[rng.next_usize(n_samples)];
770
771 let mut dists: Vec<(f64, usize)> = nodes
773 .iter()
774 .enumerate()
775 .map(|(j, n)| (sq_euclid(sample, &n.weights), j))
776 .collect();
777 dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
778
779 if dists.len() < 2 {
780 continue;
781 }
782
783 let s1 = dists[0].1;
784 let s2 = dists[1].1;
785 let dist_s1 = dists[0].0;
786
787 let edge_keys: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
789 for key in &edge_keys {
790 if key.0 == s1 || key.1 == s1 {
791 if let Some(e) = edge_map.get_mut(key) {
792 e.age += 1;
793 }
794 }
795 }
796
797 let edge_key = if s1 < s2 { (s1, s2) } else { (s2, s1) };
799 edge_map.insert(edge_key, GngEdge { age: 0 });
800
801 nodes[s1].error += dist_s1;
803
804 for k in 0..n_features {
806 let delta = sample[k] - nodes[s1].weights[k];
807 nodes[s1].weights[k] += self.lr_winner * delta;
808 }
809
810 let neighbours: Vec<usize> = edge_map
812 .keys()
813 .filter_map(|&(a, b)| {
814 if a == s1 {
815 Some(b)
816 } else if b == s1 {
817 Some(a)
818 } else {
819 None
820 }
821 })
822 .collect();
823
824 for nb in &neighbours {
825 for k in 0..n_features {
826 let delta = sample[k] - nodes[*nb].weights[k];
827 nodes[*nb].weights[k] += self.lr_neighbor * delta;
828 }
829 }
830
831 edge_map.retain(|_, e| e.age <= self.max_age);
833
834 for node in nodes.iter_mut() {
836 node.error *= self.beta;
837 }
838
839 if step > 0
841 && step % self.insert_interval == 0
842 && nodes.len() < self.max_units
843 && nodes.len() >= 2
844 {
845 let q = nodes
847 .iter()
848 .enumerate()
849 .max_by(|a, b| {
850 a.1.error
851 .partial_cmp(&b.1.error)
852 .unwrap_or(std::cmp::Ordering::Equal)
853 })
854 .map(|(i, _)| i)
855 .unwrap_or(0);
856
857 let q_neighbours: Vec<usize> = edge_map
859 .keys()
860 .filter_map(|&(a, b)| {
861 if a == q {
862 Some(b)
863 } else if b == q {
864 Some(a)
865 } else {
866 None
867 }
868 })
869 .collect();
870
871 if !q_neighbours.is_empty() {
872 let f = q_neighbours
873 .iter()
874 .max_by(|&&a, &&b| {
875 nodes[a]
876 .error
877 .partial_cmp(&nodes[b].error)
878 .unwrap_or(std::cmp::Ordering::Equal)
879 })
880 .cloned()
881 .unwrap_or(q_neighbours[0]);
882
883 let new_weights: Vec<f64> = nodes[q]
885 .weights
886 .iter()
887 .zip(nodes[f].weights.iter())
888 .map(|(a, b)| (a + b) / 2.0)
889 .collect();
890
891 let new_idx = nodes.len();
892 let new_error = nodes[q].error * self.alpha;
893 nodes.push(GngNode {
894 weights: new_weights,
895 error: new_error,
896 });
897
898 nodes[q].error *= self.alpha;
899 nodes[f].error *= self.alpha;
900
901 let qf_key = if q < f { (q, f) } else { (f, q) };
903 edge_map.remove(&qf_key);
904
905 let qn_key = if q < new_idx {
906 (q, new_idx)
907 } else {
908 (new_idx, q)
909 };
910 let fn_key = if f < new_idx {
911 (f, new_idx)
912 } else {
913 (new_idx, f)
914 };
915 edge_map.insert(qn_key, GngEdge { age: 0 });
916 edge_map.insert(fn_key, GngEdge { age: 0 });
917 }
918 }
919 }
920
921 let n_units = nodes.len();
922 let mut proto_arr = Array2::<f64>::zeros((n_units, n_features));
923 for (j, node) in nodes.iter().enumerate() {
924 for k in 0..n_features {
925 proto_arr[[j, k]] = node.weights[k];
926 }
927 }
928
929 let edges: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
930
931 let mut labels_vec = vec![0usize; n_samples];
933 let mut total_qe = 0.0f64;
934 for i in 0..n_samples {
935 let row = data_rows[i].as_slice();
936 let (best, best_dist) = nodes
937 .iter()
938 .enumerate()
939 .map(|(j, node)| (j, sq_euclid(row, &node.weights)))
940 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
941 .unwrap_or((0, 0.0));
942 labels_vec[i] = best;
943 total_qe += best_dist;
944 }
945 let quantization_error = if n_samples > 0 {
946 total_qe / n_samples as f64
947 } else {
948 0.0
949 };
950
951 Ok(GrowingNeuralGasModel {
952 prototypes: proto_arr,
953 edges,
954 labels: Array1::from_vec(labels_vec),
955 quantization_error,
956 })
957 }
958}
959
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Two tight, well separated 2-D clusters of 12 points each: around the
    /// origin (label 0) and around (5, 5) (label 1).
    fn two_cluster_data() -> (Array2<f64>, Vec<usize>) {
        let vals = vec![
            // cluster 0
            0.00, 0.00, 0.10, 0.00, 0.00, 0.10, 0.10, 0.10, 0.05, 0.05, -0.05, 0.05, -0.05, -0.05,
            0.10, -0.05, 0.00, 0.15, -0.10, 0.00, 0.15, 0.10, 0.00, 0.20,
            // cluster 1
            5.00, 5.00, 5.10, 5.00, 5.00, 5.10, 5.10, 5.10, 5.05, 5.05, 4.95, 5.05, 4.95, 4.95,
            5.10, 4.95, 5.00, 5.15, 4.90, 5.00, 5.15, 5.10, 5.00, 5.20,
        ];
        let x = Array2::from_shape_vec((24, 2), vals).expect("shape ok");
        let y: Vec<usize> = (0..12).map(|_| 0).chain((0..12).map(|_| 1)).collect();
        (x, y)
    }

    // ---- Winner-take-all ----

    #[test]
    fn test_wta_basic() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        let protos = wta.fit(x.view(), 2).expect("fit");
        assert_eq!(protos.shape(), [2, 2]);
    }

    #[test]
    fn test_wta_single_prototype() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        let protos = wta.fit(x.view(), 1).expect("fit");
        assert_eq!(protos.shape(), [1, 2]);
    }

    #[test]
    fn test_wta_annealing() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll {
            lr_init: 0.5,
            lr_final: Some(0.001),
            max_epochs: 50,
            seed: 7,
        };
        let protos = wta.fit(x.view(), 2).expect("fit annealing");
        assert_eq!(protos.shape()[0], 2);
    }

    #[test]
    fn test_wta_predict() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        let protos = wta.fit(x.view(), 2).expect("fit");
        let labels = wta.predict(x.view(), &protos);
        assert_eq!(labels.len(), 24);
        assert!(labels.iter().all(|&l| l < 2));
    }

    #[test]
    fn test_wta_converges_two_clusters() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll {
            lr_init: 0.5,
            lr_final: Some(0.01),
            max_epochs: 200,
            seed: 42,
        };
        let protos = wta.fit(x.view(), 2).expect("fit");
        // Each prototype must sit nearer a different cluster centre.
        let p0 = protos.row(0).to_vec();
        let p1 = protos.row(1).to_vec();
        let d00 = sq_euclid(&p0, &[0.0, 0.0]);
        let d05 = sq_euclid(&p0, &[5.0, 5.0]);
        let d10 = sq_euclid(&p1, &[0.0, 0.0]);
        let d15 = sq_euclid(&p1, &[5.0, 5.0]);
        let well_placed = (d00 < d05 && d15 < d10) || (d05 < d00 && d10 < d15);
        assert!(well_placed, "prototypes should converge to cluster centres");
    }

    #[test]
    fn test_wta_error_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let wta = WinnerTakeAll::default();
        assert!(wta.fit(x.view(), 2).is_err());
    }

    #[test]
    fn test_wta_error_zero_prototypes() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        assert!(wta.fit(x.view(), 0).is_err());
    }

    // ---- Learning Vector Quantization ----

    #[test]
    fn test_lvq_fit_basic() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        assert_eq!(model.prototypes.shape()[0], 2);
        assert_eq!(model.labels.len(), 2);
    }

    #[test]
    fn test_lvq_predict() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        let preds = model.predict(x.view());
        assert_eq!(preds.len(), 24);
        let correct = preds.iter().zip(y.iter()).filter(|(&p, &t)| p == t).count();
        assert!(
            correct as f64 / 24.0 > 0.75,
            "accuracy should exceed 75%, got {}",
            correct
        );
    }

    #[test]
    fn test_lvq_predict_one() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        let pred = model.predict_one(&[0.0, 0.0]);
        assert_eq!(pred, 0, "origin should map to class 0");
        let pred2 = model.predict_one(&[5.0, 5.0]);
        assert_eq!(pred2, 1, "(5,5) should map to class 1");
    }

    #[test]
    fn test_lvq_model_predict_one_manual() {
        let model = LvqModel {
            prototypes: Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 5.0, 5.0])
                .expect("shape ok"),
            labels: Array1::from_vec(vec![3, 7]),
        };
        assert_eq!(model.predict_one(&[1.0, 1.0]), 3);
        assert_eq!(model.predict_one(&[4.0, 4.5]), 7);
    }

    #[test]
    fn test_lvq_function_predict() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        let preds = LearningVectorQuantization::predict(&model, x.view());
        assert_eq!(preds.len(), 24);
    }

    #[test]
    fn test_lvq_multi_proto_per_class() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::new(2, 0.1, 0.001, 50, 42);
        let model = lvq.fit(x.view(), &y).expect("fit");
        assert_eq!(model.prototypes.shape()[0], 4);
    }

    #[test]
    fn test_lvq_error_label_mismatch() {
        let (x, _) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        assert!(lvq.fit(x.view(), &[0, 1, 2]).is_err());
    }

    // ---- Neural Gas ----

    #[test]
    fn test_ng_basic() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let model = ng.fit(x.view(), 2).expect("fit");
        assert_eq!(model.prototypes.shape(), [2, 2]);
        assert_eq!(model.labels.len(), 24);
        assert!(model.quantization_error >= 0.0);
    }

    #[test]
    fn test_ng_single_neuron() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let model = ng.fit(x.view(), 1).expect("fit");
        assert_eq!(model.prototypes.shape()[0], 1);
        assert!(model.labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_ng_converges() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas {
            lr_winner: 0.5,
            lr_final: 0.01,
            lambda_init: None,
            lambda_final: 0.01,
            max_epochs: 200,
            seed: 42,
        };
        let model = ng.fit(x.view(), 2).expect("fit");
        assert!(
            model.quantization_error < 1.0,
            "QE={} too high",
            model.quantization_error
        );
    }

    #[test]
    fn test_ng_error_empty() {
        let x = Array2::<f64>::zeros((0, 2));
        let ng = NeuralGas::default();
        assert!(ng.fit(x.view(), 2).is_err());
    }

    #[test]
    fn test_ng_error_zero_neurons() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        assert!(ng.fit(x.view(), 0).is_err());
    }

    // ---- Growing Neural Gas ----

    #[test]
    fn test_gng_basic() {
        let (x, _) = two_cluster_data();
        let gng = GrowingNeuralGas {
            max_steps: 300,
            insert_interval: 30,
            max_units: 15,
            seed: 7,
            ..GrowingNeuralGas::default()
        };
        let model = gng.fit(x.view()).expect("fit");
        assert!(
            model.prototypes.shape()[0] >= 2,
            "should have at least initial units"
        );
        assert_eq!(model.labels.len(), 24);
        assert!(model.quantization_error >= 0.0);
    }

    #[test]
    fn test_gng_grows_units() {
        let (x, _) = two_cluster_data();
        let gng = GrowingNeuralGas {
            max_steps: 1000,
            insert_interval: 50,
            max_units: 20,
            seed: 99,
            ..GrowingNeuralGas::default()
        };
        let model = gng.fit(x.view()).expect("fit");
        assert!(
            model.prototypes.shape()[0] > 2,
            "GNG should grow beyond initial 2 units"
        );
    }

    #[test]
    fn test_gng_edges_valid() {
        let (x, _) = two_cluster_data();
        let gng = GrowingNeuralGas {
            max_steps: 500,
            seed: 42,
            ..GrowingNeuralGas::default()
        };
        let model = gng.fit(x.view()).expect("fit");
        let n_units = model.prototypes.shape()[0];
        for &(a, b) in &model.edges {
            assert!(a < n_units, "edge endpoint {} out of range", a);
            assert!(b < n_units, "edge endpoint {} out of range", b);
            assert_ne!(a, b, "self-loop detected");
            assert!(a < b, "edge ({}, {}) not stored as (lower, higher)", a, b);
        }
    }

    #[test]
    fn test_gng_error_too_few_samples() {
        let x = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("shape ok");
        let gng = GrowingNeuralGas::default();
        assert!(gng.fit(x.view()).is_err());
    }
}