Skip to main content

torsh_cluster/initialization/
kmeans_plus_plus.rs

1//! K-means++ initialization strategy
2
3use super::InitializationStrategy;
4use crate::error::{ClusterError, ClusterResult};
5use scirs2_core::random::Random;
6use scirs2_core::RngExt;
7use torsh_tensor::Tensor;
8
/// K-means++ seeding strategy for choosing initial cluster centroids.
///
/// This is a stateless unit type: the cluster count and RNG seed are supplied per call to
/// [`InitializationStrategy::initialize`], so a single value can be freely copied and reused.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct KMeansPlusPlus;
12
13impl InitializationStrategy for KMeansPlusPlus {
14    fn initialize(
15        &self,
16        data: &Tensor,
17        n_clusters: usize,
18        seed: Option<u64>,
19    ) -> ClusterResult<Tensor> {
20        let n_samples = data.shape().dims()[0];
21        let n_features = data.shape().dims()[1];
22
23        if n_clusters > n_samples {
24            return Err(ClusterError::InvalidClusters(n_clusters));
25        }
26
27        let mut rng = Random::seed(seed.unwrap_or_else(|| {
28            use std::time::{SystemTime, UNIX_EPOCH};
29            SystemTime::now()
30                .duration_since(UNIX_EPOCH)
31                .expect("system time should be after UNIX_EPOCH")
32                .as_secs()
33        }));
34
35        let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
36        let mut centroids_data = Vec::with_capacity(n_clusters * n_features);
37
38        // Choose first centroid randomly
39        let first_idx = rng.gen_range(0..n_samples);
40        for j in 0..n_features {
41            centroids_data.push(data_vec[first_idx * n_features + j]);
42        }
43
44        // Choose remaining centroids using K-means++ strategy
45        for k in 1..n_clusters {
46            let mut distances = vec![f32::INFINITY; n_samples];
47
48            // Compute minimum distance to existing centroids for each point
49            for i in 0..n_samples {
50                for c in 0..k {
51                    let mut dist = 0.0;
52                    for j in 0..n_features {
53                        let diff =
54                            data_vec[i * n_features + j] - centroids_data[c * n_features + j];
55                        dist += diff * diff;
56                    }
57                    distances[i] = distances[i].min(dist);
58                }
59            }
60
61            // Choose next centroid with probability proportional to squared distance
62            let total_dist: f32 = distances.iter().sum();
63            if total_dist <= 0.0 {
64                // Fallback to random selection
65                let idx = rng.gen_range(0..n_samples);
66                for j in 0..n_features {
67                    centroids_data.push(data_vec[idx * n_features + j]);
68                }
69            } else {
70                let threshold = rng.random::<f32>() * total_dist;
71                let mut cumsum = 0.0;
72                let mut selected_idx = 0;
73
74                for (i, &distance) in distances.iter().enumerate() {
75                    cumsum += distance;
76                    if cumsum >= threshold {
77                        selected_idx = i;
78                        break;
79                    }
80                }
81
82                for j in 0..n_features {
83                    centroids_data.push(data_vec[selected_idx * n_features + j]);
84                }
85            }
86        }
87
88        Tensor::from_vec(centroids_data, &[n_clusters, n_features])
89            .map_err(ClusterError::TensorError)
90    }
91
92    fn name(&self) -> &str {
93        "K-means++"
94    }
95}