// scirs2_cluster/vq/minibatch_kmeans.rs

//! Mini-Batch K-means clustering implementation
//!
//! This module provides an implementation of the Mini-Batch K-means algorithm
//! (Sculley 2010), a variant of k-means that uses mini-batches to reduce
//! computation time while still attempting to optimize the same objective function.
//!
//! # Advantages over standard K-means
//!
//! - **Much faster** for large datasets (per-iteration cost depends on the batch
//!   size rather than on the full dataset size)
//! - **Similar quality** to standard k-means in practice
//! - **Streaming compatible**: can process data in chunks
//!
//! # Algorithm
//!
//! 1. Initialize centroids (k-means++ or random)
//! 2. For each iteration:
//!    a. Sample a mini-batch of size `batch_size` from the data
//!    b. Assign each sample in the batch to its nearest centroid
//!    c. Update centroids using a per-center learning rate
//!       `eta = 1 / (count(center) + 1)`, where every count starts at 1
//!    d. Detect near-empty clusters and re-seed them from the batch
//! 3. Monitor convergence using an exponentially weighted average (EWA) of the
//!    batch inertia and the mean centroid displacement between iterations
//!
//! # References
//!
//! Sculley, D. (2010). "Web-Scale K-Means Clustering." WWW, pp. 1177-1178.

use std::collections::HashSet;
use std::fmt::Debug;

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{Rng, RngExt, SeedableRng};

use super::{euclidean_distance, kmeans_plus_plus};
use crate::error::{ClusteringError, Result};

/// Initialization method for Mini-Batch K-means
///
/// The default is [`MiniBatchInit::KMeansPlusPlus`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MiniBatchInit {
    /// K-means++ initialization (default, recommended)
    #[default]
    KMeansPlusPlus,
    /// Centroids are copied from randomly chosen data points
    Random,
}

