1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
4use scirs2_core::numeric::{Float, FromPrimitive};
5use scirs2_core::random::{rngs::StdRng, Rng, RngCore, SeedableRng};
6use scirs2_core::random::{Distribution, Normal};
7use std::fmt::Debug;
8use std::str::FromStr;
9
10use super::{euclidean_distance, vq};
11use crate::error::{ClusteringError, Result};
12use scirs2_core::validation::{clustering::*, parameters::*};
13
/// Strategy used to pick the initial centroids before the k-means iterations start.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MinitMethod {
    /// Draw every centroid coordinate from a normal distribution fitted to the
    /// corresponding feature of the data (per-feature mean and standard deviation).
    Random,
    /// Use `k` distinct observations, chosen at random, as the initial centroids.
    Points,
    /// k-means++ seeding: each new centroid is an observation sampled with
    /// probability proportional to its squared distance to the nearest centroid
    /// chosen so far.
    PlusPlus,
}
24
25impl MinitMethod {
26 pub fn parse_method(s: &str) -> Result<Self> {
40 match s.to_lowercase().as_str() {
41 "random" => Ok(MinitMethod::Random),
42 "points" => Ok(MinitMethod::Points),
43 "k-means++" | "kmeans++" | "plusplus" => Ok(MinitMethod::PlusPlus),
44 _ => Err(ClusteringError::InvalidInput(format!(
45 "Unknown initialization method: '{}'. Valid options are: 'random', 'points', 'k-means++'",
46 s
47 ))),
48 }
49 }
50}
51
52impl FromStr for MinitMethod {
53 type Err = ClusteringError;
54
55 fn from_str(s: &str) -> Result<Self> {
56 Self::parse_method(s)
57 }
58}
59
/// Policy applied when a cluster ends up with no assigned observations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MissingMethod {
    /// Print a warning to standard error and keep going, moving the empty
    /// cluster's centroid onto one of the observations.
    Warn,
    /// Abort the clustering and return an empty-cluster error.
    Raise,
}
77
78#[allow(clippy::too_many_arguments)]
97#[allow(dead_code)]
98pub fn kmeans2<F>(
99 data: ArrayView2<F>,
100 k: usize,
101 iter: Option<usize>,
102 thresh: Option<F>,
103 minit: Option<MinitMethod>,
104 missing: Option<MissingMethod>,
105 check_finite: Option<bool>,
106 randomseed: Option<u64>,
107) -> Result<(Array2<F>, Array1<usize>)>
108where
109 F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
110{
111 let n_samples = data.shape()[0];
112 let n_features = data.shape()[1];
113 let iterations = iter.unwrap_or(10);
114 let threshold = thresh.unwrap_or(F::from(1e-5).unwrap());
115 let missing_method = missing.unwrap_or(MissingMethod::Warn);
116 let check_finite_flag = check_finite.unwrap_or(true);
117
118 validate_clustering_data(&data, "K-means", check_finite_flag, Some(k))
120 .map_err(|e| ClusteringError::InvalidInput(format!("K-means: {}", e)))?;
121
122 check_n_clusters_bounds(&data, k, "K-means")
123 .map_err(|e| ClusteringError::InvalidInput(format!("{}", e)))?;
124
125 check_iteration_params(iterations, threshold, "K-means")
126 .map_err(|e| ClusteringError::InvalidInput(format!("{}", e)))?;
127
128 let init_method = minit.unwrap_or(MinitMethod::PlusPlus); let mut centroids = match init_method {
131 MinitMethod::Random => krandinit(data, k, randomseed)?,
132 MinitMethod::Points => kpoints(data, k, randomseed)?,
133 MinitMethod::PlusPlus => kmeans_plus_plus(data, k, randomseed)?,
134 };
135
136 let mut labels;
137
138 for _iteration in 0..iterations {
140 let prev_centroids = centroids.clone();
142
143 let (new_labels, _distances) = vq(data, centroids.view())?;
145 labels = new_labels;
146
147 let mut new_centroids = Array2::zeros((k, n_features));
149 let mut counts = Array1::zeros(k);
150
151 for i in 0..n_samples {
152 let cluster = labels[i];
153 let point = data.slice(s![i, ..]);
154
155 for j in 0..n_features {
156 new_centroids[[cluster, j]] = new_centroids[[cluster, j]] + point[j];
157 }
158
159 counts[cluster] += 1;
160 }
161
162 for i in 0..k {
164 if counts[i] == 0 {
165 match missing_method {
166 MissingMethod::Warn => {
167 eprintln!("One of the clusters is empty. Re-run kmeans with a different initialization.");
168 let mut max_dist = F::zero();
170 let mut far_idx = 0;
171
172 for j in 0..n_samples {
173 let cluster_j = labels[j];
174 let dist = euclidean_distance(
175 data.slice(s![j, ..]),
176 centroids.slice(s![cluster_j, ..]),
177 );
178 if dist > max_dist {
179 max_dist = dist;
180 far_idx = j;
181 }
182 }
183
184 for j in 0..n_features {
186 new_centroids[[i, j]] = data[[far_idx, j]];
187 }
188 counts[i] = 1;
189 }
190 MissingMethod::Raise => {
191 return Err(ClusteringError::EmptyCluster(
192 "One of the clusters is empty. Re-run kmeans with a different initialization.".to_string()
193 ));
194 }
195 }
196 } else {
197 for j in 0..n_features {
199 new_centroids[[i, j]] = new_centroids[[i, j]] / F::from(counts[i]).unwrap();
200 }
201 }
202 }
203
204 centroids = new_centroids;
205
206 let mut max_centroid_shift = F::zero();
208 for i in 0..k {
209 for j in 0..n_features {
210 let shift = (centroids[[i, j]] - prev_centroids[[i, j]]).abs();
211 if shift > max_centroid_shift {
212 max_centroid_shift = shift;
213 }
214 }
215 }
216
217 if max_centroid_shift < threshold {
219 break;
220 }
221 }
222
223 let (final_labels, _distances) = vq(data, centroids.view())?;
225
226 Ok((centroids, final_labels))
227}
228
229#[allow(clippy::too_many_arguments)]
268#[allow(dead_code)]
269pub fn kmeans2_str<F>(
270 data: ArrayView2<F>,
271 k: usize,
272 iter: Option<usize>,
273 thresh: Option<F>,
274 minit: Option<&str>,
275 missing: Option<&str>,
276 check_finite: Option<bool>,
277 randomseed: Option<u64>,
278) -> Result<(Array2<F>, Array1<usize>)>
279where
280 F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
281{
282 let minit_method = if let Some(method_str) = minit {
284 Some(MinitMethod::from_str(method_str)?)
285 } else {
286 Some(MinitMethod::PlusPlus) };
288
289 let missing_method = if let Some(missing_str) = missing {
290 match missing_str.to_lowercase().as_str() {
291 "warn" => Some(MissingMethod::Warn),
292 "raise" => Some(MissingMethod::Raise),
293 _ => {
294 return Err(ClusteringError::InvalidInput(format!(
295 "Unknown missing method: '{}'. Valid options are: 'warn', 'raise'",
296 missing_str
297 )))
298 }
299 }
300 } else {
301 Some(MissingMethod::Warn) };
303
304 kmeans2(
306 data,
307 k,
308 iter,
309 thresh,
310 minit_method,
311 missing_method,
312 check_finite,
313 randomseed,
314 )
315}
316
317#[allow(dead_code)]
320fn krandinit<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
321where
322 F: Float + FromPrimitive + Debug + std::iter::Sum,
323{
324 let n_samples = data.shape()[0];
325 let n_features = data.shape()[1];
326
327 let mut means = Array1::<F>::zeros(n_features);
329 let mut vars = Array1::<F>::zeros(n_features);
330
331 for j in 0..n_features {
332 let mut sum = F::zero();
333 for i in 0..n_samples {
334 sum = sum + data[[i, j]];
335 }
336 means[j] = sum / F::from(n_samples).unwrap();
337
338 let mut var_sum = F::zero();
339 for i in 0..n_samples {
340 let diff = data[[i, j]] - means[j];
341 var_sum = var_sum + diff * diff;
342 }
343 vars[j] = var_sum / F::from(n_samples).unwrap();
344 }
345
346 let mut centroids = Array2::<F>::zeros((k, n_features));
348
349 let mut rng: Box<dyn RngCore> = if let Some(_seed) = randomseed {
350 Box::new(StdRng::seed_from_u64(_seed))
351 } else {
352 Box::new(scirs2_core::random::rng())
353 };
354
355 for i in 0..k {
356 for j in 0..n_features {
357 let mean = means[j].to_f64().unwrap();
359 let std = vars[j].sqrt().to_f64().unwrap();
360
361 if std > 0.0 {
362 let normal = Normal::new(mean, std).unwrap();
363 let value = normal.sample(&mut rng);
364 centroids[[i, j]] = F::from(value).unwrap();
365 } else {
366 centroids[[i, j]] = means[j];
367 }
368 }
369 }
370
371 Ok(centroids)
372}
373
374#[allow(dead_code)]
376fn kpoints<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
377where
378 F: Float + FromPrimitive + Debug,
379{
380 let n_samples = data.shape()[0];
381 let n_features = data.shape()[1];
382
383 let mut rng: Box<dyn RngCore> = if let Some(_seed) = randomseed {
384 Box::new(StdRng::seed_from_u64(_seed))
385 } else {
386 Box::new(scirs2_core::random::rng())
387 };
388
389 let mut indices: Vec<usize> = (0..n_samples).collect();
391
392 for i in 0..k {
394 let j = rng.random_range(i..n_samples);
395 indices.swap(i, j);
396 }
397
398 let mut centroids = Array2::zeros((k, n_features));
400 for i in 0..k {
401 let idx = indices[i];
402 for j in 0..n_features {
403 centroids[[i, j]] = data[[idx, j]];
404 }
405 }
406
407 Ok(centroids)
408}
409
410#[allow(dead_code)]
412fn kmeans_plus_plus<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
413where
414 F: Float + FromPrimitive + Debug + std::iter::Sum,
415{
416 let n_samples = data.shape()[0];
417 let n_features = data.shape()[1];
418
419 let mut rng: Box<dyn RngCore> = if let Some(_seed) = randomseed {
420 Box::new(StdRng::seed_from_u64(_seed))
421 } else {
422 Box::new(scirs2_core::random::rng())
423 };
424
425 let mut centroids = Array2::zeros((k, n_features));
426
427 let first_idx = rng.random_range(0..n_samples);
429 for j in 0..n_features {
430 centroids[[0, j]] = data[[first_idx, j]];
431 }
432
433 for i in 1..k {
435 let mut distances = Array1::<F>::zeros(n_samples);
437
438 for j in 0..n_samples {
439 let mut min_dist = F::infinity();
440 for c in 0..i {
441 let dist = euclidean_distance(data.slice(s![j, ..]), centroids.slice(s![c, ..]));
442 if dist < min_dist {
443 min_dist = dist;
444 }
445 }
446 distances[j] = min_dist * min_dist;
447 }
448
449 let total = distances.iter().fold(F::zero(), |a, &b| a + b);
451 let mut probabilities = Array1::<F>::zeros(n_samples);
452 for j in 0..n_samples {
453 probabilities[j] = distances[j] / total;
454 }
455
456 let mut cumsum = F::zero();
458 let r = F::from(rng.random::<f64>()).unwrap();
459 let mut next_idx = n_samples - 1;
460
461 for j in 0..n_samples {
462 cumsum = cumsum + probabilities[j];
463 if cumsum > r {
464 next_idx = j;
465 break;
466 }
467 }
468
469 for j in 0..n_features {
471 centroids[[i, j]] = data[[next_idx, j]];
472 }
473 }
474
475 Ok(centroids)
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array1, Array2};
    use std::collections::HashSet;

    /// Two loose, well-separated 2-D groups of three points each.
    fn two_blob_data() -> Array2<f64> {
        array![
            [1.0, 1.0],
            [1.5, 1.5],
            [0.8, 0.9],
            [8.0, 8.0],
            [8.2, 8.1],
            [7.8, 7.9],
        ]
    }

    /// Two tight, well-separated 2-D groups of three points each.
    fn tight_blob_data() -> Array2<f64> {
        array![
            [1.0, 1.0],
            [1.1, 1.1],
            [0.9, 0.9],
            [10.0, 10.0],
            [10.1, 10.1],
            [9.9, 9.9],
        ]
    }

    /// Number of different cluster labels that actually occur.
    fn distinct_label_count(labels: &Array1<usize>) -> usize {
        labels.iter().copied().collect::<HashSet<_>>().len()
    }

    /// Asserts that two centroid matrices agree element-wise within `1e-10`.
    fn assert_centroids_close(left: &Array2<f64>, right: &Array2<f64>) {
        assert_eq!(left.shape(), right.shape());
        for (&a, &b) in left.iter().zip(right.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kmeans2_basic_functionality() {
        let data = two_blob_data();

        let (centroids, labels) = kmeans2(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some(MinitMethod::PlusPlus),
            Some(MissingMethod::Warn),
            Some(true),
            Some(42),
        )
        .unwrap();

        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&l| l == 0 || l == 1));
        assert_eq!(distinct_label_count(&labels), 2);
    }

    #[test]
    fn test_kmeans2_parameter_validation() {
        let data = array![[1.0, 1.0], [2.0, 2.0]];

        // Zero clusters and more clusters than observations are both invalid.
        for k in [0, 5] {
            let result = kmeans2(
                data.view(),
                k,
                None,
                None,
                Some(MinitMethod::Random),
                None,
                None,
                None,
            );
            assert!(result.is_err());
        }
    }

    #[test]
    fn test_kmeans2_initialization_methods() {
        let data = two_blob_data();

        for method in [
            MinitMethod::Random,
            MinitMethod::Points,
            MinitMethod::PlusPlus,
        ] {
            let result = kmeans2(
                data.view(),
                2,
                Some(10),
                None,
                Some(method),
                Some(MissingMethod::Warn),
                None,
                Some(42),
            );

            assert!(result.is_ok(), "Failed with method: {:?}", method);
            let (centroids, labels) = result.unwrap();
            assert_eq!(centroids.shape(), [2, 2]);
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_kmeans2_reproducibility_with_seed() {
        let data = two_blob_data();

        let run = || {
            kmeans2(
                data.view(),
                2,
                Some(10),
                None,
                Some(MinitMethod::Random),
                None,
                None,
                Some(42),
            )
            .unwrap()
        };

        let (centroids1, labels1) = run();
        let (centroids2, labels2) = run();

        assert_eq!(labels1, labels2);
        assert_centroids_close(&centroids1, &centroids2);
    }

    #[test]
    fn test_kmeans2_single_cluster() {
        let data = array![[1.0, 1.0], [1.1, 1.1], [0.9, 0.9],];

        let (centroids, labels) = kmeans2(
            data.view(),
            1,
            Some(10),
            None,
            Some(MinitMethod::Points),
            None,
            None,
            Some(42),
        )
        .unwrap();

        assert_eq!(centroids.shape(), [1, 2]);
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_kmeans2_identical_points() {
        let data = array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],];

        let (centroids, labels) = kmeans2(
            data.view(),
            2,
            Some(10),
            None,
            Some(MinitMethod::Points),
            Some(MissingMethod::Warn),
            None,
            Some(42),
        )
        .unwrap();

        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 4);
        assert!(labels.iter().all(|&l| l == 0 || l == 1));
    }

    #[test]
    fn test_kmeans2_missing_method_warn() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [10.0, 10.0],];

        let result = kmeans2(
            data.view(),
            2,
            Some(5),
            None,
            Some(MinitMethod::Random),
            Some(MissingMethod::Warn),
            None,
            Some(123),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn test_kmeans2_convergence_behavior() {
        let data = tight_blob_data();

        // A single iteration and a generous budget must both produce a result
        // of the right shape.
        for max_iter in [1, 100] {
            let (centroids, _labels) = kmeans2(
                data.view(),
                2,
                Some(max_iter),
                None,
                Some(MinitMethod::PlusPlus),
                None,
                None,
                Some(42),
            )
            .unwrap();

            assert_eq!(centroids.shape(), [2, 2]);
        }
    }

    #[test]
    fn test_kmeans2_high_k() {
        let data = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0],];

        let (centroids, labels) = kmeans2(
            data.view(),
            5,
            Some(10),
            None,
            Some(MinitMethod::Points),
            None,
            None,
            Some(42),
        )
        .unwrap();

        assert_eq!(centroids.shape(), [5, 2]);
        assert_eq!(labels.len(), 5);
        assert_eq!(distinct_label_count(&labels), 5);
    }

    #[test]
    fn test_kmeans2_different_thresholds() {
        let data = array![[1.0, 1.0], [1.5, 1.5], [8.0, 8.0], [8.5, 8.5],];

        // Very strict and very loose thresholds.
        for thresh in [1e-10, 1e-1] {
            let result = kmeans2(
                data.view(),
                2,
                Some(100),
                Some(thresh),
                Some(MinitMethod::PlusPlus),
                None,
                None,
                Some(42),
            );

            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_kmeans2_convergence_threshold() {
        let data = tight_blob_data();

        // Very strict and very loose thresholds.
        for thresh in [1e-10, 1e-1] {
            let result = kmeans2(
                data.view(),
                2,
                Some(100),
                Some(thresh),
                Some(MinitMethod::PlusPlus),
                None,
                None,
                Some(42),
            );

            assert!(result.is_ok());
            let (centroids, labels) = result.unwrap();
            assert_eq!(centroids.shape(), [2, 2]);
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_kmeans2_check_finite() {
        let data = array![[1.0, 2.0], [1.5, 1.5], [8.0, 8.0],];

        // Finite data must be accepted whether or not the check is enabled.
        for check in [true, false] {
            let result = kmeans2(
                data.view(),
                2,
                Some(10),
                None,
                Some(MinitMethod::Random),
                None,
                Some(check),
                Some(42),
            );
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_kmeans2_large_dataset() {
        // Three interleaved groups around 1, 5 and 10 with a small drift.
        let bases = [1.0, 5.0, 10.0];
        let data = Array2::from_shape_fn((100, 3), |(i, _)| bases[i % 3] + (i as f64) * 0.01);

        let (centroids, labels) = kmeans2(
            data.view(),
            3,
            Some(50),
            None,
            Some(MinitMethod::PlusPlus),
            None,
            None,
            Some(42),
        )
        .unwrap();

        assert_eq!(centroids.shape(), [3, 3]);
        assert_eq!(labels.len(), 100);
        assert_eq!(distinct_label_count(&labels), 3);
    }

    #[test]
    fn test_kmeans2_str_basic_functionality() {
        let data = two_blob_data();

        let (centroids, labels) = kmeans2_str(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some("k-means++"),
            Some("warn"),
            Some(true),
            Some(42),
        )
        .unwrap();

        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&l| l == 0 || l == 1));
        assert_eq!(distinct_label_count(&labels), 2);
    }

    #[test]
    fn test_kmeans2_str_all_init_methods() {
        let data = two_blob_data();

        for method in ["random", "points", "k-means++", "kmeans++", "plusplus"] {
            let result = kmeans2_str(
                data.view(),
                2,
                Some(10),
                None,
                Some(method),
                Some("warn"),
                None,
                Some(42),
            );

            assert!(result.is_ok(), "Failed with method: '{}'", method);
            let (centroids, labels) = result.unwrap();
            assert_eq!(centroids.shape(), [2, 2]);
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_kmeans2_str_case_insensitive() {
        let data = array![[1.0, 1.0], [2.0, 2.0], [8.0, 8.0], [9.0, 9.0],];

        for method in [
            "RANDOM",
            "Random",
            "random",
            "POINTS",
            "Points",
            "points",
            "K-MEANS++",
            "K-Means++",
            "k-means++",
        ] {
            let result = kmeans2_str(
                data.view(),
                2,
                Some(10),
                None,
                Some(method),
                Some("warn"),
                None,
                Some(42),
            );

            assert!(result.is_ok(), "Failed with method: '{}'", method);
        }
    }

    #[test]
    fn test_kmeans2_str_missing_methods() {
        let data = array![[1.0, 1.0], [2.0, 2.0], [8.0, 8.0],];

        for missing_method in ["warn", "raise", "WARN", "RAISE"] {
            let result = kmeans2_str(
                data.view(),
                2,
                Some(5),
                None,
                Some("points"),
                Some(missing_method),
                None,
                Some(42),
            );

            assert!(
                result.is_ok(),
                "Failed with missing method: '{}'",
                missing_method
            );
        }
    }

    #[test]
    fn test_kmeans2_str_invalid_method() {
        let data = array![[1.0, 1.0], [2.0, 2.0]];

        let result = kmeans2_str(
            data.view(),
            2,
            Some(10),
            None,
            Some("invalid_method"),
            Some("warn"),
            None,
            None,
        );

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown initialization method"));
    }

    #[test]
    fn test_kmeans2_str_invalid_missing_method() {
        let data = array![[1.0, 1.0], [2.0, 2.0]];

        let result = kmeans2_str(
            data.view(),
            2,
            Some(10),
            None,
            Some("points"),
            Some("invalid_missing"),
            None,
            None,
        );

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown missing method"));
    }

    #[test]
    fn test_kmeans2_str_defaults() {
        let data = array![[1.0, 1.0], [1.5, 1.5], [8.0, 8.0], [8.5, 8.5],];

        // No initialization method, empty-cluster policy or finiteness flag.
        let result = kmeans2_str(
            data.view(),
            2,
            Some(10),
            None,
            None,
            None,
            None,
            Some(42),
        );

        assert!(result.is_ok());
        let (centroids, labels) = result.unwrap();
        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 4);
    }

    #[test]
    fn test_kmeans2_str_equivalence_with_enum() {
        let data = two_blob_data();

        let (centroids_enum, labels_enum) = kmeans2(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some(MinitMethod::PlusPlus),
            Some(MissingMethod::Warn),
            Some(true),
            Some(42),
        )
        .unwrap();

        let (centroids_str, labels_str) = kmeans2_str(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some("k-means++"),
            Some("warn"),
            Some(true),
            Some(42),
        )
        .unwrap();

        assert_eq!(labels_enum, labels_str);
        assert_centroids_close(&centroids_enum, &centroids_str);
    }

    #[test]
    fn test_minit_method_from_str() {
        let accepted = [
            ("random", MinitMethod::Random),
            ("RANDOM", MinitMethod::Random),
            ("points", MinitMethod::Points),
            ("POINTS", MinitMethod::Points),
            ("k-means++", MinitMethod::PlusPlus),
            ("kmeans++", MinitMethod::PlusPlus),
            ("plusplus", MinitMethod::PlusPlus),
            ("K-MEANS++", MinitMethod::PlusPlus),
        ];

        for (name, expected) in accepted {
            assert_eq!(MinitMethod::from_str(name).unwrap(), expected);
        }

        assert!(MinitMethod::from_str("invalid").is_err());
    }
}