Skip to main content

scirs2_cluster/
competitive_learning.rs

1//! Competitive learning algorithms for prototype-based clustering.
2//!
3//! This module provides function-oriented and struct-based competitive learning algorithms
4//! that complement the lower-level implementations in [`crate::prototype_enhanced`].
5//!
6//! # Algorithms
7//!
8//! - [`WinnerTakeAll`] – Basic competitive learning via BMU (Best Matching Unit) updates.
9//! - [`LearningVectorQuantization`] – Supervised LVQ-1 prototype learning.
10//! - [`NeuralGas`] – Topology-preserving competitive learning with rank-based neighbourhood.
11//! - [`GrowingNeuralGas`] – Adaptive-topology neural gas that grows its unit graph dynamically.
12//!
13//! # Examples
14//!
15//! ```
16//! use scirs2_core::ndarray::Array2;
17//! use scirs2_cluster::competitive_learning::WinnerTakeAll;
18//!
19//! let data = Array2::from_shape_vec((8, 2), vec![
20//!     0.0, 0.0,  0.1, 0.0,  0.0, 0.1,  0.1, 0.1,
21//!     5.0, 5.0,  5.1, 5.0,  5.0, 5.1,  5.1, 5.1,
22//! ]).expect("shape ok");
23//!
24//! let wta = WinnerTakeAll::default();
25//! let prototypes = wta.fit(data.view(), 2).expect("fit ok");
26//! assert_eq!(prototypes.shape(), [2, 2]);
27//! ```
28
29use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
30
31use crate::error::{ClusteringError, Result};
32
33// ---------------------------------------------------------------------------
34// Internal helpers
35// ---------------------------------------------------------------------------
36
/// Squared Euclidean distance between two slices.
///
/// Components are paired with `zip`, so only the common prefix of the two
/// slices contributes; callers are expected to pass equal-length slices.
#[inline]
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
42
43/// Euclidean distance between two slices.
44#[inline]
45fn euclid(a: &[f64], b: &[f64]) -> f64 {
46    sq_euclid(a, b).sqrt()
47}
48
/// Small, dependency-free linear congruential generator.
///
/// Its only job is reproducibility: a given seed always yields the same
/// stream. It is not meant for cryptographic or high-quality statistical use.
struct Lcg(u64);

impl Lcg {
    /// Multiplier of the recurrence `state = state * MUL + INC (mod 2^64)`.
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// Increment of the recurrence.
    const INC: u64 = 1_442_695_040_888_963_407;

    /// Build a generator whose initial state is `seed`.
    fn new(seed: u64) -> Self {
        Lcg(seed)
    }

