// scry_learn/cluster/mini_batch_kmeans.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Mini-Batch K-Means clustering.
//!
//! Uses random mini-batches for centroid updates instead of full-data passes.
//! Much faster on large datasets with slightly worse cluster quality.
//!
//! # Example
//!
//! ```
//! use scry_learn::cluster::MiniBatchKMeans;
//! use scry_learn::dataset::Dataset;
//!
//! let data = Dataset::new(
//!     vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
//!     vec![0.0; 4],
//!     vec!["x".into(), "y".into()],
//!     "label",
//! );
//!
//! let mut mbk = MiniBatchKMeans::new(2).batch_size(2).seed(42);
//! mbk.fit(&data).unwrap();
//! assert_eq!(mbk.labels().len(), 4);
//! ```
25use super::kmeans::kmeans_plus_plus;
26use crate::dataset::Dataset;
27use crate::distance::euclidean_sq;
28use crate::error::{Result, ScryLearnError};
29use crate::partial_fit::PartialFit;
30
/// Mini-Batch K-Means clustering.
///
/// Approximates standard K-Means by updating centroids using random mini-batches
/// of the data at each iteration, rather than the full dataset.
/// This is significantly faster for large datasets while producing similar results.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MiniBatchKMeans {
    // Number of clusters; validated against n_samples when fitting.
    k: usize,
    // Samples drawn per mini-batch iteration (builder clamps to >= 1).
    batch_size: usize,
    // Upper bound on mini-batch iterations in `fit`.
    max_iter: usize,
    // Early-stop threshold on the change in inertia; 0.0 disables early stopping.
    tolerance: f64,
    // Seed shared by batch sampling and K-means++ initialization.
    seed: u64,
    // Fitted cluster centers: one Vec<f64> of length n_features per cluster.
    centroids: Vec<Vec<f64>>,
    // Cluster index per sample of the last `fit` (or last `partial_fit` batch).
    labels: Vec<usize>,
    // Sum of squared distances to nearest centroid; INFINITY until computed.
    inertia: f64,
    // Number of iterations actually run by the last `fit`.
    n_iter: usize,
    // True once `fit` or `partial_fit` has completed successfully.
    fitted: bool,
    // Per-centroid update counts for streaming average (used by partial_fit).
    centroid_counts: Vec<u64>,
    // Serialized schema tag, checked in `predict` for model compatibility.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
