Skip to main content

scirs2_cluster/
deep_cluster.rs

//! Deep Clustering and Advanced Kernel-Based Clustering Algorithms
//!
//! This module provides:
//! - Kernel K-Means: K-means in a reproducing kernel Hilbert space
//! - Trimmed K-Means: Robust clustering that ignores outliers
//! - Dirichlet Process Mixture (CRP Gibbs sampler): Non-parametric Bayesian clustering
//! - Fuzzy C-Means: Soft clustering with fuzzy membership
//!
//! # References
//! - Scholkopf et al. (1998) "Kernel PCA and De-Noising in Feature Spaces"
//! - Cuesta-Albertos et al. (1997) "Trimmed k-means: an attempt to robustify quantizers"
//! - Ferguson (1973) "A Bayesian Analysis of Some Nonparametric Problems"
//! - Bezdek (1981) "Pattern Recognition with Fuzzy Objective Function Algorithms"
14
15use std::f64::consts::TAU;
16
17use crate::error::{ClusteringError, Result};
18
// ─────────────────────────────────────────────────────────────────────────────
// Simple deterministic 64-bit LCG (Knuth constants; no external rand dependency)
// ─────────────────────────────────────────────────────────────────────────────
22
/// Small deterministic 64-bit linear congruential generator.
///
/// The recurrence is `state = state * MULTIPLIER + INCREMENT (mod 2^64)` with
/// the multiplier/increment pair popularised by Knuth. It is used only to make
/// the clustering routines reproducible without an external RNG crate.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Multiplier of the LCG recurrence.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Additive increment of the LCG recurrence.
    const INCREMENT: u64 = 1442695040888963407;

    /// Creates a generator from `seed`; a zero seed is replaced by a fixed
    /// non-zero constant.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self {
                state: Self::MULTIPLIER,
            },
            nonzero => Self { state: nonzero },
        }
    }

    /// Advances the generator and returns a value in `[0, 1)`.
    fn next_f64(&mut self) -> f64 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        // Keep the top 53 bits so the quotient is an exactly representable f64.
        let mantissa = advanced >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Returns a value in `[low, high)`.
    ///
    /// When the range is empty (`low >= high`) this returns `low` without
    /// advancing the generator.
    fn next_range_usize(&mut self, low: usize, high: usize) -> usize {
        if low >= high {
            return low;
        }
        let width = (high - low) as f64;
        let offset = (self.next_f64() * width) as usize;
        low + offset
    }

    /// Draws a standard normal variate with the Box-Muller transform
    /// (consumes two uniform draws).
    fn next_normal(&mut self) -> f64 {
        // Clamp away from zero so the logarithm stays finite.
        let radial_uniform = self.next_f64().max(1e-15);
        let angular_uniform = self.next_f64();
        (-2.0 * radial_uniform.ln()).sqrt() * (TAU * angular_uniform).cos()
    }
}
61
62// ─────────────────────────────────────────────────────────────────────────────
63// Kernel Types
64// ─────────────────────────────────────────────────────────────────────────────
65
/// Kernel function types for Kernel K-Means.
///
/// Each variant fixes the formula used to evaluate `k(x, y)` between two
/// feature vectors.
#[derive(Clone, Debug, PartialEq)]
pub enum KernelType {
    /// Linear kernel: k(x,y) = x·y
    Linear,
    /// Polynomial kernel: k(x,y) = (gamma·x·y + coef0)^degree
    Polynomial {
        /// Integer exponent applied to the affine dot product.
        degree: u32,
        /// Additive offset inside the power.
        coef0: f64,
        /// Scale applied to the dot product.
        gamma: f64,
    },
    /// RBF / Gaussian kernel: k(x,y) = exp(-gamma * ||x-y||^2)
    Rbf {
        /// Scale applied to the squared Euclidean distance.
        gamma: f64,
    },
    /// Sigmoid kernel: k(x,y) = tanh(gamma·x·y + coef0)
    Sigmoid {
        /// Additive offset inside the `tanh`.
        coef0: f64,
        /// Scale applied to the dot product.
        gamma: f64,
    },
}
78
79impl KernelType {
80    /// Evaluate kernel between two vectors.
81    pub fn compute(&self, x: &[f64], y: &[f64]) -> f64 {
82        debug_assert_eq!(x.len(), y.len(), "kernel vectors must have same dimension");
83        match self {
84            KernelType::Linear => dot(x, y),
85            KernelType::Polynomial {
86                degree,
87                coef0,
88                gamma,
89            } => (gamma * dot(x, y) + coef0).powi(*degree as i32),
90            KernelType::Rbf { gamma } => {
91                let sq = sq_dist(x, y);
92                (-gamma * sq).exp()
93            }
94            KernelType::Sigmoid { coef0, gamma } => (gamma * dot(x, y) + coef0).tanh(),
95        }
96    }
97}
98
/// Dot product of `a` and `b`, taken over the shorter of the two lengths.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let shared_len = a.len().min(b.len());
    (0..shared_len).map(|i| a[i] * b[i]).sum()
}
102
/// Squared Euclidean distance between `a` and `b`, taken over the shorter of
/// the two lengths.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    let shared_len = a.len().min(b.len());
    (0..shared_len)
        .map(|i| {
            let delta = a[i] - b[i];
            delta.powi(2)
        })
        .sum()
}
109
110// ─────────────────────────────────────────────────────────────────────────────
111// Kernel K-Means
112// ─────────────────────────────────────────────────────────────────────────────
113
114/// Build the full kernel matrix K[i,j] = kernel(x_i, x_j).
115fn build_kernel_matrix(data: &[Vec<f64>], kernel: &KernelType) -> Vec<Vec<f64>> {
116    let n = data.len();
117    let mut k = vec![vec![0.0f64; n]; n];
118    for i in 0..n {
119        k[i][i] = kernel.compute(&data[i], &data[i]);
120        for j in (i + 1)..n {
121            let v = kernel.compute(&data[i], &data[j]);
122            k[i][j] = v;
123            k[j][i] = v;
124        }
125    }
126    k
127}
128
/// Compute the kernel k-means objective given assignments.
///
/// The distance in RKHS between x_i and cluster centre c_l is:
///   d²(φ(x_i), μ_l) = K(i,i) - 2/|Cl| Σ_{j∈Cl} K(i,j) + 1/|Cl|² Σ_{j,k∈Cl} K(j,k)
///
/// The objective is the sum of that distance over all points, each measured
/// against its own cluster.
///
/// # Panics
/// Panics if any label is `>= n_clusters` or if `k_mat` is smaller than
/// `labels.len()` in either dimension.
fn kernel_kmeans_objective(k_mat: &[Vec<f64>], labels: &[usize], n_clusters: usize) -> f64 {
    // cluster member lists
    let mut members: Vec<Vec<usize>> = vec![Vec::new(); n_clusters];
    for (i, &l) in labels.iter().enumerate() {
        members[l].push(i);
    }

    // Σ_{j,k∈Cl} K(j,k) depends only on the cluster, so compute it once per
    // cluster rather than once per point.
    let pair_sums: Vec<f64> = members
        .iter()
        .map(|cl| {
            cl.iter()
                .flat_map(|&j| cl.iter().map(move |&kk| k_mat[j][kk]))
                .sum::<f64>()
        })
        .collect();

    let mut total = 0.0f64;
    for (i, &l) in labels.iter().enumerate() {
        // `cl` always contains `i`, so it is never empty here.
        let cl = &members[l];
        let sz = cl.len() as f64;
        let kii = k_mat[i][i];
        let cross: f64 = cl.iter().map(|&j| k_mat[i][j]).sum::<f64>();
        total += kii - 2.0 * cross / sz + pair_sums[l] / (sz * sz);
    }
    total
}
160
/// Assign each point to the cluster whose RKHS centroid it is closest to.
///
/// Centroids are implied by the current `labels`. Empty clusters are never
/// chosen, and ties go to the lowest cluster index.
fn kernel_kmeans_assign(k_mat: &[Vec<f64>], labels: &[usize], n_clusters: usize) -> Vec<usize> {
    // Group point indices by their current cluster.
    let mut groups: Vec<Vec<usize>> = vec![Vec::new(); n_clusters];
    for (point, &label) in labels.iter().enumerate() {
        groups[label].push(point);
    }

    // Per-cluster self term: (1/|Cl|^2) * sum_{j,k in Cl} K(j,k).
    let self_terms: Vec<f64> = groups
        .iter()
        .map(|group| {
            if group.is_empty() {
                return f64::INFINITY;
            }
            let size = group.len() as f64;
            let pair_sum: f64 = group
                .iter()
                .flat_map(|&a| group.iter().map(move |&b| k_mat[a][b]))
                .sum();
            pair_sum / (size * size)
        })
        .collect();

    (0..labels.len())
        .map(|point| {
            let mut winner = 0;
            let mut winner_dist = f64::INFINITY;
            for (candidate, group) in groups.iter().enumerate() {
                if group.is_empty() {
                    continue;
                }
                let size = group.len() as f64;
                // Cross term: sum_{j in Cl} K(point, j), scaled by 2/|Cl| below.
                let cross: f64 = group.iter().map(|&other| k_mat[point][other]).sum::<f64>();
                let dist = k_mat[point][point] - 2.0 * cross / size + self_terms[candidate];
                if dist < winner_dist {
                    winner_dist = dist;
                    winner = candidate;
                }
            }
            winner
        })
        .collect()
}
207
208/// Kernel K-Means clustering.
209///
210/// Performs K-means in a reproducing kernel Hilbert space, allowing non-linear
211/// cluster boundaries in the original feature space.
212///
213/// # Arguments
214/// - `data`: slice of n feature vectors, each of length d
215/// - `n_clusters`: number of clusters K
216/// - `kernel`: kernel function defining the feature space
217/// - `max_iter`: maximum EM iterations
218/// - `n_init`: number of random restarts (best is kept)
219/// - `seed`: RNG seed for reproducibility
220///
221/// # Returns
222/// `(labels, inertia)` — cluster assignments (0..K-1) and kernel objective value
223///
224/// # Errors
225/// Returns an error when input is empty or `n_clusters` > n_samples.
226pub fn kernel_kmeans(
227    data: &[Vec<f64>],
228    n_clusters: usize,
229    kernel: KernelType,
230    max_iter: usize,
231    n_init: usize,
232    seed: u64,
233) -> Result<(Vec<usize>, f64)> {
234    if data.is_empty() {
235        return Err(ClusteringError::InvalidInput(
236            "data must not be empty".into(),
237        ));
238    }
239    if n_clusters == 0 {
240        return Err(ClusteringError::InvalidInput(
241            "n_clusters must be >= 1".into(),
242        ));
243    }
244    if n_clusters > data.len() {
245        return Err(ClusteringError::InvalidInput(format!(
246            "n_clusters ({}) > n_samples ({})",
247            n_clusters,
248            data.len()
249        )));
250    }
251    let n = data.len();
252    let k_mat = build_kernel_matrix(data, &kernel);
253    let _rng = Lcg::new(seed); // seed base, actual work uses run_rng
254
255    let mut best_labels = vec![0usize; n];
256    let mut best_obj = f64::INFINITY;
257
258    for run in 0..n_init.max(1) {
259        // Random initialisation: assign each point uniformly at random, then
260        // ensure every cluster has at least one member via forced seeding.
261        let run_seed = seed.wrapping_add(run as u64).wrapping_add(1);
262        let mut run_rng = Lcg::new(run_seed);
263
264        // Force first k distinct points as cluster centres
265        let mut labels = vec![0usize; n];
266        // shuffle indices to pick k seeds
267        let mut idx: Vec<usize> = (0..n).collect();
268        // Fisher-Yates for k steps
269        for i in 0..n_clusters {
270            let j = run_rng.next_range_usize(i, n);
271            idx.swap(i, j);
272            labels[idx[i]] = i;
273        }
274        // assign remaining randomly
275        for i in n_clusters..n {
276            labels[i] = run_rng.next_range_usize(0, n_clusters);
277        }
278
279        // Iteratively re-assign
280        for _ in 0..max_iter {
281            let new_labels = kernel_kmeans_assign(&k_mat, &labels, n_clusters);
282            // Check if any cluster became empty and re-seed if so
283            let mut counts = vec![0usize; n_clusters];
284            for &l in &new_labels {
285                counts[l] += 1;
286            }
287            let empty = counts.iter().any(|&c| c == 0);
288            if empty {
289                // Keep old labels for empty clusters — restart inner loop
290                break;
291            }
292            if new_labels == labels {
293                break;
294            }
295            labels = new_labels;
296        }
297
298        let obj = kernel_kmeans_objective(&k_mat, &labels, n_clusters);
299        if obj < best_obj {
300            best_obj = obj;
301            best_labels = labels;
302        }
303        // run_rng used per-iteration
304    }
305
306    Ok((best_labels, best_obj))
307}
308
309// ─────────────────────────────────────────────────────────────────────────────
310// Trimmed K-Means
311// ─────────────────────────────────────────────────────────────────────────────
312
313/// Trimmed K-Means (Cuesta-Albertos et al., 1997).
314///
315/// At each iteration a fraction `trim_ratio` of the points with the largest
316/// distance to their assigned centroid are treated as outliers (label = None)
317/// and excluded from centroid updates.
318///
319/// # Arguments
320/// - `data`: n feature vectors of length d
321/// - `n_clusters`: number of clusters K
322/// - `trim_ratio`: fraction in `[0, 0.5)` of samples to trim as outliers
323/// - `max_iter`: maximum EM iterations
324/// - `seed`: RNG seed
325///
326/// # Returns
327/// `(labels, centroids)` where `labels[i] = None` for trimmed points.
328///
329/// # Errors
330/// Returns an error if data is empty, n_clusters is 0, or trim_ratio is out of range.
331pub fn trimmed_kmeans(
332    data: &[Vec<f64>],
333    n_clusters: usize,
334    trim_ratio: f64,
335    max_iter: usize,
336    seed: u64,
337) -> Result<(Vec<Option<usize>>, Vec<Vec<f64>>)> {
338    if data.is_empty() {
339        return Err(ClusteringError::InvalidInput(
340            "data must not be empty".into(),
341        ));
342    }
343    if n_clusters == 0 {
344        return Err(ClusteringError::InvalidInput(
345            "n_clusters must be >= 1".into(),
346        ));
347    }
348    if !(0.0..0.5).contains(&trim_ratio) {
349        return Err(ClusteringError::InvalidInput(
350            "trim_ratio must be in [0, 0.5)".into(),
351        ));
352    }
353    let n = data.len();
354    let d = data[0].len();
355    if n_clusters > n {
356        return Err(ClusteringError::InvalidInput(format!(
357            "n_clusters ({}) > n_samples ({})",
358            n_clusters, n
359        )));
360    }
361
362    let n_trim = (n as f64 * trim_ratio).floor() as usize;
363    let n_active = n - n_trim;
364    if n_active < n_clusters {
365        return Err(ClusteringError::InvalidInput(
366            "After trimming, too few points remain for the requested n_clusters".into(),
367        ));
368    }
369
370    let mut rng = Lcg::new(seed);
371
372    // Initialise centroids via k-means++ style
373    let mut centroids = kmeans_plus_plus_init(data, n_clusters, &mut rng);
374
375    let mut labels = vec![None::<usize>; n];
376
377    for _iter in 0..max_iter {
378        // Step 1: assign each point to nearest centroid
379        let mut dists: Vec<(usize, f64)> = (0..n)
380            .map(|i| {
381                let (cl, dist) = nearest_centroid(&data[i], &centroids);
382                (cl, dist)
383            })
384            .collect();
385
386        // Step 2: trim n_trim points with largest assignment distance
387        let mut order: Vec<usize> = (0..n).collect();
388        order.sort_by(|&a, &b| {
389            dists[b]
390                .1
391                .partial_cmp(&dists[a].1)
392                .unwrap_or(std::cmp::Ordering::Equal)
393        });
394        let trimmed_set: std::collections::HashSet<usize> =
395            order[..n_trim].iter().cloned().collect();
396
397        // Step 3: update labels
398        for i in 0..n {
399            if trimmed_set.contains(&i) {
400                labels[i] = None;
401            } else {
402                labels[i] = Some(dists[i].0);
403            }
404        }
405
406        // Step 4: update centroids using only active (non-trimmed) points
407        let mut new_centroids = vec![vec![0.0f64; d]; n_clusters];
408        let mut counts = vec![0usize; n_clusters];
409        for (i, lbl) in labels.iter().enumerate() {
410            if let Some(l) = lbl {
411                for (feat, &v) in new_centroids[*l].iter_mut().zip(data[i].iter()) {
412                    *feat += v;
413                }
414                counts[*l] += 1;
415            }
416        }
417        let mut changed = false;
418        for l in 0..n_clusters {
419            if counts[l] > 0 {
420                let old = &centroids[l];
421                let new_c: Vec<f64> = new_centroids[l]
422                    .iter()
423                    .map(|&s| s / counts[l] as f64)
424                    .collect();
425                let diff: f64 = old
426                    .iter()
427                    .zip(new_c.iter())
428                    .map(|(a, b)| (a - b).powi(2))
429                    .sum::<f64>()
430                    .sqrt();
431                if diff > 1e-10 {
432                    changed = true;
433                }
434                centroids[l] = new_c;
435            }
436        }
437        if !changed {
438            break;
439        }
440    }
441
442    Ok((labels, centroids))
443}
444
/// Returns `(index, squared_distance)` of the centroid closest to `point`.
///
/// Ties go to the lowest index. With no centroids the result is
/// `(0, f64::INFINITY)`.
fn nearest_centroid(point: &[f64], centroids: &[Vec<f64>]) -> (usize, f64) {
    centroids.iter().enumerate().fold(
        (0usize, f64::INFINITY),
        |(best_idx, best_sq), (idx, centroid)| {
            let candidate: f64 = point
                .iter()
                .zip(centroid.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            if candidate < best_sq {
                (idx, candidate)
            } else {
                (best_idx, best_sq)
            }
        },
    )
}
461
462fn kmeans_plus_plus_init(data: &[Vec<f64>], k: usize, rng: &mut Lcg) -> Vec<Vec<f64>> {
463    let n = data.len();
464    let first = rng.next_range_usize(0, n);
465    let mut centroids = vec![data[first].clone()];
466    for _ in 1..k {
467        let dists: Vec<f64> = data
468            .iter()
469            .map(|x| {
470                centroids
471                    .iter()
472                    .map(|c| sq_dist(x, c))
473                    .fold(f64::INFINITY, f64::min)
474            })
475            .collect();
476        let total: f64 = dists.iter().sum();
477        let target = rng.next_f64() * total;
478        let mut cumsum = 0.0;
479        let mut chosen = n - 1;
480        for (i, &d) in dists.iter().enumerate() {
481            cumsum += d;
482            if cumsum >= target {
483                chosen = i;
484                break;
485            }
486        }
487        centroids.push(data[chosen].clone());
488    }
489    centroids
490}
491
492// ─────────────────────────────────────────────────────────────────────────────
493// Dirichlet Process Mixture (Chinese Restaurant Process Gibbs sampler)
494// ─────────────────────────────────────────────────────────────────────────────
495
496/// Dirichlet Process Mixture model estimated via collapsed Gibbs sampling.
497///
498/// Uses a conjugate Normal-Wishart prior under the hood with simplified
499/// spherical (isotropic) Gaussians per component, allowing automatic
500/// inference of the number of clusters.
501///
502/// # References
503/// - Neal (2000) "Markov Chain Sampling Methods for Dirichlet Process Mixture Models"
504#[derive(Debug, Clone)]
505pub struct DpMixture {
506    /// Concentration parameter α (higher → more clusters)
507    pub alpha: f64,
508    /// Number of active components found after fitting
509    pub n_components: usize,
510    /// Component weights (mixing proportions)
511    pub weights: Vec<f64>,
512    /// Component means (one `Vec<f64>` per component)
513    pub means: Vec<Vec<f64>>,
514    /// Precision parameters per component (scalar per component → isotropic)
515    pub concentrations: Vec<f64>,
516}
517
518impl DpMixture {
519    /// Create a new DpMixture with concentration parameter `alpha`.
520    pub fn new(alpha: f64) -> Self {
521        Self {
522            alpha,
523            n_components: 0,
524            weights: Vec::new(),
525            means: Vec::new(),
526            concentrations: Vec::new(),
527        }
528    }
529
530    /// Fit via collapsed Gibbs sampling (Algorithm 3 from Neal 2000 simplified).
531    ///
532    /// # Returns
533    /// Cluster label assignments for each sample.
534    pub fn fit(&mut self, data: &[Vec<f64>], max_iter: usize, seed: u64) -> Vec<usize> {
535        if data.is_empty() {
536            self.n_components = 0;
537            return Vec::new();
538        }
539        let n = data.len();
540        let d = data[0].len();
541        let mut rng = Lcg::new(seed);
542
543        // Prior hyperparameters (conjugate Normal with known spherical precision)
544        let prior_mean = vec![0.0f64; d];
545        let prior_kappa = 1.0f64; // prior pseudo-count
546        let lambda = 1.0f64; // prior precision (isotropic)
547
548        // Initialise: each point gets its own cluster (CRP start)
549        let mut assignments: Vec<usize> = (0..n).collect();
550        // component_members[k] = list of sample indices in component k
551        let mut component_members: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
552
553        for _iter in 0..max_iter {
554            for i in 0..n {
555                let xi = &data[i];
556                let current_k = assignments[i];
557
558                // Remove i from its current component
559                component_members[current_k].retain(|&m| m != i);
560
561                // Remove empty components
562                let alive: Vec<usize> = (0..component_members.len())
563                    .filter(|&k| !component_members[k].is_empty())
564                    .collect();
565                // Re-index
566                let new_members: Vec<Vec<usize>> = alive
567                    .iter()
568                    .map(|&k| component_members[k].clone())
569                    .collect();
570                // Update assignments to new indices
571                for j in 0..n {
572                    if j == i {
573                        continue;
574                    }
575                    let old_k = assignments[j];
576                    if let Some(pos) = alive.iter().position(|&k| k == old_k) {
577                        assignments[j] = pos;
578                    }
579                }
580                component_members = new_members;
581
582                let k_live = component_members.len();
583
584                // Compute posterior predictive probabilities for each existing cluster
585                let mut log_probs: Vec<f64> = Vec::with_capacity(k_live + 1);
586                for k in 0..k_live {
587                    let members = &component_members[k];
588                    let n_k = members.len() as f64;
589                    // Posterior Normal-Normal predictive
590                    let kappa_n = prior_kappa + n_k;
591                    let mu_n: Vec<f64> = {
592                        let mut s = prior_mean.clone();
593                        for &m in members.iter() {
594                            for (f, &v) in s.iter_mut().zip(data[m].iter()) {
595                                *f += v;
596                            }
597                        }
598                        s.iter().map(|&v| v / (prior_kappa + n_k)).collect()
599                    };
600                    let pred_var = (kappa_n + 1.0) / (kappa_n * lambda);
601                    // log p(xi | cluster k data) — Gaussian predictive
602                    let log_lik: f64 = xi
603                        .iter()
604                        .zip(mu_n.iter())
605                        .map(|(&xf, &mf)| {
606                            let z = (xf - mf).powi(2);
607                            -0.5 * (z / pred_var + (TAU * pred_var).ln())
608                        })
609                        .sum();
610                    let log_prior = (n_k / (n as f64 - 1.0 + self.alpha)).ln();
611                    log_probs.push(log_prior + log_lik);
612                }
613
614                // New cluster probability
615                let pred_var_new = (prior_kappa + 1.0) / (prior_kappa * lambda);
616                let log_lik_new: f64 = xi
617                    .iter()
618                    .zip(prior_mean.iter())
619                    .map(|(&xf, &mf)| {
620                        let z = (xf - mf).powi(2);
621                        -0.5 * (z / pred_var_new + (TAU * pred_var_new).ln())
622                    })
623                    .sum();
624                let log_prior_new = (self.alpha / (n as f64 - 1.0 + self.alpha)).ln();
625                log_probs.push(log_prior_new + log_lik_new);
626
627                // Numerically stable softmax sampling
628                let max_lp = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
629                let probs: Vec<f64> = log_probs.iter().map(|&lp| (lp - max_lp).exp()).collect();
630                let total: f64 = probs.iter().sum();
631                let u = rng.next_f64() * total;
632                let mut cumsum = 0.0;
633                let mut chosen_k = probs.len() - 1;
634                for (idx, &p) in probs.iter().enumerate() {
635                    cumsum += p;
636                    if cumsum >= u {
637                        chosen_k = idx;
638                        break;
639                    }
640                }
641
642                if chosen_k == k_live {
643                    // New component
644                    component_members.push(vec![i]);
645                    assignments[i] = k_live;
646                } else {
647                    component_members[chosen_k].push(i);
648                    assignments[i] = chosen_k;
649                }
650            }
651        }
652
653        // Populate model parameters from final state
654        let k_final = component_members.len();
655        self.n_components = k_final;
656        self.weights = component_members
657            .iter()
658            .map(|m| m.len() as f64 / n as f64)
659            .collect();
660        self.means = component_members
661            .iter()
662            .map(|members| {
663                let mut mu = vec![0.0f64; d];
664                for &m in members.iter() {
665                    for (f, &v) in mu.iter_mut().zip(data[m].iter()) {
666                        *f += v;
667                    }
668                }
669                mu.iter().map(|&v| v / members.len() as f64).collect()
670            })
671            .collect();
672        self.concentrations = vec![lambda; k_final];
673
674        assignments
675    }
676
677    /// Number of active (non-empty) clusters found after fitting.
678    pub fn n_clusters(&self) -> usize {
679        self.n_components
680    }
681}
682
683// ─────────────────────────────────────────────────────────────────────────────
684// Fuzzy C-Means
685// ─────────────────────────────────────────────────────────────────────────────
686
687/// Fuzzy C-Means clustering (Bezdek 1981).
688///
689/// Each point has a soft membership degree `u[i][c]` to each cluster c, with
690/// degrees summing to 1 over clusters. The fuzziness exponent `m > 1` controls
691/// how crisp (→1) or diffuse (→∞) the memberships are.
692///
693/// # Arguments
694/// - `data`: n feature vectors of length d
695/// - `n_clusters`: number of clusters K
696/// - `fuzziness`: exponent m (> 1, typically 2.0)
697/// - `max_iter`: maximum EM iterations
698/// - `tol`: convergence threshold on centroid change
699/// - `seed`: RNG seed
700///
701/// # Returns
702/// `(centroids, membership_matrix)` where `membership_matrix` is n × K.
703///
704/// # Errors
705/// Returns an error for invalid parameters.
706pub fn fuzzy_cmeans(
707    data: &[Vec<f64>],
708    n_clusters: usize,
709    fuzziness: f64,
710    max_iter: usize,
711    tol: f64,
712    seed: u64,
713) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>)> {
714    if data.is_empty() {
715        return Err(ClusteringError::InvalidInput(
716            "data must not be empty".into(),
717        ));
718    }
719    if n_clusters == 0 {
720        return Err(ClusteringError::InvalidInput(
721            "n_clusters must be >= 1".into(),
722        ));
723    }
724    if fuzziness <= 1.0 {
725        return Err(ClusteringError::InvalidInput(
726            "fuzziness (m) must be > 1.0".into(),
727        ));
728    }
729    if n_clusters > data.len() {
730        return Err(ClusteringError::InvalidInput(format!(
731            "n_clusters ({}) > n_samples ({})",
732            n_clusters,
733            data.len()
734        )));
735    }
736
737    let n = data.len();
738    let d = data[0].len();
739    let m = fuzziness;
740    let exp = 2.0 / (m - 1.0);
741
742    // Initialise memberships randomly
743    let mut rng = Lcg::new(seed);
744    let mut u: Vec<Vec<f64>> = (0..n)
745        .map(|_| {
746            let raw: Vec<f64> = (0..n_clusters).map(|_| rng.next_f64() + 1e-12).collect();
747            let s: f64 = raw.iter().sum();
748            raw.iter().map(|&v| v / s).collect()
749        })
750        .collect();
751
752    let mut centroids = compute_fuzzy_centroids(data, &u, n_clusters, m, d);
753
754    for _iter in 0..max_iter {
755        // Update membership matrix
756        let mut new_u = vec![vec![0.0f64; n_clusters]; n];
757        for i in 0..n {
758            let dists: Vec<f64> = (0..n_clusters)
759                .map(|c| sq_dist(&data[i], &centroids[c]).max(1e-30))
760                .collect();
761            // Check if point is exactly on a centroid
762            let exact: Vec<usize> = dists
763                .iter()
764                .enumerate()
765                .filter(|(_, &d)| d < 1e-30)
766                .map(|(c, _)| c)
767                .collect();
768            if !exact.is_empty() {
769                let share = 1.0 / exact.len() as f64;
770                for &c in &exact {
771                    new_u[i][c] = share;
772                }
773            } else {
774                for c in 0..n_clusters {
775                    let ratio_sum: f64 = (0..n_clusters)
776                        .map(|j| (dists[c] / dists[j]).powf(exp))
777                        .sum();
778                    new_u[i][c] = 1.0 / ratio_sum;
779                }
780            }
781        }
782
783        // Update centroids
784        let new_centroids = compute_fuzzy_centroids(data, &new_u, n_clusters, m, d);
785
786        // Check convergence
787        let max_change: f64 = centroids
788            .iter()
789            .zip(new_centroids.iter())
790            .map(|(c_old, c_new)| {
791                c_old
792                    .iter()
793                    .zip(c_new.iter())
794                    .map(|(a, b)| (a - b).abs())
795                    .fold(0.0f64, f64::max)
796            })
797            .fold(0.0f64, f64::max);
798
799        u = new_u;
800        centroids = new_centroids;
801
802        if max_change < tol {
803            break;
804        }
805    }
806
807    Ok((centroids, u))
808}
809
/// Computes the fuzzy centroids `v_c = Σ_i u[i][c]^m · x_i / Σ_i u[i][c]^m`.
///
/// A cluster whose total weight is (numerically) zero gets the all-zero
/// vector of length `d`.
fn compute_fuzzy_centroids(
    data: &[Vec<f64>],
    u: &[Vec<f64>],
    n_clusters: usize,
    m: f64,
    d: usize,
) -> Vec<Vec<f64>> {
    let mut centroids = Vec::with_capacity(n_clusters);
    for cluster in 0..n_clusters {
        let mut weighted_sum = vec![0.0f64; d];
        let mut weight_total = 0.0f64;
        for (i, point) in data.iter().enumerate() {
            let weight = u[i][cluster].powf(m);
            weight_total += weight;
            weighted_sum
                .iter_mut()
                .zip(point.iter())
                .for_each(|(acc, &coord)| *acc += weight * coord);
        }
        let centroid: Vec<f64> = if weight_total.abs() < 1e-30 {
            vec![0.0f64; d]
        } else {
            weighted_sum.iter().map(|&s| s / weight_total).collect()
        };
        centroids.push(centroid);
    }
    centroids
}
836
837// ─────────────────────────────────────────────────────────────────────────────
838// Unit tests
839// ─────────────────────────────────────────────────────────────────────────────
840
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the deterministic 40-point, two-dimensional fixture shared by the tests.
    ///
    /// Points 0..20 are `(i * 0.1, i * 0.1)` and points 20..40 are
    /// `(10.0 + i * 0.1, 10.0 + i * 0.1)` for `i` in `0..20`, i.e. two short,
    /// well-separated diagonal segments. No random sampling is involved.
    fn two_cluster_data() -> Vec<Vec<f64>> {
        [0.0_f64, 10.0]
            .iter()
            .flat_map(|&origin| {
                (0..20_i32).map(move |step| {
                    // Adding the 0.0 origin is exact, so the first segment is bit-identical
                    // to `step as f64 * 0.1` on its own.
                    let coord = origin + step as f64 * 0.1;
                    vec![coord, coord]
                })
            })
            .collect()
    }

    // ── Kernel K-Means ──────────────────────────────────────────────────────

    /// RBF kernel k-means must put each segment of the fixture into its own cluster.
    #[test]
    fn test_kernel_kmeans_rbf_two_clusters() {
        let points = two_cluster_data();
        let kernel = KernelType::Rbf { gamma: 0.5 };
        let (assignment, inertia) =
            kernel_kmeans(&points, 2, kernel, 50, 3, 42).expect("kernel_kmeans should succeed");
        assert_eq!(assignment.len(), 40);
        assert!(inertia.is_finite());

        // Every point within a segment must share that segment's label, and the two
        // segments must carry different labels.
        let (first_blob, second_blob) = assignment.split_at(20);
        let first_label = first_blob[0];
        let second_label = second_blob[0];
        assert_ne!(
            first_label, second_label,
            "blobs should be in different clusters"
        );
        assert!(first_blob.iter().all(|&label| label == first_label));
        assert!(second_blob.iter().all(|&label| label == second_label));
    }

    /// Smoke test: the linear kernel yields one label per input point.
    #[test]
    fn test_kernel_kmeans_linear() {
        let points = two_cluster_data();
        let outcome = kernel_kmeans(&points, 2, KernelType::Linear, 20, 2, 7);
        let (assignment, _) = outcome.expect("kernel_kmeans linear should succeed");
        assert_eq!(assignment.len(), 40);
    }

    /// Smoke test: the polynomial kernel yields one label per input point.
    #[test]
    fn test_kernel_kmeans_polynomial() {
        let points = two_cluster_data();
        let kernel = KernelType::Polynomial {
            degree: 2,
            coef0: 1.0,
            gamma: 0.1,
        };
        let outcome = kernel_kmeans(&points, 2, kernel, 20, 2, 99);
        let (assignment, _) = outcome.expect("kernel_kmeans poly should succeed");
        assert_eq!(assignment.len(), 40);
    }

    /// Empty data, zero clusters, and more clusters than points are all rejected.
    #[test]
    fn test_kernel_kmeans_invalid_inputs() {
        let points = two_cluster_data();

        let empty_input = kernel_kmeans(&[], 2, KernelType::Linear, 10, 1, 0);
        assert!(empty_input.is_err());

        let zero_clusters = kernel_kmeans(&points, 0, KernelType::Linear, 10, 1, 0);
        assert!(zero_clusters.is_err());

        // 100 clusters requested for only 40 points.
        let too_many_clusters = kernel_kmeans(&points, 100, KernelType::Linear, 10, 1, 0);
        assert!(too_many_clusters.is_err());
    }

    // ── Trimmed K-Means ─────────────────────────────────────────────────────

    /// With two far-away points appended, at least one label must come back as `None`.
    #[test]
    fn test_trimmed_kmeans_basic() {
        let mut points = two_cluster_data();
        // Two extreme outliers, one on each side of the fixture.
        points.extend(vec![vec![100.0, 100.0], vec![-100.0, -100.0]]);

        let outcome = trimmed_kmeans(&points, 2, 0.05, 100, 42);
        let (assignment, centers) = outcome.expect("trimmed_kmeans should succeed");
        assert_eq!(assignment.len(), points.len());
        assert_eq!(centers.len(), 2);

        // Only "at least one trimmed point" is required here, not that both outliers are.
        let any_trimmed = assignment.iter().any(|label| label.is_none());
        assert!(any_trimmed, "at least one outlier should be trimmed");
    }

    /// A trim argument of 0.0 must leave every point with a label.
    #[test]
    fn test_trimmed_kmeans_no_trim() {
        let points = two_cluster_data();
        let outcome = trimmed_kmeans(&points, 2, 0.0, 50, 0);
        let (assignment, centers) = outcome.expect("trimmed_kmeans with trim=0 should succeed");
        assert_eq!(assignment.len(), 40);
        assert_eq!(centers.len(), 2);
        // Nothing may be marked as trimmed.
        assert!(assignment.iter().all(|label| label.is_some()));
    }

    /// Empty data, zero clusters, and an over-large trim argument are all rejected.
    #[test]
    fn test_trimmed_kmeans_invalid() {
        let points = two_cluster_data();

        let empty_input = trimmed_kmeans(&[], 2, 0.1, 10, 0);
        assert!(empty_input.is_err());

        let zero_clusters = trimmed_kmeans(&points, 0, 0.1, 10, 0);
        assert!(zero_clusters.is_err());

        let excessive_trim = trimmed_kmeans(&points, 2, 0.6, 10, 0);
        assert!(excessive_trim.is_err()); // trim_ratio >= 0.5
    }

    // ── DpMixture ───────────────────────────────────────────────────────────

    /// Fitting the fixture yields one label per point, a non-empty model, and
    /// weights that sum to 1.
    #[test]
    fn test_dp_mixture_finds_clusters() {
        let points = two_cluster_data();
        let mut model = DpMixture::new(1.0);
        let assignment = model.fit(&points, 30, 42);
        assert_eq!(assignment.len(), 40);

        // Only a non-empty set of components is required; the exact number the
        // sampler settles on for the two segments is not pinned by this test.
        assert!(model.n_clusters() > 0);
        assert!(!model.means.is_empty());
        assert!(!model.weights.is_empty());

        let total_weight: f64 = model.weights.iter().sum();
        assert!((total_weight - 1.0).abs() < 1e-10);
    }

    /// Fitting an empty data set returns no labels and leaves the model without clusters.
    #[test]
    fn test_dp_mixture_empty() {
        let mut model = DpMixture::new(1.0);
        assert!(model.fit(&[], 10, 0).is_empty());
        assert_eq!(model.n_clusters(), 0);
    }

    // ── Fuzzy C-Means ───────────────────────────────────────────────────────

    /// Checks output shapes and that every membership row sums to 1.
    #[test]
    fn test_fuzzy_cmeans_basic() {
        let points = two_cluster_data();
        let outcome = fuzzy_cmeans(&points, 2, 2.0, 100, 1e-6, 42);
        let (centers, membership) = outcome.expect("fuzzy_cmeans should succeed");
        assert_eq!(centers.len(), 2);
        assert_eq!(membership.len(), 40);
        assert_eq!(membership[0].len(), 2);

        // One row per point; each row's memberships must add up to 1.
        for total in membership.iter().map(|row| row.iter().sum::<f64>()) {
            assert!(
                (total - 1.0).abs() < 1e-8,
                "membership row must sum to 1, got {}",
                total
            );
        }
    }

    /// Membership rows still sum to 1 with three clusters and a fuzziness of 3.5.
    #[test]
    fn test_fuzzy_cmeans_high_fuzziness() {
        let points = two_cluster_data();
        let outcome = fuzzy_cmeans(&points, 3, 3.5, 50, 1e-5, 99);
        let (_, membership) = outcome.expect("fuzzy_cmeans high fuzz should succeed");
        for total in membership.iter().map(|row| row.iter().sum::<f64>()) {
            assert!((total - 1.0).abs() < 1e-7);
        }
    }

    /// Empty data, bad cluster counts, and fuzziness values of 1 or less are all rejected.
    #[test]
    fn test_fuzzy_cmeans_invalid() {
        let points = two_cluster_data();

        let empty_input = fuzzy_cmeans(&[], 2, 2.0, 10, 1e-6, 0);
        assert!(empty_input.is_err());

        let zero_clusters = fuzzy_cmeans(&points, 0, 2.0, 10, 1e-6, 0);
        assert!(zero_clusters.is_err());

        let fuzziness_one = fuzzy_cmeans(&points, 2, 1.0, 10, 1e-6, 0);
        assert!(fuzziness_one.is_err()); // fuzziness <= 1

        let fuzziness_below_one = fuzzy_cmeans(&points, 2, 0.5, 10, 1e-6, 0);
        assert!(fuzziness_below_one.is_err()); // fuzziness < 1

        let too_many_clusters = fuzzy_cmeans(&points, 100, 2.0, 10, 1e-6, 0);
        assert!(too_many_clusters.is_err()); // k > n
    }
}