use super::kmeans::kmeans_plus_plus;
use crate::dataset::Dataset;
use crate::distance::euclidean_sq;
use crate::error::{Result, ScryLearnError};
use crate::partial_fit::PartialFit;

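/// Mini-batch K-Means clustering.
///
/// Instead of reassigning every sample on every pass, `fit` draws a random
/// batch per iteration and nudges each winning centroid toward its assigned
/// points with a decaying per-centroid learning rate of `1 / count`, the
/// web-scale scheme of Sculley (2010). This trades a little clustering
/// quality for a much cheaper per-iteration cost and makes the model a
/// natural implementor of the streaming [`PartialFit`] trait.
///
/// # Example
///
/// ```ignore
/// // Illustrative sketch (crate paths and `data` construction elided); see
/// // the tests at the bottom of this file for a runnable `Dataset` setup.
/// let mut mbk = MiniBatchKMeans::new(2).seed(42).batch_size(256);
/// mbk.fit(&data)?;
/// let labels = mbk.labels();
/// ```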
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MiniBatchKMeans {
    // Hyperparameters.
    k: usize,
    batch_size: usize,
    max_iter: usize,
    tolerance: f64,
    seed: u64,
    // Learned state, populated by `fit` / `partial_fit`.
    centroids: Vec<Vec<f64>>,
    labels: Vec<usize>,
    inertia: f64,
    n_iter: usize,
    fitted: bool,
    // Per-centroid sample counts driving the 1/count learning rate.
    centroid_counts: Vec<u64>,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl MiniBatchKMeans {
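    /// Creates a model targeting `k` clusters with the default
    /// hyperparameters: batch size 1024, at most 100 iterations,
    /// tolerance 0 (early stopping disabled), and seed 42.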
    pub fn new(k: usize) -> Self {
        Self {
            k,
            batch_size: 1024,
            max_iter: 100,
            tolerance: 0.0,
            seed: 42,
            centroids: Vec::new(),
            labels: Vec::new(),
            inertia: f64::INFINITY,
            n_iter: 0,
            fitted: false,
            centroid_counts: Vec::new(),
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

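    /// Sets the number of samples drawn (with replacement) each iteration.
    /// Values below 1 are clamped to 1; at fit time the batch is further
    /// capped at the dataset size.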
    pub fn batch_size(mut self, n: usize) -> Self {
        self.batch_size = n.max(1);
        self
    }

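    /// Sets the maximum number of mini-batch iterations for `fit`.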
    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

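    /// Sets the absolute change in full-dataset inertia below which `fit`
    /// stops early. A value of 0 (the default) disables the check.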
    pub fn tolerance(mut self, t: f64) -> Self {
        self.tolerance = t;
        self
    }

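    /// Shorthand alias for [`Self::tolerance`].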
    pub fn tol(self, t: f64) -> Self {
        self.tolerance(t)
    }

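    /// Sets the RNG seed used for both k-means++ initialization and batch
    /// sampling, making `fit` deterministic for a given dataset.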
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

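    /// Fits `k` centroids with mini-batch updates, then assigns a label to
    /// every sample.
    ///
    /// Errors on non-finite features, an empty dataset, or `k` outside
    /// `1..=n_samples`. Inertia (and the tolerance check) is evaluated over
    /// the full dataset every 10 iterations and on the last one.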
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.k == 0 || self.k > n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "k must be between 1 and n_samples ({}), got {}",
                n, self.k
            )));
        }

        let rows = data.feature_matrix();
        let mut rng = crate::rng::FastRng::new(self.seed);
        let effective_batch = self.batch_size.min(n);

        // Seed centroids with k-means++ over the full dataset.
        let mut centroids = kmeans_plus_plus(&rows, self.k, self.seed);

        let mut centroid_counts = vec![0_u64; self.k];
        let mut prev_inertia = f64::INFINITY;

        for iter in 0..self.max_iter {
            // Sample a batch uniformly with replacement.
            let batch_indices: Vec<usize> = (0..effective_batch).map(|_| rng.usize(0..n)).collect();

            // Assign each sampled point to its nearest centroid.
            let mut assignments = Vec::with_capacity(effective_batch);
            for &idx in &batch_indices {
                let mut best_c = 0;
                let mut best_dist = f64::INFINITY;
                for (c, centroid) in centroids.iter().enumerate() {
                    let d = euclidean_sq(&rows[idx], centroid);
                    if d < best_dist {
                        best_dist = d;
                        best_c = c;
                    }
                }
                assignments.push(best_c);
            }

            // Online update: move each winning centroid toward its point
            // with a decaying per-centroid learning rate of 1 / count.
            for (batch_i, &idx) in batch_indices.iter().enumerate() {
                let c = assignments[batch_i];
                centroid_counts[c] += 1;
                let lr = 1.0 / centroid_counts[c] as f64;
                for j in 0..m {
                    centroids[c][j] += lr * (rows[idx][j] - centroids[c][j]);
                }
            }

            // Recomputing full-dataset inertia costs O(n * k), so only do it
            // every 10 iterations and on the final one.
            if iter % 10 == 0 || iter == self.max_iter - 1 {
                let mut inertia = 0.0;
                for row in &rows {
                    let mut best_dist = f64::INFINITY;
                    for centroid in &centroids {
                        let d = euclidean_sq(row, centroid);
                        if d < best_dist {
                            best_dist = d;
                        }
                    }
                    inertia += best_dist;
                }

                self.n_iter = iter + 1;
                self.inertia = inertia;

                // Early stop once inertia changes by less than the tolerance
                // between consecutive checks (tolerance 0 disables this).
                if self.tolerance > 0.0 && (prev_inertia - inertia).abs() < self.tolerance {
                    break;
                }
                prev_inertia = inertia;
            }
        }

        // Final full-dataset assignment with the converged centroids.
        self.labels = rows
            .iter()
            .map(|row| {
                centroids
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        euclidean_sq(row, a)
                            .partial_cmp(&euclidean_sq(row, b))
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map_or(0, |(idx, _)| idx)
            })
            .collect();

        self.centroids = centroids;
        // Persist per-centroid counts so a later `partial_fit` continues the
        // 1/count learning-rate schedule instead of indexing an empty vec.
        self.centroid_counts = centroid_counts;
        self.fitted = true;
        Ok(())
    }

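    /// Returns the index of the nearest centroid for each feature row.
    ///
    /// Errors if the model has not been fitted, or if it was deserialized
    /// from an incompatible schema version.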
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<usize>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                self.centroids
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        euclidean_sq(row, a)
                            .partial_cmp(&euclidean_sq(row, b))
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map_or(0, |(idx, _)| idx)
            })
            .collect())
    }

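    /// The learned cluster centers (empty before fitting).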
    pub fn centroids(&self) -> &[Vec<f64>] {
        &self.centroids
    }

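    /// Cluster assignments for the most recently fitted data.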
    pub fn labels(&self) -> &[usize] {
        &self.labels
    }

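    /// Sum of squared distances from each sample to its nearest centroid,
    /// or `f64::INFINITY` before fitting.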
    pub fn inertia(&self) -> f64 {
        self.inertia
    }

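    /// Number of iterations run, as recorded at the last inertia check.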
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }
}

impl PartialFit for MiniBatchKMeans {
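    /// Consumes one batch of streaming data: seeds centroids with k-means++
    /// on the first non-empty batch, then applies one online update per row.
    /// Labels and inertia are recomputed for this batch's rows only.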
    fn partial_fit(&mut self, data: &Dataset) -> Result<()> {
        // Mirror `fit` and reject non-finite features up front.
        data.validate_finite()?;
        let n = data.n_samples();
        if n == 0 {
            // An empty batch is a no-op once initialized, an error otherwise.
            if self.is_initialized() {
                return Ok(());
            }
            return Err(ScryLearnError::EmptyDataset);
        }

        let rows = data.feature_matrix();
        let m = data.n_features();

        if !self.is_initialized() {
            if self.k == 0 || self.k > n {
                return Err(ScryLearnError::InvalidParameter(format!(
                    "k must be between 1 and n_samples ({}), got {}",
                    n, self.k
                )));
            }
            // First batch: seed centroids exactly as `fit` does.
            self.centroids = kmeans_plus_plus(&rows, self.k, self.seed);
            self.centroid_counts = vec![0_u64; self.k];
        }

        // One online update per row, continuing the 1/count learning-rate
        // schedule across calls via the persisted counts.
        for row in &rows {
            let mut best_c = 0;
            let mut best_dist = f64::INFINITY;
            for (c, centroid) in self.centroids.iter().enumerate() {
                let d = euclidean_sq(row, centroid);
                if d < best_dist {
                    best_dist = d;
                    best_c = c;
                }
            }
            self.centroid_counts[best_c] += 1;
            let lr = 1.0 / self.centroid_counts[best_c] as f64;
            for j in 0..m {
                self.centroids[best_c][j] += lr * (row[j] - self.centroids[best_c][j]);
            }
        }

        // Labels and inertia below reflect only the rows in this batch.
        self.labels = rows
            .iter()
            .map(|row| {
                self.centroids
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        euclidean_sq(row, a)
                            .partial_cmp(&euclidean_sq(row, b))
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map_or(0, |(idx, _)| idx)
            })
            .collect();

        self.inertia = rows
            .iter()
            .map(|row| {
                self.centroids
                    .iter()
                    .map(|c| euclidean_sq(row, c))
                    .fold(f64::INFINITY, f64::min)
            })
            .sum();

        self.fitted = true;
        Ok(())
    }

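    /// Whether centroids have been seeded by an earlier `fit` or
    /// `partial_fit` call.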
    fn is_initialized(&self) -> bool {
        !self.centroids.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mini_batch_kmeans_two_blobs() {
        // Two well-separated blobs: 30 points near the origin, 30 near (100, 100).
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        for i in 0..30 {
            f1.push(i as f64 % 3.0);
            f2.push(i as f64 % 3.0);
        }
        for i in 0..30 {
            f1.push(100.0 + i as f64 % 3.0);
            f2.push(100.0 + i as f64 % 3.0);
        }

        let data = Dataset::new(
            vec![f1, f2],
            vec![0.0; 60],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut mbk = MiniBatchKMeans::new(2).seed(42).batch_size(20);
        mbk.fit(&data).unwrap();

        // All points within a blob share a label, and the two blobs differ.
        let labels = mbk.labels();
        let first_label = labels[0];
        assert!(labels[..30].iter().all(|&l| l == first_label));
        assert!(labels[30..].iter().all(|&l| l != first_label));
    }

    #[test]
    fn test_mini_batch_kmeans_predict() {
        let data = Dataset::new(
            vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
            vec![0.0; 4],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut mbk = MiniBatchKMeans::new(2).seed(42).batch_size(4);
        mbk.fit(&data).unwrap();

        let pred = mbk.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert_ne!(
            pred[0], pred[1],
            "points near different blobs should land in different clusters"
        );
    }

    #[test]
    fn test_partial_fit_is_initialized() {
        let mut mbk = MiniBatchKMeans::new(2);
        assert!(!mbk.is_initialized());

        let data = Dataset::new(
            vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]],
            vec![0.0; 4],
            vec!["x".into(), "y".into()],
            "label",
        );
        mbk.partial_fit(&data).unwrap();
        assert!(mbk.is_initialized());
    }

    #[test]
    fn test_partial_fit_convergence() {
        let mut mbk = MiniBatchKMeans::new(2).seed(42);

        // Two sequential batches, one per blob.
        let b1 = Dataset::new(
            vec![vec![0.5, 1.0, 1.5], vec![0.5, 1.0, 1.5]],
            vec![0.0; 3],
            vec!["x".into(), "y".into()],
            "label",
        );
        let b2 = Dataset::new(
            vec![vec![9.5, 10.0, 10.5], vec![9.5, 10.0, 10.5]],
            vec![0.0; 3],
            vec!["x".into(), "y".into()],
            "label",
        );

        mbk.partial_fit(&b1).unwrap();
        mbk.partial_fit(&b2).unwrap();

        // Each blob should have pulled one centroid into its vicinity.
        let c = mbk.centroids();
        let c0_near_1 = c
            .iter()
            .any(|ci| (ci[0] - 1.0).abs() < 3.0 && (ci[1] - 1.0).abs() < 3.0);
        let c1_near_10 = c
            .iter()
            .any(|ci| (ci[0] - 10.0).abs() < 3.0 && (ci[1] - 10.0).abs() < 3.0);
        assert!(c0_near_1, "expected a centroid near (1,1), got {c:?}");
        assert!(c1_near_10, "expected a centroid near (10,10), got {c:?}");

        let pred = mbk.predict(&[vec![1.0, 1.0], vec![10.0, 10.0]]).unwrap();
        assert_ne!(pred[0], pred[1], "different clusters expected");
    }
}