55
56impl MiniBatchKMeans {
57    /// Create a Mini-Batch K-Means model with k clusters.
58    pub fn new(k: usize) -> Self {
59        Self {
60            k,
61            batch_size: 1024,
62            max_iter: 100,
63            tolerance: 0.0,
64            seed: 42,
65            centroids: Vec::new(),
66            labels: Vec::new(),
67            inertia: f64::INFINITY,
68            n_iter: 0,
69            fitted: false,
70            centroid_counts: Vec::new(),
71            _schema_version: crate::version::SCHEMA_VERSION,
72        }
73    }
74
75    /// Set the mini-batch size (default 1024).
76    pub fn batch_size(mut self, n: usize) -> Self {
77        self.batch_size = n.max(1);
78        self
79    }
80
81    /// Set maximum iterations.
82    pub fn max_iter(mut self, n: usize) -> Self {
83        self.max_iter = n;
84        self
85    }
86
87    /// Set convergence tolerance.
88    pub fn tolerance(mut self, t: f64) -> Self {
89        self.tolerance = t;
90        self
91    }
92
93    /// Alias for [`tolerance`](Self::tolerance) (sklearn convention).
94    pub fn tol(self, t: f64) -> Self {
95        self.tolerance(t)
96    }
97
98    /// Set random seed.
99    pub fn seed(mut self, s: u64) -> Self {
100        self.seed = s;
101        self
102    }
103
104    /// Fit the model on a dataset (uses features only, ignores target).
105    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
106        data.validate_finite()?;
107        let n = data.n_samples();
108        let m = data.n_features();
109        if n == 0 {
110            return Err(ScryLearnError::EmptyDataset);
111        }
112        if self.k == 0 || self.k > n {
113            return Err(ScryLearnError::InvalidParameter(format!(
114                "k must be between 1 and n_samples ({}), got {}",
115                n, self.k
116            )));
117        }
118
119        let rows = data.feature_matrix();
120        let mut rng = crate::rng::FastRng::new(self.seed);
121        let effective_batch = self.batch_size.min(n);
122
123        // K-means++ initialization.
124        let mut centroids = kmeans_plus_plus(&rows, self.k, self.seed);
125
126        // Per-centroid update counts for streaming average.
127        let mut centroid_counts = vec![0_u64; self.k];
128        let mut prev_inertia = f64::INFINITY;
129
130        for iter in 0..self.max_iter {
131            // Sample a mini-batch.
132            let batch_indices: Vec<usize> = (0..effective_batch).map(|_| rng.usize(0..n)).collect();
133
134            // Assign batch samples to nearest centroid.
135            let mut assignments = Vec::with_capacity(effective_batch);
136            for &idx in &batch_indices {
137                let mut best_c = 0;
138                let mut best_dist = f64::INFINITY;
139                for (c, centroid) in centroids.iter().enumerate() {
140                    let d = euclidean_sq(&rows[idx], centroid);
141                    if d < best_dist {
142                        best_dist = d;
143                        best_c = c;
144                    }
145                }
146                assignments.push(best_c);
147            }
148
149            // Update centroids with streaming average.
150            for (batch_i, &idx) in batch_indices.iter().enumerate() {
151                let c = assignments[batch_i];
152                centroid_counts[c] += 1;
153                let lr = 1.0 / centroid_counts[c] as f64;
154                for j in 0..m {
155                    centroids[c][j] += lr * (rows[idx][j] - centroids[c][j]);
156                }
157            }
158
159            // Compute full inertia periodically (every 10 iters or last iter).
160            if iter % 10 == 0 || iter == self.max_iter - 1 {
161                let mut inertia = 0.0;
162                for row in &rows {
163                    let mut best_dist = f64::INFINITY;
164                    for centroid in &centroids {
165                        let d = euclidean_sq(row, centroid);
166                        if d < best_dist {
167                            best_dist = d;
168                        }
169                    }
170                    inertia += best_dist;
171                }
172
173                self.n_iter = iter + 1;
174                self.inertia = inertia;
175
176                if self.tolerance > 0.0 && (prev_inertia - inertia).abs() < self.tolerance {
177                    break;
178                }
179                prev_inertia = inertia;
180            }
181        }
182
183        // Final assignment of all points.
184        self.labels = rows
185            .iter()
186            .map(|row| {
187                centroids
188                    .iter()
189                    .enumerate()
190                    .min_by(|(_, a), (_, b)| {
191                        euclidean_sq(row, a)
192                            .partial_cmp(&euclidean_sq(row, b))
193                            .unwrap_or(std::cmp::Ordering::Equal)
194                    })
195                    .map_or(0, |(idx, _)| idx)
196            })
197            .collect();
198
199        self.centroids = centroids;
200        self.fitted = true;
201        Ok(())
202    }
203
204    /// Predict cluster assignments for new data.
205    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<usize>> {
206        crate::version::check_schema_version(self._schema_version)?;
207        if !self.fitted {
208            return Err(ScryLearnError::NotFitted);
209        }
210        Ok(features
211            .iter()
212            .map(|row| {
213                self.centroids
214                    .iter()
215                    .enumerate()
216                    .min_by(|(_, a), (_, b)| {
217                        euclidean_sq(row, a)
218                            .partial_cmp(&euclidean_sq(row, b))
219                            .unwrap_or(std::cmp::Ordering::Equal)
220                    })
221                    .map_or(0, |(idx, _)| idx)
222            })
223            .collect())
224    }
225
226    /// Get the cluster centroids.
227    pub fn centroids(&self) -> &[Vec<f64>] {
228        &self.centroids
229    }
230
231    /// Get cluster labels for training data.
232    pub fn labels(&self) -> &[usize] {
233        &self.labels
234    }
235
236    /// Sum of squared distances to the nearest centroid.
237    pub fn inertia(&self) -> f64 {
238        self.inertia
239    }
240
241    /// Number of iterations run.
242    pub fn n_iter(&self) -> usize {
243        self.n_iter
244    }
245}
246
247impl PartialFit for MiniBatchKMeans {
248    /// Update centroids with a streaming average over the given batch.
249    ///
250    /// On the first call, initializes centroids via K-Means++ on the batch.
251    /// Subsequent calls assign each sample to the nearest centroid and
252    /// update it with a decaying learning rate (`1 / count`).
253    fn partial_fit(&mut self, data: &Dataset) -> Result<()> {
254        let n = data.n_samples();
255        if n == 0 {
256            // No-op on empty batch if already initialized; error if not.
257            if self.is_initialized() {
258                return Ok(());
259            }
260            return Err(ScryLearnError::EmptyDataset);
261        }
262
263        let rows = data.feature_matrix();
264        let m = data.n_features();
265
266        if !self.is_initialized() {
267            if self.k == 0 || self.k > n {
268                return Err(ScryLearnError::InvalidParameter(format!(
269                    "k must be between 1 and n_samples ({}), got {}",
270                    n, self.k
271                )));
272            }
273            self.centroids = kmeans_plus_plus(&rows, self.k, self.seed);
274            self.centroid_counts = vec![0_u64; self.k];
275        }
276
277        // Assign each sample to nearest centroid and update with streaming average.
278        for row in &rows {
279            let mut best_c = 0;
280            let mut best_dist = f64::INFINITY;
281            for (c, centroid) in self.centroids.iter().enumerate() {
282                let d = euclidean_sq(row, centroid);
283                if d < best_dist {
284                    best_dist = d;
285                    best_c = c;
286                }
287            }
288            self.centroid_counts[best_c] += 1;
289            let lr = 1.0 / self.centroid_counts[best_c] as f64;
290            for j in 0..m {
291                self.centroids[best_c][j] += lr * (row[j] - self.centroids[best_c][j]);
292            }
293        }
294
295        // Assign labels and compute inertia for this batch.
296        self.labels = rows
297            .iter()
298            .map(|row| {
299                self.centroids
300                    .iter()
301                    .enumerate()
302                    .min_by(|(_, a), (_, b)| {
303                        euclidean_sq(row, a)
304                            .partial_cmp(&euclidean_sq(row, b))
305                            .unwrap_or(std::cmp::Ordering::Equal)
306                    })
307                    .map_or(0, |(idx, _)| idx)
308            })
309            .collect();
310
311        self.inertia = rows
312            .iter()
313            .map(|row| {
314                self.centroids
315                    .iter()
316                    .map(|c| euclidean_sq(row, c))
317                    .fold(f64::INFINITY, f64::min)
318            })
319            .sum();
320
321        self.fitted = true;
322        Ok(())
323    }
324
325    fn is_initialized(&self) -> bool {
326        !self.centroids.is_empty()
327    }
328}
329
330#[cfg(test)]
331mod tests {
332    use super::*;
333
334    #[test]
335    fn test_mini_batch_kmeans_two_blobs() {
336        let mut f1 = Vec::new();
337        let mut f2 = Vec::new();
338        for i in 0..30 {
339            f1.push(i as f64 % 3.0);
340            f2.push(i as f64 % 3.0);
341        }
342        for i in 0..30 {
343            f1.push(100.0 + i as f64 % 3.0);
344            f2.push(100.0 + i as f64 % 3.0);
345        }
346
347        let data = Dataset::new(
348            vec![f1, f2],
349            vec![0.0; 60],
350            vec!["x".into(), "y".into()],
351            "label",
352        );
353
354        let mut mbk = MiniBatchKMeans::new(2).seed(42).batch_size(20);
355        mbk.fit(&data).unwrap();
356
357        let labels = mbk.labels();
358        let first_label = labels[0];
359        assert!(labels[..30].iter().all(|&l| l == first_label));
360        assert!(labels[30..].iter().all(|&l| l != first_label));
361    }
362
363    #[test]
364    fn test_mini_batch_kmeans_predict() {
365        let data = Dataset::new(
366            vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
367            vec![0.0; 4],
368            vec!["x".into(), "y".into()],
369            "label",
370        );
371
372        let mut mbk = MiniBatchKMeans::new(2).seed(42).batch_size(4);
373        mbk.fit(&data).unwrap();
374
375        let pred = mbk.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
376        assert_ne!(
377            pred[0], pred[1],
378            "nearby and far points should be in different clusters"
379        );
380    }
381
382    #[test]
383    fn test_partial_fit_is_initialized() {
384        let mut mbk = MiniBatchKMeans::new(2);
385        assert!(!mbk.is_initialized());
386
387        let data = Dataset::new(
388            vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
389            vec![0.0; 4],
390            vec!["x".into(), "y".into()],
391            "label",
392        );
393        mbk.partial_fit(&data).unwrap();
394        assert!(mbk.is_initialized());
395    }
396
397    #[test]
398    fn test_partial_fit_convergence() {
399        // Two well-separated blobs, fed in batches.
400        let mut mbk = MiniBatchKMeans::new(2).seed(42);
401
402        // Batch 1: cluster A around (1, 1)
403        let b1 = Dataset::new(
404            vec![vec![0.5, 1.0, 1.5], vec![0.5, 1.0, 1.5]],
405            vec![0.0; 3],
406            vec!["x".into(), "y".into()],
407            "label",
408        );
409        // Batch 2: cluster B around (10, 10)
410        let b2 = Dataset::new(
411            vec![vec![9.5, 10.0, 10.5], vec![9.5, 10.0, 10.5]],
412            vec![0.0; 3],
413            vec!["x".into(), "y".into()],
414            "label",
415        );
416
417        mbk.partial_fit(&b1).unwrap();
418        mbk.partial_fit(&b2).unwrap();
419
420        // Centroids should be near the two cluster centers.
421        let c = mbk.centroids();
422        let c0_near_1 = c
423            .iter()
424            .any(|ci| (ci[0] - 1.0).abs() < 3.0 && (ci[1] - 1.0).abs() < 3.0);
425        let c1_near_10 = c
426            .iter()
427            .any(|ci| (ci[0] - 10.0).abs() < 3.0 && (ci[1] - 10.0).abs() < 3.0);
428        assert!(c0_near_1, "expected a centroid near (1,1), got {c:?}");
429        assert!(c1_near_10, "expected a centroid near (10,10), got {c:?}");
430
431        // Should predict correctly.
432        let pred = mbk.predict(&[vec![1.0, 1.0], vec![10.0, 10.0]]).unwrap();
433        assert_ne!(pred[0], pred[1], "different clusters expected");
434    }
435}