    /// Step the state once and map its top 53 bits to a float in `[0, 1)`.
    fn next_f64(&mut self) -> f64 {
        let state = self.0.wrapping_mul(Self::MUL).wrapping_add(Self::INC);
        self.0 = state;
        let mantissa = state >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Draw an index in `[0, n)`. `n` must be non-zero (the final modulo panics on 0).
    fn next_usize(&mut self, n: usize) -> usize {
        let scaled = self.next_f64() * n as f64;
        (scaled as usize) % n
    }

    /// Shuffle `v` in place (Fisher–Yates), walking from the last element down to index 1.
    fn shuffle(&mut self, v: &mut [usize]) {
        let mut i = v.len();
        while i > 1 {
            i -= 1;
            let j = self.next_usize(i + 1);
            v.swap(i, j);
        }
    }
}
79
80/// Return the index of the Best Matching Unit (BMU) — the prototype nearest to `input`.
81fn find_bmu(input: &[f64], prototypes: &[Vec<f64>]) -> usize {
82    prototypes
83        .iter()
84        .enumerate()
85        .min_by(|(_, a), (_, b)| {
86            sq_euclid(input, a)
87                .partial_cmp(&sq_euclid(input, b))
88                .unwrap_or(std::cmp::Ordering::Equal)
89        })
90        .map(|(i, _)| i)
91        .unwrap_or(0)
92}
93
94// ---------------------------------------------------------------------------
95// WinnerTakeAll
96// ---------------------------------------------------------------------------
97
/// Basic competitive learning — Winner-Take-All (WTA).
///
/// For every presented sample, only the closest prototype (the Best Matching
/// Unit) is updated: it is pulled toward the sample by the current learning
/// rate. When `lr_final` is set, that rate is interpolated linearly from
/// `lr_init` toward `lr_final` over the whole run.
///
/// `fit` returns the learned prototype matrix.
#[derive(Debug, Clone)]
pub struct WinnerTakeAll {
    /// Starting (or, without annealing, constant) learning rate.  Default 0.3.
    pub lr_init: f64,
    /// Target learning rate for linear annealing; `None` disables annealing.  Default `None`.
    pub lr_final: Option<f64>,
    /// How many full passes over the dataset to make.  Default 100.
    pub max_epochs: usize,
    /// Seed for the internal RNG, for reproducible runs.  Default 42.
    pub seed: u64,
}

impl Default for WinnerTakeAll {
    fn default() -> Self {
        const DEFAULT_LR: f64 = 0.3;
        const DEFAULT_EPOCHS: usize = 100;
        const DEFAULT_SEED: u64 = 42;

        WinnerTakeAll {
            lr_init: DEFAULT_LR,
            lr_final: None,
            max_epochs: DEFAULT_EPOCHS,
            seed: DEFAULT_SEED,
        }
    }
}
127
128impl WinnerTakeAll {
129    /// Create a `WinnerTakeAll` with all options specified.
130    pub fn new(lr_init: f64, lr_final: Option<f64>, max_epochs: usize, seed: u64) -> Self {
131        Self {
132            lr_init,
133            lr_final,
134            max_epochs,
135            seed,
136        }
137    }
138
139    /// Fit the WTA network on `data`, producing `n_prototypes` learned prototypes.
140    ///
141    /// Prototypes are initialised by sampling distinct data points uniformly at random.
142    ///
143    /// # Arguments
144    /// * `data` – Data matrix `(n_samples, n_features)`.
145    /// * `n_prototypes` – Number of prototype vectors to learn.
146    ///
147    /// # Returns
148    /// `Array2<f64>` of shape `(n_prototypes, n_features)`.
149    pub fn fit(&self, data: ArrayView2<f64>, n_prototypes: usize) -> Result<Array2<f64>> {
150        let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
151
152        if n_samples == 0 {
153            return Err(ClusteringError::InvalidInput("Empty input data".into()));
154        }
155        if n_prototypes == 0 {
156            return Err(ClusteringError::InvalidInput(
157                "n_prototypes must be > 0".into(),
158            ));
159        }
160        if n_features == 0 {
161            return Err(ClusteringError::InvalidInput(
162                "Data must have at least one feature".into(),
163            ));
164        }
165
166        let mut rng = Lcg::new(self.seed);
167
168        // Initialise prototypes by sampling (with replacement if necessary) from data.
169        let mut prototypes: Vec<Vec<f64>> = (0..n_prototypes)
170            .map(|_| {
171                let idx = rng.next_usize(n_samples);
172                data.row(idx).to_vec()
173            })
174            .collect();
175
176        let total_steps = self.max_epochs * n_samples;
177        let mut order: Vec<usize> = (0..n_samples).collect();
178
179        for epoch in 0..self.max_epochs {
180            rng.shuffle(&mut order);
181
182            for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
183                let global_step = epoch * n_samples + step_in_epoch;
184                let t = global_step as f64 / total_steps.max(1) as f64;
185
186                let lr = match self.lr_final {
187                    Some(lr_f) => self.lr_init + t * (lr_f - self.lr_init),
188                    None => self.lr_init,
189                };
190
191                let input = data.row(sample_idx).to_vec();
192                let bmu_idx = find_bmu(&input, &prototypes);
193
194                let bmu = &mut prototypes[bmu_idx];
195                for k in 0..n_features {
196                    bmu[k] += lr * (input[k] - bmu[k]);
197                }
198            }
199        }
200
201        // Pack into Array2.
202        let mut out = Array2::<f64>::zeros((n_prototypes, n_features));
203        for (j, p) in prototypes.iter().enumerate() {
204            for k in 0..n_features {
205                out[[j, k]] = p[k];
206            }
207        }
208        Ok(out)
209    }
210
211    /// Assign each sample in `data` to its nearest prototype.
212    ///
213    /// Returns `Array1<usize>` of length `n_samples`.
214    pub fn predict(&self, data: ArrayView2<f64>, prototypes: &Array2<f64>) -> Array1<usize> {
215        let n_samples = data.shape()[0];
216        let n_proto = prototypes.shape()[0];
217        let protos: Vec<Vec<f64>> = (0..n_proto).map(|j| prototypes.row(j).to_vec()).collect();
218
219        let labels: Vec<usize> = (0..n_samples)
220            .map(|i| {
221                let row = data.row(i).to_vec();
222                find_bmu(&row, &protos)
223            })
224            .collect();
225
226        Array1::from_vec(labels)
227    }
228}
229
230// ---------------------------------------------------------------------------
231// LearningVectorQuantization (supervised)
232// ---------------------------------------------------------------------------
233
234/// Learned LVQ-1 model, produced by [`LearningVectorQuantization::fit`].
235#[derive(Debug, Clone)]
236pub struct LvqModel {
237    /// Prototype weight vectors, shape `(n_prototypes, n_features)`.
238    pub prototypes: Array2<f64>,
239    /// Class label for each prototype row.
240    pub labels: Array1<usize>,
241}
242
243impl LvqModel {
244    /// Predict the class of a single 1-D sample slice.
245    pub fn predict_one(&self, sample: &[f64]) -> usize {
246        let n_proto = self.prototypes.shape()[0];
247        let best = (0..n_proto)
248            .map(|j| {
249                let p = self.prototypes.row(j).to_vec();
250                (j, sq_euclid(sample, &p))
251            })
252            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
253            .map(|(j, _)| j)
254            .unwrap_or(0);
255        self.labels[best]
256    }
257
258    /// Predict class labels for every row in `data`.
259    pub fn predict(&self, data: ArrayView2<f64>) -> Array1<usize> {
260        let n = data.shape()[0];
261        let preds: Vec<usize> = (0..n)
262            .map(|i| self.predict_one(&data.row(i).to_vec()))
263            .collect();
264        Array1::from_vec(preds)
265    }
266}
267
/// Supervised prototype learning with LVQ-1.
///
/// Every prototype carries a class label. For each training sample, only the
/// single nearest prototype is adjusted:
/// - it is pulled **toward** the sample when its label matches the sample's class;
/// - it is pushed **away** from the sample when the labels differ.
///
/// Initial prototypes are drawn from each class's own samples
/// (`n_prototypes_per_class` per class).
#[derive(Debug, Clone)]
pub struct LearningVectorQuantization {
    /// How many prototypes to create for each class.  Default 1.
    pub n_prototypes_per_class: usize,
    /// Learning rate at the start of training.  Default 0.1.
    pub lr_init: f64,
    /// Learning rate reached at the end of the annealing schedule.  Default 0.001.
    pub lr_final: f64,
    /// Number of full passes over the training data.  Default 50.
    pub max_epochs: usize,
    /// Seed for the internal RNG.  Default 42.
    pub seed: u64,
}

