torsh_cluster/initialization/
kmeans_plus_plus.rs1use super::InitializationStrategy;
4use crate::error::{ClusterError, ClusterResult};
5use scirs2_core::random::Random;
6use scirs2_core::RngExt;
7use torsh_tensor::Tensor;
8
/// K-means++ centroid initialization strategy.
///
/// A stateless marker type: it has no fields, so all behavior lives in the
/// trait implementations attached to it.
#[derive(Default, Debug)]
pub struct KMeansPlusPlus;
12
13impl InitializationStrategy for KMeansPlusPlus {
14 fn initialize(
15 &self,
16 data: &Tensor,
17 n_clusters: usize,
18 seed: Option<u64>,
19 ) -> ClusterResult<Tensor> {
20 let n_samples = data.shape().dims()[0];
21 let n_features = data.shape().dims()[1];
22
23 if n_clusters > n_samples {
24 return Err(ClusterError::InvalidClusters(n_clusters));
25 }
26
27 let mut rng = Random::seed(seed.unwrap_or_else(|| {
28 use std::time::{SystemTime, UNIX_EPOCH};
29 SystemTime::now()
30 .duration_since(UNIX_EPOCH)
31 .expect("system time should be after UNIX_EPOCH")
32 .as_secs()
33 }));
34
35 let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
36 let mut centroids_data = Vec::with_capacity(n_clusters * n_features);
37
38 let first_idx = rng.gen_range(0..n_samples);
40 for j in 0..n_features {
41 centroids_data.push(data_vec[first_idx * n_features + j]);
42 }
43
44 for k in 1..n_clusters {
46 let mut distances = vec![f32::INFINITY; n_samples];
47
48 for i in 0..n_samples {
50 for c in 0..k {
51 let mut dist = 0.0;
52 for j in 0..n_features {
53 let diff =
54 data_vec[i * n_features + j] - centroids_data[c * n_features + j];
55 dist += diff * diff;
56 }
57 distances[i] = distances[i].min(dist);
58 }
59 }
60
61 let total_dist: f32 = distances.iter().sum();
63 if total_dist <= 0.0 {
64 let idx = rng.gen_range(0..n_samples);
66 for j in 0..n_features {
67 centroids_data.push(data_vec[idx * n_features + j]);
68 }
69 } else {
70 let threshold = rng.random::<f32>() * total_dist;
71 let mut cumsum = 0.0;
72 let mut selected_idx = 0;
73
74 for (i, &distance) in distances.iter().enumerate() {
75 cumsum += distance;
76 if cumsum >= threshold {
77 selected_idx = i;
78 break;
79 }
80 }
81
82 for j in 0..n_features {
83 centroids_data.push(data_vec[selected_idx * n_features + j]);
84 }
85 }
86 }
87
88 Tensor::from_vec(centroids_data, &[n_clusters, n_features])
89 .map_err(ClusterError::TensorError)
90 }
91
92 fn name(&self) -> &str {
93 "K-means++"
94 }
95}