1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
4use scirs2_core::numeric::{Float, FromPrimitive};
5use scirs2_core::random::Rng;
6use std::fmt::Debug;
7
8use super::{euclidean_distance, vq};
9use crate::error::{ClusteringError, Result};
/// Tuning options shared by the k-means routines in this module.
///
/// `F` is the floating-point element type of the data being clustered.
#[derive(Debug, Clone)]
pub struct KMeansOptions<F: Float> {
    /// Maximum number of Lloyd iterations performed by a single run.
    pub max_iter: usize,
    /// Convergence tolerance: a run stops once the summed centroid shift between two
    /// consecutive iterations is at most this value (runs also stop as soon as that shift
    /// stops decreasing).
    pub tol: F,
    /// Seed handed to the initialization routines.
    ///
    /// NOTE(review): the initializers in this file accept this value but draw from
    /// `scirs2_core::random::rng()` without using it, so results are not reproducible yet —
    /// confirm before relying on it.
    pub random_seed: Option<u64>,
    /// Number of independent runs, each with a fresh initialization; the run with the lowest
    /// inertia is kept. A single run is used when `init_method` is `KMeansInit::KMeansParallel`.
    pub n_init: usize,
    /// Strategy used to choose the initial centroids.
    pub init_method: KMeansInit,
}
28
29impl<F: Float + FromPrimitive> Default for KMeansOptions<F> {
30 fn default() -> Self {
31 Self {
32 max_iter: 300,
33 tol: F::from(1e-4).unwrap(),
34 random_seed: None,
35 n_init: 10,
36 init_method: KMeansInit::KMeansPlusPlus,
37 }
38 }
39}
40
41#[allow(clippy::too_many_arguments)]
76#[allow(dead_code)]
77pub fn kmeans<F>(
78 obs: ArrayView2<F>,
79 k_or_guess: usize,
80 iter: Option<usize>,
81 thresh: Option<F>,
82 check_finite: Option<bool>,
83 seed: Option<u64>,
84) -> Result<(Array2<F>, F)>
85where
86 F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
87{
88 let k = k_or_guess; let max_iter = iter.unwrap_or(20);
90 let tol = thresh.unwrap_or(F::from(1e-5).unwrap());
91 let _check_finite_flag = check_finite.unwrap_or(true);
92
93 if obs.is_empty() {
95 return Err(ClusteringError::InvalidInput(
96 "Input data is empty".to_string(),
97 ));
98 }
99 if k == 0 {
100 return Err(ClusteringError::InvalidInput(
101 "Number of clusters must be greater than 0".to_string(),
102 ));
103 }
104 if k > obs.nrows() {
105 return Err(ClusteringError::InvalidInput(format!(
106 "Number of clusters ({}) cannot be greater than number of data points ({})",
107 k,
108 obs.nrows()
109 )));
110 }
111
112 let options = KMeansOptions {
114 max_iter,
115 tol,
116 random_seed: seed,
117 n_init: 1, init_method: KMeansInit::KMeansPlusPlus,
119 };
120
121 let (centroids, labels) = kmeans_with_options(obs, k, Some(options))?;
123
124 let distortion = calculate_distortion(obs, centroids.view(), &labels);
126
127 Ok((centroids, distortion))
128}
129
130#[allow(dead_code)]
165pub fn kmeans_with_options<F>(
166 data: ArrayView2<F>,
167 k: usize,
168 options: Option<KMeansOptions<F>>,
169) -> Result<(Array2<F>, Array1<usize>)>
170where
171 F: Float + FromPrimitive + Debug + std::iter::Sum,
172{
173 if k == 0 {
174 return Err(ClusteringError::InvalidInput(
175 "Number of clusters must be greater than 0".to_string(),
176 ));
177 }
178
179 let n_samples = data.shape()[0];
180 if n_samples == 0 {
181 return Err(ClusteringError::InvalidInput(
182 "Input data is empty".to_string(),
183 ));
184 }
185
186 if k > n_samples {
187 return Err(ClusteringError::InvalidInput(format!(
188 "Number of clusters ({}) cannot be greater than number of data points ({})",
189 k, n_samples
190 )));
191 }
192
193 let opts = options.unwrap_or_default();
194 let mut bestcentroids = None;
197 let mut best_labels = None;
198 let mut best_inertia = F::infinity();
199
200 let n_init = if opts.init_method == KMeansInit::KMeansParallel {
202 1
203 } else {
204 opts.n_init
205 };
206
207 for _ in 0..n_init {
208 let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
210
211 let (centroids, labels, inertia) = _kmeans_single(data, centroids.view(), &opts)?;
213
214 if inertia < best_inertia {
215 bestcentroids = Some(centroids);
216 best_labels = Some(labels);
217 best_inertia = inertia;
218 }
219 }
220
221 Ok((bestcentroids.unwrap(), best_labels.unwrap()))
222}
223
224#[allow(dead_code)]
226fn calculate_distortion<F>(
227 data: ArrayView2<F>,
228 centroids: ArrayView2<F>,
229 labels: &Array1<usize>,
230) -> F
231where
232 F: Float + FromPrimitive + Debug + std::iter::Sum,
233{
234 let n_samples = data.shape()[0];
235 let mut total_distortion = F::zero();
236
237 for i in 0..n_samples {
238 let cluster = labels[i];
239 let point = data.slice(s![i, ..]);
240 let centroid = centroids.slice(s![cluster, ..]);
241
242 let squared_distance = euclidean_distance(point, centroid).powi(2);
243 total_distortion = total_distortion + squared_distance;
244 }
245
246 total_distortion
247}
248
249#[allow(dead_code)]
251fn _kmeans_single<F>(
252 data: ArrayView2<F>,
253 initcentroids: ArrayView2<F>,
254 opts: &KMeansOptions<F>,
255) -> Result<(Array2<F>, Array1<usize>, F)>
256where
257 F: Float + FromPrimitive + Debug + std::iter::Sum,
258{
259 let n_samples = data.shape()[0];
260 let n_features = data.shape()[1];
261 let k = initcentroids.shape()[0];
262
263 let mut centroids = initcentroids.to_owned();
264 let mut labels = Array1::zeros(n_samples);
265 let mut prev_centroid_diff = F::infinity();
266
267 for _iter in 0..opts.max_iter {
268 let (new_labels, distances) = vq(data, centroids.view())?;
270 labels = new_labels;
271
272 let mut newcentroids = Array2::zeros((k, n_features));
274 let mut counts = Array1::zeros(k);
275
276 for i in 0..n_samples {
277 let cluster = labels[i];
278 let point = data.slice(s![i, ..]);
279
280 for j in 0..n_features {
281 newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j];
282 }
283
284 counts[cluster] += 1;
285 }
286
287 for i in 0..k {
289 if counts[i] == 0 {
290 let mut max_dist = F::zero();
292 let mut far_idx = 0;
293
294 for j in 0..n_samples {
295 let dist = distances[j];
296 if dist > max_dist {
297 max_dist = dist;
298 far_idx = j;
299 }
300 }
301
302 for j in 0..n_features {
304 newcentroids[[i, j]] = data[[far_idx, j]];
305 }
306
307 counts[i] = 1;
308 } else {
309 for j in 0..n_features {
311 newcentroids[[i, j]] = newcentroids[[i, j]] / F::from(counts[i]).unwrap();
312 }
313 }
314 }
315
316 let mut centroid_diff = F::zero();
318 for i in 0..k {
319 let dist =
320 euclidean_distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
321 centroid_diff = centroid_diff + dist;
322 }
323
324 centroids = newcentroids;
325
326 if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
327 break;
328 }
329
330 prev_centroid_diff = centroid_diff;
331 }
332
333 let mut inertia = F::zero();
335 for i in 0..n_samples {
336 let cluster = labels[i];
337 let dist = euclidean_distance(data.slice(s![i, ..]), centroids.slice(s![cluster, ..]));
338 inertia = inertia + dist * dist;
339 }
340
341 Ok((centroids, labels, inertia))
342}
343
/// Strategy used to choose the initial centroids (dispatched by `kmeans_init`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KMeansInit {
    /// Pick `k` distinct rows of the data at random (`random_init`).
    Random,
    /// k-means++ seeding (`kmeans_plus_plus`); this is the default.
    #[default]
    KMeansPlusPlus,
    /// k-means|| seeding (`kmeans_parallel`).
    KMeansParallel,
}
355
356#[allow(dead_code)]
369pub fn kmeans_init<F>(
370 data: ArrayView2<F>,
371 k: usize,
372 init_method: Option<KMeansInit>,
373 random_seed: Option<u64>,
374) -> Result<Array2<F>>
375where
376 F: Float + FromPrimitive + Debug + std::iter::Sum,
377{
378 match init_method.unwrap_or_default() {
379 KMeansInit::Random => random_init(data, k, random_seed),
380 KMeansInit::KMeansPlusPlus => kmeans_plus_plus(data, k, random_seed),
381 KMeansInit::KMeansParallel => kmeans_parallel(data, k, random_seed),
382 }
383}
384
385#[allow(dead_code)]
397pub fn random_init<F>(data: ArrayView2<F>, k: usize, random_seed: Option<u64>) -> Result<Array2<F>>
398where
399 F: Float + FromPrimitive + Debug + std::iter::Sum,
400{
401 let n_samples = data.shape()[0];
402 let n_features = data.shape()[1];
403
404 if k == 0 || k > n_samples {
405 return Err(ClusteringError::InvalidInput(format!(
406 "Number of clusters ({}) must be between 1 and number of samples ({})",
407 k, n_samples
408 )));
409 }
410
411 let mut rng = scirs2_core::random::rng();
412 let mut centroids = Array2::zeros((k, n_features));
413 let mut selected_indices = Vec::with_capacity(k);
414
415 while selected_indices.len() < k {
417 let idx = rng.random_range(0..n_samples);
418 if !selected_indices.contains(&idx) {
419 selected_indices.push(idx);
420 }
421 }
422
423 for (i, &idx) in selected_indices.iter().enumerate() {
425 for j in 0..n_features {
426 centroids[[i, j]] = data[[idx, j]];
427 }
428 }
429
430 Ok(centroids)
431}
432
433#[allow(dead_code)]
445pub fn kmeans_plus_plus<F>(
446 data: ArrayView2<F>,
447 k: usize,
448 random_seed: Option<u64>,
449) -> Result<Array2<F>>
450where
451 F: Float + FromPrimitive + Debug + std::iter::Sum,
452{
453 let n_samples = data.shape()[0];
454 let n_features = data.shape()[1];
455
456 if k == 0 || k > n_samples {
457 return Err(ClusteringError::InvalidInput(format!(
458 "Number of clusters ({}) must be between 1 and number of samples ({})",
459 k, n_samples
460 )));
461 }
462
463 let mut rng = scirs2_core::random::rng();
464
465 let mut centroids = Array2::zeros((k, n_features));
466
467 let first_idx = rng.random_range(0..n_samples);
469 for j in 0..n_features {
470 centroids[[0, j]] = data[[first_idx, j]];
471 }
472
473 if k == 1 {
474 return Ok(centroids);
475 }
476
477 for i in 1..k {
479 let mut min_distances = Array1::from_elem(n_samples, F::infinity());
481
482 for sample_idx in 0..n_samples {
483 let sample = data.slice(s![sample_idx, ..]);
484
485 for centroid_idx in 0..i {
486 let centroid = centroids.slice(s![centroid_idx, ..]);
487 let dist = euclidean_distance(sample, centroid);
488
489 if dist < min_distances[sample_idx] {
490 min_distances[sample_idx] = dist;
491 }
492 }
493 }
494
495 let mut weights = min_distances.mapv(|d| d * d);
497
498 let sum_weights = weights.sum();
500 if sum_weights > F::zero() {
501 weights.mapv_inplace(|w| w / sum_weights);
502 } else {
503 weights.fill(F::from(1.0 / n_samples as f64).unwrap());
505 }
506
507 let mut cum_weights = weights.clone();
509 for j in 1..n_samples {
510 cum_weights[j] = cum_weights[j] + cum_weights[j - 1];
511 }
512
513 let rand_val = F::from(rng.random_range(0.0..1.0)).unwrap();
515 let mut next_idx = 0;
516
517 for j in 0..n_samples {
518 if rand_val <= cum_weights[j] {
519 next_idx = j;
520 break;
521 }
522 }
523
524 for j in 0..n_features {
526 centroids[[i, j]] = data[[next_idx, j]];
527 }
528 }
529
530 Ok(centroids)
531}
532
533#[allow(dead_code)]
552pub fn kmeans_parallel<F>(
553 data: ArrayView2<F>,
554 k: usize,
555 random_seed: Option<u64>,
556) -> Result<Array2<F>>
557where
558 F: Float + FromPrimitive + Debug + std::iter::Sum,
559{
560 let n_samples = data.shape()[0];
561 let n_features = data.shape()[1];
562
563 if k == 0 || k > n_samples {
564 return Err(ClusteringError::InvalidInput(format!(
565 "Number of clusters ({}) must be between 1 and number of samples ({})",
566 k, n_samples
567 )));
568 }
569
570 let mut rng = scirs2_core::random::rng();
571
572 let l = F::from(5.0).unwrap(); let n_rounds = 8; let mut centers = Vec::new();
578 let mut weights = Vec::new();
579
580 let first_idx = rng.random_range(0..n_samples);
582 let mut first_center = Vec::with_capacity(n_features);
583 for j in 0..n_features {
584 first_center.push(data[[first_idx, j]]);
585 }
586 centers.push(first_center);
587 weights.push(F::one()); for _ in 0..n_rounds {
591 let mut min_distances = Array1::from_elem(n_samples, F::infinity());
593
594 for sample_idx in 0..n_samples {
595 let sample = data.slice(s![sample_idx, ..]);
596
597 for center in centers.iter() {
598 let mut dist_sq = F::zero();
599 for j in 0..n_features {
600 let diff = sample[j] - center[j];
601 dist_sq = dist_sq + diff * diff;
602 }
603 let dist = dist_sq.sqrt();
604
605 if dist < min_distances[sample_idx] {
606 min_distances[sample_idx] = dist;
607 }
608 }
609 }
610
611 let potential: F = min_distances.iter().map(|&d| d * d).sum();
613 if potential <= F::epsilon() {
614 break; }
616
617 let expected_new_centers = l * F::from(k).unwrap();
619 let oversampling = F::min(expected_new_centers / potential, F::one());
620
621 for sample_idx in 0..n_samples {
622 let probability = min_distances[sample_idx] * min_distances[sample_idx] * oversampling;
623
624 if F::from(rng.random_range(0.0..1.0)).unwrap() < probability {
626 let mut new_center = Vec::with_capacity(n_features);
627 for j in 0..n_features {
628 new_center.push(data[[sample_idx, j]]);
629 }
630 centers.push(new_center);
631 weights.push(F::one()); }
633 }
634 }
635
636 match centers.len().cmp(&k) {
638 std::cmp::Ordering::Greater => {
639 let n_centers = centers.len();
641 let mut centers_array = Array2::zeros((n_centers, n_features));
642 let mut weights_array = Array1::zeros(n_centers);
643
644 for i in 0..n_centers {
645 for j in 0..n_features {
646 centers_array[[i, j]] = centers[i][j];
647 }
648 weights_array[i] = weights[i];
649 }
650
651 let options = KMeansOptions {
653 max_iter: 100,
654 tol: F::from(1e-4).unwrap(),
655 random_seed,
656 n_init: 1,
657 init_method: KMeansInit::KMeansPlusPlus,
658 };
659
660 let init_indices: Vec<usize> = (0..n_centers)
662 .filter(|_| rng.random_range(0.0..1.0) < 0.5) .take(k) .collect();
665
666 let actual_indices = if init_indices.len() < k {
668 (0..k.min(n_centers)).collect::<Vec<usize>>()
669 } else {
670 init_indices
671 };
672
673 let mut initcentroids = Array2::zeros((actual_indices.len(), n_features));
674 for (i, &idx) in actual_indices.iter().enumerate() {
675 for j in 0..n_features {
676 initcentroids[[i, j]] = centers_array[[idx, j]];
677 }
678 }
679
680 let (finalcentroids_, _) = _weighted_kmeans_single(
682 centers_array.view(),
683 weights_array.view(),
684 initcentroids.view(),
685 &options,
686 )?;
687
688 Ok(finalcentroids_)
689 }
690 std::cmp::Ordering::Less => {
691 let mut centroids = Array2::zeros((k, n_features));
693
694 for i in 0..centers.len() {
696 for j in 0..n_features {
697 centroids[[i, j]] = centers[i][j];
698 }
699 }
700
701 let mut selected_indices = Vec::with_capacity(k - centers.len());
703 while selected_indices.len() < k - centers.len() {
704 let idx = rng.random_range(0..n_samples);
705 if !selected_indices.contains(&idx) {
706 selected_indices.push(idx);
707 }
708 }
709
710 for (i, &idx) in selected_indices.iter().enumerate() {
711 for j in 0..n_features {
712 centroids[[centers.len() + i, j]] = data[[idx, j]];
713 }
714 }
715
716 Ok(centroids)
717 }
718 std::cmp::Ordering::Equal => {
719 let mut centroids = Array2::zeros((k, n_features));
721 for i in 0..k {
722 for j in 0..n_features {
723 centroids[[i, j]] = centers[i][j];
724 }
725 }
726 Ok(centroids)
727 }
728 }
729}
730
731#[allow(dead_code)]
733fn _weighted_kmeans_single<F>(
734 data: ArrayView2<F>,
735 weights: ArrayView1<F>,
736 initcentroids: ArrayView2<F>,
737 opts: &KMeansOptions<F>,
738) -> Result<(Array2<F>, Array1<usize>)>
739where
740 F: Float + FromPrimitive + Debug + std::iter::Sum,
741{
742 let n_samples = data.shape()[0];
743 let n_features = data.shape()[1];
744 let k = initcentroids.shape()[0];
745
746 let mut centroids = initcentroids.to_owned();
747 let mut labels = Array1::zeros(n_samples);
748 let mut prev_centroid_diff = F::infinity();
749
750 for _iter in 0..opts.max_iter {
751 let (new_labels_, _) = vq(data, centroids.view())?;
753 labels = new_labels_;
754
755 let mut newcentroids = Array2::zeros((k, n_features));
757 let mut total_weights = Array1::zeros(k);
758
759 for i in 0..n_samples {
760 let cluster = labels[i];
761 let point = data.slice(s![i, ..]);
762 let weight = weights[i];
763
764 for j in 0..n_features {
765 newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j] * weight;
766 }
767
768 total_weights[cluster] = total_weights[cluster] + weight;
769 }
770
771 for i in 0..k {
773 if total_weights[i] <= F::epsilon() {
774 let mut max_dist = F::zero();
776 let mut far_idx = 0;
777
778 for j in 0..n_samples {
779 let dist = euclidean_distance(
780 data.slice(s![j, ..]),
781 centroids.slice(s![labels[j], ..]),
782 );
783 if dist > max_dist {
784 max_dist = dist;
785 far_idx = j;
786 }
787 }
788
789 for j in 0..n_features {
791 newcentroids[[i, j]] = data[[far_idx, j]];
792 }
793
794 total_weights[i] = weights[far_idx];
795 } else {
796 for j in 0..n_features {
798 newcentroids[[i, j]] = newcentroids[[i, j]] / total_weights[i];
799 }
800 }
801 }
802
803 let mut centroid_diff = F::zero();
805 for i in 0..k {
806 let dist =
807 euclidean_distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
808 centroid_diff = centroid_diff + dist;
809 }
810
811 centroids = newcentroids;
812
813 if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
814 break;
815 }
816
817 prev_centroid_diff = centroid_diff;
818 }
819
820 Ok((centroids, labels))
821}
822
823#[allow(dead_code)]
860pub fn kmeans_with_metric<F>(
861 data: ArrayView2<F>,
862 k: usize,
863 metric: Box<dyn crate::vq::VQDistanceMetric<F>>,
864 options: Option<KMeansOptions<F>>,
865) -> Result<(Array2<F>, Array1<usize>)>
866where
867 F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync + 'static,
868{
869 if k == 0 {
870 return Err(ClusteringError::InvalidInput(
871 "Number of clusters must be greater than 0".to_string(),
872 ));
873 }
874
875 let n_samples = data.shape()[0];
876 if n_samples == 0 {
877 return Err(ClusteringError::InvalidInput(
878 "Input data is empty".to_string(),
879 ));
880 }
881
882 if k > n_samples {
883 return Err(ClusteringError::InvalidInput(format!(
884 "Number of clusters ({}) cannot be greater than number of data points ({})",
885 k, n_samples
886 )));
887 }
888
889 let opts = options.unwrap_or_default();
890
891 let mut bestcentroids = None;
892 let mut best_labels = None;
893 let mut best_inertia = F::infinity();
894
895 let n_init = if opts.init_method == KMeansInit::KMeansParallel {
897 1
898 } else {
899 opts.n_init
900 };
901
902 for _ in 0..n_init {
903 let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
905
906 let (centroids, labels, inertia) =
908 _kmeans_single_with_metric(data, centroids.view(), metric.as_ref(), &opts)?;
909
910 if inertia < best_inertia {
911 bestcentroids = Some(centroids);
912 best_labels = Some(labels);
913 best_inertia = inertia;
914 }
915 }
916
917 Ok((bestcentroids.unwrap(), best_labels.unwrap()))
918}
919
920#[allow(dead_code)]
922fn _kmeans_single_with_metric<F>(
923 data: ArrayView2<F>,
924 initcentroids: ArrayView2<F>,
925 metric: &dyn crate::vq::VQDistanceMetric<F>,
926 opts: &KMeansOptions<F>,
927) -> Result<(Array2<F>, Array1<usize>, F)>
928where
929 F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
930{
931 let n_samples = data.shape()[0];
932 let n_features = data.shape()[1];
933 let k = initcentroids.shape()[0];
934
935 let mut centroids = initcentroids.to_owned();
936 let mut labels = Array1::zeros(n_samples);
937 let mut prev_centroid_diff = F::infinity();
938
939 for _iter in 0..opts.max_iter {
940 let (new_labels, distances) = _vq_with_metric(data, centroids.view(), metric)?;
942 labels = new_labels;
943
944 let mut newcentroids = Array2::zeros((k, n_features));
946 let mut counts = Array1::zeros(k);
947
948 for i in 0..n_samples {
949 let cluster = labels[i];
950 let point = data.slice(s![i, ..]);
951
952 for j in 0..n_features {
953 newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j];
954 }
955
956 counts[cluster] += 1;
957 }
958
959 for i in 0..k {
961 if counts[i] == 0 {
962 let mut max_dist = F::zero();
964 let mut far_idx = 0;
965
966 for j in 0..n_samples {
967 let dist = distances[j];
968 if dist > max_dist {
969 max_dist = dist;
970 far_idx = j;
971 }
972 }
973
974 for j in 0..n_features {
976 newcentroids[[i, j]] = data[[far_idx, j]];
977 }
978
979 counts[i] = 1;
980 } else {
981 for j in 0..n_features {
983 newcentroids[[i, j]] = newcentroids[[i, j]] / F::from(counts[i]).unwrap();
984 }
985 }
986 }
987
988 let mut centroid_diff = F::zero();
990 for i in 0..k {
991 let dist = metric.distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
992 centroid_diff = centroid_diff + dist;
993 }
994
995 centroids = newcentroids;
996
997 if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
998 break;
999 }
1000
1001 prev_centroid_diff = centroid_diff;
1002 }
1003
1004 let mut inertia = F::zero();
1006 for i in 0..n_samples {
1007 let cluster = labels[i];
1008 let dist = metric.distance(data.slice(s![i, ..]), centroids.slice(s![cluster, ..]));
1009 inertia = inertia + dist * dist;
1010 }
1011
1012 Ok((centroids, labels, inertia))
1013}
1014
1015#[allow(dead_code)]
1017fn _vq_with_metric<F>(
1018 data: ArrayView2<F>,
1019 centroids: ArrayView2<F>,
1020 metric: &dyn crate::vq::VQDistanceMetric<F>,
1021) -> Result<(Array1<usize>, Array1<F>)>
1022where
1023 F: Float + FromPrimitive + Debug + Send + Sync,
1024{
1025 let n_samples = data.shape()[0];
1026 let ncentroids = centroids.shape()[0];
1027
1028 let mut labels = Array1::zeros(n_samples);
1029 let mut distances = Array1::zeros(n_samples);
1030
1031 for i in 0..n_samples {
1032 let point = data.slice(s![i, ..]);
1033 let mut min_dist = F::infinity();
1034 let mut closest_centroid = 0;
1035
1036 for j in 0..ncentroids {
1037 let centroid = centroids.slice(s![j, ..]);
1038 let dist = metric.distance(point, centroid);
1039
1040 if dist < min_dist {
1041 min_dist = dist;
1042 closest_centroid = j;
1043 }
1044 }
1045
1046 labels[i] = closest_centroid;
1047 distances[i] = min_dist;
1048 }
1049
1050 Ok((labels, distances))
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::collections::HashSet;

    /// Six 2-D points: three around (1, 2) and three around (4, 5).
    fn six_point_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap()
    }

    /// Clusters `data` into two groups with `init_method` and checks the result's shape: a 2x2
    /// centroid matrix, one label per row, and both labels in use. Returns the labels.
    fn cluster_into_two(data: &Array2<f64>, init_method: KMeansInit) -> Vec<usize> {
        let options = KMeansOptions {
            init_method,
            ..Default::default()
        };
        let result = kmeans_with_options(data.view(), 2, Some(options));
        assert!(result.is_ok());

        let (centroids, labels) = result.unwrap();
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), data.nrows());

        let distinct: HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 2);

        labels.to_vec()
    }

    #[test]
    fn test_kmeans_random_init() {
        cluster_into_two(&six_point_data(), KMeansInit::Random);
    }

    #[test]
    fn test_kmeans_plusplus_init() {
        cluster_into_two(&six_point_data(), KMeansInit::KMeansPlusPlus);
    }

    #[test]
    fn test_kmeans_parallel_init() {
        let data = Array2::from_shape_vec(
            (20, 2),
            vec![
                1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 1.1, 2.2, 0.9, 1.7, 1.3, 2.1, 1.0, 1.9, 0.7, 2.0,
                1.2, 2.3, 1.5, 1.8, 5.0, 6.0, 5.2, 5.8, 4.8, 6.2, 5.1, 5.9, 5.3, 6.1, 4.9, 5.7,
                5.0, 6.3, 5.4, 5.6, 4.7, 5.9, 5.2, 6.2,
            ],
        )
        .unwrap();

        let labels = cluster_into_two(&data, KMeansInit::KMeansParallel);

        // The first ten rows form one blob and the last ten the other.
        let (left, right) = labels.split_at(10);
        assert!(left.iter().all(|&label| label == left[0]));
        assert_ne!(left[0], right[0]);
        assert!(right.iter().all(|&label| label == right[0]));
    }

    #[test]
    fn test_scipy_compatible_kmeans() {
        let data = six_point_data();

        // Every argument given explicitly.
        let (centroids, distortion) =
            kmeans(data.view(), 2, Some(20), Some(1e-5), Some(true), Some(42))
                .expect("kmeans with explicit arguments should succeed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert!(distortion > 0.0);

        // Every optional argument left at its default.
        let (centroids, distortion) = kmeans(data.view(), 2, None, None, None, None)
            .expect("kmeans with default arguments should succeed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert!(distortion > 0.0);
    }

    #[test]
    fn test_scipy_kmeans_check_finite() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.5, 1.5, 8.0, 8.0, 8.5, 8.5]).unwrap();

        for check in [true, false] {
            let result = kmeans(data.view(), 2, Some(10), Some(1e-5), Some(check), Some(42));
            assert!(result.is_ok());
        }
    }
}