impl Default for LearningVectorQuantization {
    fn default() -> Self {
        const PER_CLASS: usize = 1;
        const LR_START: f64 = 0.1;
        const LR_END: f64 = 0.001;
        const EPOCHS: usize = 50;
        const SEED: u64 = 42;

        LearningVectorQuantization {
            n_prototypes_per_class: PER_CLASS,
            lr_init: LR_START,
            lr_final: LR_END,
            max_epochs: EPOCHS,
            seed: SEED,
        }
    }
}
301
302impl LearningVectorQuantization {
303    /// Create a new LVQ instance with all hyperparameters.
304    pub fn new(
305        n_prototypes_per_class: usize,
306        lr_init: f64,
307        lr_final: f64,
308        max_epochs: usize,
309        seed: u64,
310    ) -> Self {
311        Self {
312            n_prototypes_per_class,
313            lr_init,
314            lr_final,
315            max_epochs,
316            seed,
317        }
318    }
319
320    /// Train LVQ-1 on labelled data.
321    ///
322    /// # Arguments
323    /// * `data`   – Feature matrix `(n_samples, n_features)`.
324    /// * `labels` – Integer class labels; values must be in `0..n_classes`.
325    ///
326    /// # Returns
327    /// A trained [`LvqModel`].
328    pub fn fit(&self, data: ArrayView2<f64>, labels: &[usize]) -> Result<LvqModel> {
329        let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
330
331        if n_samples == 0 {
332            return Err(ClusteringError::InvalidInput("Empty input data".into()));
333        }
334        if labels.len() != n_samples {
335            return Err(ClusteringError::InvalidInput(
336                "labels length must equal number of data rows".into(),
337            ));
338        }
339        if self.n_prototypes_per_class == 0 {
340            return Err(ClusteringError::InvalidInput(
341                "n_prototypes_per_class must be > 0".into(),
342            ));
343        }
344
345        let n_classes = labels.iter().cloned().max().map(|m| m + 1).unwrap_or(0);
346        if n_classes == 0 {
347            return Err(ClusteringError::InvalidInput(
348                "No valid class labels found".into(),
349            ));
350        }
351
352        // Collect sample indices per class.
353        let mut class_samples: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
354        for (i, &lbl) in labels.iter().enumerate() {
355            if lbl < n_classes {
356                class_samples[lbl].push(i);
357            }
358        }
359
360        let mut rng = Lcg::new(self.seed);
361
362        // Initialise prototypes.
363        let mut proto_weights: Vec<Vec<f64>> = Vec::new();
364        let mut proto_labels: Vec<usize> = Vec::new();
365
366        for cls in 0..n_classes {
367            let samples = &class_samples[cls];
368            if samples.is_empty() {
369                // Class has no samples; cannot initialise — skip it.
370                continue;
371            }
372            for _ in 0..self.n_prototypes_per_class {
373                let idx = samples[rng.next_usize(samples.len())];
374                proto_weights.push(data.row(idx).to_vec());
375                proto_labels.push(cls);
376            }
377        }
378
379        if proto_weights.is_empty() {
380            return Err(ClusteringError::ComputationError(
381                "Could not initialise any prototypes".into(),
382            ));
383        }
384
385        let n_proto = proto_weights.len();
386        let total_steps = self.max_epochs * n_samples;
387        let mut order: Vec<usize> = (0..n_samples).collect();
388
389        // LVQ-1 training loop.
390        for epoch in 0..self.max_epochs {
391            rng.shuffle(&mut order);
392
393            for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
394                let global_step = epoch * n_samples + step_in_epoch;
395                let t = global_step as f64 / total_steps.max(1) as f64;
396                let lr = self.lr_init * (self.lr_final / self.lr_init).powf(t);
397
398                let input = data.row(sample_idx).to_vec();
399                let true_class = labels[sample_idx];
400
401                // Find the nearest prototype.
402                let nearest = (0..n_proto)
403                    .map(|j| (j, sq_euclid(&input, &proto_weights[j])))
404                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
405                    .map(|(j, _)| j)
406                    .unwrap_or(0);
407
408                // Attract if correct class, repel otherwise.
409                let sign = if proto_labels[nearest] == true_class {
410                    1.0f64
411                } else {
412                    -1.0f64
413                };
414
415                let w = &mut proto_weights[nearest];
416                for k in 0..n_features {
417                    w[k] += lr * sign * (input[k] - w[k]);
418                }
419            }
420        }
421
422        // Assemble result.
423        let mut proto_arr = Array2::<f64>::zeros((n_proto, n_features));
424        for (j, w) in proto_weights.iter().enumerate() {
425            for k in 0..n_features {
426                proto_arr[[j, k]] = w[k];
427            }
428        }
429
430        Ok(LvqModel {
431            prototypes: proto_arr,
432            labels: Array1::from_vec(proto_labels),
433        })
434    }
435
436    /// Convenience: predict class labels for all rows in `data` using a trained `model`.
437    pub fn predict(model: &LvqModel, data: ArrayView2<f64>) -> Array1<usize> {
438        model.predict(data)
439    }
440}
441
442// ---------------------------------------------------------------------------
443// NeuralGas (function-oriented API)
444// ---------------------------------------------------------------------------
445
446/// Trained Neural Gas model, produced by [`NeuralGas::fit`].
447#[derive(Debug, Clone)]
448pub struct NeuralGasModel {
449    /// Prototype / reference vectors, shape `(n_neurons, n_features)`.
450    pub prototypes: Array2<f64>,
451    /// Label (index of nearest prototype) for each training sample.
452    pub labels: Array1<usize>,
453    /// Mean quantization error over the training set.
454    pub quantization_error: f64,
455}
456
/// Neural Gas — topology-preserving competitive learning.
///
/// Every sample updates *all* prototypes. They are sorted by distance to the
/// sample, and the one at rank `r` moves toward it with step
/// `lr * exp(-r / lambda)`. The closest prototype (rank 0) therefore gets the
/// full `lr`. Both `lr` and `lambda` shrink exponentially over training.
///
/// Reference: Martinetz & Schulten (1991).
#[derive(Debug, Clone)]
pub struct NeuralGas {
    /// Learning rate at the start of training, applied in full to the winner.  Default 0.5.
    pub lr_winner: f64,
    /// Learning rate reached at the end of the exponential schedule.  Default 0.01.
    pub lr_final: f64,
    /// Starting neighbourhood width λ.  `None` → `n_neurons / 2`. In either case,
    /// values below 0.5 are raised to 0.5.
    pub lambda_init: Option<f64>,
    /// λ at the end of the exponential schedule; should be positive.  Default 0.01.
    pub lambda_final: f64,
    /// Number of full passes over the dataset.  Default 100.
    pub max_epochs: usize,
    /// Seed for the internal RNG.  Default 42.
    pub seed: u64,
}

