1use rayon::prelude::*;
23
24use crate::constants::KMEANS_PAR_THRESHOLD;
25use crate::dataset::Dataset;
26use crate::distance::euclidean_sq;
27use crate::error::{Result, ScryLearnError};
28
/// Lloyd's k-means clustering with k-means++ initialization and multiple
/// seeded restarts; configure via the builder-style setters, then call `fit`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KMeans {
    /// Number of clusters to find.
    k: usize,
    /// Maximum Lloyd iterations per restart (default 300).
    max_iter: usize,
    /// Convergence threshold on inertia change and total centroid shift
    /// (default 1e-4).
    tolerance: f64,
    /// Base RNG seed; restart `i` uses `seed.wrapping_add(i)` (default 42).
    seed: u64,
    /// Number of independent restarts; lowest inertia wins (default 10).
    n_init: usize,
    /// Fitted cluster centers, one `Vec<f64>` per cluster (empty until fit).
    centroids: Vec<Vec<f64>>,
    /// Cluster index assigned to each training sample (empty until fit).
    labels: Vec<usize>,
    /// Sum of squared distances from samples to their nearest centroid.
    inertia: f64,
    /// Iterations performed by the best restart.
    n_iter: usize,
    /// Whether `fit` has completed successfully.
    fitted: bool,
    // serde(default) makes this 0 when deserializing older payloads that
    // predate the field.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
