// scirs2_cluster/vq/minibatch_kmeans.rs

//! Mini-Batch K-means clustering implementation
//!
//! This module provides an implementation of the Mini-Batch K-means algorithm,
//! a variant of k-means that uses mini-batches to reduce computation time while
//! still attempting to optimize the same objective function.
//!
//! Mini-Batch K-means is much faster than standard K-means for large datasets
//! and provides results that are generally close to those of the standard algorithm.
9
10use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use scirs2_core::random::{Rng, SeedableRng};
13use std::fmt::Debug;
14
15use super::{euclidean_distance, kmeans_plus_plus};
16use crate::error::{ClusteringError, Result};
17
18/// Options for Mini-Batch K-means clustering
19#[derive(Debug, Clone)]
20pub struct MiniBatchKMeansOptions<F: Float> {
21    /// Maximum number of iterations
22    pub max_iter: usize,
23    /// Size of mini-batches
24    pub batch_size: usize,
25    /// Convergence threshold for centroid movement
26    pub tol: F,
27    /// Random seed for initialization and batch sampling
28    pub random_seed: Option<u64>,
29    /// Number of iterations without improvement before stopping
30    pub max_no_improvement: usize,
31    /// Number of samples to use for initialization
32    pub init_size: Option<usize>,
33    /// Ratio of samples that should be reassigned to prevent empty clusters
34    pub reassignment_ratio: F,
35}
36
37impl<F: Float + FromPrimitive> Default for MiniBatchKMeansOptions<F> {
38    fn default() -> Self {
39        Self {
40            max_iter: 100,
41            batch_size: 1024,
42            tol: F::from(1e-4).unwrap(),
43            random_seed: None,
44            max_no_improvement: 10,
45            init_size: None,
46            reassignment_ratio: F::from(0.01).unwrap(),
47        }
48    }
49}
50
51/// Mini-Batch K-means clustering algorithm
52///
53/// # Arguments
54///
55/// * `data` - Input data (n_samples × n_features)
56/// * `k` - Number of clusters
57/// * `options` - Optional parameters
58///
59/// # Returns
60///
61/// * Tuple of (centroids, labels) where:
62///   - centroids: Array of shape (k × n_features)
63///   - labels: Array of shape (n_samples,) with cluster assignments
64///
65/// # Examples
66///
67/// ```
68/// use scirs2_core::ndarray::{Array2, ArrayView2};
69/// use scirs2_cluster::vq::minibatch_kmeans;
70///
71/// let data = Array2::from_shape_vec((6, 2), vec![
72///     1.0, 2.0,
73///     1.2, 1.8,
74///     0.8, 1.9,
75///     3.7, 4.2,
76///     3.9, 3.9,
77///     4.2, 4.1,
78/// ]).unwrap();
79///
80/// let (centroids, labels) = minibatch_kmeans(ArrayView2::from(&data), 2, None).unwrap();
81/// ```
82#[allow(dead_code)]
83pub fn minibatch_kmeans<F>(
84    data: ArrayView2<F>,
85    k: usize,
86    options: Option<MiniBatchKMeansOptions<F>>,
87) -> Result<(Array2<F>, Array1<usize>)>
88where
89    F: Float + FromPrimitive + Debug + std::iter::Sum,
90{
91    // Input validation
92    if k == 0 {
93        return Err(ClusteringError::InvalidInput(
94            "Number of clusters must be greater than 0".to_string(),
95        ));
96    }
97
98    let n_samples = data.shape()[0];
99    let n_features = data.shape()[1];
100
101    if n_samples == 0 {
102        return Err(ClusteringError::InvalidInput(
103            "Input data is empty".to_string(),
104        ));
105    }
106
107    if k > n_samples {
108        return Err(ClusteringError::InvalidInput(format!(
109            "Number of clusters ({}) cannot be greater than number of data points ({})",
110            k, n_samples
111        )));
112    }
113
114    let opts = options.unwrap_or_default();
115
116    // Setup RNG
117    let mut rng = match opts.random_seed {
118        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
119        None => {
120            scirs2_core::random::rngs::StdRng::seed_from_u64(scirs2_core::random::rng().random())
121        }
122    };
123
124    // Determine initialization size
125    let init_size = opts.init_size.unwrap_or_else(|| {
126        let default_size = 3 * opts.batch_size;
127        if default_size < 3 * k {
128            default_size
129        } else {
130            3 * k
131        }
132    });
133
134    let init_size = init_size.min(n_samples);
135
136    // Initialize centroids using kmeans++
137    let centroids = if init_size < n_samples {
138        // Sample init_size data points for initialization (simpler method for this example)
139        let mut indices = Vec::with_capacity(init_size);
140        for _ in 0..init_size {
141            indices.push(rng.random_range(0..n_samples));
142        }
143
144        let init_data =
145            Array2::from_shape_fn((init_size, n_features), |(i, j)| data[[indices[i], j]]);
146        kmeans_plus_plus(init_data.view(), k, opts.random_seed)?
147    } else {
148        // Use all data points for initialization
149        kmeans_plus_plus(data, k, opts.random_seed)?
150    };
151
152    // Initialize variables for optimization
153    let mut centroids = centroids;
154    let mut counts = Array1::ones(k); // Initialize counts to avoid division by zero
155
156    // Variables for convergence detection
157    let mut ewa_inertia = None; // Exponentially weighted average of inertia
158    let mut no_improvement_count = 0;
159    let mut best_inertia = F::infinity();
160    let mut prev_centers: Option<Array2<F>> = None;
161
162    // Mini-batch optimization
163    for iter in 0..opts.max_iter {
164        // Sample a mini-batch
165        let batch_size = opts.batch_size.min(n_samples);
166        let mut batch_indices = Vec::with_capacity(batch_size);
167        for _ in 0..batch_size {
168            batch_indices.push(rng.random_range(0..n_samples));
169        }
170
171        // Perform mini-batch step
172        let (batch_inertia, has_converged) =
173            mini_batch_step(&data, &batch_indices, &mut centroids, &mut counts, &opts)?;
174
175        // If this is the last iteration, assign all points to clusters for final labeling
176        // We don't need to do this on every iteration, just for the final result
177        if iter == opts.max_iter - 1 {
178            // This will be used only for the final return value
179            let (_new_labels_) = assign_labels(data, centroids.view())?;
180            // We don't store this since we'll recompute it at the end anyway
181        }
182
183        // Update exponentially weighted average of inertia
184        let ewa_factor = F::from(0.7).unwrap(); // Smoothing factor for EWA
185        let current_ewa = match ewa_inertia {
186            Some(prev_ewa) => prev_ewa * ewa_factor + batch_inertia * (F::one() - ewa_factor),
187            None => batch_inertia,
188        };
189        ewa_inertia = Some(current_ewa);
190
191        // Check for convergence based on inertia
192        if current_ewa < best_inertia {
193            best_inertia = current_ewa;
194            no_improvement_count = 0;
195        } else {
196            no_improvement_count += 1;
197        }
198
199        // Check for convergence based on centroid movement
200        if let Some(prev) = prev_centers {
201            let mut center_shift = F::zero();
202            for i in 0..k {
203                let dist = euclidean_distance(centroids.slice(s![i, ..]), prev.slice(s![i, ..]));
204                center_shift = center_shift + dist;
205            }
206
207            // Normalize by number of centroids and features
208            center_shift = center_shift / F::from(k).unwrap();
209
210            if center_shift < opts.tol {
211                // Converged based on centroid movement
212                break;
213            }
214        }
215
216        // Store current centroids for next iteration
217        prev_centers = Some(centroids.clone());
218
219        // Check for early stopping
220        if no_improvement_count >= opts.max_no_improvement {
221            break;
222        }
223
224        // If convergence detected in mini-batch step
225        if has_converged {
226            break;
227        }
228    }
229
230    // Final label assignment
231    let (final_labels, _) = assign_labels(data, centroids.view())?;
232
233    Ok((centroids, final_labels))
234}
235
236/// Performs a single Mini-Batch K-means step
237///
238/// # Arguments
239///
240/// * `data` - Input data
241/// * `batch_indices` - Indices of samples in the current mini-batch
242/// * `centroids` - Current centroids (modified in-place)
243/// * `counts` - Counts of samples assigned to each centroid (modified in-place)
244/// * `opts` - Algorithm options
245///
246/// # Returns
247///
248/// * Tuple of (batch_inertia, has_converged)
249#[allow(dead_code)]
250fn mini_batch_step<F>(
251    data: &ArrayView2<F>,
252    batch_indices: &[usize],
253    centroids: &mut Array2<F>,
254    counts: &mut Array1<F>,
255    opts: &MiniBatchKMeansOptions<F>,
256) -> Result<(F, bool)>
257where
258    F: Float + FromPrimitive + Debug,
259{
260    let k = centroids.shape()[0];
261    let n_features = centroids.shape()[1];
262    let batch_size = batch_indices.len();
263
264    // Initialize mini-batch specific variables
265    let mut closest_distances = Array1::from_elem(batch_size, F::infinity());
266    let mut closest_centers = Array1::zeros(batch_size);
267    let mut inertia = F::zero();
268
269    // Assign samples to closest centroids
270    for (i, &sample_idx) in batch_indices.iter().enumerate() {
271        let sample = data.slice(s![sample_idx, ..]);
272
273        // Find closest centroid
274        let mut min_dist = F::infinity();
275        let mut min_idx = 0;
276
277        for j in 0..k {
278            let dist = euclidean_distance(sample, centroids.slice(s![j, ..]));
279            if dist < min_dist {
280                min_dist = dist;
281                min_idx = j;
282            }
283        }
284
285        closest_centers[i] = min_idx;
286        closest_distances[i] = min_dist;
287        inertia = inertia + min_dist * min_dist;
288    }
289
290    // Update centroids based on mini-batch assignments
291    for i in 0..batch_size {
292        let center_idx = closest_centers[i];
293        let sample_idx = batch_indices[i];
294        let sample = data.slice(s![sample_idx, ..]);
295
296        // Incremental update of centroid
297        let count = counts[center_idx];
298        let learning_rate = F::one() / (count + F::one()); // Decrease learning rate as count increases
299
300        for j in 0..n_features {
301            centroids[[center_idx, j]] =
302                centroids[[center_idx, j]] * (F::one() - learning_rate) + sample[j] * learning_rate;
303        }
304
305        counts[center_idx] = count + F::one();
306    }
307
308    // Handle reassignment of small or empty clusters
309    let mut has_empty = false;
310    let max_count = counts.fold(F::zero(), |a, &b| a.max(b));
311    let reassign_threshold = max_count * opts.reassignment_ratio;
312
313    for i in 0..k {
314        if counts[i] < reassign_threshold {
315            has_empty = true;
316
317            // Find the point furthest from its centroid in this batch
318            let mut max_dist = F::zero();
319            let mut max_idx = 0;
320
321            for j in 0..batch_size {
322                if closest_distances[j] > max_dist {
323                    max_dist = closest_distances[j];
324                    max_idx = j;
325                }
326            }
327
328            // Reassign this centroid to the furthest point
329            if max_dist > F::zero() {
330                let sample_idx = batch_indices[max_idx];
331                let sample = data.slice(s![sample_idx, ..]);
332
333                for j in 0..n_features {
334                    centroids[[i, j]] = sample[j];
335                }
336
337                // Reset count to a small value to prevent immediate reassignment
338                counts[i] = counts[i].max(F::from(1.0).unwrap());
339
340                // Update closest center and distance for this point
341                closest_centers[max_idx] = i;
342                closest_distances[max_idx] = F::zero();
343            }
344        }
345    }
346
347    // Normalize inertia by batch size
348    inertia = inertia / F::from(batch_size).unwrap();
349
350    // Check if we have converged
351    let has_converged = !has_empty && inertia < opts.tol;
352
353    Ok((inertia, has_converged))
354}
355
356/// Assigns each sample in the dataset to its closest centroid
357///
358/// # Arguments
359///
360/// * `data` - Input data
361/// * `centroids` - Current centroids
362///
363/// # Returns
364///
365/// * Tuple of (labels, distances)
366#[allow(dead_code)]
367fn assign_labels<F>(
368    data: ArrayView2<F>,
369    centroids: ArrayView2<F>,
370) -> Result<(Array1<usize>, Array1<F>)>
371where
372    F: Float + FromPrimitive + Debug,
373{
374    let n_samples = data.shape()[0];
375    let n_clusters = centroids.shape()[0];
376
377    let mut labels = Array1::zeros(n_samples);
378    let mut distances = Array1::zeros(n_samples);
379
380    for i in 0..n_samples {
381        let sample = data.slice(s![i, ..]);
382        let mut min_dist = F::infinity();
383        let mut min_idx = 0;
384
385        for j in 0..n_clusters {
386            let centroid = centroids.slice(s![j, ..]);
387            let dist = euclidean_distance(sample, centroid);
388
389            if dist < min_dist {
390                min_dist = dist;
391                min_idx = j;
392            }
393        }
394
395        labels[i] = min_idx;
396        distances[i] = min_dist;
397    }
398
399    Ok((labels, distances))
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::collections::HashSet;

    /// Number of distinct cluster labels produced.
    fn distinct_labels<'a>(labels: impl IntoIterator<Item = &'a usize>) -> usize {
        labels.into_iter().copied().collect::<HashSet<_>>().len()
    }

    #[test]
    fn test_minibatch_kmeans_simple() {
        // Two well-separated groups of three points: around (1, 2) and (4, 5).
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let options = MiniBatchKMeansOptions {
            max_iter: 10,
            batch_size: 3,
            random_seed: Some(42), // fixed seed keeps the run reproducible
            ..Default::default()
        };

        let (centroids, labels) = minibatch_kmeans(data.view(), 2, Some(options)).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.shape(), &[6]);
        assert_eq!(distinct_labels(labels.iter()), 2);

        // Each group is labelled uniformly.
        let low_label = labels[0];
        let high_label = labels[3];
        for idx in 0..3 {
            assert_eq!(labels[idx], low_label);
        }
        for idx in 3..6 {
            assert_eq!(labels[idx], high_label);
        }

        // With exactly two labels in {0, 1}, the label doubles as the centroid row.
        let low_row = if low_label == 0 { 0 } else { 1 };
        let high_row = 1 - low_row;

        assert!((centroids[[low_row, 0]] - 1.0).abs() < 0.5);
        assert!((centroids[[low_row, 1]] - 2.0).abs() < 0.5);
        assert!((centroids[[high_row, 0]] - 4.0).abs() < 0.5);
        assert!((centroids[[high_row, 1]] - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_minibatch_kmeans_empty_clusters() {
        // Two tight groups of four points but three clusters requested, so one
        // centroid is at risk of starving without reassignment.
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 1.2, 1.0, 1.0, 1.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.0, 5.0, 5.2,
            ],
        )
        .unwrap();

        let options = MiniBatchKMeansOptions {
            max_iter: 20,
            batch_size: 4,
            random_seed: Some(42),   // fixed seed keeps the run reproducible
            reassignment_ratio: 0.1, // aggressive threshold to exercise reassignment
            ..Default::default()
        };

        let (centroids, labels) = minibatch_kmeans(data.view(), 3, Some(options)).unwrap();

        assert_eq!(centroids.shape(), &[3, 2]);
        assert_eq!(labels.shape(), &[8]);
        assert!(distinct_labels(labels.iter()) <= 3);

        // Reassignment should leave every one of the three centroids with at
        // least one point in the final labelling.
        let mut per_cluster = [0; 3];
        labels.iter().for_each(|&label| per_cluster[label] += 1);

        for &count in per_cluster.iter() {
            assert!(count > 0);
        }
    }
}