1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::parallel_ops::*;
9use std::fmt::Debug;
10use std::sync::Mutex;
11
12use super::{euclidean_distance, kmeans_init, KMeansInit};
13use crate::error::{ClusteringError, Result};
14
15#[derive(Debug, Clone)]
17pub struct ParallelKMeansOptions<F: Float> {
18 pub max_iter: usize,
20 pub tol: F,
22 pub random_seed: Option<u64>,
24 pub n_init: usize,
26 pub init_method: KMeansInit,
28 pub n_threads: Option<usize>,
30}
31
32impl<F: Float + FromPrimitive> Default for ParallelKMeansOptions<F> {
33 fn default() -> Self {
34 Self {
35 max_iter: 300,
36 tol: F::from(1e-4).unwrap(),
37 random_seed: None,
38 n_init: 10,
39 init_method: KMeansInit::KMeansPlusPlus,
40 n_threads: None,
41 }
42 }
43}
44
45#[allow(dead_code)]
75pub fn parallel_kmeans<F>(
76 data: ArrayView2<F>,
77 k: usize,
78 options: Option<ParallelKMeansOptions<F>>,
79) -> Result<(Array2<F>, Array1<usize>)>
80where
81 F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
82{
83 if k == 0 {
84 return Err(ClusteringError::InvalidInput(
85 "Number of clusters must be greater than 0".to_string(),
86 ));
87 }
88
89 let n_samples = data.shape()[0];
90 if n_samples == 0 {
91 return Err(ClusteringError::InvalidInput(
92 "Input data is empty".to_string(),
93 ));
94 }
95
96 if k > n_samples {
97 return Err(ClusteringError::InvalidInput(format!(
98 "Number of clusters ({}) cannot be greater than number of data points ({})",
99 k, n_samples
100 )));
101 }
102
103 let opts = options.unwrap_or_default();
104
105 if let Some(_n_threads) = opts.n_threads {
107 }
110
111 let mut bestcentroids = None;
112 let mut best_labels = None;
113 let mut best_inertia = F::infinity();
114
115 for _ in 0..opts.n_init {
117 let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
119
120 let (centroids, labels, inertia) = parallel_kmeans_single(data, centroids.view(), &opts)?;
122
123 if inertia < best_inertia {
124 bestcentroids = Some(centroids);
125 best_labels = Some(labels);
126 best_inertia = inertia;
127 }
128 }
129
130 Ok((bestcentroids.unwrap(), best_labels.unwrap()))
131}
132
133#[allow(dead_code)]
135fn parallel_kmeans_single<F>(
136 data: ArrayView2<F>,
137 initcentroids: ArrayView2<F>,
138 opts: &ParallelKMeansOptions<F>,
139) -> Result<(Array2<F>, Array1<usize>, F)>
140where
141 F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
142{
143 let n_samples = data.shape()[0];
144 let _n_features = data.shape()[1];
145 let k = initcentroids.shape()[0];
146
147 let mut centroids = initcentroids.to_owned();
148 let mut labels = Array1::zeros(n_samples);
149 let mut prev_inertia = F::infinity();
150
151 for _iter in 0..opts.max_iter {
152 let (new_labels, distances) = parallel_assign_labels(data, centroids.view())?;
154 labels = new_labels;
155
156 let newcentroids = parallel_updatecentroids(data, &labels, k)?;
158
159 let cluster_counts = count_clusters(&labels, k);
161
162 let mut finalcentroids = newcentroids;
164 for (i, &count) in cluster_counts.iter().enumerate() {
165 if count == 0 {
166 let (far_idx, _) = distances
168 .iter()
169 .enumerate()
170 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
171 .unwrap();
172
173 finalcentroids
175 .slice_mut(s![i, ..])
176 .assign(&data.slice(s![far_idx, ..]));
177 }
178 }
179
180 let inertia = parallel_compute_inertia(data, &labels, finalcentroids.view())?;
182
183 if (prev_inertia - inertia).abs() <= opts.tol {
185 return Ok((finalcentroids, labels, inertia));
186 }
187
188 centroids = finalcentroids;
189 prev_inertia = inertia;
190 }
191
192 let final_inertia = parallel_compute_inertia(data, &labels, centroids.view())?;
194
195 Ok((centroids, labels, final_inertia))
196}
197
198#[allow(dead_code)]
200fn parallel_assign_labels<F>(
201 data: ArrayView2<F>,
202 centroids: ArrayView2<F>,
203) -> Result<(Array1<usize>, Array1<F>)>
204where
205 F: Float + FromPrimitive + Debug + Send + Sync,
206{
207 let n_samples = data.shape()[0];
208 let k = centroids.shape()[0];
209
210 let results: Vec<(usize, F)> = (0..n_samples)
212 .into_par_iter()
213 .map(|i| {
214 let sample = data.slice(s![i, ..]);
215 let mut min_dist = F::infinity();
216 let mut best_label = 0;
217
218 for j in 0..k {
219 let centroid = centroids.slice(s![j, ..]);
220 let dist = euclidean_distance(sample, centroid);
221
222 if dist < min_dist {
223 min_dist = dist;
224 best_label = j;
225 }
226 }
227
228 (best_label, min_dist)
229 })
230 .collect();
231
232 let mut labels = Array1::zeros(n_samples);
234 let mut distances = Array1::zeros(n_samples);
235
236 for (i, (label, dist)) in results.into_iter().enumerate() {
237 labels[i] = label;
238 distances[i] = dist;
239 }
240
241 Ok((labels, distances))
242}
243
244#[allow(dead_code)]
246fn parallel_updatecentroids<F>(
247 data: ArrayView2<F>,
248 labels: &Array1<usize>,
249 k: usize,
250) -> Result<Array2<F>>
251where
252 F: Float + FromPrimitive + Debug + Send + Sync + std::iter::Sum,
253{
254 let n_features = data.shape()[1];
255
256 let sums: Vec<Mutex<Array1<F>>> = (0..k)
258 .map(|_| Mutex::new(Array1::zeros(n_features)))
259 .collect();
260
261 let counts: Vec<Mutex<usize>> = (0..k).map(|_| Mutex::new(0)).collect();
262
263 data.axis_iter(Axis(0))
265 .zip(labels.iter())
266 .par_bridge()
267 .for_each(|(sample, &label)| {
268 let mut sum = sums[label].lock().unwrap();
269 for i in 0..n_features {
270 sum[i] = sum[i] + sample[i];
271 }
272
273 let mut count = counts[label].lock().unwrap();
274 *count += 1;
275 });
276
277 let mut newcentroids = Array2::zeros((k, n_features));
279
280 for i in 0..k {
281 let sum = sums[i].lock().unwrap();
282 let count = *counts[i].lock().unwrap();
283
284 if count > 0 {
285 for j in 0..n_features {
286 newcentroids[[i, j]] = sum[j] / F::from(count).unwrap();
287 }
288 }
289 }
290
291 Ok(newcentroids)
292}
293
294#[allow(dead_code)]
296fn count_clusters(labels: &Array1<usize>, k: usize) -> Vec<usize> {
297 let mut counts = vec![0; k];
298 for &label in labels.iter() {
299 counts[label] += 1;
300 }
301 counts
302}
303
304#[allow(dead_code)]
306fn parallel_compute_inertia<F>(
307 data: ArrayView2<F>,
308 labels: &Array1<usize>,
309 centroids: ArrayView2<F>,
310) -> Result<F>
311where
312 F: Float + FromPrimitive + Debug + Send + Sync + std::iter::Sum,
313{
314 let inertia: F = data
315 .axis_iter(Axis(0))
316 .zip(labels.iter())
317 .par_bridge()
318 .map(|(sample, &label)| {
319 let centroid = centroids.slice(s![label, ..]);
320 let dist = euclidean_distance(sample.view(), centroid);
321 dist * dist
322 })
323 .sum();
324
325 Ok(inertia)
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two well-separated groups of three 2-D points should yield two clusters.
    #[test]
    fn test_parallel_kmeans_simple() {
        let points = vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1];
        let data = Array2::from_shape_vec((6, 2), points).unwrap();

        let opts = ParallelKMeansOptions {
            n_init: 1,
            random_seed: Some(42),
            ..Default::default()
        };

        let (centroids, labels) = parallel_kmeans(data.view(), 2, Some(opts)).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);

        let distinct: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 2);
    }

    /// Smoke test on a larger synthetic dataset; checks shapes and label range.
    #[test]
    fn test_parallel_kmeans_large_dataset() {
        let n_samples = 1000;
        let n_features = 10;
        let group_size = n_samples / 3;

        let data_vec: Vec<f64> = (0..n_samples)
            .flat_map(|i| {
                let cluster = i / group_size;
                (0..n_features)
                    .map(move |j| (cluster * 10) as f64 + (j as f64 + i as f64 * 0.01))
            })
            .collect();

        let data = Array2::from_shape_vec((n_samples, n_features), data_vec).unwrap();

        let opts = ParallelKMeansOptions {
            n_init: 3,
            max_iter: 50,
            random_seed: Some(42),
            ..Default::default()
        };

        let timer = std::time::Instant::now();
        let (centroids, labels) = parallel_kmeans(data.view(), 3, Some(opts)).unwrap();
        let duration = timer.elapsed();

        println!("Parallel K-means took: {duration:?}");

        assert_eq!(centroids.shape(), &[3, n_features]);
        assert_eq!(labels.len(), n_samples);
        assert!(labels.iter().all(|&label| label < 3));
    }
}