impl Default for NeuralGas {
    fn default() -> Self {
        const LR_START: f64 = 0.5;
        const LR_END: f64 = 0.01;
        const LAMBDA_END: f64 = 0.01;
        const EPOCHS: usize = 100;
        const SEED: u64 = 42;

        NeuralGas {
            lr_winner: LR_START,
            lr_final: LR_END,
            lambda_init: None,
            lambda_final: LAMBDA_END,
            max_epochs: EPOCHS,
            seed: SEED,
        }
    }
}
493
494impl NeuralGas {
495    /// Create a `NeuralGas` with explicit hyperparameters.
496    pub fn new(
497        lr_winner: f64,
498        lr_final: f64,
499        lambda_init: Option<f64>,
500        lambda_final: f64,
501        max_epochs: usize,
502        seed: u64,
503    ) -> Self {
504        Self {
505            lr_winner,
506            lr_final,
507            lambda_init,
508            lambda_final,
509            max_epochs,
510            seed,
511        }
512    }
513
514    /// Fit the Neural Gas network.
515    ///
516    /// # Arguments
517    /// * `data`      – Data matrix `(n_samples, n_features)`.
518    /// * `n_neurons` – Number of prototype units.
519    ///
520    /// # Returns
521    /// A [`NeuralGasModel`] with the learned prototypes and training assignments.
522    pub fn fit(&self, data: ArrayView2<f64>, n_neurons: usize) -> Result<NeuralGasModel> {
523        let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
524
525        if n_samples == 0 {
526            return Err(ClusteringError::InvalidInput("Empty input data".into()));
527        }
528        if n_neurons == 0 {
529            return Err(ClusteringError::InvalidInput(
530                "n_neurons must be > 0".into(),
531            ));
532        }
533        if self.max_epochs == 0 {
534            return Err(ClusteringError::InvalidInput(
535                "max_epochs must be > 0".into(),
536            ));
537        }
538
539        let mut rng = Lcg::new(self.seed);
540
541        // Initialise prototypes by sampling from data.
542        let mut prototypes: Vec<Vec<f64>> = (0..n_neurons)
543            .map(|_| {
544                let idx = rng.next_usize(n_samples);
545                data.row(idx).to_vec()
546            })
547            .collect();
548
549        let total_steps = self.max_epochs * n_samples;
550        let lambda_i = self.lambda_init.unwrap_or(n_neurons as f64 / 2.0).max(0.5);
551        let mut order: Vec<usize> = (0..n_samples).collect();
552
553        for epoch in 0..self.max_epochs {
554            rng.shuffle(&mut order);
555
556            for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
557                let global_step = epoch * n_samples + step_in_epoch;
558                let t = global_step as f64 / total_steps.max(1) as f64;
559
560                // Exponential annealing for lr and lambda.
561                let lr = self.lr_winner * (self.lr_final / self.lr_winner).powf(t);
562                let lam = lambda_i * (self.lambda_final / lambda_i).powf(t);
563
564                let input = data.row(sample_idx).to_vec();
565
566                // Rank all prototypes by Euclidean distance.
567                let mut ranked: Vec<(f64, usize)> = prototypes
568                    .iter()
569                    .enumerate()
570                    .map(|(j, p)| (euclid(&input, p), j))
571                    .collect();
572                ranked.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
573
574                // Update each prototype with neighbourhood factor based on rank.
575                for (rank, (_, proto_idx)) in ranked.iter().enumerate() {
576                    let h = (-(rank as f64) / lam).exp();
577                    let p = &mut prototypes[*proto_idx];
578                    for k in 0..n_features {
579                        p[k] += lr * h * (input[k] - p[k]);
580                    }
581                }
582            }
583        }
584
585        // Assign labels and compute quantization error.
586        let mut labels_vec = vec![0usize; n_samples];
587        let mut total_qe = 0.0f64;
588        for i in 0..n_samples {
589            let row = data.row(i).to_vec();
590            let (best, best_dist) = prototypes
591                .iter()
592                .enumerate()
593                .map(|(j, p)| (j, sq_euclid(&row, p)))
594                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
595                .unwrap_or((0, 0.0));
596            labels_vec[i] = best;
597            total_qe += best_dist;
598        }
599        let quantization_error = if n_samples > 0 {
600            total_qe / n_samples as f64
601        } else {
602            0.0
603        };
604
605        // Pack into Array2.
606        let mut proto_arr = Array2::<f64>::zeros((n_neurons, n_features));
607        for (j, p) in prototypes.iter().enumerate() {
608            for k in 0..n_features {
609                proto_arr[[j, k]] = p[k];
610            }
611        }
612
613        Ok(NeuralGasModel {
614            prototypes: proto_arr,
615            labels: Array1::from_vec(labels_vec),
616            quantization_error,
617        })
618    }
619}
620
621// ---------------------------------------------------------------------------
622// GrowingNeuralGas (function-oriented API)
623// ---------------------------------------------------------------------------
624
/// An undirected edge in the GNG graph.
///
/// Edges are stored in a map keyed by the sorted node pair `(a, b)` with
/// `a < b`, so each connection appears exactly once whatever the endpoint order.
#[derive(Debug, Clone)]
struct GngEdge {
    /// Ticks since the edge was created or last refreshed. It is incremented when
    /// an endpoint wins, reset to 0 when the edge links the winner and the
    /// runner-up, and the edge is dropped once this exceeds `max_age`.
    age: usize,
}

