1use std::fmt::Debug;
56use std::marker::PhantomData;
57
58use rand::Rng;
59#[cfg(feature = "serde")]
60use serde::{Deserialize, Serialize};
61
62use crate::algorithm::neighbour::bbd_tree::BBDTree;
63use crate::api::{Predictor, UnsupervisedEstimator};
64use crate::error::Failed;
65use crate::linalg::basic::arrays::{Array1, Array2};
66use crate::metrics::distance::euclidian::*;
67use crate::numbers::basenum::Number;
68use crate::rand_custom::get_rng_impl;
69
70#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
72#[derive(Debug)]
73pub struct KMeans<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
74 k: usize,
75 _y: Vec<usize>,
76 size: Vec<usize>,
77 _distortion: f64,
78 centroids: Vec<Vec<f64>>,
79 _phantom_tx: PhantomData<TX>,
80 _phantom_ty: PhantomData<TY>,
81 _phantom_x: PhantomData<X>,
82 _phantom_y: PhantomData<Y>,
83}
84
85impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<TX, TY, X, Y> {
86 fn eq(&self, other: &Self) -> bool {
87 if self.k != other.k
88 || self.size != other.size
89 || self.centroids.len() != other.centroids.len()
90 {
91 false
92 } else {
93 let n_centroids = self.centroids.len();
94 for i in 0..n_centroids {
95 if self.centroids[i].len() != other.centroids[i].len() {
96 return false;
97 }
98 for j in 0..self.centroids[i].len() {
99 if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
100 return false;
101 }
102 }
103 }
104 true
105 }
106 }
107}
108
/// Hyper-parameters accepted by [`KMeans::fit`].
///
/// With the `serde` feature every field is marked `#[serde(default)]`, so a
/// field missing from the serialized input takes the `Default` value of the
/// field's own type (for example `k == 0`), not the value used by
/// `KMeansParameters::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansParameters {
    /// Number of clusters to form. `KMeans::fit` rejects values below 2.
    #[cfg_attr(feature = "serde", serde(default))]
    pub k: usize,
    /// Upper bound on the number of refinement passes. `KMeans::fit` rejects 0.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_iter: usize,
    /// Optional seed forwarded to the random number generator that drives the
    /// initial cluster assignment.
    #[cfg_attr(feature = "serde", serde(default))]
    pub seed: Option<u64>,
}
124
125impl KMeansParameters {
126 pub fn with_k(mut self, k: usize) -> Self {
128 self.k = k;
129 self
130 }
131 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
133 self.max_iter = max_iter;
134 self
135 }
136}
137
138impl Default for KMeansParameters {
139 fn default() -> Self {
140 KMeansParameters {
141 k: 2,
142 max_iter: 100,
143 seed: Option::None,
144 }
145 }
146}
147
/// Grid of candidate values for each [`KMeansParameters`] field.
///
/// Converting the grid into an iterator yields one `KMeansParameters` per
/// combination of the listed values. With the `serde` feature a list missing
/// from the serialized input deserializes to an empty `Vec`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansSearchParameters {
    /// Candidate numbers of clusters.
    #[cfg_attr(feature = "serde", serde(default))]
    pub k: Vec<usize>,
    /// Candidate iteration limits.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_iter: Vec<usize>,
    /// Candidate random seeds.
    #[cfg_attr(feature = "serde", serde(default))]
    pub seed: Vec<Option<u64>>,
}
163
164pub struct KMeansSearchParametersIterator {
166 kmeans_search_parameters: KMeansSearchParameters,
167 current_k: usize,
168 current_max_iter: usize,
169 current_seed: usize,
170}
171
172impl IntoIterator for KMeansSearchParameters {
173 type Item = KMeansParameters;
174 type IntoIter = KMeansSearchParametersIterator;
175
176 fn into_iter(self) -> Self::IntoIter {
177 KMeansSearchParametersIterator {
178 kmeans_search_parameters: self,
179 current_k: 0,
180 current_max_iter: 0,
181 current_seed: 0,
182 }
183 }
184}
185
186impl Iterator for KMeansSearchParametersIterator {
187 type Item = KMeansParameters;
188
189 fn next(&mut self) -> Option<Self::Item> {
190 if self.current_k == self.kmeans_search_parameters.k.len()
191 && self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
192 && self.current_seed == self.kmeans_search_parameters.seed.len()
193 {
194 return None;
195 }
196
197 let next = KMeansParameters {
198 k: self.kmeans_search_parameters.k[self.current_k],
199 max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
200 seed: self.kmeans_search_parameters.seed[self.current_seed],
201 };
202
203 if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
204 self.current_k += 1;
205 } else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
206 self.current_k = 0;
207 self.current_max_iter += 1;
208 } else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
209 self.current_k = 0;
210 self.current_max_iter = 0;
211 self.current_seed += 1;
212 } else {
213 self.current_k += 1;
214 self.current_max_iter += 1;
215 self.current_seed += 1;
216 }
217
218 Some(next)
219 }
220}
221
222impl Default for KMeansSearchParameters {
223 fn default() -> Self {
224 let default_params = KMeansParameters::default();
225
226 KMeansSearchParameters {
227 k: vec![default_params.k],
228 max_iter: vec![default_params.max_iter],
229 seed: vec![default_params.seed],
230 }
231 }
232}
233
234impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
235 UnsupervisedEstimator<X, KMeansParameters> for KMeans<TX, TY, X, Y>
236{
237 fn fit(x: &X, parameters: KMeansParameters) -> Result<Self, Failed> {
238 KMeans::fit(x, parameters)
239 }
240}
241
242impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
243 for KMeans<TX, TY, X, Y>
244{
245 fn predict(&self, x: &X) -> Result<Y, Failed> {
246 self.predict(x)
247 }
248}
249
250impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y> {
251 pub fn fit(data: &X, parameters: KMeansParameters) -> Result<KMeans<TX, TY, X, Y>, Failed> {
255 let bbd = BBDTree::new(data);
256
257 if parameters.k < 2 {
258 return Err(Failed::fit(&format!(
259 "invalid number of clusters: {}",
260 parameters.k
261 )));
262 }
263
264 if parameters.max_iter == 0 {
265 return Err(Failed::fit(&format!(
266 "invalid maximum number of iterations: {}",
267 parameters.max_iter
268 )));
269 }
270
271 let (n, d) = data.shape();
272
273 let mut distortion = f64::MAX;
274 let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
275 let mut size = vec![0; parameters.k];
276 let mut centroids = vec![vec![0f64; d]; parameters.k];
277
278 for i in 0..n {
279 size[y[i]] += 1;
280 }
281
282 for i in 0..n {
283 for j in 0..d {
284 centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
285 }
286 }
287
288 for i in 0..parameters.k {
289 for j in 0..d {
290 centroids[i][j] /= size[i] as f64;
291 }
292 }
293
294 let mut sums = vec![vec![0f64; d]; parameters.k];
295 for _ in 1..=parameters.max_iter {
296 let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y);
297 for i in 0..parameters.k {
298 if size[i] > 0 {
299 for j in 0..d {
300 centroids[i][j] = sums[i][j] / size[i] as f64;
301 }
302 }
303 }
304
305 if distortion <= dist {
306 break;
307 } else {
308 distortion = dist;
309 }
310 }
311
312 Ok(KMeans {
313 k: parameters.k,
314 _y: y,
315 size,
316 _distortion: distortion,
317 centroids,
318 _phantom_tx: PhantomData,
319 _phantom_ty: PhantomData,
320 _phantom_x: PhantomData,
321 _phantom_y: PhantomData,
322 })
323 }
324
325 pub fn predict(&self, x: &X) -> Result<Y, Failed> {
328 let (n, _) = x.shape();
329 let mut result = Y::zeros(n);
330
331 let mut row = vec![0f64; x.shape().1];
332
333 for i in 0..n {
334 let mut min_dist = f64::MAX;
335 let mut best_cluster = 0;
336
337 for j in 0..self.k {
338 x.get_row(i)
339 .iterator(0)
340 .zip(row.iter_mut())
341 .for_each(|(&x, r)| *r = x.to_f64().unwrap());
342 let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
343 if dist < min_dist {
344 min_dist = dist;
345 best_cluster = j;
346 }
347 }
348 result.set(i, TY::from_usize(best_cluster).unwrap());
349 }
350
351 Ok(result)
352 }
353
354 fn kmeans_plus_plus(data: &X, k: usize, seed: Option<u64>) -> Vec<usize> {
355 let mut rng = get_rng_impl(seed);
356 let (n, _) = data.shape();
357 let mut y = vec![0; n];
358 let mut centroid: Vec<TX> = data
359 .get_row(rng.gen_range(0..n))
360 .iterator(0)
361 .cloned()
362 .collect();
363
364 let mut d = vec![f64::MAX; n];
365 let mut row = vec![TX::zero(); data.shape().1];
366
367 for j in 1..k {
368 for i in 0..n {
369 data.get_row(i)
370 .iterator(0)
371 .zip(row.iter_mut())
372 .for_each(|(&x, r)| *r = x);
373 let dist = Euclidian::squared_distance(&row, ¢roid);
374
375 if dist < d[i] {
376 d[i] = dist;
377 y[i] = j - 1;
378 }
379 }
380
381 let mut sum = 0f64;
382 for i in d.iter() {
383 sum += *i;
384 }
385 let cutoff = rng.gen::<f64>() * sum;
386 let mut cost = 0f64;
387 let mut index = 0;
388 while index < n {
389 cost += d[index];
390 if cost >= cutoff {
391 break;
392 }
393 index += 1;
394 }
395
396 centroid = data.get_row(index).iterator(0).cloned().collect();
397 }
398
399 for i in 0..n {
400 data.get_row(i)
401 .iterator(0)
402 .zip(row.iter_mut())
403 .for_each(|(&x, r)| *r = x);
404 let dist = Euclidian::squared_distance(&row, ¢roid);
405
406 if dist < d[i] {
407 d[i] = dist;
408 y[i] = k - 1;
409 }
410 }
411
412 y
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    /// `fit` must reject cluster counts below two.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn invalid_k() {
        type IntKMeans = KMeans<i32, i32, DenseMatrix<i32>, Vec<i32>>;

        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();

        let zero_clusters = IntKMeans::fit(&x, KMeansParameters::default().with_k(0));
        assert!(zero_clusters.is_err());

        let one_cluster = IntKMeans::fit(&x, KMeansParameters::default().with_k(1));
        assert_eq!(
            "Fit failed: invalid number of clusters: 1",
            one_cluster.unwrap_err().to_string()
        );
    }

    /// The search grid yields every (k, max_iter) pair, with `k` varying
    /// fastest, and then stops.
    #[test]
    fn search_parameters() {
        let parameters = KMeansSearchParameters {
            k: vec![2, 4],
            max_iter: vec![10, 100],
            ..Default::default()
        };

        let combinations: Vec<(usize, usize)> = parameters
            .into_iter()
            .map(|candidate| (candidate.k, candidate.max_iter))
            .collect();

        assert_eq!(combinations, vec![(2, 10), (4, 10), (2, 100), (4, 100)]);
    }

    /// Predicting on the training data reproduces the labels stored by `fit`.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_predict() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();

        let kmeans = KMeans::fit(&x, Default::default()).unwrap();

        let y: Vec<usize> = kmeans.predict(&x).unwrap();

        for (i, &label) in y.iter().enumerate() {
            assert_eq!(label, kmeans._y[i]);
        }
    }

    /// A fitted model survives a JSON round trip unchanged.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();

        let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
            KMeans::fit(&x, Default::default()).unwrap();

        let json = serde_json::to_string(&kmeans).unwrap();
        let deserialized_kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
            serde_json::from_str(&json).unwrap();

        assert_eq!(kmeans, deserialized_kmeans);
    }
}