scirs2_cluster/vq/
kmeans2.rs

1//! Enhanced K-means clustering implementation with multiple initialization methods
2
3use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
4use scirs2_core::numeric::{Float, FromPrimitive};
5use scirs2_core::random::{rngs::StdRng, Rng, RngCore, SeedableRng};
6use scirs2_core::random::{Distribution, Normal};
7use std::fmt::Debug;
8use std::str::FromStr;
9
10use super::{euclidean_distance, vq};
11use crate::error::{ClusteringError, Result};
12use scirs2_core::validation::{clustering::*, parameters::*};
13
/// Strategy used by [`kmeans2`] to pick the starting centroids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MinitMethod {
    /// Draw every centroid coordinate from a normal distribution whose mean and
    /// variance are estimated from the corresponding feature (column) of the data
    Random,
    /// Use k different rows of the data, picked at random, as the starting centroids
    Points,
    /// k-means++ careful seeding: each new centroid is sampled with probability
    /// proportional to its squared distance from the nearest centroid chosen so far
    PlusPlus,
}
24
25impl MinitMethod {
26    /// Parse a string into a MinitMethod (SciPy-compatible)
27    ///
28    /// # Arguments
29    ///
30    /// * `s` - String representation of the initialization method
31    ///
32    /// # Returns
33    ///
34    /// * The corresponding MinitMethod enum value
35    ///
36    /// # Errors
37    ///
38    /// * Returns an error if the string is not recognized
39    pub fn parse_method(s: &str) -> Result<Self> {
40        match s.to_lowercase().as_str() {
41            "random" => Ok(MinitMethod::Random),
42            "points" => Ok(MinitMethod::Points),
43            "k-means++" | "kmeans++" | "plusplus" => Ok(MinitMethod::PlusPlus),
44            _ => Err(ClusteringError::InvalidInput(format!(
45                "Unknown initialization method: '{}'. Valid options are: 'random', 'points', 'k-means++'",
46                s
47            ))),
48        }
49    }
50}
51
52impl FromStr for MinitMethod {
53    type Err = ClusteringError;
54
55    fn from_str(s: &str) -> Result<Self> {
56        Self::parse_method(s)
57    }
58}
59
/// Methods for handling empty clusters during K-means clustering
///
/// A cluster is empty when no sample is assigned to its centroid during an
/// iteration of [`kmeans2`]; this enum selects how the algorithm responds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingMethod {
    /// Print a warning to stderr and keep going
    ///
    /// The empty cluster is re-seeded with an observation that lies far from its
    /// currently assigned centroid, so the run still produces `k` centroids.
    Warn,
    /// Return a `ClusteringError::EmptyCluster` and stop
    ///
    /// The algorithm aborts as soon as an empty cluster is encountered.
    Raise,
}
77
78/// Enhanced K-means clustering algorithm compatible with SciPy's kmeans2
79///
80/// # Arguments
81///
82/// * `data` - Input data (n_samples × n_features) or pre-computed centroids if minit is None
83/// * `k` - Number of clusters or initial centroids array
84/// * `iter` - Number of iterations
85/// * `thresh` - Convergence threshold (not used yet)
86/// * `minit` - Initialization method (None if k is a centroid array)
87/// * `missing` - Method to handle empty clusters
88/// * `check_finite` - Whether to check input validity
89/// * `randomseed` - Optional random seed
90///
91/// # Returns
92///
93/// * Tuple of (centroids, labels) where:
94///   - centroids: Array of shape (k × n_features)
95///   - labels: Array of shape (n_samples,) with cluster assignments
96#[allow(clippy::too_many_arguments)]
97#[allow(dead_code)]
98pub fn kmeans2<F>(
99    data: ArrayView2<F>,
100    k: usize,
101    iter: Option<usize>,
102    thresh: Option<F>,
103    minit: Option<MinitMethod>,
104    missing: Option<MissingMethod>,
105    check_finite: Option<bool>,
106    randomseed: Option<u64>,
107) -> Result<(Array2<F>, Array1<usize>)>
108where
109    F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
110{
111    let n_samples = data.shape()[0];
112    let n_features = data.shape()[1];
113    let iterations = iter.unwrap_or(10);
114    let threshold = thresh.unwrap_or(F::from(1e-5).unwrap());
115    let missing_method = missing.unwrap_or(MissingMethod::Warn);
116    let check_finite_flag = check_finite.unwrap_or(true);
117
118    // Use unified validation
119    validate_clustering_data(&data, "K-means", check_finite_flag, Some(k))
120        .map_err(|e| ClusteringError::InvalidInput(format!("K-means: {}", e)))?;
121
122    check_n_clusters_bounds(&data, k, "K-means")
123        .map_err(|e| ClusteringError::InvalidInput(format!("{}", e)))?;
124
125    check_iteration_params(iterations, threshold, "K-means")
126        .map_err(|e| ClusteringError::InvalidInput(format!("{}", e)))?;
127
128    // Initialize centroids
129    let init_method = minit.unwrap_or(MinitMethod::PlusPlus); // Default to k-means++
130    let mut centroids = match init_method {
131        MinitMethod::Random => krandinit(data, k, randomseed)?,
132        MinitMethod::Points => kpoints(data, k, randomseed)?,
133        MinitMethod::PlusPlus => kmeans_plus_plus(data, k, randomseed)?,
134    };
135
136    let mut labels;
137
138    // Run K-means iterations
139    for _iteration in 0..iterations {
140        // Store previous centroids for convergence check
141        let prev_centroids = centroids.clone();
142
143        // Assign samples to nearest centroid
144        let (new_labels, _distances) = vq(data, centroids.view())?;
145        labels = new_labels;
146
147        // Compute new centroids
148        let mut new_centroids = Array2::zeros((k, n_features));
149        let mut counts = Array1::zeros(k);
150
151        for i in 0..n_samples {
152            let cluster = labels[i];
153            let point = data.slice(s![i, ..]);
154
155            for j in 0..n_features {
156                new_centroids[[cluster, j]] = new_centroids[[cluster, j]] + point[j];
157            }
158
159            counts[cluster] += 1;
160        }
161
162        // Handle empty clusters
163        for i in 0..k {
164            if counts[i] == 0 {
165                match missing_method {
166                    MissingMethod::Warn => {
167                        eprintln!("One of the clusters is empty. Re-run kmeans with a different initialization.");
168                        // Find point furthest from its centroid
169                        let mut max_dist = F::zero();
170                        let mut far_idx = 0;
171
172                        for j in 0..n_samples {
173                            let cluster_j = labels[j];
174                            let dist = euclidean_distance(
175                                data.slice(s![j, ..]),
176                                centroids.slice(s![cluster_j, ..]),
177                            );
178                            if dist > max_dist {
179                                max_dist = dist;
180                                far_idx = j;
181                            }
182                        }
183
184                        // Move this point to the empty cluster
185                        for j in 0..n_features {
186                            new_centroids[[i, j]] = data[[far_idx, j]];
187                        }
188                        counts[i] = 1;
189                    }
190                    MissingMethod::Raise => {
191                        return Err(ClusteringError::EmptyCluster(
192                            "One of the clusters is empty. Re-run kmeans with a different initialization.".to_string()
193                        ));
194                    }
195                }
196            } else {
197                // Normalize by the number of points in the cluster
198                for j in 0..n_features {
199                    new_centroids[[i, j]] = new_centroids[[i, j]] / F::from(counts[i]).unwrap();
200                }
201            }
202        }
203
204        centroids = new_centroids;
205
206        // Check for convergence
207        let mut max_centroid_shift = F::zero();
208        for i in 0..k {
209            for j in 0..n_features {
210                let shift = (centroids[[i, j]] - prev_centroids[[i, j]]).abs();
211                if shift > max_centroid_shift {
212                    max_centroid_shift = shift;
213                }
214            }
215        }
216
217        // If convergence reached, break early
218        if max_centroid_shift < threshold {
219            break;
220        }
221    }
222
223    // Final assignment
224    let (final_labels, _distances) = vq(data, centroids.view())?;
225
226    Ok((centroids, final_labels))
227}
228
229/// SciPy-compatible K-means clustering with string-based parameters
230///
231/// This function provides an interface compatible with SciPy's kmeans2,
232/// accepting string-based initialization methods.
233///
234/// # Arguments
235///
236/// * `data` - Input data (n_samples × n_features)
237/// * `k` - Number of clusters
238/// * `iter` - Number of iterations
239/// * `thresh` - Convergence threshold
240/// * `minit` - Initialization method as string ('random', 'points', 'k-means++')
241/// * `missing` - Method to handle empty clusters ('warn', 'raise')
242/// * `check_finite` - Whether to check input validity
243/// * `randomseed` - Optional random seed
244///
245/// # Returns
246///
247/// * Tuple of (centroids, labels) where:
248///   - centroids: Array of shape (k × n_features)
249///   - labels: Array of shape (n_samples,) with cluster assignments
250///
251/// # Examples
252///
253/// ```
254/// use scirs2_core::ndarray::Array2;
255/// use scirs2_cluster::vq::kmeans2_str;
256///
257/// let data = Array2::from_shape_vec((6, 2), vec![
258///     1.0, 1.0, 1.1, 1.1, 0.9, 0.9,
259///     8.0, 8.0, 8.1, 8.1, 7.9, 7.9,
260/// ]).unwrap();
261///
262/// let (centroids, labels) = kmeans2_str(
263///     data.view(), 2, Some(20), Some(1e-5), Some("k-means++"),
264///     Some("warn"), Some(true), Some(42)
265/// ).unwrap();
266/// ```
267#[allow(clippy::too_many_arguments)]
268#[allow(dead_code)]
269pub fn kmeans2_str<F>(
270    data: ArrayView2<F>,
271    k: usize,
272    iter: Option<usize>,
273    thresh: Option<F>,
274    minit: Option<&str>,
275    missing: Option<&str>,
276    check_finite: Option<bool>,
277    randomseed: Option<u64>,
278) -> Result<(Array2<F>, Array1<usize>)>
279where
280    F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
281{
282    // Parse string parameters
283    let minit_method = if let Some(method_str) = minit {
284        Some(MinitMethod::from_str(method_str)?)
285    } else {
286        Some(MinitMethod::PlusPlus) // Default to k-means++ like SciPy
287    };
288
289    let missing_method = if let Some(missing_str) = missing {
290        match missing_str.to_lowercase().as_str() {
291            "warn" => Some(MissingMethod::Warn),
292            "raise" => Some(MissingMethod::Raise),
293            _ => {
294                return Err(ClusteringError::InvalidInput(format!(
295                    "Unknown missing method: '{}'. Valid options are: 'warn', 'raise'",
296                    missing_str
297                )))
298            }
299        }
300    } else {
301        Some(MissingMethod::Warn) // Default to warn like SciPy
302    };
303
304    // Call the main kmeans2 function
305    kmeans2(
306        data,
307        k,
308        iter,
309        thresh,
310        minit_method,
311        missing_method,
312        check_finite,
313        randomseed,
314    )
315}
316
317/// Random initialization: generate k centroids from a Gaussian with mean and
318/// variance estimated from the data
319#[allow(dead_code)]
320fn krandinit<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
321where
322    F: Float + FromPrimitive + Debug + std::iter::Sum,
323{
324    let n_samples = data.shape()[0];
325    let n_features = data.shape()[1];
326
327    // Calculate mean and variance for each feature
328    let mut means = Array1::<F>::zeros(n_features);
329    let mut vars = Array1::<F>::zeros(n_features);
330
331    for j in 0..n_features {
332        let mut sum = F::zero();
333        for i in 0..n_samples {
334            sum = sum + data[[i, j]];
335        }
336        means[j] = sum / F::from(n_samples).unwrap();
337
338        let mut var_sum = F::zero();
339        for i in 0..n_samples {
340            let diff = data[[i, j]] - means[j];
341            var_sum = var_sum + diff * diff;
342        }
343        vars[j] = var_sum / F::from(n_samples).unwrap();
344    }
345
346    // Generate random centroids from Gaussian distribution
347    let mut centroids = Array2::<F>::zeros((k, n_features));
348
349    let mut rng: Box<dyn RngCore> = if let Some(_seed) = randomseed {
350        Box::new(StdRng::seed_from_u64(_seed))
351    } else {
352        Box::new(scirs2_core::random::rng())
353    };
354
355    for i in 0..k {
356        for j in 0..n_features {
357            // Convert to f64 for normal distribution
358            let mean = means[j].to_f64().unwrap();
359            let std = vars[j].sqrt().to_f64().unwrap();
360
361            if std > 0.0 {
362                let normal = Normal::new(mean, std).unwrap();
363                let value = normal.sample(&mut rng);
364                centroids[[i, j]] = F::from(value).unwrap();
365            } else {
366                centroids[[i, j]] = means[j];
367            }
368        }
369    }
370
371    Ok(centroids)
372}
373
374/// Points initialization: choose k observations (rows) at random from data
375#[allow(dead_code)]
376fn kpoints<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
377where
378    F: Float + FromPrimitive + Debug,
379{
380    let n_samples = data.shape()[0];
381    let n_features = data.shape()[1];
382
383    let mut rng: Box<dyn RngCore> = if let Some(_seed) = randomseed {
384        Box::new(StdRng::seed_from_u64(_seed))
385    } else {
386        Box::new(scirs2_core::random::rng())
387    };
388
389    // Choose k random indices without replacement
390    let mut indices: Vec<usize> = (0..n_samples).collect();
391
392    // Shuffle and take first k
393    for i in 0..k {
394        let j = rng.random_range(i..n_samples);
395        indices.swap(i, j);
396    }
397
398    // Extract centroids from _data
399    let mut centroids = Array2::zeros((k, n_features));
400    for i in 0..k {
401        let idx = indices[i];
402        for j in 0..n_features {
403            centroids[[i, j]] = data[[idx, j]];
404        }
405    }
406
407    Ok(centroids)
408}
409
410/// K-means++ initialization
411#[allow(dead_code)]
412fn kmeans_plus_plus<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
413where
414    F: Float + FromPrimitive + Debug + std::iter::Sum,
415{
416    let n_samples = data.shape()[0];
417    let n_features = data.shape()[1];
418
419    let mut rng: Box<dyn RngCore> = if let Some(_seed) = randomseed {
420        Box::new(StdRng::seed_from_u64(_seed))
421    } else {
422        Box::new(scirs2_core::random::rng())
423    };
424
425    let mut centroids = Array2::zeros((k, n_features));
426
427    // Choose first centroid randomly
428    let first_idx = rng.random_range(0..n_samples);
429    for j in 0..n_features {
430        centroids[[0, j]] = data[[first_idx, j]];
431    }
432
433    // Choose remaining centroids
434    for i in 1..k {
435        // Calculate squared distances to nearest centroid
436        let mut distances = Array1::<F>::zeros(n_samples);
437
438        for j in 0..n_samples {
439            let mut min_dist = F::infinity();
440            for c in 0..i {
441                let dist = euclidean_distance(data.slice(s![j, ..]), centroids.slice(s![c, ..]));
442                if dist < min_dist {
443                    min_dist = dist;
444                }
445            }
446            distances[j] = min_dist * min_dist;
447        }
448
449        // Calculate probabilities
450        let total = distances.iter().fold(F::zero(), |a, &b| a + b);
451        let mut probabilities = Array1::<F>::zeros(n_samples);
452        for j in 0..n_samples {
453            probabilities[j] = distances[j] / total;
454        }
455
456        // Choose next centroid based on weighted probability
457        let mut cumsum = F::zero();
458        let r = F::from(rng.random::<f64>()).unwrap();
459        let mut next_idx = n_samples - 1;
460
461        for j in 0..n_samples {
462            cumsum = cumsum + probabilities[j];
463            if cumsum > r {
464                next_idx = j;
465                break;
466            }
467        }
468
469        // Add chosen centroid
470        for j in 0..n_features {
471            centroids[[i, j]] = data[[next_idx, j]];
472        }
473    }
474
475    Ok(centroids)
476}
477
478#[cfg(test)]
479mod tests {
480    use super::*;
481    use approx::assert_abs_diff_eq;
482    use scirs2_core::ndarray::{array, Array2};
483
484    #[test]
485    fn test_kmeans2_basic_functionality() {
486        let data = array![
487            [1.0, 1.0],
488            [1.5, 1.5],
489            [0.8, 0.9],
490            [8.0, 8.0],
491            [8.2, 8.1],
492            [7.8, 7.9],
493        ];
494
495        let (centroids, labels) = kmeans2(
496            data.view(),
497            2,
498            Some(50),
499            Some(1e-6),
500            Some(MinitMethod::PlusPlus),
501            Some(MissingMethod::Warn),
502            Some(true),
503            Some(42),
504        )
505        .unwrap();
506
507        // Should have 2 centroids
508        assert_eq!(centroids.shape(), [2, 2]);
509
510        // Should have labels for all 6 points
511        assert_eq!(labels.len(), 6);
512
513        // All labels should be 0 or 1
514        assert!(labels.iter().all(|&l| l == 0 || l == 1));
515
516        // Should have points from both clusters
517        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
518        assert_eq!(unique_labels.len(), 2);
519    }
520
521    #[test]
522    fn test_kmeans2_parameter_validation() {
523        let data = array![[1.0, 1.0], [2.0, 2.0]];
524
525        // Test k=0 (invalid)
526        let result = kmeans2(
527            data.view(),
528            0,
529            None,
530            None,
531            Some(MinitMethod::Random),
532            None,
533            None,
534            None,
535        );
536        assert!(result.is_err());
537
538        // Test k > n_samples (invalid)
539        let result = kmeans2(
540            data.view(),
541            5,
542            None,
543            None,
544            Some(MinitMethod::Random),
545            None,
546            None,
547            None,
548        );
549        assert!(result.is_err());
550    }
551
552    #[test]
553    fn test_kmeans2_initialization_methods() {
554        let data = array![
555            [1.0, 1.0],
556            [1.5, 1.5],
557            [0.8, 0.9],
558            [8.0, 8.0],
559            [8.2, 8.1],
560            [7.8, 7.9],
561        ];
562
563        let methods = vec![
564            MinitMethod::Random,
565            MinitMethod::Points,
566            MinitMethod::PlusPlus,
567        ];
568
569        for method in methods {
570            let result = kmeans2(
571                data.view(),
572                2,
573                Some(10),
574                None,
575                Some(method),
576                Some(MissingMethod::Warn),
577                None,
578                Some(42),
579            );
580
581            assert!(result.is_ok(), "Failed with method: {:?}", method);
582            let (centroids, labels) = result.unwrap();
583            assert_eq!(centroids.shape(), [2, 2]);
584            assert_eq!(labels.len(), 6);
585        }
586    }
587
588    #[test]
589    fn test_kmeans2_reproducibility_with_seed() {
590        let data = array![
591            [1.0, 1.0],
592            [1.5, 1.5],
593            [0.8, 0.9],
594            [8.0, 8.0],
595            [8.2, 8.1],
596            [7.8, 7.9],
597        ];
598
599        let (centroids1, labels1) = kmeans2(
600            data.view(),
601            2,
602            Some(10),
603            None,
604            Some(MinitMethod::Random),
605            None,
606            None,
607            Some(42),
608        )
609        .unwrap();
610
611        let (centroids2, labels2) = kmeans2(
612            data.view(),
613            2,
614            Some(10),
615            None,
616            Some(MinitMethod::Random),
617            None,
618            None,
619            Some(42),
620        )
621        .unwrap();
622
623        // With same seed, results should be identical
624        assert_eq!(labels1, labels2);
625
626        // Centroids should be very close (allowing for floating point precision)
627        for i in 0..centroids1.shape()[0] {
628            for j in 0..centroids1.shape()[1] {
629                assert_abs_diff_eq!(centroids1[[i, j]], centroids2[[i, j]], epsilon = 1e-10);
630            }
631        }
632    }
633
634    #[test]
635    fn test_kmeans2_single_cluster() {
636        let data = array![[1.0, 1.0], [1.1, 1.1], [0.9, 0.9],];
637
638        let (centroids, labels) = kmeans2(
639            data.view(),
640            1,
641            Some(10),
642            None,
643            Some(MinitMethod::Points),
644            None,
645            None,
646            Some(42),
647        )
648        .unwrap();
649
650        // Should have 1 centroid
651        assert_eq!(centroids.shape(), [1, 2]);
652
653        // All labels should be 0
654        assert!(labels.iter().all(|&l| l == 0));
655    }
656
657    #[test]
658    fn test_kmeans2_identical_points() {
659        let data = array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],];
660
661        let (centroids, labels) = kmeans2(
662            data.view(),
663            2,
664            Some(10),
665            None,
666            Some(MinitMethod::Points),
667            Some(MissingMethod::Warn),
668            None,
669            Some(42),
670        )
671        .unwrap();
672
673        // Should still produce valid results
674        assert_eq!(centroids.shape(), [2, 2]);
675        assert_eq!(labels.len(), 4);
676
677        // All labels should be valid (0 or 1)
678        assert!(labels.iter().all(|&l| l == 0 || l == 1));
679    }
680
681    #[test]
682    fn test_kmeans2_missing_method_warn() {
683        // Create data that might lead to empty clusters
684        let data = array![[0.0, 0.0], [0.1, 0.1], [10.0, 10.0],];
685
686        let result = kmeans2(
687            data.view(),
688            2,
689            Some(5),
690            None,
691            Some(MinitMethod::Random),
692            Some(MissingMethod::Warn),
693            None,
694            Some(123),
695        );
696
697        // Should succeed with warning
698        assert!(result.is_ok());
699    }
700
701    #[test]
702    fn test_kmeans2_convergence_behavior() {
703        // Create well-separated clusters
704        let data = array![
705            [1.0, 1.0],
706            [1.1, 1.1],
707            [0.9, 0.9],
708            [10.0, 10.0],
709            [10.1, 10.1],
710            [9.9, 9.9],
711        ];
712
713        // Test with different iteration counts
714        let (centroids_few_) = kmeans2(
715            data.view(),
716            2,
717            Some(1),
718            None,
719            Some(MinitMethod::PlusPlus),
720            None,
721            None,
722            Some(42),
723        )
724        .unwrap();
725
726        let (centroids_many_) = kmeans2(
727            data.view(),
728            2,
729            Some(100),
730            None,
731            Some(MinitMethod::PlusPlus),
732            None,
733            None,
734            Some(42),
735        )
736        .unwrap();
737
738        // Results should be valid for both
739        assert_eq!(centroids_few_.0.shape(), [2, 2]);
740        assert_eq!(centroids_many_.0.shape(), [2, 2]);
741    }
742
743    #[test]
744    fn test_kmeans2_high_k() {
745        let data = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0],];
746
747        // Test with k equal to number of points
748        let (centroids, labels) = kmeans2(
749            data.view(),
750            5,
751            Some(10),
752            None,
753            Some(MinitMethod::Points),
754            None,
755            None,
756            Some(42),
757        )
758        .unwrap();
759
760        assert_eq!(centroids.shape(), [5, 2]);
761        assert_eq!(labels.len(), 5);
762
763        // Each point should be its own cluster
764        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
765        assert_eq!(unique_labels.len(), 5);
766    }
767
768    #[test]
769    fn test_kmeans2_different_thresholds() {
770        let data = array![[1.0, 1.0], [1.5, 1.5], [8.0, 8.0], [8.5, 8.5],];
771
772        // Test with different convergence thresholds
773        let result1 = kmeans2(
774            data.view(),
775            2,
776            Some(100),
777            Some(1e-10), // Very strict
778            Some(MinitMethod::PlusPlus),
779            None,
780            None,
781            Some(42),
782        );
783
784        let result2 = kmeans2(
785            data.view(),
786            2,
787            Some(100),
788            Some(1e-1), // Very loose
789            Some(MinitMethod::PlusPlus),
790            None,
791            None,
792            Some(42),
793        );
794
795        assert!(result1.is_ok());
796        assert!(result2.is_ok());
797    }
798
799    #[test]
800    fn test_kmeans2_convergence_threshold() {
801        // Test early convergence with well-separated clusters
802        let data = array![
803            [1.0, 1.0],
804            [1.1, 1.1],
805            [0.9, 0.9],
806            [10.0, 10.0],
807            [10.1, 10.1],
808            [9.9, 9.9],
809        ];
810
811        // With a tight threshold, should converge quickly
812        let result1 = kmeans2(
813            data.view(),
814            2,
815            Some(100),   // Allow many iterations
816            Some(1e-10), // Very strict threshold - should converge quickly
817            Some(MinitMethod::PlusPlus),
818            None,
819            None,
820            Some(42),
821        );
822
823        assert!(result1.is_ok());
824        let (centroids1, labels1) = result1.unwrap();
825        assert_eq!(centroids1.shape(), [2, 2]);
826        assert_eq!(labels1.len(), 6);
827
828        // With a loose threshold, should also work
829        let result2 = kmeans2(
830            data.view(),
831            2,
832            Some(100),
833            Some(1e-1), // Very loose threshold
834            Some(MinitMethod::PlusPlus),
835            None,
836            None,
837            Some(42),
838        );
839
840        assert!(result2.is_ok());
841        let (centroids2, labels2) = result2.unwrap();
842        assert_eq!(centroids2.shape(), [2, 2]);
843        assert_eq!(labels2.len(), 6);
844    }
845
846    #[test]
847    fn test_kmeans2_check_finite() {
848        // Test with finite data (should work)
849        let data = array![[1.0, 2.0], [1.5, 1.5], [8.0, 8.0],];
850
851        let result = kmeans2(
852            data.view(),
853            2,
854            Some(10),
855            None,
856            Some(MinitMethod::Random),
857            None,
858            Some(true), // check_finite = true
859            Some(42),
860        );
861        assert!(result.is_ok());
862
863        // Test with check_finite disabled (should also work with finite data)
864        let result = kmeans2(
865            data.view(),
866            2,
867            Some(10),
868            None,
869            Some(MinitMethod::Random),
870            None,
871            Some(false), // check_finite = false
872            Some(42),
873        );
874        assert!(result.is_ok());
875    }
876
877    #[test]
878    fn test_kmeans2_large_dataset() {
879        // Generate a larger dataset for stress testing
880        let mut data = Array2::zeros((100, 3));
881
882        // Create 3 clusters
883        for i in 0..100 {
884            let cluster = i % 3;
885            match cluster {
886                0 => {
887                    data[[i, 0]] = 1.0 + (i as f64) * 0.01;
888                    data[[i, 1]] = 1.0 + (i as f64) * 0.01;
889                    data[[i, 2]] = 1.0 + (i as f64) * 0.01;
890                }
891                1 => {
892                    data[[i, 0]] = 5.0 + (i as f64) * 0.01;
893                    data[[i, 1]] = 5.0 + (i as f64) * 0.01;
894                    data[[i, 2]] = 5.0 + (i as f64) * 0.01;
895                }
896                2 => {
897                    data[[i, 0]] = 10.0 + (i as f64) * 0.01;
898                    data[[i, 1]] = 10.0 + (i as f64) * 0.01;
899                    data[[i, 2]] = 10.0 + (i as f64) * 0.01;
900                }
901                _ => unreachable!(),
902            }
903        }
904
905        let (centroids, labels) = kmeans2(
906            data.view(),
907            3,
908            Some(50),
909            None,
910            Some(MinitMethod::PlusPlus),
911            None,
912            None,
913            Some(42),
914        )
915        .unwrap();
916
917        assert_eq!(centroids.shape(), [3, 3]);
918        assert_eq!(labels.len(), 100);
919
920        // Should find 3 clusters
921        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
922        assert_eq!(unique_labels.len(), 3);
923    }
924
925    // Tests for string-based parameter support
926    use super::kmeans2_str;
927
928    #[test]
929    fn test_kmeans2_str_basic_functionality() {
930        let data = array![
931            [1.0, 1.0],
932            [1.5, 1.5],
933            [0.8, 0.9],
934            [8.0, 8.0],
935            [8.2, 8.1],
936            [7.8, 7.9],
937        ];
938
939        let (centroids, labels) = kmeans2_str(
940            data.view(),
941            2,
942            Some(50),
943            Some(1e-6),
944            Some("k-means++"),
945            Some("warn"),
946            Some(true),
947            Some(42),
948        )
949        .unwrap();
950
951        assert_eq!(centroids.shape(), [2, 2]);
952        assert_eq!(labels.len(), 6);
953        assert!(labels.iter().all(|&l| l == 0 || l == 1));
954
955        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
956        assert_eq!(unique_labels.len(), 2);
957    }
958
959    #[test]
960    fn test_kmeans2_str_all_init_methods() {
961        let data = array![
962            [1.0, 1.0],
963            [1.5, 1.5],
964            [0.8, 0.9],
965            [8.0, 8.0],
966            [8.2, 8.1],
967            [7.8, 7.9],
968        ];
969
970        let methods = vec!["random", "points", "k-means++", "kmeans++", "plusplus"];
971
972        for method in methods {
973            let result = kmeans2_str(
974                data.view(),
975                2,
976                Some(10),
977                None,
978                Some(method),
979                Some("warn"),
980                None,
981                Some(42),
982            );
983
984            assert!(result.is_ok(), "Failed with method: '{}'", method);
985            let (centroids, labels) = result.unwrap();
986            assert_eq!(centroids.shape(), [2, 2]);
987            assert_eq!(labels.len(), 6);
988        }
989    }
990
991    #[test]
992    fn test_kmeans2_str_case_insensitive() {
993        let data = array![[1.0, 1.0], [2.0, 2.0], [8.0, 8.0], [9.0, 9.0],];
994
995        // Test case insensitive method names
996        let methods = vec![
997            "RANDOM",
998            "Random",
999            "random",
1000            "POINTS",
1001            "Points",
1002            "points",
1003            "K-MEANS++",
1004            "K-Means++",
1005            "k-means++",
1006        ];
1007
1008        for method in methods {
1009            let result = kmeans2_str(
1010                data.view(),
1011                2,
1012                Some(10),
1013                None,
1014                Some(method),
1015                Some("warn"),
1016                None,
1017                Some(42),
1018            );
1019
1020            assert!(result.is_ok(), "Failed with method: '{}'", method);
1021        }
1022    }
1023
1024    #[test]
1025    fn test_kmeans2_str_missing_methods() {
1026        let data = array![[1.0, 1.0], [2.0, 2.0], [8.0, 8.0],];
1027
1028        // Test different missing methods
1029        let missing_methods = vec!["warn", "raise", "WARN", "RAISE"];
1030
1031        for missing_method in missing_methods {
1032            let result = kmeans2_str(
1033                data.view(),
1034                2,
1035                Some(5),
1036                None,
1037                Some("points"),
1038                Some(missing_method),
1039                None,
1040                Some(42),
1041            );
1042
1043            assert!(
1044                result.is_ok(),
1045                "Failed with missing method: '{}'",
1046                missing_method
1047            );
1048        }
1049    }
1050
1051    #[test]
1052    fn test_kmeans2_str_invalid_method() {
1053        let data = array![[1.0, 1.0], [2.0, 2.0]];
1054
1055        // Test invalid initialization method
1056        let result = kmeans2_str(
1057            data.view(),
1058            2,
1059            Some(10),
1060            None,
1061            Some("invalid_method"),
1062            Some("warn"),
1063            None,
1064            None,
1065        );
1066
1067        assert!(result.is_err());
1068        assert!(result
1069            .unwrap_err()
1070            .to_string()
1071            .contains("Unknown initialization method"));
1072    }
1073
1074    #[test]
1075    fn test_kmeans2_str_invalid_missing_method() {
1076        let data = array![[1.0, 1.0], [2.0, 2.0]];
1077
1078        // Test invalid missing method
1079        let result = kmeans2_str(
1080            data.view(),
1081            2,
1082            Some(10),
1083            None,
1084            Some("points"),
1085            Some("invalid_missing"),
1086            None,
1087            None,
1088        );
1089
1090        assert!(result.is_err());
1091        assert!(result
1092            .unwrap_err()
1093            .to_string()
1094            .contains("Unknown missing method"));
1095    }
1096
1097    #[test]
1098    fn test_kmeans2_str_defaults() {
1099        let data = array![[1.0, 1.0], [1.5, 1.5], [8.0, 8.0], [8.5, 8.5],];
1100
1101        // Test with all None parameters (should use defaults)
1102        let result = kmeans2_str(
1103            data.view(),
1104            2,
1105            Some(10),
1106            None,
1107            None, // Should default to k-means++
1108            None, // Should default to warn
1109            None,
1110            Some(42),
1111        );
1112
1113        assert!(result.is_ok());
1114        let (centroids, labels) = result.unwrap();
1115        assert_eq!(centroids.shape(), [2, 2]);
1116        assert_eq!(labels.len(), 4);
1117    }
1118
1119    #[test]
1120    fn test_kmeans2_str_equivalence_with_enum() {
1121        let data = array![
1122            [1.0, 1.0],
1123            [1.5, 1.5],
1124            [0.8, 0.9],
1125            [8.0, 8.0],
1126            [8.2, 8.1],
1127            [7.8, 7.9],
1128        ];
1129
1130        // Test that string version produces same results as enum version
1131        let (centroids_enum, labels_enum) = kmeans2(
1132            data.view(),
1133            2,
1134            Some(50),
1135            Some(1e-6),
1136            Some(MinitMethod::PlusPlus),
1137            Some(MissingMethod::Warn),
1138            Some(true),
1139            Some(42),
1140        )
1141        .unwrap();
1142
1143        let (centroids_str, labels_str) = kmeans2_str(
1144            data.view(),
1145            2,
1146            Some(50),
1147            Some(1e-6),
1148            Some("k-means++"),
1149            Some("warn"),
1150            Some(true),
1151            Some(42),
1152        )
1153        .unwrap();
1154
1155        // Results should be identical
1156        assert_eq!(labels_enum, labels_str);
1157
1158        for i in 0..centroids_enum.shape()[0] {
1159            for j in 0..centroids_enum.shape()[1] {
1160                assert_abs_diff_eq!(
1161                    centroids_enum[[i, j]],
1162                    centroids_str[[i, j]],
1163                    epsilon = 1e-10
1164                );
1165            }
1166        }
1167    }
1168
1169    #[test]
1170    fn test_minit_method_from_str() {
1171        // Test MinitMethod::from_str function directly
1172        assert_eq!(
1173            MinitMethod::from_str("random").unwrap(),
1174            MinitMethod::Random
1175        );
1176        assert_eq!(
1177            MinitMethod::from_str("RANDOM").unwrap(),
1178            MinitMethod::Random
1179        );
1180        assert_eq!(
1181            MinitMethod::from_str("points").unwrap(),
1182            MinitMethod::Points
1183        );
1184        assert_eq!(
1185            MinitMethod::from_str("POINTS").unwrap(),
1186            MinitMethod::Points
1187        );
1188        assert_eq!(
1189            MinitMethod::from_str("k-means++").unwrap(),
1190            MinitMethod::PlusPlus
1191        );
1192        assert_eq!(
1193            MinitMethod::from_str("kmeans++").unwrap(),
1194            MinitMethod::PlusPlus
1195        );
1196        assert_eq!(
1197            MinitMethod::from_str("plusplus").unwrap(),
1198            MinitMethod::PlusPlus
1199        );
1200        assert_eq!(
1201            MinitMethod::from_str("K-MEANS++").unwrap(),
1202            MinitMethod::PlusPlus
1203        );
1204
1205        // Test invalid method
1206        assert!(MinitMethod::from_str("invalid").is_err());
1207    }
1208}