/// A node (unit) in the GNG network.
#[derive(Debug, Clone)]
struct GngNode {
    /// Reference vector of the unit, one entry per feature.
    weights: Vec<f64>,
    /// Accumulated squared distance from samples this unit won. It is decayed by
    /// `beta` each step and scaled by `alpha` when a unit is inserted next to it.
    error: f64,
}
637
638/// Trained Growing Neural Gas model, produced by [`GrowingNeuralGas::fit`].
639#[derive(Debug, Clone)]
640pub struct GrowingNeuralGasModel {
641    /// Learned prototype vectors, shape `(n_units, n_features)`.
642    pub prototypes: Array2<f64>,
643    /// Topology edges as `(node_a, node_b)` pairs (sorted, a < b).
644    pub edges: Vec<(usize, usize)>,
645    /// Label (nearest prototype) for each training sample.
646    pub labels: Array1<usize>,
647    /// Final mean quantization error.
648    pub quantization_error: f64,
649}
650
/// Growing Neural Gas — competitive learning whose topology adapts during training.
///
/// Training begins with two units sampled from the data. Every `insert_interval`
/// steps, a new unit is placed halfway between the unit with the largest
/// accumulated error and its highest-error neighbour. Edges age and are dropped
/// once older than `max_age`. The number of units therefore need not be chosen
/// in advance; it is only capped by `max_units`.
///
/// Reference: Fritzke (1995).
#[derive(Debug, Clone)]
pub struct GrowingNeuralGas {
    /// Step size used to move the winning unit toward the sample.  Default 0.1.
    pub lr_winner: f64,
    /// Step size for units joined to the winner by an edge.  Default 0.01.
    pub lr_neighbor: f64,
    /// Edges whose age exceeds this value are removed.  Default 50.
    pub max_age: usize,
    /// Number of steps between unit insertions; must be non-zero, because
    /// training computes `step % insert_interval`.  Default 100.
    pub insert_interval: usize,
    /// Multiplier applied on insertion to the errors of the two parent units.
    /// It is also used to seed the new unit's error.  Default 0.5.
    pub alpha: f64,
    /// Factor by which every unit's error is multiplied after each step.  Default 0.995.
    pub beta: f64,
    /// Maximum number of units the network may grow to.  Default 200.
    pub max_units: usize,
    /// Total number of training steps (one random sample each).  Default 5000.
    pub max_steps: usize,
    /// Seed for the internal RNG.  Default 42.
    pub seed: u64,
}

