Skip to main content

scirs2_cluster/
prototype_enhanced.rs

1//! Enhanced prototype-based clustering algorithms
2//!
3//! This module provides advanced prototype-based methods beyond standard K-means,
4//! including competitive learning networks and learning vector quantization variants.
5//!
6//! # Algorithms
7//!
8//! - **Neural Gas**: Topology-preserving competitive learning with rank-ordered updates
9//! - **Growing Neural Gas (GNG)**: Adaptive topology without a fixed unit count
10//! - **LVQ** (Learning Vector Quantization): Supervised prototype adaptation
11//! - **GLVQ** (Generalized LVQ): Soft-margin prototype learning with adaptive metrics
12
13use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
14use scirs2_core::numeric::{Float, FromPrimitive};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use crate::error::{ClusteringError, Result};
19
20// ---------------------------------------------------------------------------
21// Shared distance helpers
22// ---------------------------------------------------------------------------
23
/// Returns the squared Euclidean distance `Σ (a_i - b_i)²` between two slices.
///
/// Only the overlapping prefix is compared: when the slices differ in length,
/// the trailing elements of the longer one are ignored.
#[inline]
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum()
}
29
30/// Euclidean distance between two slices.
31#[inline]
32fn euclid(a: &[f64], b: &[f64]) -> f64 {
33    sq_euclid(a, b).sqrt()
34}
35
/// Advances the 64-bit linear congruential generator `state` in place and
/// returns a pseudo-random `f64` in `[0, 1)`.
///
/// Uses the multiplier/increment pair from Knuth's MMIX. Only the top 53 bits
/// of the new state are used: they fit exactly in an `f64` mantissa, and the
/// low bits of a power-of-two-modulus LCG are the least random.
fn lcg_next(state: &mut u64) -> f64 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    // 53 high bits divided by 2^53 maps onto [0, 1).
    (*state >> 11) as f64 / (1u64 << 53) as f64
}

