scry_learn/cluster/kmeans.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! K-Means clustering with k-means++ initialization.
//!
//! # Example
//!
//! ```
//! use scry_learn::cluster::KMeans;
//! use scry_learn::dataset::Dataset;
//!
//! let data = Dataset::new(
//!     vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
//!     vec![0.0; 4],
//!     vec!["x".into(), "y".into()],
//!     "label",
//! );
//!
//! let mut km = KMeans::new(2).n_init(10).seed(42);
//! km.fit(&data).unwrap();
//! assert_eq!(km.labels().len(), 4);
//! ```

use rayon::prelude::*;

use crate::constants::KMEANS_PAR_THRESHOLD;
use crate::dataset::Dataset;
use crate::distance::euclidean_sq;
use crate::error::{Result, ScryLearnError};

/// K-Means clustering.
///
/// Uses k-means++ initialization for better convergence.
/// When `n_init > 1` (default 10), the algorithm runs multiple times
/// with different random seeds and keeps the result with the lowest inertia.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KMeans {
    k: usize,
    max_iter: usize,
    tolerance: f64,
    seed: u64,
    n_init: usize,
    centroids: Vec<Vec<f64>>,
    labels: Vec<usize>,
    inertia: f64,
    n_iter: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl KMeans {
    /// Create a K-Means model with k clusters.
    pub fn new(k: usize) -> Self {
        Self {
            k,
            max_iter: 300,
            tolerance: 1e-4,
            seed: 42,
            n_init: 10,
            centroids: Vec::new(),
            labels: Vec::new(),
            inertia: f64::INFINITY,
            n_iter: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set maximum iterations per run.
    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    /// Set convergence tolerance.
    pub fn tolerance(mut self, t: f64) -> Self {
        self.tolerance = t;
        self
    }

    /// Alias for [`tolerance`](Self::tolerance) (sklearn convention).
    pub fn tol(self, t: f64) -> Self {
        self.tolerance(t)
    }

    /// Set random seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Set the number of independent runs with different random seeds.
    ///
    /// The result with the lowest inertia is kept. Default is 10, matching sklearn.
    /// Set to 1 for a single run (faster but less reliable).
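    ///
    /// A minimal builder sketch (construction only, nothing is fitted here):
    ///
    /// ```
    /// use scry_learn::cluster::KMeans;
    ///
    /// // One run: fast and reproducible with a fixed seed, but a single
    /// // unlucky initialization is kept as-is.
    /// let _fast = KMeans::new(3).n_init(1).seed(7);
    /// // Ten runs (the default): the run with the lowest inertia wins.
    /// let _robust = KMeans::new(3).n_init(10);
    /// ```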
    pub fn n_init(mut self, n: usize) -> Self {
        self.n_init = n.max(1);
        self
    }

    /// Fit the model on a dataset (uses features only, ignores target).
    ///
    /// When `n_init > 1`, runs K-Means multiple times and keeps the best.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.k == 0 || self.k > n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "k must be between 1 and n_samples ({}), got {}",
                n, self.k
            )));
        }

        let rows = data.feature_matrix();
        let m = data.n_features();

        let mut best_centroids: Option<Vec<Vec<f64>>> = None;
        let mut best_labels: Option<Vec<usize>> = None;
        let mut best_inertia = f64::INFINITY;
        let mut best_n_iter = 0;

        for run in 0..self.n_init {
            let run_seed = self.seed.wrapping_add(run as u64);
            let (centroids, labels, inertia, n_iter) = self.run_once(&rows, n, m, run_seed);

            if inertia < best_inertia {
                best_centroids = Some(centroids);
                best_labels = Some(labels);
                best_inertia = inertia;
                best_n_iter = n_iter;
            }
        }

        self.centroids = best_centroids.unwrap_or_default();
        self.labels = best_labels.unwrap_or_default();
        self.inertia = best_inertia;
        self.n_iter = best_n_iter;
        self.fitted = true;
        Ok(())
    }

    /// Run a single K-Means pass with the given seed.
    #[allow(clippy::type_complexity)]
    fn run_once(
        &self,
        rows: &[Vec<f64>],
        n: usize,
        m: usize,
        seed: u64,
    ) -> (Vec<Vec<f64>>, Vec<usize>, f64, usize) {
        let mut centroids = kmeans_plus_plus(rows, self.k, seed);
        let mut labels = vec![0usize; n];
        let mut prev_inertia = f64::INFINITY;
        let mut final_inertia = f64::INFINITY;
        let mut final_n_iter = 0;
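        // Parallelize the assignment step with rayon once the per-iteration work
        // (n_samples * k distance evaluations) crosses the crate-wide threshold.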
        let use_par = n * self.k >= KMEANS_PAR_THRESHOLD;

        for iter in 0..self.max_iter {
            // Assignment step.
            let inertia;
            if use_par {
                let results: Vec<(usize, f64)> = rows
                    .par_iter()
                    .map(|row| {
                        let mut best_dist = f64::INFINITY;
                        let mut best_c = 0;
                        for (c, centroid) in centroids.iter().enumerate() {
                            let d = euclidean_sq(row, centroid);
                            if d < best_dist {
                                best_dist = d;
                                best_c = c;
                            }
                        }
                        (best_c, best_dist)
                    })
                    .collect();
                inertia = results.iter().map(|(_, d)| d).sum();
                for (i, (c, _)) in results.into_iter().enumerate() {
                    labels[i] = c;
                }
            } else {
                let mut seq_inertia = 0.0;
                for (i, row) in rows.iter().enumerate() {
                    let mut best_dist = f64::INFINITY;
                    let mut best_c = 0;
                    for (c, centroid) in centroids.iter().enumerate() {
                        let d = euclidean_sq(row, centroid);
                        if d < best_dist {
                            best_dist = d;
                            best_c = c;
                        }
                    }
                    labels[i] = best_c;
                    seq_inertia += best_dist;
                }
                inertia = seq_inertia;
            }

            // Update step.
            let mut new_centroids = vec![vec![0.0; m]; self.k];
            let mut counts = vec![0usize; self.k];

            for (i, row) in rows.iter().enumerate() {
                let c = labels[i];
                counts[c] += 1;
                for (j, &val) in row.iter().enumerate() {
                    new_centroids[c][j] += val;
                }
            }

            // Normalize non-empty centroids.
            for c in 0..self.k {
                if counts[c] > 0 {
                    for val in &mut new_centroids[c] {
                        *val /= counts[c] as f64;
                    }
                }
            }

            // Reinitialize empty centroids: pick the data point farthest
            // from its nearest occupied centroid (sklearn's approach).
            for c in 0..self.k {
                if counts[c] == 0 {
                    let mut max_dist = f64::NEG_INFINITY;
                    let mut best_idx = 0;
                    for (i, row) in rows.iter().enumerate() {
                        let min_dist = new_centroids
                            .iter()
                            .enumerate()
                            .filter(|&(ci, _)| ci != c && (counts[ci] > 0 || ci < c))
                            .map(|(_, cen)| euclidean_sq(row, cen))
                            .fold(f64::INFINITY, f64::min);
                        if min_dist > max_dist {
                            max_dist = min_dist;
                            best_idx = i;
                        }
                    }
                    new_centroids[c].clone_from(&rows[best_idx]);
                }
            }

            // Check convergence.
            let shift: f64 = centroids
                .iter()
                .zip(new_centroids.iter())
                .map(|(old, new)| euclidean_sq(old, new))
                .sum();

            centroids = new_centroids;
            final_n_iter = iter + 1;
            final_inertia = inertia;

            if (prev_inertia - inertia).abs() < self.tolerance || shift < self.tolerance {
                break;
            }
            prev_inertia = inertia;
        }

        (centroids, labels, final_inertia, final_n_iter)
    }

    /// Predict cluster assignments for new data.
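    ///
    /// A short sketch on the same toy data as the module-level example:
    ///
    /// ```
    /// # use scry_learn::cluster::KMeans;
    /// # use scry_learn::dataset::Dataset;
    /// # let data = Dataset::new(
    /// #     vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
    /// #     vec![0.0; 4], vec!["x".into(), "y".into()], "label",
    /// # );
    /// # let mut km = KMeans::new(2).n_init(1).seed(42);
    /// # km.fit(&data).unwrap();
    /// let pred = km.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
    /// assert_ne!(pred[0], pred[1]); // the two probes sit in different blobs
    /// ```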
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<usize>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                self.centroids
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        euclidean_sq(row, a)
                            .partial_cmp(&euclidean_sq(row, b))
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map_or(0, |(idx, _)| idx)
            })
            .collect())
    }

    /// Transform data into cluster-distance space.
    ///
    /// Returns an `n_samples × k` matrix where each value is the Euclidean
    /// distance from the sample to each centroid.
    ///
    /// # Example
    ///
    /// ```
    /// # use scry_learn::cluster::KMeans;
    /// # use scry_learn::dataset::Dataset;
    /// # let data = Dataset::new(
    /// #     vec![vec![0.0, 10.0], vec![0.0, 10.0]],
    /// #     vec![0.0; 2], vec!["x".into(), "y".into()], "l",
    /// # );
    /// # let mut km = KMeans::new(2).n_init(1).seed(42);
    /// # km.fit(&data).unwrap();
    /// let distances = km.transform(&[vec![5.0, 5.0]]).unwrap();
    /// assert_eq!(distances[0].len(), 2); // one distance per centroid
    /// ```
    pub fn transform(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                self.centroids
                    .iter()
                    .map(|c| euclidean_sq(row, c).sqrt())
                    .collect()
            })
            .collect())
    }

    /// Get the cluster centroids.
    pub fn centroids(&self) -> &[Vec<f64>] {
        &self.centroids
    }

    /// Get cluster labels for training data.
    pub fn labels(&self) -> &[usize] {
        &self.labels
    }

    /// Sum of squared distances from each training sample to its nearest centroid.
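    ///
    /// Concretely: `inertia = Σ_i ‖x_i − c_{label(i)}‖²`, summed over the
    /// training samples, where `c_{label(i)}` is the centroid assigned to sample `i`.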
    pub fn inertia(&self) -> f64 {
        self.inertia
    }

    /// Number of iterations taken by the best-scoring run.
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }
}

/// K-means++ initialization: select initial centroids to be spread apart.
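///
/// Picks the first centroid uniformly at random, then draws each subsequent
/// centroid with probability proportional to its squared distance (D²) to the
/// nearest centroid chosen so far; degenerate data with no remaining D² mass
/// falls back to a uniform pick (the `total < 1e-12` guard below).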
pub(crate) fn kmeans_plus_plus(rows: &[Vec<f64>], k: usize, seed: u64) -> Vec<Vec<f64>> {
    let mut rng = crate::rng::FastRng::new(seed);
    let n = rows.len();
    let mut centroids = Vec::with_capacity(k);

    // Pick first centroid randomly.
    centroids.push(rows[rng.usize(0..n)].clone());

    for _ in 1..k {
        // Compute distances to nearest centroid.
        let mut dists: Vec<f64> = rows
            .iter()
            .map(|row| {
                centroids
                    .iter()
                    .map(|c| euclidean_sq(row, c))
                    .fold(f64::INFINITY, f64::min)
            })
            .collect();

        // Weighted random selection proportional to D².
        let total: f64 = dists.iter().sum();
        if total < 1e-12 {
            centroids.push(rows[rng.usize(0..n)].clone());
            continue;
        }
        for d in &mut dists {
            *d /= total;
        }

        let r = rng.f64();
        let mut cumsum = 0.0;
        let mut selected = n - 1;
        for (i, &d) in dists.iter().enumerate() {
            cumsum += d;
            if cumsum >= r {
                selected = i;
                break;
            }
        }
        centroids.push(rows[selected].clone());
    }

    centroids
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kmeans_two_blobs() {
        // Two well-separated clusters.
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();
        for i in 0..30 {
            f1.push(i as f64 % 3.0);
            f2.push(i as f64 % 3.0);
            target.push(0.0);
        }
        for i in 0..30 {
            f1.push(100.0 + i as f64 % 3.0);
            f2.push(100.0 + i as f64 % 3.0);
            target.push(1.0);
        }

        let data = Dataset::new(vec![f1, f2], target, vec!["x".into(), "y".into()], "label");

        let mut km = KMeans::new(2).seed(42).n_init(1);
        km.fit(&data).unwrap();

        // All points in the same blob should have the same label.
        let labels = km.labels();
        let first_label = labels[0];
        assert!(labels[..30].iter().all(|&l| l == first_label));
        assert!(labels[30..].iter().all(|&l| l != first_label));
    }

    #[test]
    fn test_kmeans_predict() {
        let data = Dataset::new(
            vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
            vec![0.0; 4],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut km = KMeans::new(2).seed(42).n_init(1);
        km.fit(&data).unwrap();

        let pred = km.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert_ne!(
            pred[0], pred[1],
            "points near different blobs should be assigned to different clusters"
        );
    }

    #[test]
    fn test_kmeans_n_init_improves_inertia() {
        // n_init=10 should produce inertia no greater than that of n_init=1.
        let mut rng = crate::rng::FastRng::new(7);
        let n = 100;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        for _ in 0..n / 2 {
            f1.push(rng.f64() * 5.0);
            f2.push(rng.f64() * 5.0);
        }
        for _ in 0..n / 2 {
            f1.push(20.0 + rng.f64() * 5.0);
            f2.push(20.0 + rng.f64() * 5.0);
        }
        let data = Dataset::new(
            vec![f1, f2],
            vec![0.0; n],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut km1 = KMeans::new(3).seed(7).n_init(1);
        km1.fit(&data).unwrap();
        let inertia1 = km1.inertia();

        let mut km10 = KMeans::new(3).seed(7).n_init(10);
        km10.fit(&data).unwrap();
        let inertia10 = km10.inertia();

        assert!(
            inertia10 <= inertia1 + 1e-6,
            "n_init=10 inertia ({inertia10:.4}) should be ≤ n_init=1 ({inertia1:.4})"
        );
    }

    #[test]
    fn test_kmeans_empty_cluster_reinit() {
        // Pathological case: 3 clusters are requested but the data has 2 clear blobs.
        // With bad initialization, one centroid can end up with zero assigned points.
        // Empty centroids must be reinitialized rather than left at the zero vector.
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        for _ in 0..50 {
            f1.push(0.0);
            f2.push(0.0);
        }
        for _ in 0..50 {
            f1.push(100.0);
            f2.push(100.0);
        }
        let data = Dataset::new(
            vec![f1, f2],
            vec![0.0; 100],
            vec!["x".into(), "y".into()],
            "l",
        );

        let mut km = KMeans::new(3).seed(42).n_init(1);
        km.fit(&data).unwrap();

        // Both blobs should be covered: at least one centroid near (0,0)
        // and at least one near (100,100).
        let centroids = km.centroids();
        assert_eq!(centroids.len(), 3);
        let has_near_origin = centroids.iter().any(|c| c[0] < 50.0 && c[1] < 50.0);
        let has_near_far = centroids.iter().any(|c| c[0] > 50.0 && c[1] > 50.0);
        assert!(has_near_origin, "should have centroid near (0,0)");
        assert!(has_near_far, "should have centroid near (100,100)");

        // All 100 points should be assigned to some cluster.
        assert_eq!(km.labels().len(), 100);
    }

    #[test]
    fn test_kmeans_transform_shape() {
        let data = Dataset::new(
            vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
            vec![0.0; 4],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut km = KMeans::new(2).seed(42).n_init(1);
        km.fit(&data).unwrap();

        let dists = km.transform(&[vec![5.0, 5.0], vec![0.0, 0.0]]).unwrap();
        assert_eq!(dists.len(), 2, "should have 2 samples");
        assert_eq!(
            dists[0].len(),
            2,
            "should have distance to each of 2 centroids"
        );
        // All distances should be non-negative.
        for row in &dists {
            for &d in row {
                assert!(d >= 0.0, "distance should be non-negative");
            }
        }
    }
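
    // A small sketch exercising the k-means++ seeding helper directly: it should
    // return exactly k centroids, each a clone of some input row. (The seed value
    // 123 here is arbitrary.)
    #[test]
    fn test_kmeans_plus_plus_returns_k_rows() {
        let rows = vec![
            vec![0.0, 0.0],
            vec![100.0, 100.0],
            vec![0.0, 100.0],
            vec![100.0, 0.0],
        ];
        let centroids = kmeans_plus_plus(&rows, 3, 123);
        assert_eq!(centroids.len(), 3);
        assert!(centroids.iter().all(|c| c.len() == 2));
        assert!(centroids.iter().all(|c| rows.contains(c)));
    }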
}