impl Default for GrowingNeuralGas {
    fn default() -> Self {
        const LR_WIN: f64 = 0.1;
        const LR_NEIGHBOUR: f64 = 0.01;
        const MAX_AGE: usize = 50;
        const INSERT_EVERY: usize = 100;
        const ALPHA: f64 = 0.5;
        const BETA: f64 = 0.995;
        const MAX_UNITS: usize = 200;
        const MAX_STEPS: usize = 5000;
        const SEED: u64 = 42;

        GrowingNeuralGas {
            lr_winner: LR_WIN,
            lr_neighbor: LR_NEIGHBOUR,
            max_age: MAX_AGE,
            insert_interval: INSERT_EVERY,
            alpha: ALPHA,
            beta: BETA,
            max_units: MAX_UNITS,
            max_steps: MAX_STEPS,
            seed: SEED,
        }
    }
}
695
696impl GrowingNeuralGas {
697    /// Create a new GNG with all hyperparameters specified.
698    #[allow(clippy::too_many_arguments)]
699    pub fn new(
700        lr_winner: f64,
701        lr_neighbor: f64,
702        max_age: usize,
703        insert_interval: usize,
704        alpha: f64,
705        beta: f64,
706        max_units: usize,
707        max_steps: usize,
708        seed: u64,
709    ) -> Self {
710        Self {
711            lr_winner,
712            lr_neighbor,
713            max_age,
714            insert_interval,
715            alpha,
716            beta,
717            max_units,
718            max_steps,
719            seed,
720        }
721    }
722
723    /// Fit the GNG model to `data`.
724    ///
725    /// # Arguments
726    /// * `data` – Data matrix `(n_samples, n_features)`.
727    ///
728    /// # Returns
729    /// A [`GrowingNeuralGasModel`] with the learned topology and assignments.
730    pub fn fit(&self, data: ArrayView2<f64>) -> Result<GrowingNeuralGasModel> {
731        let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
732
733        if n_samples < 2 {
734            return Err(ClusteringError::InvalidInput(
735                "GNG requires at least 2 samples".into(),
736            ));
737        }
738        if n_features == 0 {
739            return Err(ClusteringError::InvalidInput(
740                "Data must have at least one feature".into(),
741            ));
742        }
743
744        let mut rng = Lcg::new(self.seed);
745
746        // Initialise with two nodes sampled from data.
747        let idx0 = rng.next_usize(n_samples);
748        let idx1 = (idx0 + 1 + rng.next_usize(n_samples.saturating_sub(1).max(1))) % n_samples;
749
750        let mut nodes: Vec<GngNode> = vec![
751            GngNode {
752                weights: data.row(idx0).to_vec(),
753                error: 0.0,
754            },
755            GngNode {
756                weights: data.row(idx1).to_vec(),
757                error: 0.0,
758            },
759        ];
760
761        // Edge map: key = sorted (a, b) where a < b.
762        let mut edge_map: std::collections::HashMap<(usize, usize), GngEdge> =
763            std::collections::HashMap::new();
764        edge_map.insert((0, 1), GngEdge { age: 0 });
765
766        let data_rows: Vec<Vec<f64>> = (0..n_samples).map(|i| data.row(i).to_vec()).collect();
767
768        for step in 0..self.max_steps {
769            let sample = &data_rows[rng.next_usize(n_samples)];
770
771            // Find winner (s1) and runner-up (s2) by squared Euclidean distance.
772            let mut dists: Vec<(f64, usize)> = nodes
773                .iter()
774                .enumerate()
775                .map(|(j, n)| (sq_euclid(sample, &n.weights), j))
776                .collect();
777            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
778
779            if dists.len() < 2 {
780                continue;
781            }
782
783            let s1 = dists[0].1;
784            let s2 = dists[1].1;
785            let dist_s1 = dists[0].0;
786
787            // Age all edges incident to s1.
788            let edge_keys: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
789            for key in &edge_keys {
790                if key.0 == s1 || key.1 == s1 {
791                    if let Some(e) = edge_map.get_mut(key) {
792                        e.age += 1;
793                    }
794                }
795            }
796
797            // Set/reset edge (s1, s2).
798            let edge_key = if s1 < s2 { (s1, s2) } else { (s2, s1) };
799            edge_map.insert(edge_key, GngEdge { age: 0 });
800
801            // Accumulate error for winner.
802            nodes[s1].error += dist_s1;
803
804            // Move winner toward sample.
805            for k in 0..n_features {
806                let delta = sample[k] - nodes[s1].weights[k];
807                nodes[s1].weights[k] += self.lr_winner * delta;
808            }
809
810            // Move topological neighbours of s1 toward sample.
811            let neighbours: Vec<usize> = edge_map
812                .keys()
813                .filter_map(|&(a, b)| {
814                    if a == s1 {
815                        Some(b)
816                    } else if b == s1 {
817                        Some(a)
818                    } else {
819                        None
820                    }
821                })
822                .collect();
823
824            for nb in &neighbours {
825                for k in 0..n_features {
826                    let delta = sample[k] - nodes[*nb].weights[k];
827                    nodes[*nb].weights[k] += self.lr_neighbor * delta;
828                }
829            }
830
831            // Remove edges older than max_age.
832            edge_map.retain(|_, e| e.age <= self.max_age);
833
834            // Global error decay.
835            for node in nodes.iter_mut() {
836                node.error *= self.beta;
837            }
838
839            // Insert new node periodically.
840            if step > 0
841                && step % self.insert_interval == 0
842                && nodes.len() < self.max_units
843                && nodes.len() >= 2
844            {
845                // Node with highest accumulated error.
846                let q = nodes
847                    .iter()
848                    .enumerate()
849                    .max_by(|a, b| {
850                        a.1.error
851                            .partial_cmp(&b.1.error)
852                            .unwrap_or(std::cmp::Ordering::Equal)
853                    })
854                    .map(|(i, _)| i)
855                    .unwrap_or(0);
856
857                // Neighbour of q with highest error.
858                let q_neighbours: Vec<usize> = edge_map
859                    .keys()
860                    .filter_map(|&(a, b)| {
861                        if a == q {
862                            Some(b)
863                        } else if b == q {
864                            Some(a)
865                        } else {
866                            None
867                        }
868                    })
869                    .collect();
870
871                if !q_neighbours.is_empty() {
872                    let f = q_neighbours
873                        .iter()
874                        .max_by(|&&a, &&b| {
875                            nodes[a]
876                                .error
877                                .partial_cmp(&nodes[b].error)
878                                .unwrap_or(std::cmp::Ordering::Equal)
879                        })
880                        .cloned()
881                        .unwrap_or(q_neighbours[0]);
882
883                    // New node between q and f.
884                    let new_weights: Vec<f64> = nodes[q]
885                        .weights
886                        .iter()
887                        .zip(nodes[f].weights.iter())
888                        .map(|(a, b)| (a + b) / 2.0)
889                        .collect();
890
891                    let new_idx = nodes.len();
892                    let new_error = nodes[q].error * self.alpha;
893                    nodes.push(GngNode {
894                        weights: new_weights,
895                        error: new_error,
896                    });
897
898                    nodes[q].error *= self.alpha;
899                    nodes[f].error *= self.alpha;
900
901                    // Remove q-f edge; add q-new and f-new.
902                    let qf_key = if q < f { (q, f) } else { (f, q) };
903                    edge_map.remove(&qf_key);
904
905                    let qn_key = if q < new_idx {
906                        (q, new_idx)
907                    } else {
908                        (new_idx, q)
909                    };
910                    let fn_key = if f < new_idx {
911                        (f, new_idx)
912                    } else {
913                        (new_idx, f)
914                    };
915                    edge_map.insert(qn_key, GngEdge { age: 0 });
916                    edge_map.insert(fn_key, GngEdge { age: 0 });
917                }
918            }
919        }
920
921        let n_units = nodes.len();
922        let mut proto_arr = Array2::<f64>::zeros((n_units, n_features));
923        for (j, node) in nodes.iter().enumerate() {
924            for k in 0..n_features {
925                proto_arr[[j, k]] = node.weights[k];
926            }
927        }
928
929        let edges: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
930
931        // Assign labels and quantization error.
932        let mut labels_vec = vec![0usize; n_samples];
933        let mut total_qe = 0.0f64;
934        for i in 0..n_samples {
935            let row = data_rows[i].as_slice();
936            let (best, best_dist) = nodes
937                .iter()
938                .enumerate()
939                .map(|(j, node)| (j, sq_euclid(row, &node.weights)))
940                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
941                .unwrap_or((0, 0.0));
942            labels_vec[i] = best;
943            total_qe += best_dist;
944        }
945        let quantization_error = if n_samples > 0 {
946            total_qe / n_samples as f64
947        } else {
948            0.0
949        };
950
951        Ok(GrowingNeuralGasModel {
952            prototypes: proto_arr,
953            edges,
954            labels: Array1::from_vec(labels_vec),
955            quantization_error,
956        })
957    }
958}
959
960// ---------------------------------------------------------------------------
961// Tests
962// ---------------------------------------------------------------------------
963
#[cfg(test)]
mod tests {
    //! Unit tests for the competitive-learning estimators in this module.
    //!
    //! All tests share a small, deterministic two-cluster fixture so that
    //! results are reproducible and fast.

    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Generate two well-separated Gaussian-like clusters (12 points each, 2 features).
    ///
    /// Returns the `(24, 2)` data matrix and the ground-truth class label for each row
    /// (rows 0..12 are class 0 near the origin, rows 12..24 are class 1 near `(5, 5)`).
    fn two_cluster_data() -> (Array2<f64>, Vec<usize>) {
        let vals = vec![
            // Cluster 0  (~origin)
            0.00, 0.00, 0.10, 0.00, 0.00, 0.10, 0.10, 0.10, 0.05, 0.05, -0.05, 0.05, -0.05, -0.05,
            0.10, -0.05, 0.00, 0.15, -0.10, 0.00, 0.15, 0.10, 0.00, 0.20,
            // Cluster 1  (~(5, 5))
            5.00, 5.00, 5.10, 5.00, 5.00, 5.10, 5.10, 5.10, 5.05, 5.05, 4.95, 5.05, 4.95, 4.95,
            5.10, 4.95, 5.00, 5.15, 4.90, 5.00, 5.15, 5.10, 5.00, 5.20,
        ];
        let x = Array2::from_shape_vec((24, 2), vals).expect("shape ok");
        let y: Vec<usize> = [vec![0usize; 12], vec![1usize; 12]].concat();
        (x, y)
    }