50/// Options for Mini-Batch K-means clustering
51#[derive(Debug, Clone)]
52pub struct MiniBatchKMeansOptions<F: Float> {
53    /// Maximum number of iterations
54    pub max_iter: usize,
55    /// Size of mini-batches
56    pub batch_size: usize,
57    /// Convergence threshold for centroid movement
58    pub tol: F,
59    /// Random seed for initialization and batch sampling
60    pub random_seed: Option<u64>,
61    /// Number of iterations without improvement before stopping
62    pub max_no_improvement: usize,
63    /// Number of samples to use for initialization
64    pub init_size: Option<usize>,
65    /// Ratio of samples that should be reassigned to prevent empty clusters
66    pub reassignment_ratio: F,
67    /// Initialization method
68    pub init: MiniBatchInit,
69    /// EWA smoothing factor for inertia tracking (0 to 1)
70    pub ewa_smoothing: F,
71}
72
73impl<F: Float + FromPrimitive> Default for MiniBatchKMeansOptions<F> {
74    fn default() -> Self {
75        Self {
76            max_iter: 100,
77            batch_size: 1024,
78            tol: F::from(1e-4).unwrap_or(F::epsilon()),
79            random_seed: None,
80            max_no_improvement: 10,
81            init_size: None,
82            reassignment_ratio: F::from(0.01).unwrap_or(F::epsilon()),
83            init: MiniBatchInit::KMeansPlusPlus,
84            ewa_smoothing: F::from(0.7).unwrap_or(F::one()),
85        }
86    }
87}
88
89/// Result of Mini-Batch K-means with convergence diagnostics
90#[derive(Debug, Clone)]
91pub struct MiniBatchKMeansResult<F: Float> {
92    /// Final cluster centroids (k x n_features)
93    pub centroids: Array2<F>,
94    /// Cluster assignments for each data point
95    pub labels: Array1<usize>,
96    /// Number of iterations performed
97    pub n_iter: usize,
98    /// Final inertia (sum of squared distances to nearest centroid)
99    pub inertia: F,
100    /// Whether the algorithm converged
101    pub converged: bool,
102    /// History of EWA inertia values per iteration
103    pub inertia_history: Vec<F>,
104    /// Per-cluster count of assigned samples
105    pub cluster_counts: Array1<usize>,
106    /// Number of reassignments performed during training
107    pub n_reassignments: usize,
108}
109
110/// Mini-Batch K-means clustering algorithm
111///
112/// # Arguments
113///
114/// * `data` - Input data (n_samples x n_features)
115/// * `k` - Number of clusters
116/// * `options` - Optional parameters
117///
118/// # Returns
119///
120/// * Tuple of (centroids, labels)
121///
122/// # Examples
123///
124/// ```
125/// use scirs2_core::ndarray::{Array2, ArrayView2};
126/// use scirs2_cluster::vq::minibatch_kmeans;
127///
128/// let data = Array2::from_shape_vec((6, 2), vec![
129///     1.0, 2.0,
130///     1.2, 1.8,
131///     0.8, 1.9,
132///     3.7, 4.2,
133///     3.9, 3.9,
134///     4.2, 4.1,
135/// ]).expect("Operation failed");
136///
137/// let (centroids, labels) = minibatch_kmeans(ArrayView2::from(&data), 2, None)
138///     .expect("Operation failed");
139/// ```
140pub fn minibatch_kmeans<F>(
141    data: ArrayView2<F>,
142    k: usize,
143    options: Option<MiniBatchKMeansOptions<F>>,
144) -> Result<(Array2<F>, Array1<usize>)>
145where
146    F: Float + FromPrimitive + Debug + std::iter::Sum,
147{
148    let result = minibatch_kmeans_full(data, k, options)?;
149    Ok((result.centroids, result.labels))
150}
151
152/// Mini-Batch K-means with full diagnostic output
153pub fn minibatch_kmeans_full<F>(
154    data: ArrayView2<F>,
155    k: usize,
156    options: Option<MiniBatchKMeansOptions<F>>,
157) -> Result<MiniBatchKMeansResult<F>>
158where
159    F: Float + FromPrimitive + Debug + std::iter::Sum,
160{
161    // Input validation
162    if k == 0 {
163        return Err(ClusteringError::InvalidInput(
164            "Number of clusters must be greater than 0".to_string(),
165        ));
166    }
167
168    let n_samples = data.shape()[0];
169    let n_features = data.shape()[1];
170
171    if n_samples == 0 {
172        return Err(ClusteringError::InvalidInput(
173            "Input data is empty".to_string(),
174        ));
175    }
176
177    if k > n_samples {
178        return Err(ClusteringError::InvalidInput(format!(
179            "Number of clusters ({}) cannot be greater than number of data points ({})",
180            k, n_samples
181        )));
182    }
183
184    let opts = options.unwrap_or_default();
185
186    // Setup RNG
187    let mut rng = match opts.random_seed {
188        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
189        None => {
190            scirs2_core::random::rngs::StdRng::seed_from_u64(scirs2_core::random::rng().random())
191        }
192    };
193
194    // Determine initialization size
195    let init_size = opts
196        .init_size
197        .unwrap_or_else(|| {
198            let default_size = 3 * opts.batch_size;
199            if default_size < 3 * k {
200                3 * k
201            } else {
202                default_size
203            }
204        })
205        .min(n_samples);
206
207    // Initialize centroids
208    let centroids = match opts.init {
209        MiniBatchInit::KMeansPlusPlus => {
210            if init_size < n_samples {
211                let mut indices = Vec::with_capacity(init_size);
212                for _ in 0..init_size {
213                    indices.push(rng.random_range(0..n_samples));
214                }
215                let init_data =
216                    Array2::from_shape_fn((init_size, n_features), |(i, j)| data[[indices[i], j]]);
217                kmeans_plus_plus(init_data.view(), k, opts.random_seed)?
218            } else {
219                kmeans_plus_plus(data, k, opts.random_seed)?
220            }
221        }
222        MiniBatchInit::Random => {
223            let mut centers = Array2::zeros((k, n_features));
224            for i in 0..k {
225                let idx = rng.random_range(0..n_samples);
226                centers.row_mut(i).assign(&data.row(idx));
227            }
228            centers
229        }
230    };
231
232    // Initialize variables
233    let mut centroids = centroids;
234    let mut counts = Array1::<F>::from_elem(k, F::one());
235
236    // Convergence tracking
237    let mut ewa_inertia: Option<F> = None;
238    let mut no_improvement_count = 0;
239    let mut best_inertia = F::infinity();
240    let mut prev_centers: Option<Array2<F>> = None;
241    let mut inertia_history = Vec::with_capacity(opts.max_iter);
242    let mut total_reassignments = 0;
243    let mut converged = false;
244    let mut n_iter = 0;
245
246    // Mini-batch optimization loop
247    for iter in 0..opts.max_iter {
248        n_iter = iter + 1;
249
250        // Sample a mini-batch
251        let batch_size = opts.batch_size.min(n_samples);
252        let mut batch_indices = Vec::with_capacity(batch_size);
253        for _ in 0..batch_size {
254            batch_indices.push(rng.random_range(0..n_samples));
255        }
256
257        // Perform mini-batch step
258        let step_result =
259            mini_batch_step(&data, &batch_indices, &mut centroids, &mut counts, &opts)?;
260
261        total_reassignments += step_result.n_reassignments;
262
263        // Update EWA of inertia
264        let current_ewa = match ewa_inertia {
265            Some(prev_ewa) => {
266                prev_ewa * opts.ewa_smoothing
267                    + step_result.batch_inertia * (F::one() - opts.ewa_smoothing)
268            }
269            None => step_result.batch_inertia,
270        };
271        ewa_inertia = Some(current_ewa);
272        inertia_history.push(current_ewa);
273
274        // Check inertia improvement
275        if current_ewa < best_inertia {
276            best_inertia = current_ewa;
277            no_improvement_count = 0;
278        } else {
279            no_improvement_count += 1;
280        }
281
282        // Check centroid movement convergence
283        if let Some(ref prev) = prev_centers {
284            let mut center_shift = F::zero();
285            for i in 0..k {
286                let dist = euclidean_distance(centroids.slice(s![i, ..]), prev.slice(s![i, ..]));
287                center_shift = center_shift + dist;
288            }
289            let k_f = F::from(k).unwrap_or(F::one());
290            center_shift = center_shift / k_f;
291
292            if center_shift < opts.tol {
293                converged = true;
294                break;
295            }
296        }
297
298        prev_centers = Some(centroids.clone());
299
300        // Early stopping
301        if no_improvement_count >= opts.max_no_improvement {
302            converged = true;
303            break;
304        }
305    }
306
307    // Final label assignment
308    let (final_labels, final_distances) = assign_labels(data, centroids.view())?;
309
310    // Compute final inertia
311    let final_inertia = final_distances
312        .iter()
313        .fold(F::zero(), |acc, &d| acc + d * d);
314
315    // Compute per-cluster counts
316    let mut cluster_counts = Array1::<usize>::zeros(k);
317    for &label in final_labels.iter() {
318        if label < k {
319            cluster_counts[label] += 1;
320        }
321    }
322
323    Ok(MiniBatchKMeansResult {
324        centroids,
325        labels: final_labels,
326        n_iter,
327        inertia: final_inertia,
328        converged,
329        inertia_history,
330        cluster_counts,
331        n_reassignments: total_reassignments,
332    })
333}
334
335/// Result of a single mini-batch step
336struct MiniBatchStepResult<F: Float> {
337    /// Inertia of the mini-batch (average squared distance)
338    batch_inertia: F,
339    /// Number of reassignments performed
340    n_reassignments: usize,
341}
342
343/// Performs a single Mini-Batch K-means step
344fn mini_batch_step<F>(
345    data: &ArrayView2<F>,
346    batch_indices: &[usize],
347    centroids: &mut Array2<F>,
348    counts: &mut Array1<F>,
349    opts: &MiniBatchKMeansOptions<F>,
350) -> Result<MiniBatchStepResult<F>>
351where
352    F: Float + FromPrimitive + Debug,
353{
354    let k = centroids.shape()[0];
355    let n_features = centroids.shape()[1];
356    let batch_size = batch_indices.len();
357
358    let mut closest_distances = Array1::from_elem(batch_size, F::infinity());
359    let mut closest_centers = Array1::<usize>::zeros(batch_size);
360    let mut inertia = F::zero();
361
362    // Assignment: find nearest centroid for each batch sample
363    for (i, &sample_idx) in batch_indices.iter().enumerate() {
364        let sample = data.slice(s![sample_idx, ..]);
365
366        let mut min_dist = F::infinity();
367        let mut min_idx = 0;
368
369        for j in 0..k {
370            let dist = euclidean_distance(sample, centroids.slice(s![j, ..]));
371            if dist < min_dist {
372                min_dist = dist;
373                min_idx = j;
374            }
375        }
376
377        closest_centers[i] = min_idx;
378        closest_distances[i] = min_dist;
379        inertia = inertia + min_dist * min_dist;
380    }
381
382    // Update centroids using per-center learning rate
383    for i in 0..batch_size {
384        let center_idx = closest_centers[i];
385        let sample_idx = batch_indices[i];
386        let sample = data.slice(s![sample_idx, ..]);
387
388        let count = counts[center_idx];
389        // Learning rate decreases as count increases: eta = 1 / (count + 1)
390        let learning_rate = F::one() / (count + F::one());
391
392        for j in 0..n_features {
393            centroids[[center_idx, j]] =
394                centroids[[center_idx, j]] * (F::one() - learning_rate) + sample[j] * learning_rate;
395        }
396
397        counts[center_idx] = count + F::one();
398    }
399
400    // Handle near-empty clusters via reassignment
401    let mut n_reassignments = 0;
402    let max_count = counts.fold(F::zero(), |a, &b| a.max(b));
403    let reassign_threshold = max_count * opts.reassignment_ratio;
404
405    for c in 0..k {
406        if counts[c] < reassign_threshold {
407            // Find the batch point furthest from its assigned centroid
408            let mut max_dist = F::zero();
409            let mut max_idx = 0;
410
411            for j in 0..batch_size {
412                if closest_distances[j] > max_dist {
413                    max_dist = closest_distances[j];
414                    max_idx = j;
415                }
416            }
417
418            if max_dist > F::zero() {
419                let sample_idx = batch_indices[max_idx];
420                let sample = data.slice(s![sample_idx, ..]);
421
422                for j in 0..n_features {
423                    centroids[[c, j]] = sample[j];
424                }
425
426                counts[c] = counts[c].max(F::one());
427                closest_centers[max_idx] = c;
428                closest_distances[max_idx] = F::zero();
429                n_reassignments += 1;
430            }
431        }
432    }
433
434    // Normalize inertia by batch size
435    let batch_f = F::from(batch_size).unwrap_or(F::one());
436    inertia = inertia / batch_f;
437
438    Ok(MiniBatchStepResult {
439        batch_inertia: inertia,
440        n_reassignments,
441    })
442}
443
444/// Assigns each sample in the dataset to its closest centroid
445fn assign_labels<F>(
446    data: ArrayView2<F>,
447    centroids: ArrayView2<F>,
448) -> Result<(Array1<usize>, Array1<F>)>
449where
450    F: Float + FromPrimitive + Debug,
451{
452    let n_samples = data.shape()[0];
453    let n_clusters = centroids.shape()[0];
454
455    let mut labels = Array1::<usize>::zeros(n_samples);
456    let mut distances = Array1::<F>::zeros(n_samples);
457
458    for i in 0..n_samples {
459        let sample = data.slice(s![i, ..]);
460        let mut min_dist = F::infinity();
461        let mut min_idx = 0;
462
463        for j in 0..n_clusters {
464            let centroid = centroids.slice(s![j, ..]);
465            let dist = euclidean_distance(sample, centroid);
466
467            if dist < min_dist {
468                min_dist = dist;
469                min_idx = j;
470            }
471        }
472
473        labels[i] = min_idx;
474        distances[i] = min_dist;
475    }
476
477    Ok((labels, distances))
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Six 2-D points: three around (1, 2) followed by three around (4, 5).
    fn make_two_cluster_data() -> Array2<f64> {
        let values = vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1];
        Array2::from_shape_vec((6, 2), values).expect("Failed to create test data")
    }

    /// Default options with seed 42 and the given iteration/batch limits.
    fn seeded_options(max_iter: usize, batch_size: usize) -> MiniBatchKMeansOptions<f64> {
        MiniBatchKMeansOptions {
            max_iter,
            batch_size,
            random_seed: Some(42),
            ..Default::default()
        }
    }

    #[test]
    fn test_minibatch_kmeans_simple() {
        let points = make_two_cluster_data();

        let (centroids, labels) = minibatch_kmeans(points.view(), 2, Some(seeded_options(10, 3)))
            .expect("Should succeed");

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.shape(), &[6]);

        let distinct: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 2);

        // The first three points form one group, the last three the other
        for i in 1..3 {
            assert_eq!(labels[i], labels[0]);
        }
        for i in 4..6 {
            assert_eq!(labels[i], labels[3]);
        }
    }

    #[test]
    fn test_minibatch_kmeans_full_diagnostics() {
        let points = make_two_cluster_data();

        let fit = minibatch_kmeans_full(points.view(), 2, Some(seeded_options(50, 4)))
            .expect("Should succeed");

        assert_eq!(fit.centroids.shape(), &[2, 2]);
        assert_eq!(fit.labels.shape(), &[6]);
        assert!(fit.n_iter > 0);
        assert!(fit.inertia >= 0.0);
        assert!(!fit.inertia_history.is_empty());

        // No cluster may end up without points
        for &members in fit.cluster_counts.iter() {
            assert!(members > 0, "Each cluster should have assigned points");
        }
    }

    #[test]
    fn test_minibatch_kmeans_convergence() {
        let points = make_two_cluster_data();

        // Batch size 6 covers the whole dataset
        let opts = MiniBatchKMeansOptions {
            tol: 1e-6,
            max_no_improvement: 20,
            ..seeded_options(1000, 6)
        };

        let fit = minibatch_kmeans_full(points.view(), 2, Some(opts)).expect("Should succeed");

        // The run has to stop before the iteration budget is used up
        assert!(
            fit.n_iter < 1000,
            "Should converge before max_iter, took {} iters",
            fit.n_iter
        );
    }

    #[test]
    fn test_minibatch_kmeans_empty_clusters() {
        let values = vec![
            1.0, 1.0, 1.1, 1.1, 1.2, 1.0, 1.0, 1.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.0, 5.0, 5.2,
        ];
        let points = Array2::from_shape_vec((8, 2), values).expect("Failed to create data");

        let opts = MiniBatchKMeansOptions {
            reassignment_ratio: 0.1,
            ..seeded_options(20, 4)
        };

        let (centroids, labels) =
            minibatch_kmeans(points.view(), 3, Some(opts)).expect("Should succeed");

        assert_eq!(centroids.shape(), &[3, 2]);
        assert_eq!(labels.shape(), &[8]);

        let distinct: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert!(distinct.len() <= 3);
    }

    #[test]
    fn test_minibatch_kmeans_random_init() {
        let points = make_two_cluster_data();

        let opts = MiniBatchKMeansOptions {
            init: MiniBatchInit::Random,
            ..seeded_options(50, 4)
        };

        let (centroids, labels) =
            minibatch_kmeans(points.view(), 2, Some(opts)).expect("Should succeed");

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.shape(), &[6]);
    }

    #[test]
    fn test_minibatch_kmeans_inertia_decreases() {
        let points = make_two_cluster_data();

        let opts = MiniBatchKMeansOptions {
            ewa_smoothing: 0.5,
            ..seeded_options(50, 6)
        };

        let fit = minibatch_kmeans_full(points.view(), 2, Some(opts)).expect("Should succeed");

        // Compare the average of the first and the last three history entries
        let history = &fit.inertia_history;
        if history.len() >= 3 {
            let head = &history[..3];
            let tail = &history[history.len() - 3..];

            let first_few: f64 = head.iter().copied().sum::<f64>() / 3.0;
            let last_few: f64 = tail.iter().copied().sum::<f64>() / 3.0;

            assert!(
                last_few <= first_few + 1.0,
                "Inertia should generally decrease: first_avg={}, last_avg={}",
                first_few,
                last_few
            );
        }
    }

    #[test]
    fn test_minibatch_kmeans_invalid_inputs() {
        let points = make_two_cluster_data();

        // Zero clusters
        assert!(minibatch_kmeans(points.view(), 0, None).is_err());

        // More clusters than samples
        assert!(minibatch_kmeans(points.view(), 100, None).is_err());

        // No samples at all
        let no_rows = Array2::<f64>::zeros((0, 2));
        assert!(minibatch_kmeans(no_rows.view(), 2, None).is_err());
    }

    #[test]
    fn test_minibatch_kmeans_k_equals_n() {
        let points = make_two_cluster_data();

        let opts = MiniBatchKMeansOptions {
            random_seed: Some(42),
            max_iter: 10,
            ..Default::default()
        };

        let (centroids, labels) =
            minibatch_kmeans(points.view(), 6, Some(opts)).expect("Should succeed");

        assert_eq!(centroids.shape(), &[6, 2]);
        assert_eq!(labels.shape(), &[6]);

        // With as many clusters as points, every point gets its own label
        let distinct: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 6);
    }
}