scirs2_cluster/vq/
kmeans.rs

1//! K-means clustering implementation
2
3use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
4use scirs2_core::numeric::{Float, FromPrimitive};
5use scirs2_core::random::Rng;
6use std::fmt::Debug;
7
8use super::{euclidean_distance, vq};
9use crate::error::{ClusteringError, Result};
10// use scirs2_core::validation::{clustering::*, parameters::*};
11
12// Re-export kmeans2 related types and functions
13
14/// Options for K-means clustering
15#[derive(Debug, Clone)]
16pub struct KMeansOptions<F: Float> {
17    /// Maximum number of iterations
18    pub max_iter: usize,
19    /// Convergence threshold for centroid movement
20    pub tol: F,
21    /// Random seed for initialization
22    pub random_seed: Option<u64>,
23    /// Number of different initializations to try
24    pub n_init: usize,
25    /// Method to use for centroid initialization
26    pub init_method: KMeansInit,
27}
28
29impl<F: Float + FromPrimitive> Default for KMeansOptions<F> {
30    fn default() -> Self {
31        Self {
32            max_iter: 300,
33            tol: F::from(1e-4).unwrap(),
34            random_seed: None,
35            n_init: 10,
36            init_method: KMeansInit::KMeansPlusPlus,
37        }
38    }
39}
40
41/// K-means clustering algorithm (SciPy-compatible version)
42///
43/// # Arguments
44///
45/// * `obs` - Input data (n_samples × n_features)
46/// * `k_or_guess` - Number of clusters or initial guess for centroids
47/// * `iter` - Maximum number of iterations (default: 20)
48/// * `thresh` - Convergence threshold (default: 1e-5)
49/// * `check_finite` - Whether to check for finite values (default: true)
50/// * `seed` - Random seed for initialization (optional)
51///
52/// # Returns
53///
54/// * Tuple of (centroids, distortion) where:
55///   - centroids: Array of shape (k × n_features)
56///   - distortion: Sum of squared distances to centroids
57///
58/// # Examples
59///
60/// ```
61/// use scirs2_core::ndarray::{ArrayView1, Array2, ArrayView2};
62/// use scirs2_cluster::vq::kmeans;
63///
64/// let data = Array2::from_shape_vec((6, 2), vec![
65///     1.0, 2.0,
66///     1.2, 1.8,
67///     0.8, 1.9,
68///     3.7, 4.2,
69///     3.9, 3.9,
70///     4.2, 4.1,
71/// ]).unwrap();
72///
73/// let (centroids, distortion) = kmeans(data.view(), 2, Some(20), Some(1e-5), Some(true), Some(42)).unwrap();
74/// ```
75#[allow(clippy::too_many_arguments)]
76#[allow(dead_code)]
77pub fn kmeans<F>(
78    obs: ArrayView2<F>,
79    k_or_guess: usize,
80    iter: Option<usize>,
81    thresh: Option<F>,
82    check_finite: Option<bool>,
83    seed: Option<u64>,
84) -> Result<(Array2<F>, F)>
85where
86    F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
87{
88    let k = k_or_guess; // For now, just treat as number of clusters
89    let max_iter = iter.unwrap_or(20);
90    let tol = thresh.unwrap_or(F::from(1e-5).unwrap());
91    let _check_finite_flag = check_finite.unwrap_or(true);
92
93    // Basic validation
94    if obs.is_empty() {
95        return Err(ClusteringError::InvalidInput(
96            "Input data is empty".to_string(),
97        ));
98    }
99    if k == 0 {
100        return Err(ClusteringError::InvalidInput(
101            "Number of clusters must be greater than 0".to_string(),
102        ));
103    }
104    if k > obs.nrows() {
105        return Err(ClusteringError::InvalidInput(format!(
106            "Number of clusters ({}) cannot be greater than number of data points ({})",
107            k,
108            obs.nrows()
109        )));
110    }
111
112    // Create options struct for internal use
113    let options = KMeansOptions {
114        max_iter,
115        tol,
116        random_seed: seed,
117        n_init: 1, // SciPy's kmeans does single initialization
118        init_method: KMeansInit::KMeansPlusPlus,
119    };
120
121    // Use the options-based version internally
122    let (centroids, labels) = kmeans_with_options(obs, k, Some(options))?;
123
124    // Calculate distortion (sum of squared distances to centroids)
125    let distortion = calculate_distortion(obs, centroids.view(), &labels);
126
127    Ok((centroids, distortion))
128}
129
130/// K-means clustering algorithm (options-based version)
131///
132/// This is the original implementation that uses the options struct.
133/// The SciPy-compatible version above is a wrapper around this function.
134///
135/// # Arguments
136///
137/// * `data` - Input data (n_samples × n_features)
138/// * `k` - Number of clusters
139/// * `options` - Optional parameters
140///
141/// # Returns
142///
143/// * Tuple of (centroids, labels) where:
144///   - centroids: Array of shape (k × n_features)
145///   - labels: Array of shape (n_samples,) with cluster assignments
146///
147/// # Examples
148///
149/// ```
150/// use scirs2_core::ndarray::{ArrayView1, Array2, ArrayView2};
151/// use scirs2_cluster::vq::kmeans_with_options;
152///
153/// let data = Array2::from_shape_vec((6, 2), vec![
154///     1.0, 2.0,
155///     1.2, 1.8,
156///     0.8, 1.9,
157///     3.7, 4.2,
158///     3.9, 3.9,
159///     4.2, 4.1,
160/// ]).unwrap();
161///
162/// let (centroids, labels) = kmeans_with_options(data.view(), 2, None).unwrap();
163/// ```
164#[allow(dead_code)]
165pub fn kmeans_with_options<F>(
166    data: ArrayView2<F>,
167    k: usize,
168    options: Option<KMeansOptions<F>>,
169) -> Result<(Array2<F>, Array1<usize>)>
170where
171    F: Float + FromPrimitive + Debug + std::iter::Sum,
172{
173    if k == 0 {
174        return Err(ClusteringError::InvalidInput(
175            "Number of clusters must be greater than 0".to_string(),
176        ));
177    }
178
179    let n_samples = data.shape()[0];
180    if n_samples == 0 {
181        return Err(ClusteringError::InvalidInput(
182            "Input data is empty".to_string(),
183        ));
184    }
185
186    if k > n_samples {
187        return Err(ClusteringError::InvalidInput(format!(
188            "Number of clusters ({}) cannot be greater than number of data points ({})",
189            k, n_samples
190        )));
191    }
192
193    let opts = options.unwrap_or_default();
194    // Random seed is handled in kmeans_init function
195
196    let mut bestcentroids = None;
197    let mut best_labels = None;
198    let mut best_inertia = F::infinity();
199
200    // If we're using K-means|| initialization, we only need to run once
201    let n_init = if opts.init_method == KMeansInit::KMeansParallel {
202        1
203    } else {
204        opts.n_init
205    };
206
207    for _ in 0..n_init {
208        // Initialize centroids using the specified method
209        let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
210
211        // Run k-means
212        let (centroids, labels, inertia) = _kmeans_single(data, centroids.view(), &opts)?;
213
214        if inertia < best_inertia {
215            bestcentroids = Some(centroids);
216            best_labels = Some(labels);
217            best_inertia = inertia;
218        }
219    }
220
221    Ok((bestcentroids.unwrap(), best_labels.unwrap()))
222}
223
224/// Calculate distortion (sum of squared distances to centroids)
225#[allow(dead_code)]
226fn calculate_distortion<F>(
227    data: ArrayView2<F>,
228    centroids: ArrayView2<F>,
229    labels: &Array1<usize>,
230) -> F
231where
232    F: Float + FromPrimitive + Debug + std::iter::Sum,
233{
234    let n_samples = data.shape()[0];
235    let mut total_distortion = F::zero();
236
237    for i in 0..n_samples {
238        let cluster = labels[i];
239        let point = data.slice(s![i, ..]);
240        let centroid = centroids.slice(s![cluster, ..]);
241
242        let squared_distance = euclidean_distance(point, centroid).powi(2);
243        total_distortion = total_distortion + squared_distance;
244    }
245
246    total_distortion
247}
248
249/// Run a single k-means clustering iteration
250#[allow(dead_code)]
251fn _kmeans_single<F>(
252    data: ArrayView2<F>,
253    initcentroids: ArrayView2<F>,
254    opts: &KMeansOptions<F>,
255) -> Result<(Array2<F>, Array1<usize>, F)>
256where
257    F: Float + FromPrimitive + Debug + std::iter::Sum,
258{
259    let n_samples = data.shape()[0];
260    let n_features = data.shape()[1];
261    let k = initcentroids.shape()[0];
262
263    let mut centroids = initcentroids.to_owned();
264    let mut labels = Array1::zeros(n_samples);
265    let mut prev_centroid_diff = F::infinity();
266
267    for _iter in 0..opts.max_iter {
268        // Assign samples to nearest centroid
269        let (new_labels, distances) = vq(data, centroids.view())?;
270        labels = new_labels;
271
272        // Compute new centroids
273        let mut newcentroids = Array2::zeros((k, n_features));
274        let mut counts = Array1::zeros(k);
275
276        for i in 0..n_samples {
277            let cluster = labels[i];
278            let point = data.slice(s![i, ..]);
279
280            for j in 0..n_features {
281                newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j];
282            }
283
284            counts[cluster] += 1;
285        }
286
287        // If a cluster is empty, reinitialize it
288        for i in 0..k {
289            if counts[i] == 0 {
290                // Find the point furthest from its centroid
291                let mut max_dist = F::zero();
292                let mut far_idx = 0;
293
294                for j in 0..n_samples {
295                    let dist = distances[j];
296                    if dist > max_dist {
297                        max_dist = dist;
298                        far_idx = j;
299                    }
300                }
301
302                // Move this point to the empty cluster
303                for j in 0..n_features {
304                    newcentroids[[i, j]] = data[[far_idx, j]];
305                }
306
307                counts[i] = 1;
308            } else {
309                // Normalize by the number of points in the cluster
310                for j in 0..n_features {
311                    newcentroids[[i, j]] = newcentroids[[i, j]] / F::from(counts[i]).unwrap();
312                }
313            }
314        }
315
316        // Check for convergence
317        let mut centroid_diff = F::zero();
318        for i in 0..k {
319            let dist =
320                euclidean_distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
321            centroid_diff = centroid_diff + dist;
322        }
323
324        centroids = newcentroids;
325
326        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
327            break;
328        }
329
330        prev_centroid_diff = centroid_diff;
331    }
332
333    // Calculate inertia (sum of squared distances to nearest centroid)
334    let mut inertia = F::zero();
335    for i in 0..n_samples {
336        let cluster = labels[i];
337        let dist = euclidean_distance(data.slice(s![i, ..]), centroids.slice(s![cluster, ..]));
338        inertia = inertia + dist * dist;
339    }
340
341    Ok((centroids, labels, inertia))
342}
343
/// Strategy used to choose the starting centroids for K-means.
///
/// The default strategy is [`KMeansInit::KMeansPlusPlus`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum KMeansInit {
    /// Pick `k` distinct samples uniformly at random.
    Random,
    /// K-means++: every further centroid is drawn with probability proportional to a
    /// sample's squared distance to the nearest centroid chosen so far.
    #[default]
    KMeansPlusPlus,
    /// K-means|| (scalable K-means++): oversample candidate centers over several
    /// rounds, then reduce them to `k` centroids.
    KMeansParallel,
}
355
356/// K-means initialization algorithm
357///
358/// # Arguments
359///
360/// * `data` - Input data (n_samples × n_features)
361/// * `k` - Number of clusters
362/// * `init_method` - Initialization method (default: K-means++)
363/// * `random_seed` - Optional random seed
364///
365/// # Returns
366///
367/// * Array of shape (k × n_features) with initial centroids
368#[allow(dead_code)]
369pub fn kmeans_init<F>(
370    data: ArrayView2<F>,
371    k: usize,
372    init_method: Option<KMeansInit>,
373    random_seed: Option<u64>,
374) -> Result<Array2<F>>
375where
376    F: Float + FromPrimitive + Debug + std::iter::Sum,
377{
378    match init_method.unwrap_or_default() {
379        KMeansInit::Random => random_init(data, k, random_seed),
380        KMeansInit::KMeansPlusPlus => kmeans_plus_plus(data, k, random_seed),
381        KMeansInit::KMeansParallel => kmeans_parallel(data, k, random_seed),
382    }
383}
384
385/// Random initialization algorithm for K-means
386///
387/// # Arguments
388///
389/// * `data` - Input data (n_samples × n_features)
390/// * `k` - Number of clusters
391/// * `random_seed` - Optional random seed
392///
393/// # Returns
394///
395/// * Array of shape (k × n_features) with initial centroids
396#[allow(dead_code)]
397pub fn random_init<F>(data: ArrayView2<F>, k: usize, random_seed: Option<u64>) -> Result<Array2<F>>
398where
399    F: Float + FromPrimitive + Debug + std::iter::Sum,
400{
401    let n_samples = data.shape()[0];
402    let n_features = data.shape()[1];
403
404    if k == 0 || k > n_samples {
405        return Err(ClusteringError::InvalidInput(format!(
406            "Number of clusters ({}) must be between 1 and number of samples ({})",
407            k, n_samples
408        )));
409    }
410
411    let mut rng = scirs2_core::random::rng();
412    let mut centroids = Array2::zeros((k, n_features));
413    let mut selected_indices = Vec::with_capacity(k);
414
415    // Select k unique random points from the _data
416    while selected_indices.len() < k {
417        let idx = rng.random_range(0..n_samples);
418        if !selected_indices.contains(&idx) {
419            selected_indices.push(idx);
420        }
421    }
422
423    // Copy the selected points to the centroids
424    for (i, &idx) in selected_indices.iter().enumerate() {
425        for j in 0..n_features {
426            centroids[[i, j]] = data[[idx, j]];
427        }
428    }
429
430    Ok(centroids)
431}
432
433/// K-means++ initialization algorithm
434///
435/// # Arguments
436///
437/// * `data` - Input data (n_samples × n_features)
438/// * `k` - Number of clusters
439/// * `_random_seed` - Optional random seed
440///
441/// # Returns
442///
443/// * Array of shape (k × n_features) with initial centroids
444#[allow(dead_code)]
445pub fn kmeans_plus_plus<F>(
446    data: ArrayView2<F>,
447    k: usize,
448    random_seed: Option<u64>,
449) -> Result<Array2<F>>
450where
451    F: Float + FromPrimitive + Debug + std::iter::Sum,
452{
453    let n_samples = data.shape()[0];
454    let n_features = data.shape()[1];
455
456    if k == 0 || k > n_samples {
457        return Err(ClusteringError::InvalidInput(format!(
458            "Number of clusters ({}) must be between 1 and number of samples ({})",
459            k, n_samples
460        )));
461    }
462
463    let mut rng = scirs2_core::random::rng();
464
465    let mut centroids = Array2::zeros((k, n_features));
466
467    // Choose the first centroid randomly
468    let first_idx = rng.random_range(0..n_samples);
469    for j in 0..n_features {
470        centroids[[0, j]] = data[[first_idx, j]];
471    }
472
473    if k == 1 {
474        return Ok(centroids);
475    }
476
477    // Choose remaining centroids using the k-means++ algorithm
478    for i in 1..k {
479        // Compute distances to closest centroid for each point
480        let mut min_distances = Array1::from_elem(n_samples, F::infinity());
481
482        for sample_idx in 0..n_samples {
483            let sample = data.slice(s![sample_idx, ..]);
484
485            for centroid_idx in 0..i {
486                let centroid = centroids.slice(s![centroid_idx, ..]);
487                let dist = euclidean_distance(sample, centroid);
488
489                if dist < min_distances[sample_idx] {
490                    min_distances[sample_idx] = dist;
491                }
492            }
493        }
494
495        // Square the distances to get the probability distribution
496        let mut weights = min_distances.mapv(|d| d * d);
497
498        // Normalize the weights to create a probability distribution
499        let sum_weights = weights.sum();
500        if sum_weights > F::zero() {
501            weights.mapv_inplace(|w| w / sum_weights);
502        } else {
503            // If all weights are zero, use uniform distribution
504            weights.fill(F::from(1.0 / n_samples as f64).unwrap());
505        }
506
507        // Convert weights to cumulative distribution
508        let mut cum_weights = weights.clone();
509        for j in 1..n_samples {
510            cum_weights[j] = cum_weights[j] + cum_weights[j - 1];
511        }
512
513        // Sample the next centroid based on the probability distribution
514        let rand_val = F::from(rng.random_range(0.0..1.0)).unwrap();
515        let mut next_idx = 0;
516
517        for j in 0..n_samples {
518            if rand_val <= cum_weights[j] {
519                next_idx = j;
520                break;
521            }
522        }
523
524        // Add the new centroid
525        for j in 0..n_features {
526            centroids[[i, j]] = data[[next_idx, j]];
527        }
528    }
529
530    Ok(centroids)
531}
532
533/// K-means|| initialization algorithm (parallel version of K-means++)
534///
535/// This algorithm samples more than one center at each step, which makes it
536/// suitable for parallel or distributed implementations.
537///
538/// # Arguments
539///
540/// * `data` - Input data (n_samples × n_features)
541/// * `k` - Number of clusters
542/// * `_random_seed` - Optional random seed
543///
544/// # Returns
545///
546/// * Array of shape (k × n_features) with initial centroids
547///
548/// # References
549///
550/// * [Scalable K-means++ by Bahmani et al.](https://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf)
551#[allow(dead_code)]
552pub fn kmeans_parallel<F>(
553    data: ArrayView2<F>,
554    k: usize,
555    random_seed: Option<u64>,
556) -> Result<Array2<F>>
557where
558    F: Float + FromPrimitive + Debug + std::iter::Sum,
559{
560    let n_samples = data.shape()[0];
561    let n_features = data.shape()[1];
562
563    if k == 0 || k > n_samples {
564        return Err(ClusteringError::InvalidInput(format!(
565            "Number of clusters ({}) must be between 1 and number of samples ({})",
566            k, n_samples
567        )));
568    }
569
570    let mut rng = scirs2_core::random::rng();
571
572    // Hyperparameters for K-means||
573    let l = F::from(5.0).unwrap(); // Multiplication factor for oversampling
574    let n_rounds = 8; // Number of rounds for parallel sampling
575
576    // Centers is a weighted set of candidate centers
577    let mut centers = Vec::new();
578    let mut weights = Vec::new();
579
580    // Choose the first center randomly
581    let first_idx = rng.random_range(0..n_samples);
582    let mut first_center = Vec::with_capacity(n_features);
583    for j in 0..n_features {
584        first_center.push(data[[first_idx, j]]);
585    }
586    centers.push(first_center);
587    weights.push(F::one()); // Initial weight is 1
588
589    // Perform parallel sampling rounds
590    for _ in 0..n_rounds {
591        // Compute distances to the closest center for each point
592        let mut min_distances = Array1::from_elem(n_samples, F::infinity());
593
594        for sample_idx in 0..n_samples {
595            let sample = data.slice(s![sample_idx, ..]);
596
597            for center in centers.iter() {
598                let mut dist_sq = F::zero();
599                for j in 0..n_features {
600                    let diff = sample[j] - center[j];
601                    dist_sq = dist_sq + diff * diff;
602                }
603                let dist = dist_sq.sqrt();
604
605                if dist < min_distances[sample_idx] {
606                    min_distances[sample_idx] = dist;
607                }
608            }
609        }
610
611        // Compute the sum of squared minimum distances (a.k.a. potential)
612        let potential: F = min_distances.iter().map(|&d| d * d).sum();
613        if potential <= F::epsilon() {
614            break; // Already covered all points well
615        }
616
617        // Sample new centers proportional to their squared distance
618        let expected_new_centers = l * F::from(k).unwrap();
619        let oversampling = F::min(expected_new_centers / potential, F::one());
620
621        for sample_idx in 0..n_samples {
622            let probability = min_distances[sample_idx] * min_distances[sample_idx] * oversampling;
623
624            // Sample with probability proportional to distance^2
625            if F::from(rng.random_range(0.0..1.0)).unwrap() < probability {
626                let mut new_center = Vec::with_capacity(n_features);
627                for j in 0..n_features {
628                    new_center.push(data[[sample_idx, j]]);
629                }
630                centers.push(new_center);
631                weights.push(F::one()); // Initial weight is 1
632            }
633        }
634    }
635
636    // If we have too many candidate centers, cluster them using weighted k-means
637    match centers.len().cmp(&k) {
638        std::cmp::Ordering::Greater => {
639            // Convert centers and weights to arrays for clustering
640            let n_centers = centers.len();
641            let mut centers_array = Array2::zeros((n_centers, n_features));
642            let mut weights_array = Array1::zeros(n_centers);
643
644            for i in 0..n_centers {
645                for j in 0..n_features {
646                    centers_array[[i, j]] = centers[i][j];
647                }
648                weights_array[i] = weights[i];
649            }
650
651            // Use regular k-means with weights to reduce to k centers
652            let options = KMeansOptions {
653                max_iter: 100,
654                tol: F::from(1e-4).unwrap(),
655                random_seed,
656                n_init: 1,
657                init_method: KMeansInit::KMeansPlusPlus,
658            };
659
660            // Initialize with random k centers from the candidate centers
661            let init_indices: Vec<usize> = (0..n_centers)
662            .filter(|_| rng.random_range(0.0..1.0) < 0.5) // Randomly select some centers
663            .take(k) // Take at most k centers
664            .collect();
665
666            // If we didn't get k centers..just take the first k
667            let actual_indices = if init_indices.len() < k {
668                (0..k.min(n_centers)).collect::<Vec<usize>>()
669            } else {
670                init_indices
671            };
672
673            let mut initcentroids = Array2::zeros((actual_indices.len(), n_features));
674            for (i, &idx) in actual_indices.iter().enumerate() {
675                for j in 0..n_features {
676                    initcentroids[[i, j]] = centers_array[[idx, j]];
677                }
678            }
679
680            // Run weighted k-means to get final centroids
681            let (finalcentroids_, _) = _weighted_kmeans_single(
682                centers_array.view(),
683                weights_array.view(),
684                initcentroids.view(),
685                &options,
686            )?;
687
688            Ok(finalcentroids_)
689        }
690        std::cmp::Ordering::Less => {
691            // If we have too few centers, add random points
692            let mut centroids = Array2::zeros((k, n_features));
693
694            // Copy existing centers
695            for i in 0..centers.len() {
696                for j in 0..n_features {
697                    centroids[[i, j]] = centers[i][j];
698                }
699            }
700
701            // Add random points to reach k centers
702            let mut selected_indices = Vec::with_capacity(k - centers.len());
703            while selected_indices.len() < k - centers.len() {
704                let idx = rng.random_range(0..n_samples);
705                if !selected_indices.contains(&idx) {
706                    selected_indices.push(idx);
707                }
708            }
709
710            for (i, &idx) in selected_indices.iter().enumerate() {
711                for j in 0..n_features {
712                    centroids[[centers.len() + i, j]] = data[[idx, j]];
713                }
714            }
715
716            Ok(centroids)
717        }
718        std::cmp::Ordering::Equal => {
719            // We have exactly k centers
720            let mut centroids = Array2::zeros((k, n_features));
721            for i in 0..k {
722                for j in 0..n_features {
723                    centroids[[i, j]] = centers[i][j];
724                }
725            }
726            Ok(centroids)
727        }
728    }
729}
730
731/// Run a single weighted k-means clustering iteration
732#[allow(dead_code)]
733fn _weighted_kmeans_single<F>(
734    data: ArrayView2<F>,
735    weights: ArrayView1<F>,
736    initcentroids: ArrayView2<F>,
737    opts: &KMeansOptions<F>,
738) -> Result<(Array2<F>, Array1<usize>)>
739where
740    F: Float + FromPrimitive + Debug + std::iter::Sum,
741{
742    let n_samples = data.shape()[0];
743    let n_features = data.shape()[1];
744    let k = initcentroids.shape()[0];
745
746    let mut centroids = initcentroids.to_owned();
747    let mut labels = Array1::zeros(n_samples);
748    let mut prev_centroid_diff = F::infinity();
749
750    for _iter in 0..opts.max_iter {
751        // Assign samples to nearest centroid
752        let (new_labels_, _) = vq(data, centroids.view())?;
753        labels = new_labels_;
754
755        // Compute new centroids using weights
756        let mut newcentroids = Array2::zeros((k, n_features));
757        let mut total_weights = Array1::zeros(k);
758
759        for i in 0..n_samples {
760            let cluster = labels[i];
761            let point = data.slice(s![i, ..]);
762            let weight = weights[i];
763
764            for j in 0..n_features {
765                newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j] * weight;
766            }
767
768            total_weights[cluster] = total_weights[cluster] + weight;
769        }
770
771        // If a cluster is empty, reinitialize it
772        for i in 0..k {
773            if total_weights[i] <= F::epsilon() {
774                // Find the point furthest from its centroid
775                let mut max_dist = F::zero();
776                let mut far_idx = 0;
777
778                for j in 0..n_samples {
779                    let dist = euclidean_distance(
780                        data.slice(s![j, ..]),
781                        centroids.slice(s![labels[j], ..]),
782                    );
783                    if dist > max_dist {
784                        max_dist = dist;
785                        far_idx = j;
786                    }
787                }
788
789                // Move this point to the empty cluster
790                for j in 0..n_features {
791                    newcentroids[[i, j]] = data[[far_idx, j]];
792                }
793
794                total_weights[i] = weights[far_idx];
795            } else {
796                // Normalize by the total weight in the cluster
797                for j in 0..n_features {
798                    newcentroids[[i, j]] = newcentroids[[i, j]] / total_weights[i];
799                }
800            }
801        }
802
803        // Check for convergence
804        let mut centroid_diff = F::zero();
805        for i in 0..k {
806            let dist =
807                euclidean_distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
808            centroid_diff = centroid_diff + dist;
809        }
810
811        centroids = newcentroids;
812
813        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
814            break;
815        }
816
817        prev_centroid_diff = centroid_diff;
818    }
819
820    Ok((centroids, labels))
821}
822
823/// Enhanced K-means clustering with custom distance metrics
824///
825/// This function extends the standard K-means algorithm to support various distance
826/// metrics including Euclidean, Manhattan, Chebyshev, Mahalanobis, and more.
827///
828/// # Arguments
829///
830/// * `data` - Input data (n_samples × n_features)
831/// * `k` - Number of clusters
832/// * `metric` - Distance metric to use for clustering
833/// * `options` - Optional parameters
834///
835/// # Returns
836///
837/// * Tuple of (centroids, labels) where:
838///   - centroids: Array of shape (k × n_features)
839///   - labels: Array of shape (n_samples,) with cluster assignments
840///
841/// # Examples
842///
843/// ```
844/// use scirs2_core::ndarray::Array2;
845/// use scirs2_cluster::vq::{kmeans_with_metric, EuclideanDistance, KMeansOptions};
846///
847/// let data = Array2::from_shape_vec((6, 2), vec![
848///     1.0, 2.0,
849///     1.2, 1.8,
850///     0.8, 1.9,
851///     3.7, 4.2,
852///     3.9, 3.9,
853///     4.2, 4.1,
854/// ]).unwrap();
855///
856/// let metric = Box::new(EuclideanDistance);
857/// let (centroids, labels) = kmeans_with_metric(data.view(), 2, metric, None).unwrap();
858/// ```
859#[allow(dead_code)]
860pub fn kmeans_with_metric<F>(
861    data: ArrayView2<F>,
862    k: usize,
863    metric: Box<dyn crate::vq::VQDistanceMetric<F>>,
864    options: Option<KMeansOptions<F>>,
865) -> Result<(Array2<F>, Array1<usize>)>
866where
867    F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync + 'static,
868{
869    if k == 0 {
870        return Err(ClusteringError::InvalidInput(
871            "Number of clusters must be greater than 0".to_string(),
872        ));
873    }
874
875    let n_samples = data.shape()[0];
876    if n_samples == 0 {
877        return Err(ClusteringError::InvalidInput(
878            "Input data is empty".to_string(),
879        ));
880    }
881
882    if k > n_samples {
883        return Err(ClusteringError::InvalidInput(format!(
884            "Number of clusters ({}) cannot be greater than number of data points ({})",
885            k, n_samples
886        )));
887    }
888
889    let opts = options.unwrap_or_default();
890
891    let mut bestcentroids = None;
892    let mut best_labels = None;
893    let mut best_inertia = F::infinity();
894
895    // If we're using K-means|| initialization, we only need to run once
896    let n_init = if opts.init_method == KMeansInit::KMeansParallel {
897        1
898    } else {
899        opts.n_init
900    };
901
902    for _ in 0..n_init {
903        // Initialize centroids using the specified method
904        let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
905
906        // Run k-means with custom distance metric
907        let (centroids, labels, inertia) =
908            _kmeans_single_with_metric(data, centroids.view(), metric.as_ref(), &opts)?;
909
910        if inertia < best_inertia {
911            bestcentroids = Some(centroids);
912            best_labels = Some(labels);
913            best_inertia = inertia;
914        }
915    }
916
917    Ok((bestcentroids.unwrap(), best_labels.unwrap()))
918}
919
920/// Run a single k-means clustering iteration with custom distance metric
921#[allow(dead_code)]
922fn _kmeans_single_with_metric<F>(
923    data: ArrayView2<F>,
924    initcentroids: ArrayView2<F>,
925    metric: &dyn crate::vq::VQDistanceMetric<F>,
926    opts: &KMeansOptions<F>,
927) -> Result<(Array2<F>, Array1<usize>, F)>
928where
929    F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
930{
931    let n_samples = data.shape()[0];
932    let n_features = data.shape()[1];
933    let k = initcentroids.shape()[0];
934
935    let mut centroids = initcentroids.to_owned();
936    let mut labels = Array1::zeros(n_samples);
937    let mut prev_centroid_diff = F::infinity();
938
939    for _iter in 0..opts.max_iter {
940        // Assign samples to nearest centroid using custom metric
941        let (new_labels, distances) = _vq_with_metric(data, centroids.view(), metric)?;
942        labels = new_labels;
943
944        // Compute new centroids
945        let mut newcentroids = Array2::zeros((k, n_features));
946        let mut counts = Array1::zeros(k);
947
948        for i in 0..n_samples {
949            let cluster = labels[i];
950            let point = data.slice(s![i, ..]);
951
952            for j in 0..n_features {
953                newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j];
954            }
955
956            counts[cluster] += 1;
957        }
958
959        // If a cluster is empty, reinitialize it
960        for i in 0..k {
961            if counts[i] == 0 {
962                // Find the point furthest from its centroid
963                let mut max_dist = F::zero();
964                let mut far_idx = 0;
965
966                for j in 0..n_samples {
967                    let dist = distances[j];
968                    if dist > max_dist {
969                        max_dist = dist;
970                        far_idx = j;
971                    }
972                }
973
974                // Move this point to the empty cluster
975                for j in 0..n_features {
976                    newcentroids[[i, j]] = data[[far_idx, j]];
977                }
978
979                counts[i] = 1;
980            } else {
981                // Normalize by the number of points in the cluster
982                for j in 0..n_features {
983                    newcentroids[[i, j]] = newcentroids[[i, j]] / F::from(counts[i]).unwrap();
984                }
985            }
986        }
987
988        // Check for convergence using custom metric
989        let mut centroid_diff = F::zero();
990        for i in 0..k {
991            let dist = metric.distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
992            centroid_diff = centroid_diff + dist;
993        }
994
995        centroids = newcentroids;
996
997        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
998            break;
999        }
1000
1001        prev_centroid_diff = centroid_diff;
1002    }
1003
1004    // Calculate inertia (sum of squared distances to nearest centroid)
1005    let mut inertia = F::zero();
1006    for i in 0..n_samples {
1007        let cluster = labels[i];
1008        let dist = metric.distance(data.slice(s![i, ..]), centroids.slice(s![cluster, ..]));
1009        inertia = inertia + dist * dist;
1010    }
1011
1012    Ok((centroids, labels, inertia))
1013}
1014
1015/// Vector quantization with custom distance metric
1016#[allow(dead_code)]
1017fn _vq_with_metric<F>(
1018    data: ArrayView2<F>,
1019    centroids: ArrayView2<F>,
1020    metric: &dyn crate::vq::VQDistanceMetric<F>,
1021) -> Result<(Array1<usize>, Array1<F>)>
1022where
1023    F: Float + FromPrimitive + Debug + Send + Sync,
1024{
1025    let n_samples = data.shape()[0];
1026    let ncentroids = centroids.shape()[0];
1027
1028    let mut labels = Array1::zeros(n_samples);
1029    let mut distances = Array1::zeros(n_samples);
1030
1031    for i in 0..n_samples {
1032        let point = data.slice(s![i, ..]);
1033        let mut min_dist = F::infinity();
1034        let mut closest_centroid = 0;
1035
1036        for j in 0..ncentroids {
1037            let centroid = centroids.slice(s![j, ..]);
1038            let dist = metric.distance(point, centroid);
1039
1040            if dist < min_dist {
1041                min_dist = dist;
1042                closest_centroid = j;
1043            }
1044        }
1045
1046        labels[i] = closest_centroid;
1047        distances[i] = min_dist;
1048    }
1049
1050    Ok((labels, distances))
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::collections::HashSet;
    use std::hash::Hash;

    /// Six 2-D points in two well-separated groups of three:
    /// rows 0-2 lie near (1, 2) and rows 3-5 lie near (4, 5).
    fn two_group_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("12 values fill a 6x2 array")
    }

    /// Number of distinct values yielded by `items` (used to count clusters).
    fn distinct_count<T: Eq + Hash>(items: impl IntoIterator<Item = T>) -> usize {
        items.into_iter().collect::<HashSet<T>>().len()
    }

    /// Run `kmeans_with_options` on the two-group data with the given
    /// initialisation method and check the basic shape of the result.
    fn check_two_group_init(init_method: KMeansInit) {
        let data = two_group_data();
        let options = KMeansOptions {
            init_method,
            ..Default::default()
        };

        let (centroids, labels) = kmeans_with_options(data.view(), 2, Some(options))
            .expect("kmeans_with_options should succeed on finite data");

        // Check dimensions
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);

        // Both clusters must be populated
        assert_eq!(distinct_count(labels.iter()), 2);
    }

    #[test]
    fn test_kmeans_random_init() {
        check_two_group_init(KMeansInit::Random);
    }

    #[test]
    fn test_kmeans_plusplus_init() {
        check_two_group_init(KMeansInit::KMeansPlusPlus);
    }

    #[test]
    fn test_kmeans_parallel_init() {
        // Two groups of ten points: rows 0-9 near (1, 2), rows 10-19 near (5, 6)
        let data = Array2::from_shape_vec(
            (20, 2),
            vec![
                1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 1.1, 2.2, 0.9, 1.7, 1.3, 2.1, 1.0, 1.9, 0.7, 2.0,
                1.2, 2.3, 1.5, 1.8, 5.0, 6.0, 5.2, 5.8, 4.8, 6.2, 5.1, 5.9, 5.3, 6.1, 4.9, 5.7,
                5.0, 6.3, 5.4, 5.6, 4.7, 5.9, 5.2, 6.2,
            ],
        )
        .expect("40 values fill a 20x2 array");

        // Run k-means with k-means|| initialization
        let options = KMeansOptions {
            init_method: KMeansInit::KMeansParallel,
            ..Default::default()
        };

        let (centroids, labels) = kmeans_with_options(data.view(), 2, Some(options))
            .expect("kmeans_with_options (k-means||) should succeed on finite data");

        // Check dimensions
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 20);

        // Check that we have exactly 2 clusters
        assert_eq!(distinct_count(labels.iter()), 2);

        // The first 10 points must share one cluster, the last 10 the other
        let first_cluster = labels[0];
        let second_cluster = labels[10];
        assert_ne!(first_cluster, second_cluster);
        for (i, label) in labels.iter().enumerate() {
            let expected = if i < 10 { first_cluster } else { second_cluster };
            assert_eq!(*label, expected, "point {} assigned to unexpected cluster", i);
        }
    }

    #[test]
    fn test_scipy_compatible_kmeans() {
        let data = two_group_data();

        // All parameters supplied explicitly
        let (centroids, distortion) = kmeans(
            data.view(),
            2,          // k_or_guess
            Some(20),   // iter
            Some(1e-5), // thresh
            Some(true), // check_finite
            Some(42),   // seed
        )
        .expect("kmeans with explicit parameters should succeed");

        assert_eq!(centroids.shape(), &[2, 2]);
        // Distortion should be positive: no centroid coincides with all its points
        assert!(distortion > 0.0);

        // Default parameters (None values)
        let (centroids2, distortion2) = kmeans(
            data.view(),
            2,    // k_or_guess
            None, // iter (default: 20)
            None, // thresh (default: 1e-5)
            None, // check_finite (default: true)
            None, // seed (random)
        )
        .expect("kmeans with default parameters should succeed");

        assert_eq!(centroids2.shape(), &[2, 2]);
        assert!(distortion2 > 0.0);
    }

    #[test]
    fn test_scipy_kmeans_check_finite() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.5, 1.5, 8.0, 8.0, 8.5, 8.5])
            .expect("8 values fill a 4x2 array");

        // Finite data must be accepted whether or not the finiteness check runs
        for check_finite in [true, false] {
            let result = kmeans(
                data.view(),
                2,
                Some(10),
                Some(1e-5),
                Some(check_finite),
                Some(42),
            );
            if let Err(e) = result {
                panic!("kmeans failed with check_finite = {}: {:?}", check_finite, e);
            }
        }
    }
}