/// Draws a pseudo-random index in `[0, n)` from the LCG `state`.
///
/// # Panics
/// Panics if `n == 0`, since the range is then empty.
#[inline]
fn lcg_usize(state: &mut u64, n: usize) -> usize {
    assert!(n > 0, "lcg_usize: range must be non-empty");
    // Scale the unit-interval draw up to the range *before* truncating.
    // (`as` binds tighter than `%`, so `u as usize % n` would truncate a value
    // in [0, 1) to 0 and always return index 0.)
    let idx = (lcg_next(state) * n as f64) as usize;
    // Floating-point rounding can push the product up to exactly `n`.
    idx.min(n - 1)
}
50
51// ---------------------------------------------------------------------------
52// Neural Gas
53// ---------------------------------------------------------------------------
54
55/// Result of Neural Gas clustering.
56#[derive(Debug, Clone)]
57pub struct NeuralGasResult {
58    /// Prototype / reference vectors, shape `(n_units, n_features)`.
59    pub prototypes: Array2<f64>,
60    /// Label of the nearest prototype for each training sample.
61    pub labels: Array1<usize>,
62    /// Number of prototype units.
63    pub n_units: usize,
64    /// Final quantization error (mean squared distance to nearest prototype).
65    pub quantization_error: f64,
66}
67
/// Neural Gas unsupervised competitive learning network.
///
/// For each input, all prototypes are ranked by distance, and the prototype of
/// rank `k` is moved toward the input with strength `ε · exp(-k / λ)`. Both the
/// learning rate `ε` and the neighbourhood range `λ` decay geometrically from
/// their initial to their final values over the course of training.
///
/// Reference: Martinetz & Schulten, 1991.
#[derive(Debug, Clone)]
pub struct NeuralGas {
    /// Initial learning rate `ε_i` (default 0.5).
    pub lr_i: f64,
    /// Final learning rate `ε_f` (default 0.01).
    pub lr_f: f64,
    /// Initial neighbourhood range `λ_i` (default `n_units / 2`). The value used
    /// during training is clamped to at least 0.5.
    pub lambda_i: Option<f64>,
    /// Final neighbourhood range `λ_f` (default 0.01).
    pub lambda_f: f64,
    /// Seed for the internal pseudo-random generator (default 42).
    pub seed: u64,
}
88
89impl Default for NeuralGas {
90    fn default() -> Self {
91        Self {
92            lr_i: 0.5,
93            lr_f: 0.01,
94            lambda_i: None,
95            lambda_f: 0.01,
96            seed: 42,
97        }
98    }
99}
100
101impl NeuralGas {
102    /// Fit Neural Gas.
103    ///
104    /// # Arguments
105    /// * `x` – Data matrix `(n_samples, n_features)`.
106    /// * `n_units` – Number of prototype units.
107    /// * `max_iter` – Number of training epochs (passes over the data).
108    pub fn fit(
109        &self,
110        x: ArrayView2<f64>,
111        n_units: usize,
112        max_iter: usize,
113    ) -> Result<NeuralGasResult> {
114        let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
115        if n_samples == 0 {
116            return Err(ClusteringError::InvalidInput("Empty input data".into()));
117        }
118        if n_units == 0 {
119            return Err(ClusteringError::InvalidInput("n_units must be > 0".into()));
120        }
121        if max_iter == 0 {
122            return Err(ClusteringError::InvalidInput("max_iter must be > 0".into()));
123        }
124
125        let mut rng = self.seed;
126
127        // Initialise prototypes by sampling data points.
128        let mut protos: Vec<Vec<f64>> = (0..n_units)
129            .map(|_| {
130                let idx = lcg_usize(&mut rng, n_samples);
131                x.row(idx).to_vec()
132            })
133            .collect();
134
135        let total_steps = max_iter * n_samples;
136        let lambda_i = self.lambda_i.unwrap_or((n_units as f64) / 2.0).max(0.5);
137
138        for epoch in 0..max_iter {
139            // Shuffle sample order each epoch.
140            let mut order: Vec<usize> = (0..n_samples).collect();
141            for i in (1..n_samples).rev() {
142                let j = lcg_usize(&mut rng, i + 1);
143                order.swap(i, j);
144            }
145
146            for &sample_idx in &order {
147                // Global step index for annealing schedule.
148                let step = epoch * n_samples + sample_idx;
149                let t = step as f64 / total_steps.max(1) as f64;
150
151                // Anneal learning rate and lambda.
152                let lr = self.lr_i * (self.lr_f / self.lr_i).powf(t);
153                let lam = lambda_i * (self.lambda_f / lambda_i).powf(t);
154
155                let input = x.row(sample_idx).to_vec();
156
157                // Rank all prototypes by distance to input.
158                let mut ranked: Vec<(f64, usize)> = protos
159                    .iter()
160                    .enumerate()
161                    .map(|(j, p)| (euclid(&input, p), j))
162                    .collect();
163                ranked.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
164
165                // Update each prototype with neighbourhood factor based on rank.
166                for (rank, (_, proto_idx)) in ranked.iter().enumerate() {
167                    let h = (-(rank as f64) / lam).exp();
168                    let p = &mut protos[*proto_idx];
169                    for k in 0..n_features {
170                        p[k] += lr * h * (input[k] - p[k]);
171                    }
172                }
173            }
174        }
175
176        // Assign labels and compute quantization error.
177        let mut labels = vec![0usize; n_samples];
178        let mut total_qe = 0.0f64;
179        for i in 0..n_samples {
180            let row = x.row(i).to_vec();
181            let (best, best_dist) = protos
182                .iter()
183                .enumerate()
184                .map(|(j, p)| (j, sq_euclid(&row, p)))
185                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
186                .unwrap_or((0, 0.0));
187            labels[i] = best;
188            total_qe += best_dist;
189        }
190        let quantization_error = total_qe / n_samples as f64;
191
192        // Pack prototypes into Array2.
193        let mut proto_arr = Array2::<f64>::zeros((n_units, n_features));
194        for (j, p) in protos.iter().enumerate() {
195            for k in 0..n_features {
196                proto_arr[[j, k]] = p[k];
197            }
198        }
199
200        Ok(NeuralGasResult {
201            prototypes: proto_arr,
202            labels: Array1::from_vec(labels),
203            n_units,
204            quantization_error,
205        })
206    }
207}
208
209// ---------------------------------------------------------------------------
210// Growing Neural Gas
211// ---------------------------------------------------------------------------
212
/// Edge of the GNG topology graph, stored in a map keyed by its endpoint pair.
#[derive(Debug, Clone)]
struct GngEdge {
    /// Staleness counter. It is bumped whenever one endpoint wins a sample and
    /// reset to zero when the edge links the winner and the runner-up. The
    /// edge is dropped once this exceeds the configured maximum age.
    age: usize,
}
220
/// One unit of the Growing Neural Gas network.
#[derive(Debug, Clone)]
struct GngNode {
    /// Prototype position in input space.
    weights: Vec<f64>,
    /// Running total of squared distances to samples this unit has won.
    /// It decays every step and is used to choose where new units are inserted.
    error: f64,
}
229
/// Configuration for Growing Neural Gas.
#[derive(Debug, Clone)]
pub struct GngConfig {
    /// Step size for moving the winning unit toward the sample (default 0.1).
    pub lr_winner: f64,
    /// Step size for moving the winner's graph neighbours toward the sample (default 0.01).
    pub lr_neighbor: f64,
    /// Edges whose age exceeds this value are removed (default 50).
    pub max_age: usize,
    /// A new unit is inserted once every this many training steps (default 100).
    /// Must be non-zero.
    pub insert_interval: usize,
    /// On insertion, the errors of the highest-error unit and of its
    /// highest-error neighbour are multiplied by this factor, and the new unit
    /// starts with the first one's reduced error (default 0.5). Other units are
    /// not affected.
    pub alpha: f64,
    /// Multiplicative decay applied to every unit's error after each step (default 0.995).
    pub beta: f64,
    /// Upper bound on the number of units; insertion stops once it is reached (default 200).
    pub max_units: usize,
    /// Total number of single-sample training steps (default 5000).
    pub max_steps: usize,
    /// Seed for the internal pseudo-random generator (default 42).
    pub seed: u64,
}
252
253impl Default for GngConfig {
254    fn default() -> Self {
255        Self {
256            lr_winner: 0.1,
257            lr_neighbor: 0.01,
258            max_age: 50,
259            insert_interval: 100,
260            alpha: 0.5,
261            beta: 0.995,
262            max_units: 200,
263            max_steps: 5000,
264            seed: 42,
265        }
266    }
267}
268
269/// Result of Growing Neural Gas.
270#[derive(Debug, Clone)]
271pub struct GngResult {
272    /// Learned prototype weights `(n_units, n_features)`.
273    pub prototypes: Array2<f64>,
274    /// Edges as (node_a, node_b) pairs.
275    pub edges: Vec<(usize, usize)>,
276    /// Label of the nearest prototype for each training sample.
277    pub labels: Array1<usize>,
278    /// Final quantization error.
279    pub quantization_error: f64,
280}
281
282/// Growing Neural Gas — adaptive topology competitive learning.
283///
284/// Unlike Neural Gas, GNG starts with two units and grows by inserting new
285/// units between high-error units. Edges are added/removed dynamically.
286///
287/// Reference: Fritzke, 1995.
288pub struct GrowingNeuralGas {
289    /// Configuration.
290    pub config: GngConfig,
291}
292
293impl Default for GrowingNeuralGas {
294    fn default() -> Self {
295        Self {
296            config: GngConfig::default(),
297        }
298    }
299}
300
301impl GrowingNeuralGas {
302    /// Create a new GNG with the given config.
303    pub fn new(config: GngConfig) -> Self {
304        Self { config }
305    }
306
307    /// Fit the GNG model to data `x`.
308    pub fn fit(&self, x: ArrayView2<f64>) -> Result<GngResult> {
309        let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
310        if n_samples < 2 {
311            return Err(ClusteringError::InvalidInput(
312                "Need at least 2 samples for GNG".into(),
313            ));
314        }
315
316        let cfg = &self.config;
317        let mut rng = cfg.seed;
318
319        // Initialise with two nodes sampled from data.
320        let idx0 = lcg_usize(&mut rng, n_samples);
321        let idx1 = (idx0 + 1 + lcg_usize(&mut rng, n_samples - 1)) % n_samples;
322        let mut nodes: Vec<GngNode> = vec![
323            GngNode {
324                weights: x.row(idx0).to_vec(),
325                error: 0.0,
326            },
327            GngNode {
328                weights: x.row(idx1).to_vec(),
329                error: 0.0,
330            },
331        ];
332        // Adjacency: edges[i][j] = Option<GngEdge>
333        // Use a HashMap keyed by sorted (i, j) pairs.
334        let mut edge_map: HashMap<(usize, usize), GngEdge> = HashMap::new();
335        // Add initial edge.
336        edge_map.insert((0, 1), GngEdge { age: 0 });
337
338        let mut step = 0usize;
339        let data_vec: Vec<Vec<f64>> = (0..n_samples).map(|i| x.row(i).to_vec()).collect();
340
341        while step < cfg.max_steps {
342            // Pick random sample.
343            let sample = &data_vec[lcg_usize(&mut rng, n_samples)];
344
345            // Find winner (s1) and runner-up (s2).
346            let mut dists: Vec<(f64, usize)> = nodes
347                .iter()
348                .enumerate()
349                .map(|(j, n)| (sq_euclid(sample, &n.weights), j))
350                .collect();
351            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
352
353            if dists.len() < 2 {
354                step += 1;
355                continue;
356            }
357
358            let s1 = dists[0].1;
359            let s2 = dists[1].1;
360            let dist_s1 = dists[0].0;
361
362            // Increment age of all edges incident to s1.
363            let edge_keys: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
364            for key in &edge_keys {
365                if key.0 == s1 || key.1 == s1 {
366                    if let Some(e) = edge_map.get_mut(key) {
367                        e.age += 1;
368                    }
369                }
370            }
371
372            // Add/reset edge (s1, s2).
373            let edge_key = if s1 < s2 { (s1, s2) } else { (s2, s1) };
374            edge_map.insert(edge_key, GngEdge { age: 0 });
375
376            // Accumulate error for winner.
377            nodes[s1].error += dist_s1;
378
379            // Move winner and its topological neighbours toward sample.
380            let n_nodes = nodes.len();
381            let winner_w: Vec<f64> = nodes[s1].weights.clone();
382            for k in 0..n_features {
383                nodes[s1].weights[k] += cfg.lr_winner * (sample[k] - winner_w[k]);
384            }
385
386            let neighbor_ids: Vec<usize> = edge_map
387                .keys()
388                .filter_map(|&(a, b)| {
389                    if a == s1 {
390                        Some(b)
391                    } else if b == s1 {
392                        Some(a)
393                    } else {
394                        None
395                    }
396                })
397                .collect();
398
399            for nb in &neighbor_ids {
400                let nb_w: Vec<f64> = nodes[*nb].weights.clone();
401                for k in 0..n_features {
402                    nodes[*nb].weights[k] += cfg.lr_neighbor * (sample[k] - nb_w[k]);
403                }
404            }
405
406            // Remove edges older than max_age.
407            edge_map.retain(|_, e| e.age <= cfg.max_age);
408
409            // Remove isolated nodes (no edges).
410            // (Only do this after removing edges.)
411            let connected: std::collections::HashSet<usize> =
412                edge_map.keys().flat_map(|&(a, b)| [a, b]).collect();
413            // We'll skip removing nodes to keep index stability (just leave them).
414
415            // Apply global error decay.
416            for node in nodes.iter_mut() {
417                node.error *= cfg.beta;
418            }
419
420            // Insert new node periodically.
421            if step % cfg.insert_interval == 0 && nodes.len() < cfg.max_units && nodes.len() >= 2 {
422                // Find node with highest error.
423                let q = nodes
424                    .iter()
425                    .enumerate()
426                    .max_by(|a, b| {
427                        a.1.error
428                            .partial_cmp(&b.1.error)
429                            .unwrap_or(std::cmp::Ordering::Equal)
430                    })
431                    .map(|(i, _)| i)
432                    .unwrap_or(0);
433
434                // Find the neighbour of q with highest error.
435                let q_neighbors: Vec<usize> = edge_map
436                    .keys()
437                    .filter_map(|&(a, b)| {
438                        if a == q {
439                            Some(b)
440                        } else if b == q {
441                            Some(a)
442                        } else {
443                            None
444                        }
445                    })
446                    .collect();
447
448                if !q_neighbors.is_empty() {
449                    let f = q_neighbors
450                        .iter()
451                        .max_by(|&&a, &&b| {
452                            nodes[a]
453                                .error
454                                .partial_cmp(&nodes[b].error)
455                                .unwrap_or(std::cmp::Ordering::Equal)
456                        })
457                        .cloned()
458                        .unwrap_or(q_neighbors[0]);
459
460                    // Insert new node between q and f.
461                    let new_weights: Vec<f64> = nodes[q]
462                        .weights
463                        .iter()
464                        .zip(nodes[f].weights.iter())
465                        .map(|(a, b)| (a + b) / 2.0)
466                        .collect();
467                    let new_idx = nodes.len();
468                    nodes.push(GngNode {
469                        weights: new_weights,
470                        error: nodes[q].error * cfg.alpha,
471                    });
472
473                    // Adjust errors.
474                    nodes[q].error *= cfg.alpha;
475                    nodes[f].error *= cfg.alpha;
476
477                    // Remove q-f edge, add q-new and f-new edges.
478                    let qf_key = if q < f { (q, f) } else { (f, q) };
479                    edge_map.remove(&qf_key);
480                    let qn_key = if q < new_idx {
481                        (q, new_idx)
482                    } else {
483                        (new_idx, q)
484                    };
485                    let fn_key = if f < new_idx {
486                        (f, new_idx)
487                    } else {
488                        (new_idx, f)
489                    };
490                    edge_map.insert(qn_key, GngEdge { age: 0 });
491                    edge_map.insert(fn_key, GngEdge { age: 0 });
492                }
493            }
494
495            step += 1;
496        }
497
498        let n_units = nodes.len();
499        let mut proto_arr = Array2::<f64>::zeros((n_units, n_features));
500        for (j, node) in nodes.iter().enumerate() {
501            for k in 0..n_features {
502                proto_arr[[j, k]] = node.weights[k];
503            }
504        }
505
506        let edges: Vec<(usize, usize)> = edge_map.keys().cloned().collect();
507
508        // Assign labels.
509        let mut labels = vec![0usize; n_samples];
510        let mut total_qe = 0.0f64;
511        for i in 0..n_samples {
512            let row = x.row(i).to_vec();
513            let (best, best_dist) = nodes
514                .iter()
515                .enumerate()
516                .map(|(j, node)| (j, sq_euclid(&row, &node.weights)))
517                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
518                .unwrap_or((0, 0.0));
519            labels[i] = best;
520            total_qe += best_dist;
521        }
522
523        Ok(GngResult {
524            prototypes: proto_arr,
525            edges,
526            labels: Array1::from_vec(labels),
527            quantization_error: total_qe / n_samples as f64,
528        })
529    }
530}
531
532// ---------------------------------------------------------------------------
533// LVQ — Learning Vector Quantization
534// ---------------------------------------------------------------------------
535
/// Hyper-parameters for LVQ training.
#[derive(Debug, Clone)]
pub struct LvqConfig {
    /// How many prototypes are created for each class present in the labels (default 1).
    pub prototypes_per_class: usize,
    /// Learning rate at the first training step (default 0.1).
    pub lr_init: f64,
    /// Learning rate that the geometric schedule decays toward (default 0.001).
    pub lr_final: f64,
    /// Number of passes over the training data (default 50).
    pub max_epochs: usize,
    /// Seed for the internal pseudo-random generator (default 42).
    pub seed: u64,
}
550
551impl Default for LvqConfig {
552    fn default() -> Self {
553        Self {
554            prototypes_per_class: 1,
555            lr_init: 0.1,
556            lr_final: 0.001,
557            max_epochs: 50,
558            seed: 42,
559        }
560    }
561}
562
563/// Result of LVQ training.
564#[derive(Debug, Clone)]
565pub struct LvqResult {
566    /// Prototype weights, shape `(n_prototypes, n_features)`.
567    pub prototypes: Array2<f64>,
568    /// Class label for each prototype.
569    pub prototype_labels: Vec<usize>,
570    /// Training accuracy on the training set.
571    pub train_accuracy: f64,
572}
573
574impl LvqResult {
575    /// Predict the class of each row in `x`.
576    pub fn predict(&self, x: ArrayView2<f64>) -> Vec<usize> {
577        let n = x.shape()[0];
578        let n_proto = self.prototypes.shape()[0];
579        (0..n)
580            .map(|i| {
581                let row = x.row(i).to_vec();
582                let best = (0..n_proto)
583                    .map(|j| {
584                        let p: Vec<f64> = self.prototypes.row(j).to_vec();
585                        (j, sq_euclid(&row, &p))
586                    })
587                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
588                    .map(|(j, _)| j)
589                    .unwrap_or(0);
590                self.prototype_labels[best]
591            })
592            .collect()
593    }
594}
595
596/// LVQ-1: Learning Vector Quantization.
597///
598/// Supervised prototype learning: attracts the nearest correct-class prototype
599/// and repels the nearest wrong-class prototype toward/away from each input.
600pub struct LVQ {
601    /// Configuration.
602    pub config: LvqConfig,
603}
604
605impl Default for LVQ {
606    fn default() -> Self {
607        Self {
608            config: LvqConfig::default(),
609        }
610    }
611}
612
613impl LVQ {
614    /// Create a new LVQ with the given config.
615    pub fn new(config: LvqConfig) -> Self {
616        Self { config }
617    }
618
619    /// Fit LVQ to labelled data.
620    ///
621    /// # Arguments
622    /// * `x` – Feature matrix `(n_samples, n_features)`.
623    /// * `y` – Class labels, values in `0..n_classes`.
624    pub fn fit(&self, x: ArrayView2<f64>, y: &[usize]) -> Result<LvqResult> {
625        let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
626        if n_samples == 0 {
627            return Err(ClusteringError::InvalidInput("Empty input data".into()));
628        }
629        if y.len() != n_samples {
630            return Err(ClusteringError::InvalidInput(
631                "y must have the same length as x rows".into(),
632            ));
633        }
634
635        let n_classes = y.iter().cloned().max().map(|m| m + 1).unwrap_or(0);
636        if n_classes == 0 {
637            return Err(ClusteringError::InvalidInput("Empty class labels".into()));
638        }
639
640        let ppc = self.config.prototypes_per_class;
641        let mut rng = self.config.seed;
642
643        // Initialise prototypes by sampling from each class.
644        let mut class_samples: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
645        for (i, &label) in y.iter().enumerate() {
646            if label < n_classes {
647                class_samples[label].push(i);
648            }
649        }
650
651        let mut proto_weights: Vec<Vec<f64>> = Vec::new();
652        let mut proto_labels: Vec<usize> = Vec::new();
653
654        for class in 0..n_classes {
655            let samples = &class_samples[class];
656            if samples.is_empty() {
657                continue;
658            }
659            for _ in 0..ppc {
660                let idx = samples[lcg_usize(&mut rng, samples.len())];
661                proto_weights.push(x.row(idx).to_vec());
662                proto_labels.push(class);
663            }
664        }
665
666        let n_proto = proto_weights.len();
667        if n_proto == 0 {
668            return Err(ClusteringError::ComputationError(
669                "No prototypes initialized".into(),
670            ));
671        }
672
673        let total_steps = self.config.max_epochs * n_samples;
674
675        // LVQ-1 training loop.
676        for epoch in 0..self.config.max_epochs {
677            // Shuffle.
678            let mut order: Vec<usize> = (0..n_samples).collect();
679            for i in (1..n_samples).rev() {
680                let j = lcg_usize(&mut rng, i + 1);
681                order.swap(i, j);
682            }
683
684            for (step_in_epoch, &sample_idx) in order.iter().enumerate() {
685                let step = epoch * n_samples + step_in_epoch;
686                let t = step as f64 / total_steps.max(1) as f64;
687                let lr = self.config.lr_init * (self.config.lr_final / self.config.lr_init).powf(t);
688
689                let input = x.row(sample_idx).to_vec();
690                let true_class = y[sample_idx];
691
692                // Find nearest prototype.
693                let nearest = (0..n_proto)
694                    .map(|j| (j, sq_euclid(&input, &proto_weights[j])))
695                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
696                    .map(|(j, _)| j)
697                    .unwrap_or(0);
698
699                // Attract if correct class, repel otherwise.
700                let sign = if proto_labels[nearest] == true_class {
701                    1.0f64
702                } else {
703                    -1.0f64
704                };
705
706                let w = &mut proto_weights[nearest];
707                for k in 0..n_features {
708                    w[k] += lr * sign * (input[k] - w[k]);
709                }
710            }
711        }
712
713        // Build result array.
714        let mut proto_arr = Array2::<f64>::zeros((n_proto, n_features));
715        for (j, w) in proto_weights.iter().enumerate() {
716            for k in 0..n_features {
717                proto_arr[[j, k]] = w[k];
718            }
719        }
720
721        // Compute training accuracy.
722        let predictions = {
723            let n = n_samples;
724            (0..n)
725                .map(|i| {
726                    let row = x.row(i).to_vec();
727                    let best = (0..n_proto)
728                        .map(|j| (j, sq_euclid(&row, &proto_weights[j])))
729                        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
730                        .map(|(j, _)| j)
731                        .unwrap_or(0);
732                    proto_labels[best]
733                })
734                .collect::<Vec<usize>>()
735        };
736
737        let correct = predictions
738            .iter()
739            .zip(y.iter())
740            .filter(|(&p, &t)| p == t)
741            .count();
742        let train_accuracy = correct as f64 / n_samples as f64;
743
744        Ok(LvqResult {
745            prototypes: proto_arr,
746            prototype_labels: proto_labels,
747            train_accuracy,
748        })
749    }
750}
751
752// ---------------------------------------------------------------------------
753// GLVQ — Generalized Learning Vector Quantization
754// ---------------------------------------------------------------------------
755
756/// Result of GLVQ training.
757#[derive(Debug, Clone)]
758pub struct GlvqResult {
759    /// Prototype weights `(n_prototypes, n_features)`.
760    pub prototypes: Array2<f64>,
761    /// Class label for each prototype.
762    pub prototype_labels: Vec<usize>,
763    /// Training accuracy.
764    pub train_accuracy: f64,
765    /// Final GLVQ cost.
766    pub cost: f64,
767}
768
769impl GlvqResult {
770    /// Predict class labels for `x`.
771    pub fn predict(&self, x: ArrayView2<f64>) -> Vec<usize> {
772        let n = x.shape()[0];
773        let n_proto = self.prototypes.shape()[0];
774        (0..n)
775            .map(|i| {
776                let row = x.row(i).to_vec();
777                let best = (0..n_proto)
778                    .map(|j| {
779                        let p = self.prototypes.row(j).to_vec();
780                        (j, sq_euclid(&row, &p))
781                    })
782                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
783                    .map(|(j, _)| j)
784                    .unwrap_or(0);
785                self.prototype_labels[best]
786            })
787            .collect()
788    }
789}
790
/// Hyper-parameters for GLVQ training.
#[derive(Debug, Clone)]
pub struct GlvqConfig {
    /// How many prototypes are created for each class present in the labels (default 1).
    pub prototypes_per_class: usize,
    /// Gradient step size (default 0.01).
    pub lr: f64,
    /// Steepness of the sigmoid applied to the relative distance `μ` (default 1.0).
    pub sigma: f64,
    /// Number of passes over the training data (default 100).
    pub max_epochs: usize,
    /// Seed for the internal pseudo-random generator (default 42).
    pub seed: u64,
}
805
806impl Default for GlvqConfig {
807    fn default() -> Self {
808        Self {
809            prototypes_per_class: 1,
810            lr: 0.01,
811            sigma: 1.0,
812            max_epochs: 100,
813            seed: 42,
814        }
815    }
816}
817
818/// Generalized LVQ (GLVQ) — soft-margin prototype learning.
819///
820/// Minimises a differentiable cost function based on the relative distances
821/// to the nearest correct (d+) and nearest incorrect (d-) prototypes:
822///
823///   μ(x) = (d+ - d-) / (d+ + d-)
824///
825/// The sigmoid of μ is minimised.  Gradients are computed w.r.t. both d+
826/// and d- prototypes.
827///
828/// Reference: Sato & Yamada, 1996.
829pub struct GLVQ {
830    /// Configuration.
831    pub config: GlvqConfig,
832}
833
834impl Default for GLVQ {
835    fn default() -> Self {
836        Self {
837            config: GlvqConfig::default(),
838        }
839    }
840}
841
842impl GLVQ {
843    /// Create a new GLVQ with the given config.
844    pub fn new(config: GlvqConfig) -> Self {
845        Self { config }
846    }
847
848    /// Fit GLVQ to labelled data.
849    pub fn fit(&self, x: ArrayView2<f64>, y: &[usize]) -> Result<GlvqResult> {
850        let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
851        if n_samples == 0 {
852            return Err(ClusteringError::InvalidInput("Empty input data".into()));
853        }
854        if y.len() != n_samples {
855            return Err(ClusteringError::InvalidInput("y length mismatch".into()));
856        }
857
858        let n_classes = y.iter().cloned().max().map(|m| m + 1).unwrap_or(0);
859        if n_classes < 2 {
860            return Err(ClusteringError::InvalidInput(
861                "GLVQ requires at least 2 classes".into(),
862            ));
863        }
864
865        let ppc = self.config.prototypes_per_class;
866        let mut rng = self.config.seed;
867
868        // Initialise prototypes from class samples.
869        let mut class_samples: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
870        for (i, &label) in y.iter().enumerate() {
871            if label < n_classes {
872                class_samples[label].push(i);
873            }
874        }
875
876        let mut proto_weights: Vec<Vec<f64>> = Vec::new();
877        let mut proto_labels: Vec<usize> = Vec::new();
878
879        for class in 0..n_classes {
880            let samples = &class_samples[class];
881            if samples.is_empty() {
882                continue;
883            }
884            for _ in 0..ppc {
885                let idx = samples[lcg_usize(&mut rng, samples.len())];
886                proto_weights.push(x.row(idx).to_vec());
887                proto_labels.push(class);
888            }
889        }
890
891        let n_proto = proto_weights.len();
892        let lr = self.config.lr;
893        let sigma = self.config.sigma;
894
895        let mut total_cost = 0.0f64;
896
897        // GLVQ training loop.
898        for _epoch in 0..self.config.max_epochs {
899            // Shuffle.
900            let mut order: Vec<usize> = (0..n_samples).collect();
901            for i in (1..n_samples).rev() {
902                let j = lcg_usize(&mut rng, i + 1);
903                order.swap(i, j);
904            }
905
906            total_cost = 0.0;
907            for &sample_idx in &order {
908                let input = x.row(sample_idx).to_vec();
909                let true_class = y[sample_idx];
910
911                // Find nearest same-class prototype (winner+) and nearest other-class prototype (winner-).
912                let mut d_plus = f64::INFINITY;
913                let mut d_minus = f64::INFINITY;
914                let mut winner_plus = 0usize;
915                let mut winner_minus = 0usize;
916
917                for j in 0..n_proto {
918                    let d = sq_euclid(&input, &proto_weights[j]);
919                    if proto_labels[j] == true_class {
920                        if d < d_plus {
921                            d_plus = d;
922                            winner_plus = j;
923                        }
924                    } else if d < d_minus {
925                        d_minus = d;
926                        winner_minus = j;
927                    }
928                }
929
930                if d_plus.is_infinite() || d_minus.is_infinite() {
931                    continue;
932                }
933
934                let denom = d_plus + d_minus;
935                if denom < 1e-12 {
936                    continue;
937                }
938
939                let mu = (d_plus - d_minus) / denom;
940                // Sigmoid activation: f(mu) = 1 / (1 + exp(-sigma * mu))
941                let f_mu = 1.0 / (1.0 + (-sigma * mu).exp());
942                // Derivative: f'(mu) = sigma * f(mu) * (1 - f(mu))
943                let f_prime = sigma * f_mu * (1.0 - f_mu);
944
945                total_cost += f_mu;
946
947                // Gradient w.r.t. d+:  f'(mu) * 2 * d- / denom^2
948                let grad_dp = f_prime * (2.0 * d_minus) / (denom * denom);
949                // Gradient w.r.t. d-: -f'(mu) * 2 * d+ / denom^2
950                let grad_dm = -f_prime * (2.0 * d_plus) / (denom * denom);
951
952                // Update winner+: gradient descent w.r.t. d+ = ||x - w+||^2
953                // dL/dw+ = 2 * grad_dp * (w+ - x)
954                let wp = &mut proto_weights[winner_plus];
955                for k in 0..n_features {
956                    wp[k] -= lr * 2.0 * grad_dp * (wp[k] - input[k]);
957                }
958
959                // Update winner-: gradient descent w.r.t. d- = ||x - w-||^2
960                // dL/dw- = 2 * grad_dm * (w- - x)
961                let wm = &mut proto_weights[winner_minus];
962                for k in 0..n_features {
963                    wm[k] -= lr * 2.0 * grad_dm * (wm[k] - input[k]);
964                }
965            }
966        }
967
968        // Build output.
969        let mut proto_arr = Array2::<f64>::zeros((n_proto, n_features));
970        for (j, w) in proto_weights.iter().enumerate() {
971            for k in 0..n_features {
972                proto_arr[[j, k]] = w[k];
973            }
974        }
975
976        // Compute training accuracy.
977        let mut correct = 0usize;
978        for i in 0..n_samples {
979            let row = x.row(i).to_vec();
980            let best = (0..n_proto)
981                .map(|j| (j, sq_euclid(&row, &proto_weights[j])))
982                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
983                .map(|(j, _)| j)
984                .unwrap_or(0);
985            if proto_labels[best] == y[i] {
986                correct += 1;
987            }
988        }
989        let train_accuracy = correct as f64 / n_samples as f64;
990
991        Ok(GlvqResult {
992            prototypes: proto_arr,
993            prototype_labels: proto_labels,
994            train_accuracy,
995            cost: total_cost,
996        })
997    }
998}
999
1000// ---------------------------------------------------------------------------
1001// Tests
1002// ---------------------------------------------------------------------------
1003
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two tight, well-separated clusters of six 2-D points each:
    /// label 0 near the origin and label 1 near (5, 5).
    fn two_cluster_data() -> (Array2<f64>, Vec<usize>) {
        let x = Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.1, 0.0, 0.0, 0.1, 0.2, 0.0, 0.1, 0.1, 0.0, 0.2, 5.0, 5.0, 5.1, 5.0,
                5.0, 5.1, 5.2, 5.0, 5.1, 5.1, 5.0, 5.2,
            ],
        )
        .expect("valid shape");
        let y = vec![0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1];
        (x, y)
    }

    /// Explicit GLVQ settings: one prototype per class and a fixed seed.
    fn glvq_config() -> GlvqConfig {
        GlvqConfig {
            prototypes_per_class: 1,
            lr: 0.05,
            sigma: 1.0,
            max_epochs: 200,
            seed: 42,
        }
    }

    #[test]
    fn test_neural_gas_basic() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let result = ng.fit(x.view(), 2, 20).expect("neural gas fit");
        assert_eq!(result.n_units, 2);
        assert_eq!(result.labels.len(), 12);
        assert!(result.quantization_error >= 0.0);
    }

    #[test]
    fn test_neural_gas_more_units_than_clusters() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        // Five units for only two clusters: fitting must still succeed and keep
        // every requested unit.
        let result = ng.fit(x.view(), 5, 10).expect("ng many units");
        assert_eq!(result.n_units, 5);
    }

    #[test]
    fn test_neural_gas_single_unit() {
        let (x, _) = two_cluster_data();
        let ng = NeuralGas::default();
        let result = ng.fit(x.view(), 1, 10).expect("ng 1 unit");
        assert_eq!(result.n_units, 1);
        assert!(result.labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_growing_neural_gas_basic() {
        let (x, _) = two_cluster_data();
        let config = GngConfig {
            max_steps: 200,
            insert_interval: 20,
            max_units: 10,
            seed: 7,
            ..GngConfig::default()
        };
        let gng = GrowingNeuralGas::new(config);
        let result = gng.fit(x.view()).expect("gng fit");
        assert!(result.prototypes.shape()[0] >= 2, "should have grown");
        assert_eq!(result.labels.len(), 12);
    }

    #[test]
    fn test_lvq_two_classes() {
        let (x, y) = two_cluster_data();
        let config = LvqConfig {
            prototypes_per_class: 1,
            lr_init: 0.3,
            lr_final: 0.01,
            max_epochs: 100,
            seed: 42,
        };
        let lvq = LVQ::new(config);
        let result = lvq.fit(x.view(), &y).expect("lvq fit");
        // 1 prototype per class × 2 classes.
        assert_eq!(result.prototypes.shape()[0], 2);
        // Well-separated data should give high accuracy.
        assert!(
            result.train_accuracy > 0.8,
            "expected > 80% accuracy, got {}",
            result.train_accuracy
        );
    }

    #[test]
    fn test_lvq_predict() {
        let (x, y) = two_cluster_data();
        let lvq = LVQ::default();
        let result = lvq.fit(x.view(), &y).expect("lvq fit");
        let preds = result.predict(x.view());
        assert_eq!(preds.len(), 12);
    }

    #[test]
    fn test_glvq_two_classes() {
        let (x, y) = two_cluster_data();
        let glvq = GLVQ::new(glvq_config());
        let result = glvq.fit(x.view(), &y).expect("glvq fit");
        assert_eq!(result.prototypes.shape()[0], 2);
        // Prototypes are emitted class by class, in label order.
        assert_eq!(result.prototype_labels, vec![0, 1]);
        assert!(
            result.train_accuracy > 0.8,
            "expected > 80% accuracy, got {}",
            result.train_accuracy
        );
    }

    #[test]
    fn test_glvq_deterministic_for_fixed_seed() {
        let (x, y) = two_cluster_data();
        let first = GLVQ::new(glvq_config()).fit(x.view(), &y).expect("first fit");
        let second = GLVQ::new(glvq_config()).fit(x.view(), &y).expect("second fit");
        assert_eq!(first.prototypes, second.prototypes);
        assert_eq!(first.prototype_labels, second.prototype_labels);
    }

    #[test]
    fn test_glvq_predict() {
        let (x, y) = two_cluster_data();
        let glvq = GLVQ::default();
        let result = glvq.fit(x.view(), &y).expect("glvq fit");
        let preds = result.predict(x.view());
        assert_eq!(preds.len(), 12);
    }

    #[test]
    fn test_glvq_empty_input_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let glvq = GLVQ::default();
        assert!(glvq.fit(x.view(), &[]).is_err(), "empty input should error");
    }

    #[test]
    fn test_glvq_label_length_mismatch_error() {
        let (x, _) = two_cluster_data();
        let glvq = GLVQ::default();
        assert!(glvq.fit(x.view(), &[0, 1]).is_err(), "y length mismatch should error");
    }

    #[test]
    fn test_lvq_invalid_input() {
        let (x, _y) = two_cluster_data();
        let lvq = LVQ::default();
        // Wrong y length.
        assert!(lvq.fit(x.view(), &[0, 1, 0]).is_err());
    }

    #[test]
    fn test_glvq_single_class_error() {
        let x = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.0, 0.3, 0.1])
            .expect("shape");
        let y = vec![0usize, 0, 0, 0];
        let glvq = GLVQ::default();
        assert!(glvq.fit(x.view(), &y).is_err(), "single class should error");
    }
}