    // --- WinnerTakeAll ---

    #[test]
    fn test_wta_basic() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        let protos = wta.fit(x.view(), 2).expect("fit");
        assert_eq!(protos.shape(), [2, 2]);
    }

    #[test]
    fn test_wta_single_prototype() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        let protos = wta.fit(x.view(), 1).expect("fit");
        assert_eq!(protos.shape(), [1, 2]);
    }

    #[test]
    fn test_wta_annealing() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll {
            lr_init: 0.5,
            lr_final: Some(0.001),
            max_epochs: 50,
            seed: 7,
        };
        let protos = wta.fit(x.view(), 2).expect("fit annealing");
        assert_eq!(protos.shape()[0], 2);
    }

    #[test]
    fn test_wta_predict() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        let protos = wta.fit(x.view(), 2).expect("fit");
        let labels = wta.predict(x.view(), &protos);
        assert_eq!(labels.len(), 24);
        assert!(labels.iter().all(|&l| l < 2));
    }

    #[test]
    fn test_wta_converges_two_clusters() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll {
            lr_init: 0.5,
            lr_final: Some(0.01),
            max_epochs: 200,
            seed: 42,
        };
        let protos = wta.fit(x.view(), 2).expect("fit");
        // One prototype should be near origin and the other near (5, 5).
        let p0 = protos.row(0).to_vec();
        let p1 = protos.row(1).to_vec();
        let d00 = sq_euclid(&p0, &[0.0, 0.0]);
        let d05 = sq_euclid(&p0, &[5.0, 5.0]);
        let d10 = sq_euclid(&p1, &[0.0, 0.0]);
        let d15 = sq_euclid(&p1, &[5.0, 5.0]);
        let well_placed = (d00 < d05 && d15 < d10) || (d05 < d00 && d10 < d15);
        assert!(well_placed, "prototypes should converge to cluster centres");
    }

    #[test]
    fn test_wta_error_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let wta = WinnerTakeAll::default();
        assert!(wta.fit(x.view(), 2).is_err());
    }

    #[test]
    fn test_wta_error_zero_prototypes() {
        let (x, _) = two_cluster_data();
        let wta = WinnerTakeAll::default();
        assert!(wta.fit(x.view(), 0).is_err());
    }

    // --- LearningVectorQuantization ---

    #[test]
    fn test_lvq_fit_basic() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        assert_eq!(model.prototypes.shape()[0], 2); // 1 per class × 2 classes
        assert_eq!(model.labels.len(), 2);
    }

    #[test]
    fn test_lvq_predict() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        let preds = model.predict(x.view());
        assert_eq!(preds.len(), 24);
        // Well-separated data should be classified correctly.
        let correct = preds.iter().zip(y.iter()).filter(|(&p, &t)| p == t).count();
        assert!(
            correct as f64 / 24.0 > 0.75,
            "accuracy should exceed 75%, got {}/24 correct",
            correct
        );
    }

    #[test]
    fn test_lvq_predict_one() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        let pred = model.predict_one(&[0.0, 0.0]);
        assert_eq!(pred, 0, "origin should map to class 0");
        let pred2 = model.predict_one(&[5.0, 5.0]);
        assert_eq!(pred2, 1, "(5,5) should map to class 1");
    }

    #[test]
    fn test_lvq_function_predict() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        let model = lvq.fit(x.view(), &y).expect("fit");
        let preds = LearningVectorQuantization::predict(&model, x.view());
        assert_eq!(preds.len(), 24);
    }

    #[test]
    fn test_lvq_multi_proto_per_class() {
        let (x, y) = two_cluster_data();
        let lvq = LearningVectorQuantization::new(2, 0.1, 0.001, 50, 42);
        let model = lvq.fit(x.view(), &y).expect("fit");
        assert_eq!(model.prototypes.shape()[0], 4); // 2 per class × 2 classes
    }

    #[test]
    fn test_lvq_error_label_mismatch() {
        let (x, _) = two_cluster_data();
        let lvq = LearningVectorQuantization::default();
        assert!(lvq.fit(x.view(), &[0, 1, 2]).is_err());
    }

    // --- NeuralGas ---

    #[test]
    fn test_ng_basic() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let model = ng.fit(x.view(), 2).expect("fit");
        assert_eq!(model.prototypes.shape(), [2, 2]);
        assert_eq!(model.labels.len(), 24);
        assert!(model.quantization_error >= 0.0);
    }

    #[test]
    fn test_ng_single_neuron() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let model = ng.fit(x.view(), 1).expect("fit");
        assert_eq!(model.prototypes.shape()[0], 1);
        assert!(model.labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_ng_converges() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas {
            lr_winner: 0.5,
            lr_final: 0.01,
            lambda_init: None,
            lambda_final: 0.01,
            max_epochs: 200,
            seed: 42,
        };
        let model = ng.fit(x.view(), 2).expect("fit");
        assert!(
            model.quantization_error < 1.0,
            "QE={} too high",
            model.quantization_error
        );
    }

    #[test]
    fn test_ng_error_empty() {
        let x = Array2::<f64>::zeros((0, 2));
        let ng = NeuralGas::default();
        assert!(ng.fit(x.view(), 2).is_err());
    }

    #[test]
    fn test_ng_error_zero_neurons() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        assert!(ng.fit(x.view(), 0).is_err());
    }

    // --- GrowingNeuralGas ---

    #[test]
    fn test_gng_basic() {
        let (x, _) = two_cluster_data();
        let gng = GrowingNeuralGas {
            max_steps: 300,
            insert_interval: 30,
            max_units: 15,
            seed: 7,
            ..GrowingNeuralGas::default()
        };
        let model = gng.fit(x.view()).expect("fit");
        assert!(
            model.prototypes.shape()[0] >= 2,
            "should have at least initial units"
        );
        assert_eq!(model.labels.len(), 24);
        assert!(model.quantization_error >= 0.0);
    }

    #[test]
    fn test_gng_grows_units() {
        let (x, _) = two_cluster_data();
        let gng = GrowingNeuralGas {
            max_steps: 1000,
            insert_interval: 50,
            max_units: 20,
            seed: 99,
            ..GrowingNeuralGas::default()
        };
        let model = gng.fit(x.view()).expect("fit");
        let n_units = model.prototypes.shape()[0];
        // With 1000 steps and an insertion attempt every 50 steps, the network must
        // end up with strictly more than the 2 initial units.
        assert!(
            n_units > 2,
            "GNG should grow beyond initial 2 units, got {}",
            n_units
        );
        // Insertion is only performed while the unit count is below `max_units`.
        assert!(
            n_units <= 20,
            "GNG should not exceed max_units (20), got {}",
            n_units
        );
    }

    #[test]
    fn test_gng_edges_valid() {
        let (x, _) = two_cluster_data();
        let gng = GrowingNeuralGas {
            max_steps: 500,
            seed: 42,
            ..GrowingNeuralGas::default()
        };
        let model = gng.fit(x.view()).expect("fit");
        let n_units = model.prototypes.shape()[0];
        // All edge endpoints must be valid unit indices.
        for &(a, b) in &model.edges {
            assert!(a < n_units, "edge endpoint {} out of range", a);
            assert!(b < n_units, "edge endpoint {} out of range", b);
            assert_ne!(a, b, "self-loop detected");
        }
    }

    #[test]
    fn test_gng_error_too_few_samples() {
        let x = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("shape ok");
        let gng = GrowingNeuralGas::default();
        assert!(gng.fit(x.view()).is_err());
    }
}