1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
28use scirs2_core::numeric::{Float, FromPrimitive};
29use scirs2_core::random::{Rng, RngExt, SeedableRng};
30use std::fmt::Debug;
31
32use super::{euclidean_distance, kmeans_plus_plus};
33use crate::error::{ClusteringError, Result};
34
/// Strategy used to choose the initial centroids for mini-batch k-means.
///
/// The default strategy is [`MiniBatchInit::KMeansPlusPlus`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MiniBatchInit {
    /// Seed the centroids with the k-means++ procedure.
    #[default]
    KMeansPlusPlus,
    /// Seed the centroids with rows picked at random from the input data.
    Random,
}
49
50#[derive(Debug, Clone)]
52pub struct MiniBatchKMeansOptions<F: Float> {
53 pub max_iter: usize,
55 pub batch_size: usize,
57 pub tol: F,
59 pub random_seed: Option<u64>,
61 pub max_no_improvement: usize,
63 pub init_size: Option<usize>,
65 pub reassignment_ratio: F,
67 pub init: MiniBatchInit,
69 pub ewa_smoothing: F,
71}
72
73impl<F: Float + FromPrimitive> Default for MiniBatchKMeansOptions<F> {
74 fn default() -> Self {
75 Self {
76 max_iter: 100,
77 batch_size: 1024,
78 tol: F::from(1e-4).unwrap_or(F::epsilon()),
79 random_seed: None,
80 max_no_improvement: 10,
81 init_size: None,
82 reassignment_ratio: F::from(0.01).unwrap_or(F::epsilon()),
83 init: MiniBatchInit::KMeansPlusPlus,
84 ewa_smoothing: F::from(0.7).unwrap_or(F::one()),
85 }
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct MiniBatchKMeansResult<F: Float> {
92 pub centroids: Array2<F>,
94 pub labels: Array1<usize>,
96 pub n_iter: usize,
98 pub inertia: F,
100 pub converged: bool,
102 pub inertia_history: Vec<F>,
104 pub cluster_counts: Array1<usize>,
106 pub n_reassignments: usize,
108}
109
110pub fn minibatch_kmeans<F>(
141 data: ArrayView2<F>,
142 k: usize,
143 options: Option<MiniBatchKMeansOptions<F>>,
144) -> Result<(Array2<F>, Array1<usize>)>
145where
146 F: Float + FromPrimitive + Debug + std::iter::Sum,
147{
148 let result = minibatch_kmeans_full(data, k, options)?;
149 Ok((result.centroids, result.labels))
150}
151
152pub fn minibatch_kmeans_full<F>(
154 data: ArrayView2<F>,
155 k: usize,
156 options: Option<MiniBatchKMeansOptions<F>>,
157) -> Result<MiniBatchKMeansResult<F>>
158where
159 F: Float + FromPrimitive + Debug + std::iter::Sum,
160{
161 if k == 0 {
163 return Err(ClusteringError::InvalidInput(
164 "Number of clusters must be greater than 0".to_string(),
165 ));
166 }
167
168 let n_samples = data.shape()[0];
169 let n_features = data.shape()[1];
170
171 if n_samples == 0 {
172 return Err(ClusteringError::InvalidInput(
173 "Input data is empty".to_string(),
174 ));
175 }
176
177 if k > n_samples {
178 return Err(ClusteringError::InvalidInput(format!(
179 "Number of clusters ({}) cannot be greater than number of data points ({})",
180 k, n_samples
181 )));
182 }
183
184 let opts = options.unwrap_or_default();
185
186 let mut rng = match opts.random_seed {
188 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
189 None => {
190 scirs2_core::random::rngs::StdRng::seed_from_u64(scirs2_core::random::rng().random())
191 }
192 };
193
194 let init_size = opts
196 .init_size
197 .unwrap_or_else(|| {
198 let default_size = 3 * opts.batch_size;
199 if default_size < 3 * k {
200 3 * k
201 } else {
202 default_size
203 }
204 })
205 .min(n_samples);
206
207 let centroids = match opts.init {
209 MiniBatchInit::KMeansPlusPlus => {
210 if init_size < n_samples {
211 let mut indices = Vec::with_capacity(init_size);
212 for _ in 0..init_size {
213 indices.push(rng.random_range(0..n_samples));
214 }
215 let init_data =
216 Array2::from_shape_fn((init_size, n_features), |(i, j)| data[[indices[i], j]]);
217 kmeans_plus_plus(init_data.view(), k, opts.random_seed)?
218 } else {
219 kmeans_plus_plus(data, k, opts.random_seed)?
220 }
221 }
222 MiniBatchInit::Random => {
223 let mut centers = Array2::zeros((k, n_features));
224 for i in 0..k {
225 let idx = rng.random_range(0..n_samples);
226 centers.row_mut(i).assign(&data.row(idx));
227 }
228 centers
229 }
230 };
231
232 let mut centroids = centroids;
234 let mut counts = Array1::<F>::from_elem(k, F::one());
235
236 let mut ewa_inertia: Option<F> = None;
238 let mut no_improvement_count = 0;
239 let mut best_inertia = F::infinity();
240 let mut prev_centers: Option<Array2<F>> = None;
241 let mut inertia_history = Vec::with_capacity(opts.max_iter);
242 let mut total_reassignments = 0;
243 let mut converged = false;
244 let mut n_iter = 0;
245
246 for iter in 0..opts.max_iter {
248 n_iter = iter + 1;
249
250 let batch_size = opts.batch_size.min(n_samples);
252 let mut batch_indices = Vec::with_capacity(batch_size);
253 for _ in 0..batch_size {
254 batch_indices.push(rng.random_range(0..n_samples));
255 }
256
257 let step_result =
259 mini_batch_step(&data, &batch_indices, &mut centroids, &mut counts, &opts)?;
260
261 total_reassignments += step_result.n_reassignments;
262
263 let current_ewa = match ewa_inertia {
265 Some(prev_ewa) => {
266 prev_ewa * opts.ewa_smoothing
267 + step_result.batch_inertia * (F::one() - opts.ewa_smoothing)
268 }
269 None => step_result.batch_inertia,
270 };
271 ewa_inertia = Some(current_ewa);
272 inertia_history.push(current_ewa);
273
274 if current_ewa < best_inertia {
276 best_inertia = current_ewa;
277 no_improvement_count = 0;
278 } else {
279 no_improvement_count += 1;
280 }
281
282 if let Some(ref prev) = prev_centers {
284 let mut center_shift = F::zero();
285 for i in 0..k {
286 let dist = euclidean_distance(centroids.slice(s![i, ..]), prev.slice(s![i, ..]));
287 center_shift = center_shift + dist;
288 }
289 let k_f = F::from(k).unwrap_or(F::one());
290 center_shift = center_shift / k_f;
291
292 if center_shift < opts.tol {
293 converged = true;
294 break;
295 }
296 }
297
298 prev_centers = Some(centroids.clone());
299
300 if no_improvement_count >= opts.max_no_improvement {
302 converged = true;
303 break;
304 }
305 }
306
307 let (final_labels, final_distances) = assign_labels(data, centroids.view())?;
309
310 let final_inertia = final_distances
312 .iter()
313 .fold(F::zero(), |acc, &d| acc + d * d);
314
315 let mut cluster_counts = Array1::<usize>::zeros(k);
317 for &label in final_labels.iter() {
318 if label < k {
319 cluster_counts[label] += 1;
320 }
321 }
322
323 Ok(MiniBatchKMeansResult {
324 centroids,
325 labels: final_labels,
326 n_iter,
327 inertia: final_inertia,
328 converged,
329 inertia_history,
330 cluster_counts,
331 n_reassignments: total_reassignments,
332 })
333}
334
335struct MiniBatchStepResult<F: Float> {
337 batch_inertia: F,
339 n_reassignments: usize,
341}
342
343fn mini_batch_step<F>(
345 data: &ArrayView2<F>,
346 batch_indices: &[usize],
347 centroids: &mut Array2<F>,
348 counts: &mut Array1<F>,
349 opts: &MiniBatchKMeansOptions<F>,
350) -> Result<MiniBatchStepResult<F>>
351where
352 F: Float + FromPrimitive + Debug,
353{
354 let k = centroids.shape()[0];
355 let n_features = centroids.shape()[1];
356 let batch_size = batch_indices.len();
357
358 let mut closest_distances = Array1::from_elem(batch_size, F::infinity());
359 let mut closest_centers = Array1::<usize>::zeros(batch_size);
360 let mut inertia = F::zero();
361
362 for (i, &sample_idx) in batch_indices.iter().enumerate() {
364 let sample = data.slice(s![sample_idx, ..]);
365
366 let mut min_dist = F::infinity();
367 let mut min_idx = 0;
368
369 for j in 0..k {
370 let dist = euclidean_distance(sample, centroids.slice(s![j, ..]));
371 if dist < min_dist {
372 min_dist = dist;
373 min_idx = j;
374 }
375 }
376
377 closest_centers[i] = min_idx;
378 closest_distances[i] = min_dist;
379 inertia = inertia + min_dist * min_dist;
380 }
381
382 for i in 0..batch_size {
384 let center_idx = closest_centers[i];
385 let sample_idx = batch_indices[i];
386 let sample = data.slice(s![sample_idx, ..]);
387
388 let count = counts[center_idx];
389 let learning_rate = F::one() / (count + F::one());
391
392 for j in 0..n_features {
393 centroids[[center_idx, j]] =
394 centroids[[center_idx, j]] * (F::one() - learning_rate) + sample[j] * learning_rate;
395 }
396
397 counts[center_idx] = count + F::one();
398 }
399
400 let mut n_reassignments = 0;
402 let max_count = counts.fold(F::zero(), |a, &b| a.max(b));
403 let reassign_threshold = max_count * opts.reassignment_ratio;
404
405 for c in 0..k {
406 if counts[c] < reassign_threshold {
407 let mut max_dist = F::zero();
409 let mut max_idx = 0;
410
411 for j in 0..batch_size {
412 if closest_distances[j] > max_dist {
413 max_dist = closest_distances[j];
414 max_idx = j;
415 }
416 }
417
418 if max_dist > F::zero() {
419 let sample_idx = batch_indices[max_idx];
420 let sample = data.slice(s![sample_idx, ..]);
421
422 for j in 0..n_features {
423 centroids[[c, j]] = sample[j];
424 }
425
426 counts[c] = counts[c].max(F::one());
427 closest_centers[max_idx] = c;
428 closest_distances[max_idx] = F::zero();
429 n_reassignments += 1;
430 }
431 }
432 }
433
434 let batch_f = F::from(batch_size).unwrap_or(F::one());
436 inertia = inertia / batch_f;
437
438 Ok(MiniBatchStepResult {
439 batch_inertia: inertia,
440 n_reassignments,
441 })
442}
443
444fn assign_labels<F>(
446 data: ArrayView2<F>,
447 centroids: ArrayView2<F>,
448) -> Result<(Array1<usize>, Array1<F>)>
449where
450 F: Float + FromPrimitive + Debug,
451{
452 let n_samples = data.shape()[0];
453 let n_clusters = centroids.shape()[0];
454
455 let mut labels = Array1::<usize>::zeros(n_samples);
456 let mut distances = Array1::<F>::zeros(n_samples);
457
458 for i in 0..n_samples {
459 let sample = data.slice(s![i, ..]);
460 let mut min_dist = F::infinity();
461 let mut min_idx = 0;
462
463 for j in 0..n_clusters {
464 let centroid = centroids.slice(s![j, ..]);
465 let dist = euclidean_distance(sample, centroid);
466
467 if dist < min_dist {
468 min_dist = dist;
469 min_idx = j;
470 }
471 }
472
473 labels[i] = min_idx;
474 distances[i] = min_dist;
475 }
476
477 Ok((labels, distances))
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Six 2-D points: rows 0..3 lie near (1, 2) and rows 3..6 near (4, 5).
    fn make_two_cluster_data() -> Array2<f64> {
        let values = vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1];
        Array2::from_shape_vec((6, 2), values).expect("Failed to create test data")
    }

    /// Collects the distinct cluster labels present in `labels`.
    fn distinct_labels(labels: &Array1<usize>) -> std::collections::HashSet<usize> {
        labels.iter().copied().collect()
    }

    /// Two well-separated groups must end up in two different clusters.
    #[test]
    fn test_minibatch_kmeans_simple() {
        let points = make_two_cluster_data();

        let config = MiniBatchKMeansOptions {
            max_iter: 10,
            batch_size: 3,
            random_seed: Some(42),
            ..Default::default()
        };

        let (centers, assignments) =
            minibatch_kmeans(points.view(), 2, Some(config)).expect("Should succeed");

        assert_eq!(centers.shape(), &[2, 2]);
        assert_eq!(assignments.shape(), &[6]);
        assert_eq!(distinct_labels(&assignments).len(), 2);

        // Rows of the same group share the label of the group's first row.
        for row in 1..3 {
            assert_eq!(assignments[row], assignments[0]);
        }
        for row in 4..6 {
            assert_eq!(assignments[row], assignments[3]);
        }
    }

    /// The full result carries consistent diagnostics.
    #[test]
    fn test_minibatch_kmeans_full_diagnostics() {
        let points = make_two_cluster_data();

        let config = MiniBatchKMeansOptions {
            max_iter: 50,
            batch_size: 4,
            random_seed: Some(42),
            ..Default::default()
        };

        let outcome =
            minibatch_kmeans_full(points.view(), 2, Some(config)).expect("Should succeed");

        assert_eq!(outcome.centroids.shape(), &[2, 2]);
        assert_eq!(outcome.labels.shape(), &[6]);
        assert!(outcome.n_iter > 0);
        assert!(outcome.inertia >= 0.0);
        assert!(!outcome.inertia_history.is_empty());

        assert!(
            outcome.cluster_counts.iter().all(|&count| count > 0),
            "Each cluster should have assigned points"
        );
    }

    /// With a generous iteration budget the run stops early.
    #[test]
    fn test_minibatch_kmeans_convergence() {
        let points = make_two_cluster_data();

        let config = MiniBatchKMeansOptions {
            max_iter: 1000,
            batch_size: 6,
            random_seed: Some(42),
            tol: 1e-6,
            max_no_improvement: 20,
            ..Default::default()
        };

        let outcome =
            minibatch_kmeans_full(points.view(), 2, Some(config)).expect("Should succeed");

        assert!(
            outcome.n_iter < 1000,
            "Should converge before max_iter, took {} iters",
            outcome.n_iter
        );
    }

    /// Asking for more clusters than natural groups still yields valid output.
    #[test]
    fn test_minibatch_kmeans_empty_clusters() {
        let values = vec![
            1.0, 1.0, 1.1, 1.1, 1.2, 1.0, 1.0, 1.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.0, 5.0, 5.2,
        ];
        let points = Array2::from_shape_vec((8, 2), values).expect("Failed to create data");

        let config = MiniBatchKMeansOptions {
            max_iter: 20,
            batch_size: 4,
            random_seed: Some(42),
            reassignment_ratio: 0.1,
            ..Default::default()
        };

        let (centers, assignments) =
            minibatch_kmeans(points.view(), 3, Some(config)).expect("Should succeed");

        assert_eq!(centers.shape(), &[3, 2]);
        assert_eq!(assignments.shape(), &[8]);
        assert!(distinct_labels(&assignments).len() <= 3);
    }

    /// Random initialization produces output of the expected shape.
    #[test]
    fn test_minibatch_kmeans_random_init() {
        let points = make_two_cluster_data();

        let config = MiniBatchKMeansOptions {
            init: MiniBatchInit::Random,
            random_seed: Some(42),
            max_iter: 50,
            batch_size: 4,
            ..Default::default()
        };

        let (centers, assignments) =
            minibatch_kmeans(points.view(), 2, Some(config)).expect("Should succeed");

        assert_eq!(centers.shape(), &[2, 2]);
        assert_eq!(assignments.shape(), &[6]);
    }

    /// The smoothed inertia at the end is not much worse than at the start.
    #[test]
    fn test_minibatch_kmeans_inertia_decreases() {
        let points = make_two_cluster_data();

        let config = MiniBatchKMeansOptions {
            max_iter: 50,
            batch_size: 6,
            random_seed: Some(42),
            ewa_smoothing: 0.5,
            ..Default::default()
        };

        let outcome =
            minibatch_kmeans_full(points.view(), 2, Some(config)).expect("Should succeed");

        let history = &outcome.inertia_history;
        if history.len() >= 3 {
            // Average of a three-element window of the history.
            let mean_of = |window: &[f64]| window.iter().copied().sum::<f64>() / 3.0;
            let first_few = mean_of(&history[..3]);
            let last_few = mean_of(&history[history.len() - 3..]);

            assert!(
                last_few <= first_few + 1.0,
                "Inertia should generally decrease: first_avg={}, last_avg={}",
                first_few,
                last_few
            );
        }
    }

    /// Invalid cluster counts and empty data are rejected.
    #[test]
    fn test_minibatch_kmeans_invalid_inputs() {
        let points = make_two_cluster_data();

        // Zero clusters.
        assert!(minibatch_kmeans(points.view(), 0, None).is_err());

        // More clusters than samples.
        assert!(minibatch_kmeans(points.view(), 100, None).is_err());

        // No samples at all.
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(minibatch_kmeans(empty.view(), 2, None).is_err());
    }

    /// With as many clusters as samples every sample gets its own label.
    #[test]
    fn test_minibatch_kmeans_k_equals_n() {
        let points = make_two_cluster_data();

        let config = MiniBatchKMeansOptions {
            random_seed: Some(42),
            max_iter: 10,
            ..Default::default()
        };

        let (centers, assignments) =
            minibatch_kmeans(points.view(), 6, Some(config)).expect("Should succeed");

        assert_eq!(centers.shape(), &[6, 2]);
        assert_eq!(assignments.shape(), &[6]);
        assert_eq!(distinct_labels(&assignments).len(), 6);
    }
}