use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{Rng, SeedableRng};
use std::fmt::Debug;

use super::{euclidean_distance, kmeans_plus_plus};
use crate::error::{ClusteringError, Result};

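/// Options controlling mini-batch k-means.
///
/// The usage sketch below is illustrative only; it assumes `f64` data and that
/// this type is in scope, and simply overrides a couple of fields on top of the
/// defaults via struct-update syntax.
///
/// ```ignore
/// let opts = MiniBatchKMeansOptions::<f64> {
///     batch_size: 256,
///     random_seed: Some(0),
///     ..Default::default()
/// };
/// ```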
#[derive(Debug, Clone)]
pub struct MiniBatchKMeansOptions<F: Float> {
    /// Maximum number of mini-batch iterations.
    pub max_iter: usize,
    /// Number of samples drawn per mini-batch.
    pub batch_size: usize,
    /// Convergence tolerance, applied to the mean centroid shift between
    /// iterations (and to the per-batch mean inertia in the batch step).
    pub tol: F,
    /// Seed for the random number generator; `None` uses a random seed.
    pub random_seed: Option<u64>,
    /// Stop after this many consecutive iterations without improvement in the
    /// smoothed (exponentially weighted) batch inertia.
    pub max_no_improvement: usize,
    /// Number of samples drawn for the k-means++ initialization. Defaults to
    /// `3 * batch_size`, but never fewer than `3 * k`.
    pub init_size: Option<usize>,
    /// Centers whose accumulated count falls below `reassignment_ratio` times
    /// the largest center count are re-seeded from the current batch.
    pub reassignment_ratio: F,
}

impl<F: Float + FromPrimitive> Default for MiniBatchKMeansOptions<F> {
    fn default() -> Self {
        Self {
            max_iter: 100,
            batch_size: 1024,
            tol: F::from(1e-4).unwrap(),
            random_seed: None,
            max_no_improvement: 10,
            init_size: None,
            reassignment_ratio: F::from(0.01).unwrap(),
        }
    }
}

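/// Mini-batch k-means clustering.
///
/// Each iteration samples a random batch from `data`, assigns every batch
/// point to its nearest centroid, and moves that centroid toward the point
/// using a per-center learning rate of `1 / (count + 1)`. Returns the final
/// centroids (`k x n_features`) and a cluster label for every sample.
///
/// The example below is a minimal sketch, assuming `f64` input and that the
/// function is in scope; it is not a doctest of a published API.
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let data = Array2::from_shape_vec(
///     (4, 2),
///     vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1],
/// )
/// .unwrap();
/// let (centroids, labels) = minibatch_kmeans(data.view(), 2, None).unwrap();
/// assert_eq!(centroids.shape(), &[2, 2]);
/// assert_eq!(labels.len(), 4);
/// ```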
#[allow(dead_code)]
pub fn minibatch_kmeans<F>(
    data: ArrayView2<F>,
    k: usize,
    options: Option<MiniBatchKMeansOptions<F>>,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    // Validate inputs before doing any work.
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be greater than 0".to_string(),
        ));
    }

    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }

    if k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) cannot be greater than number of data points ({})",
            k, n_samples
        )));
    }

    let opts = options.unwrap_or_default();

    // Seeded RNG for reproducibility; fall back to a randomly seeded RNG.
    let mut rng = match opts.random_seed {
        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
        None => {
            scirs2_core::random::rngs::StdRng::seed_from_u64(scirs2_core::random::rng().random())
        }
    };

    // Size of the subsample used for k-means++ initialization: default to
    // 3 * batch_size, but never fewer than 3 * k so that initialization has
    // enough points to choose k centers from.
    let init_size = opts
        .init_size
        .unwrap_or_else(|| (3 * opts.batch_size).max(3 * k));

    let init_size = init_size.min(n_samples);

    // Initialize centroids with k-means++ on a subsample (or on the full data
    // when the subsample would not be smaller).
    let mut centroids = if init_size < n_samples {
        // Draw init_size sample indices (with replacement).
        let mut indices = Vec::with_capacity(init_size);
        for _ in 0..init_size {
            indices.push(rng.random_range(0..n_samples));
        }

        let init_data =
            Array2::from_shape_fn((init_size, n_features), |(i, j)| data[[indices[i], j]]);
        kmeans_plus_plus(init_data.view(), k, opts.random_seed)?
    } else {
        kmeans_plus_plus(data, k, opts.random_seed)?
    };

    // Per-center sample counts (these drive the per-center learning rate).
    let mut counts = Array1::ones(k);
    // Exponentially weighted average of the per-batch inertia.
    let mut ewa_inertia = None;
    let mut no_improvement_count = 0;
    let mut best_inertia = F::infinity();
    let mut prev_centers: Option<Array2<F>> = None;

    for iter in 0..opts.max_iter {
        // Draw a random mini-batch of sample indices (with replacement).
        let batch_size = opts.batch_size.min(n_samples);
        let mut batch_indices = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            batch_indices.push(rng.random_range(0..n_samples));
        }

        // Update centroids and counts from this batch.
        let (batch_inertia, has_converged) =
            mini_batch_step(&data, &batch_indices, &mut centroids, &mut counts, &opts)?;

        if iter == opts.max_iter - 1 {
            // Final assignment pass on the last iteration; the labels are
            // recomputed over the full data set after the loop.
            let (_new_labels, _) = assign_labels(data, centroids.view())?;
        }

        // Smooth the batch inertia with an exponentially weighted average.
        let ewa_factor = F::from(0.7).unwrap();
        let current_ewa = match ewa_inertia {
            Some(prev_ewa) => prev_ewa * ewa_factor + batch_inertia * (F::one() - ewa_factor),
            None => batch_inertia,
        };
        ewa_inertia = Some(current_ewa);

        // Track how long the smoothed inertia has gone without improving.
        if current_ewa < best_inertia {
            best_inertia = current_ewa;
            no_improvement_count = 0;
        } else {
            no_improvement_count += 1;
        }

        // Convergence check on the mean centroid shift since the last iteration.
        if let Some(prev) = prev_centers {
            let mut center_shift = F::zero();
            for i in 0..k {
                let dist = euclidean_distance(centroids.slice(s![i, ..]), prev.slice(s![i, ..]));
                center_shift = center_shift + dist;
            }

            center_shift = center_shift / F::from(k).unwrap();

            if center_shift < opts.tol {
                break;
            }
        }

        prev_centers = Some(centroids.clone());

        // Early stopping: too many iterations without improvement, or the
        // batch step reported convergence.
        if no_improvement_count >= opts.max_no_improvement {
            break;
        }

        if has_converged {
            break;
        }
    }

    let (final_labels, _) = assign_labels(data, centroids.view())?;

    Ok((centroids, final_labels))
}

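/// Performs one mini-batch update step.
///
/// For every sample in the batch, the nearest centroid is found and then pulled
/// toward the sample with learning rate `eta = 1 / (count + 1)`, i.e.
/// `c_new = (1 - eta) * c_old + eta * x`, where `count` is how many samples the
/// center has absorbed so far. Centers whose count falls below
/// `reassignment_ratio * max_count` are re-seeded at the batch sample currently
/// farthest from its assigned center. Returns the mean squared distance over
/// the batch and a convergence flag (no re-seeded centers and mean inertia
/// below `tol`).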
#[allow(dead_code)]
fn mini_batch_step<F>(
    data: &ArrayView2<F>,
    batch_indices: &[usize],
    centroids: &mut Array2<F>,
    counts: &mut Array1<F>,
    opts: &MiniBatchKMeansOptions<F>,
) -> Result<(F, bool)>
where
    F: Float + FromPrimitive + Debug,
{
    let k = centroids.shape()[0];
    let n_features = centroids.shape()[1];
    let batch_size = batch_indices.len();

    let mut closest_distances = Array1::from_elem(batch_size, F::infinity());
    let mut closest_centers = Array1::zeros(batch_size);
    let mut inertia = F::zero();

    // Assign every batch sample to its nearest centroid and accumulate the
    // squared distances.
    for (i, &sample_idx) in batch_indices.iter().enumerate() {
        let sample = data.slice(s![sample_idx, ..]);

        let mut min_dist = F::infinity();
        let mut min_idx = 0;

        for j in 0..k {
            let dist = euclidean_distance(sample, centroids.slice(s![j, ..]));
            if dist < min_dist {
                min_dist = dist;
                min_idx = j;
            }
        }

        closest_centers[i] = min_idx;
        closest_distances[i] = min_dist;
        inertia = inertia + min_dist * min_dist;
    }

    // Pull each assigned centroid toward its sample with a per-center
    // learning rate that decays as 1 / (count + 1).
    for i in 0..batch_size {
        let center_idx = closest_centers[i];
        let sample_idx = batch_indices[i];
        let sample = data.slice(s![sample_idx, ..]);

        let count = counts[center_idx];
        let learning_rate = F::one() / (count + F::one());

        for j in 0..n_features {
            centroids[[center_idx, j]] =
                centroids[[center_idx, j]] * (F::one() - learning_rate) + sample[j] * learning_rate;
        }

        counts[center_idx] = count + F::one();
    }

    // Reassign centers that attract too few samples relative to the most
    // populated center.
    let mut has_empty = false;
    let max_count = counts.fold(F::zero(), |a, &b| a.max(b));
    let reassign_threshold = max_count * opts.reassignment_ratio;

    for i in 0..k {
        if counts[i] < reassign_threshold {
            has_empty = true;

            // Find the batch sample that is farthest from its assigned center.
            let mut max_dist = F::zero();
            let mut max_idx = 0;

            for j in 0..batch_size {
                if closest_distances[j] > max_dist {
                    max_dist = closest_distances[j];
                    max_idx = j;
                }
            }

            if max_dist > F::zero() {
                // Re-seed the under-used center at that sample.
                let sample_idx = batch_indices[max_idx];
                let sample = data.slice(s![sample_idx, ..]);

                for j in 0..n_features {
                    centroids[[i, j]] = sample[j];
                }

                counts[i] = counts[i].max(F::from(1.0).unwrap());

                closest_centers[max_idx] = i;
                closest_distances[max_idx] = F::zero();
            }
        }
    }

    // Report the mean squared distance over the batch.
    inertia = inertia / F::from(batch_size).unwrap();

    // Converged only if no center had to be re-seeded and the mean inertia is
    // already below the tolerance.
    let has_converged = !has_empty && inertia < opts.tol;

    Ok((inertia, has_converged))
}

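/// Assigns every sample to its nearest centroid, returning the label and the
/// distance to that centroid for each sample.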
#[allow(dead_code)]
fn assign_labels<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug,
{
    let n_samples = data.shape()[0];
    let n_clusters = centroids.shape()[0];

    let mut labels = Array1::zeros(n_samples);
    let mut distances = Array1::zeros(n_samples);

    for i in 0..n_samples {
        let sample = data.slice(s![i, ..]);
        let mut min_dist = F::infinity();
        let mut min_idx = 0;

        for j in 0..n_clusters {
            let centroid = centroids.slice(s![j, ..]);
            let dist = euclidean_distance(sample, centroid);

            if dist < min_dist {
                min_dist = dist;
                min_idx = j;
            }
        }

        labels[i] = min_idx;
        distances[i] = min_dist;
    }

    Ok((labels, distances))
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_minibatch_kmeans_simple() {
        // Two well-separated groups of three points each.
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        // Fixed seed keeps the test deterministic.
        let options = MiniBatchKMeansOptions {
            max_iter: 10,
            batch_size: 3,
            random_seed: Some(42),
            ..Default::default()
        };

        let (centroids, labels) = minibatch_kmeans(data.view(), 2, Some(options)).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.shape(), &[6]);

        // Exactly two distinct labels should be used.
        let unique_labels: Vec<_> = labels
            .iter()
            .copied()
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        assert_eq!(unique_labels.len(), 2);

        // The first three points should share one label...
        let first_label = labels[0];
        assert_eq!(labels[1], first_label);
        assert_eq!(labels[2], first_label);

        // ...and the last three the other.
        let second_label = labels[3];
        assert_eq!(labels[4], second_label);
        assert_eq!(labels[5], second_label);

        // Centroids should land near the two group means.
        let cluster1_idx = if first_label == 0 { 0 } else { 1 };
        assert!((centroids[[cluster1_idx, 0]] - 1.0).abs() < 0.5);
        assert!((centroids[[cluster1_idx, 1]] - 2.0).abs() < 0.5);

        let cluster2_idx = if first_label == 0 { 1 } else { 0 };
        assert!((centroids[[cluster2_idx, 0]] - 4.0).abs() < 0.5);
        assert!((centroids[[cluster2_idx, 1]] - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_minibatch_kmeans_empty_clusters() {
        // Two tight groups of four points, but three requested clusters, so at
        // least one center depends on the reassignment logic.
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 1.2, 1.0, 1.0, 1.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.0, 5.0, 5.2,
            ],
        )
        .unwrap();

        // High reassignment_ratio to aggressively re-seed under-used centers.
        let options = MiniBatchKMeansOptions {
            max_iter: 20,
            batch_size: 4,
            random_seed: Some(42),
            reassignment_ratio: 0.1,
            ..Default::default()
        };

        let (centroids, labels) = minibatch_kmeans(data.view(), 3, Some(options)).unwrap();

        assert_eq!(centroids.shape(), &[3, 2]);
        assert_eq!(labels.shape(), &[8]);

        // At most three distinct labels may appear.
        let unique_labels: Vec<_> = labels
            .iter()
            .copied()
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        assert!(unique_labels.len() <= 3);

        let mut centroid_counts = [0; 3];
        for &label in labels.iter() {
            centroid_counts[label] += 1;
        }

        // With this seed, every center should end up owning at least one point.
        for &count in centroid_counts.iter() {
            assert!(count > 0);
        }
    }
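
    // Sketch of an additional check on the error paths; the assertions follow
    // directly from the input validation at the top of `minibatch_kmeans`.
    #[test]
    fn test_minibatch_kmeans_invalid_input() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();

        // k == 0 must be rejected.
        assert!(minibatch_kmeans(data.view(), 0, None).is_err());

        // More clusters than samples must be rejected.
        assert!(minibatch_kmeans(data.view(), 5, None).is_err());

        // Empty input must be rejected.
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(minibatch_kmeans(empty.view(), 2, None).is_err());
    }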
}