51
52impl KMeans {
53 pub fn new(k: usize) -> Self {
55 Self {
56 k,
57 max_iter: 300,
58 tolerance: 1e-4,
59 seed: 42,
60 n_init: 10,
61 centroids: Vec::new(),
62 labels: Vec::new(),
63 inertia: f64::INFINITY,
64 n_iter: 0,
65 fitted: false,
66 _schema_version: crate::version::SCHEMA_VERSION,
67 }
68 }
69
70 pub fn max_iter(mut self, n: usize) -> Self {
72 self.max_iter = n;
73 self
74 }
75
76 pub fn tolerance(mut self, t: f64) -> Self {
78 self.tolerance = t;
79 self
80 }
81
82 pub fn tol(self, t: f64) -> Self {
84 self.tolerance(t)
85 }
86
87 pub fn seed(mut self, s: u64) -> Self {
89 self.seed = s;
90 self
91 }
92
93 pub fn n_init(mut self, n: usize) -> Self {
98 self.n_init = n.max(1);
99 self
100 }
101
102 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
106 data.validate_finite()?;
107 let n = data.n_samples();
108 if n == 0 {
109 return Err(ScryLearnError::EmptyDataset);
110 }
111 if self.k == 0 || self.k > n {
112 return Err(ScryLearnError::InvalidParameter(format!(
113 "k must be between 1 and n_samples ({}), got {}",
114 n, self.k
115 )));
116 }
117
118 let rows = data.feature_matrix();
119 let m = data.n_features();
120
121 let mut best_centroids: Option<Vec<Vec<f64>>> = None;
122 let mut best_labels: Option<Vec<usize>> = None;
123 let mut best_inertia = f64::INFINITY;
124 let mut best_n_iter = 0;
125
126 for run in 0..self.n_init {
127 let run_seed = self.seed.wrapping_add(run as u64);
128 let (centroids, labels, inertia, n_iter) = self.run_once(&rows, n, m, run_seed);
129
130 if inertia < best_inertia {
131 best_centroids = Some(centroids);
132 best_labels = Some(labels);
133 best_inertia = inertia;
134 best_n_iter = n_iter;
135 }
136 }
137
138 self.centroids = best_centroids.unwrap_or_default();
139 self.labels = best_labels.unwrap_or_default();
140 self.inertia = best_inertia;
141 self.n_iter = best_n_iter;
142 self.fitted = true;
143 Ok(())
144 }
145
146 #[allow(clippy::type_complexity)]
148 fn run_once(
149 &self,
150 rows: &[Vec<f64>],
151 n: usize,
152 m: usize,
153 seed: u64,
154 ) -> (Vec<Vec<f64>>, Vec<usize>, f64, usize) {
155 let mut centroids = kmeans_plus_plus(rows, self.k, seed);
156 let mut labels = vec![0usize; n];
157 let mut prev_inertia = f64::INFINITY;
158 let mut final_inertia = f64::INFINITY;
159 let mut final_n_iter = 0;
160 let use_par = n * self.k >= KMEANS_PAR_THRESHOLD;
161
162 for iter in 0..self.max_iter {
163 let inertia;
165 if use_par {
166 let results: Vec<(usize, f64)> = rows
167 .par_iter()
168 .map(|row| {
169 let mut best_dist = f64::INFINITY;
170 let mut best_c = 0;
171 for (c, centroid) in centroids.iter().enumerate() {
172 let d = euclidean_sq(row, centroid);
173 if d < best_dist {
174 best_dist = d;
175 best_c = c;
176 }
177 }
178 (best_c, best_dist)
179 })
180 .collect();
181 inertia = results.iter().map(|(_, d)| d).sum();
182 for (i, (c, _)) in results.into_iter().enumerate() {
183 labels[i] = c;
184 }
185 } else {
186 let mut seq_inertia = 0.0;
187 for (i, row) in rows.iter().enumerate() {
188 let mut best_dist = f64::INFINITY;
189 let mut best_c = 0;
190 for (c, centroid) in centroids.iter().enumerate() {
191 let d = euclidean_sq(row, centroid);
192 if d < best_dist {
193 best_dist = d;
194 best_c = c;
195 }
196 }
197 labels[i] = best_c;
198 seq_inertia += best_dist;
199 }
200 inertia = seq_inertia;
201 }
202
203 let mut new_centroids = vec![vec![0.0; m]; self.k];
205 let mut counts = vec![0usize; self.k];
206
207 for (i, row) in rows.iter().enumerate() {
208 let c = labels[i];
209 counts[c] += 1;
210 for (j, &val) in row.iter().enumerate() {
211 new_centroids[c][j] += val;
212 }
213 }
214
215 for c in 0..self.k {
217 if counts[c] > 0 {
218 for val in &mut new_centroids[c] {
219 *val /= counts[c] as f64;
220 }
221 }
222 }
223
224 for c in 0..self.k {
227 if counts[c] == 0 {
228 let mut max_dist = f64::NEG_INFINITY;
229 let mut best_idx = 0;
230 for (i, row) in rows.iter().enumerate() {
231 let min_dist = new_centroids
232 .iter()
233 .enumerate()
234 .filter(|&(ci, _)| ci != c && (counts[ci] > 0 || ci < c))
235 .map(|(_, cen)| euclidean_sq(row, cen))
236 .fold(f64::INFINITY, f64::min);
237 if min_dist > max_dist {
238 max_dist = min_dist;
239 best_idx = i;
240 }
241 }
242 new_centroids[c].clone_from(&rows[best_idx]);
243 }
244 }
245
246 let shift: f64 = centroids
248 .iter()
249 .zip(new_centroids.iter())
250 .map(|(old, new)| euclidean_sq(old, new))
251 .sum();
252
253 centroids = new_centroids;
254 final_n_iter = iter + 1;
255 final_inertia = inertia;
256
257 if (prev_inertia - inertia).abs() < self.tolerance || shift < self.tolerance {
258 break;
259 }
260 prev_inertia = inertia;
261 }
262
263 (centroids, labels, final_inertia, final_n_iter)
264 }
265
266 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<usize>> {
268 crate::version::check_schema_version(self._schema_version)?;
269 if !self.fitted {
270 return Err(ScryLearnError::NotFitted);
271 }
272 Ok(features
273 .iter()
274 .map(|row| {
275 self.centroids
276 .iter()
277 .enumerate()
278 .min_by(|(_, a), (_, b)| {
279 euclidean_sq(row, a)
280 .partial_cmp(&euclidean_sq(row, b))
281 .unwrap_or(std::cmp::Ordering::Equal)
282 })
283 .map_or(0, |(idx, _)| idx)
284 })
285 .collect())
286 }
287
288 pub fn transform(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
308 if !self.fitted {
309 return Err(ScryLearnError::NotFitted);
310 }
311 Ok(features
312 .iter()
313 .map(|row| {
314 self.centroids
315 .iter()
316 .map(|c| euclidean_sq(row, c).sqrt())
317 .collect()
318 })
319 .collect())
320 }
321
322 pub fn centroids(&self) -> &[Vec<f64>] {
324 &self.centroids
325 }
326
327 pub fn labels(&self) -> &[usize] {
329 &self.labels
330 }
331
332 pub fn inertia(&self) -> f64 {
334 self.inertia
335 }
336
337 pub fn n_iter(&self) -> usize {
339 self.n_iter
340 }
341}
342
343pub(crate) fn kmeans_plus_plus(rows: &[Vec<f64>], k: usize, seed: u64) -> Vec<Vec<f64>> {
345 let mut rng = crate::rng::FastRng::new(seed);
346 let n = rows.len();
347 let mut centroids = Vec::with_capacity(k);
348
349 centroids.push(rows[rng.usize(0..n)].clone());
351
352 for _ in 1..k {
353 let mut dists: Vec<f64> = rows
355 .iter()
356 .map(|row| {
357 centroids
358 .iter()
359 .map(|c| euclidean_sq(row, c))
360 .fold(f64::INFINITY, f64::min)
361 })
362 .collect();
363
364 let total: f64 = dists.iter().sum();
366 if total < 1e-12 {
367 centroids.push(rows[rng.usize(0..n)].clone());
368 continue;
369 }
370 for d in &mut dists {
371 *d /= total;
372 }
373
374 let r = rng.f64();
375 let mut cumsum = 0.0;
376 let mut selected = n - 1;
377 for (i, &d) in dists.iter().enumerate() {
378 cumsum += d;
379 if cumsum >= r {
380 selected = i;
381 break;
382 }
383 }
384 centroids.push(rows[selected].clone());
385 }
386
387 centroids
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kmeans_two_blobs() {
        // Two well-separated blobs: 30 points near the origin, 30 near (100,100).
        let f1: Vec<f64> = (0..30)
            .map(|i| (i % 3) as f64)
            .chain((0..30).map(|i| 100.0 + (i % 3) as f64))
            .collect();
        let f2 = f1.clone();
        let target: Vec<f64> = std::iter::repeat(0.0)
            .take(30)
            .chain(std::iter::repeat(1.0).take(30))
            .collect();

        let data = Dataset::new(vec![f1, f2], target, vec!["x".into(), "y".into()], "label");

        let mut model = KMeans::new(2).seed(42).n_init(1);
        model.fit(&data).unwrap();

        // Each blob must be uniformly labeled, and the two blobs must differ.
        let labels = model.labels();
        let first_label = labels[0];
        assert!(labels[..30].iter().all(|&l| l == first_label));
        assert!(labels[30..].iter().all(|&l| l != first_label));
    }

    #[test]
    fn test_kmeans_predict() {
        // Two duplicated points at (0,0) and two at (10,10).
        let coords = vec![0.0, 0.0, 10.0, 10.0];
        let data = Dataset::new(
            vec![coords.clone(), coords],
            vec![0.0; 4],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut model = KMeans::new(2).seed(42).n_init(1);
        model.fit(&data).unwrap();

        let pred = model.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert_ne!(
            pred[0], pred[1],
            "nearby and far points should be in different clusters"
        );
    }

    #[test]
    fn test_kmeans_n_init_improves_inertia() {
        // Two noisy blobs; fitting k=3 leaves room for bad local optima, which
        // extra restarts should avoid.
        let mut rng = crate::rng::FastRng::new(7);
        let n = 100;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        // Same RNG draw order as before: f1 then f2, blob by blob.
        for offset in [0.0, 20.0] {
            for _ in 0..n / 2 {
                f1.push(offset + rng.f64() * 5.0);
                f2.push(offset + rng.f64() * 5.0);
            }
        }
        let data = Dataset::new(
            vec![f1, f2],
            vec![0.0; n],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut single = KMeans::new(3).seed(7).n_init(1);
        single.fit(&data).unwrap();
        let inertia1 = single.inertia();

        let mut multi = KMeans::new(3).seed(7).n_init(10);
        multi.fit(&data).unwrap();
        let inertia10 = multi.inertia();

        // More restarts can only match or improve the best inertia.
        assert!(
            inertia10 <= inertia1 + 1e-6,
            "n_init=10 inertia ({inertia10:.4}) should be ≤ n_init=1 ({inertia1:.4})"
        );
    }

    #[test]
    fn test_kmeans_empty_cluster_reinit() {
        // Only two distinct locations but k=3: at least one cluster must go
        // empty during fitting and be re-seeded.
        let f1: Vec<f64> = [vec![0.0; 50], vec![100.0; 50]].concat();
        let f2 = f1.clone();
        let data = Dataset::new(
            vec![f1, f2],
            vec![0.0; 100],
            vec!["x".into(), "y".into()],
            "l",
        );

        let mut model = KMeans::new(3).seed(42).n_init(1);
        model.fit(&data).unwrap();

        // All three centroids survive, with at least one near each location.
        let centroids = model.centroids();
        assert_eq!(centroids.len(), 3);
        let has_near_origin = centroids.iter().any(|c| c[0] < 50.0 && c[1] < 50.0);
        let has_near_far = centroids.iter().any(|c| c[0] > 50.0 && c[1] > 50.0);
        assert!(has_near_origin, "should have centroid near (0,0)");
        assert!(has_near_far, "should have centroid near (100,100)");

        assert_eq!(model.labels().len(), 100);
    }

    #[test]
    fn test_kmeans_transform_shape() {
        let coords = vec![0.0, 0.0, 10.0, 10.0];
        let data = Dataset::new(
            vec![coords.clone(), coords],
            vec![0.0; 4],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut model = KMeans::new(2).seed(42).n_init(1);
        model.fit(&data).unwrap();

        // transform yields one row per sample and one distance per centroid.
        let dists = model.transform(&[vec![5.0, 5.0], vec![0.0, 0.0]]).unwrap();
        assert_eq!(dists.len(), 2, "should have 2 samples");
        assert_eq!(
            dists[0].len(),
            2,
            "should have distance to each of 2 centroids"
        );
        for d in dists.iter().flatten() {
            assert!(*d >= 0.0, "distance should be non-negative");
        }
    }
}