smartcore/cluster/kmeans.rs

//! # K-Means Clustering
//!
//! K-means clustering partitions data into k clusters so that data points in the same cluster are similar to each other and data points in different clusters are farther apart.
//! The similarity of two points is determined by the [Euclidean Distance](../../metrics/distance/euclidian/index.html) between them.
//!
//! The k-means algorithm cannot determine the number of clusters on its own; you need to choose this number yourself.
//! One way to choose an optimal number of clusters is the [Elbow Method](https://en.wikipedia.org/wiki/Elbow_method_(clustering)).
//!
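//! For instance, you can fit the model for several values of `k` and compare the within-cluster
//! sum of squares (inertia), looking for the "elbow" where the decrease levels off. A minimal
//! sketch, reusing the API from the example further below and computing the inertia by hand
//! from the predicted labels (the tiny dataset and variable names here are illustrative):
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::kmeans::*;
//!
//! let rows: &[&[f64]] = &[
//!     &[5.1, 3.5], &[4.9, 3.0], &[4.7, 3.2],
//!     &[7.0, 3.2], &[6.4, 3.2], &[6.9, 3.1],
//! ];
//! let x = DenseMatrix::from_2d_array(rows).unwrap();
//!
//! for k in 2..=4 {
//!     let model = KMeans::fit(&x, KMeansParameters::default().with_k(k)).unwrap();
//!     let labels: Vec<usize> = model.predict(&x).unwrap();
//!
//!     // Recompute the centroids from the labels.
//!     let d = rows[0].len();
//!     let mut centroids = vec![vec![0f64; d]; k];
//!     let mut counts = vec![0usize; k];
//!     for (row, &c) in rows.iter().zip(labels.iter()) {
//!         counts[c] += 1;
//!         for j in 0..d {
//!             centroids[c][j] += row[j];
//!         }
//!     }
//!     for (centroid, &count) in centroids.iter_mut().zip(counts.iter()) {
//!         if count > 0 {
//!             for v in centroid.iter_mut() {
//!                 *v /= count as f64;
//!             }
//!         }
//!     }
//!
//!     // Inertia: sum of squared distances from each point to its centroid.
//!     let inertia: f64 = rows
//!         .iter()
//!         .zip(labels.iter())
//!         .map(|(row, &c)| {
//!             row.iter()
//!                 .zip(centroids[c].iter())
//!                 .map(|(a, b)| (a - b) * (a - b))
//!                 .sum::<f64>()
//!         })
//!         .sum();
//!     println!("k = {}: inertia = {:.2}", k, inertia);
//! }
//! ```
//!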
//! At a high level, the k-means algorithm works as follows: k data points are chosen from the dataset as initial cluster centers (centroids), and
//! every training instance is assigned to its closest cluster. The centroids are then re-calculated as the mean of the instances in each cluster, and
//! these re-calculated centroids become the new centers of their respective clusters. Next, all instances of the training set are re-assigned to their closest cluster.
//! This iterative process continues until convergence is achieved and the clusters are considered settled.
//!
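//! To make the two alternating steps concrete, here is a minimal sketch of a single
//! assignment/update round over one-dimensional points in plain Rust (illustrative only;
//! `fit` below runs an optimized variant of this loop):
//!
//! ```
//! let points = [1.0f64, 1.2, 0.8, 8.0, 8.2, 7.8];
//! let mut centroids = [1.0f64, 8.0];
//!
//! // Assignment step: each point goes to its closest centroid.
//! let labels: Vec<usize> = points
//!     .iter()
//!     .map(|p| if (p - centroids[0]).abs() <= (p - centroids[1]).abs() { 0 } else { 1 })
//!     .collect();
//!
//! // Update step: each centroid becomes the mean of its assigned points.
//! for c in 0..centroids.len() {
//!     let members: Vec<f64> = points
//!         .iter()
//!         .zip(labels.iter())
//!         .filter(|(_, &l)| l == c)
//!         .map(|(p, _)| *p)
//!         .collect();
//!     if !members.is_empty() {
//!         centroids[c] = members.iter().sum::<f64>() / members.len() as f64;
//!     }
//! }
//! assert_eq!(labels, vec![0, 0, 0, 1, 1, 1]);
//! ```
//!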
//! The initial choice of the k data points is very important and has a big effect on the performance of the algorithm. `smartcore` uses the k-means++ algorithm to initialize cluster centers.
//!
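//! k-means++ picks each new center with probability proportional to the squared distance from
//! a point to its nearest already-chosen center. A minimal sketch of that sampling step (the
//! values here are illustrative; the implementation below draws the cutoff with
//! `rng.gen::<f64>() * sum`):
//!
//! ```
//! // Squared distances from each point to its nearest chosen center.
//! let d = [4.0f64, 1.0, 0.0, 9.0];
//! let sum: f64 = d.iter().sum();
//! let cutoff = 0.5 * sum; // stand-in for a uniform random draw in [0, sum)
//!
//! // Walk the cumulative distribution until it crosses the cutoff.
//! let mut cost = 0.0;
//! let mut index = 0;
//! for (i, di) in d.iter().enumerate() {
//!     cost += di;
//!     index = i;
//!     if cost >= cutoff {
//!         break;
//!     }
//! }
//! assert_eq!(index, 3); // the farthest point is the most likely next center
//! ```
//!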
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::kmeans::*;
//!
//! // Iris data
//! let x = DenseMatrix::from_2d_array(&[
//!            &[5.1, 3.5, 1.4, 0.2],
//!            &[4.9, 3.0, 1.4, 0.2],
//!            &[4.7, 3.2, 1.3, 0.2],
//!            &[4.6, 3.1, 1.5, 0.2],
//!            &[5.0, 3.6, 1.4, 0.2],
//!            &[5.4, 3.9, 1.7, 0.4],
//!            &[4.6, 3.4, 1.4, 0.3],
//!            &[5.0, 3.4, 1.5, 0.2],
//!            &[4.4, 2.9, 1.4, 0.2],
//!            &[4.9, 3.1, 1.5, 0.1],
//!            &[7.0, 3.2, 4.7, 1.4],
//!            &[6.4, 3.2, 4.5, 1.5],
//!            &[6.9, 3.1, 4.9, 1.5],
//!            &[5.5, 2.3, 4.0, 1.3],
//!            &[6.5, 2.8, 4.6, 1.5],
//!            &[5.7, 2.8, 4.5, 1.3],
//!            &[6.3, 3.3, 4.7, 1.6],
//!            &[4.9, 2.4, 3.3, 1.0],
//!            &[6.6, 2.9, 4.6, 1.3],
//!            &[5.2, 2.7, 3.9, 1.4],
//!            ]).unwrap();
//!
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
//! let y_hat: Vec<u8> = kmeans.predict(&x).unwrap(); // use the same points for prediction
//! ```
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)

use std::fmt::Debug;
use std::marker::PhantomData;

use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;

/// K-Means clustering algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KMeans<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    k: usize,
    _y: Vec<usize>,
    size: Vec<usize>,
    _distortion: f64,
    centroids: Vec<Vec<f64>>,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_x: PhantomData<X>,
    _phantom_y: PhantomData<Y>,
}

impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<TX, TY, X, Y> {
    fn eq(&self, other: &Self) -> bool {
        if self.k != other.k
            || self.size != other.size
            || self.centroids.len() != other.centroids.len()
        {
            false
        } else {
            let n_centroids = self.centroids.len();
            for i in 0..n_centroids {
                if self.centroids[i].len() != other.centroids[i].len() {
                    return false;
                }
                for j in 0..self.centroids[i].len() {
                    if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
                        return false;
                    }
                }
            }
            true
        }
    }
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// K-Means clustering algorithm parameters
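///
/// Example (a brief sketch of the builder-style setters):
///
/// ```
/// use smartcore::cluster::kmeans::KMeansParameters;
///
/// let params = KMeansParameters::default().with_k(3).with_max_iter(50);
/// assert_eq!((params.k, params.max_iter), (3, 50));
/// ```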
pub struct KMeansParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Number of clusters.
    pub k: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Maximum number of iterations of the k-means algorithm for a single run.
    pub max_iter: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Determines random number generation for centroid initialization.
    /// Use an int to make the randomness deterministic.
    pub seed: Option<u64>,
}

impl KMeansParameters {
    /// Number of clusters.
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }
    /// Maximum number of iterations of the k-means algorithm for a single run.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }
}

impl Default for KMeansParameters {
    fn default() -> Self {
        KMeansParameters {
            k: 2,
            max_iter: 100,
            seed: None,
        }
    }
}

/// KMeans grid search parameters
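///
/// Example (a brief sketch; iterating the grid yields one `KMeansParameters` per combination):
///
/// ```
/// use smartcore::cluster::kmeans::KMeansSearchParameters;
///
/// let grid = KMeansSearchParameters {
///     k: vec![2, 3, 4],
///     ..Default::default()
/// };
/// for params in grid {
///     assert!(params.k >= 2);
/// }
/// ```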
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansSearchParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Number of clusters.
    pub k: Vec<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Maximum number of iterations of the k-means algorithm for a single run.
    pub max_iter: Vec<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Determines random number generation for centroid initialization.
    /// Use an int to make the randomness deterministic.
    pub seed: Vec<Option<u64>>,
}

/// KMeans grid search iterator
pub struct KMeansSearchParametersIterator {
    kmeans_search_parameters: KMeansSearchParameters,
    current_k: usize,
    current_max_iter: usize,
    current_seed: usize,
}

impl IntoIterator for KMeansSearchParameters {
    type Item = KMeansParameters;
    type IntoIter = KMeansSearchParametersIterator;

    fn into_iter(self) -> Self::IntoIter {
        KMeansSearchParametersIterator {
            kmeans_search_parameters: self,
            current_k: 0,
            current_max_iter: 0,
            current_seed: 0,
        }
    }
}

impl Iterator for KMeansSearchParametersIterator {
    type Item = KMeansParameters;

    fn next(&mut self) -> Option<Self::Item> {
        // The iterator is exhausted once every index has run past its list.
        if self.current_k == self.kmeans_search_parameters.k.len()
            && self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
            && self.current_seed == self.kmeans_search_parameters.seed.len()
        {
            return None;
        }

        let next = KMeansParameters {
            k: self.kmeans_search_parameters.k[self.current_k],
            max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
            seed: self.kmeans_search_parameters.seed[self.current_seed],
        };

        // Advance the indices like an odometer: `k` fastest, then `max_iter`, then `seed`.
        if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
            self.current_k += 1;
        } else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
            self.current_k = 0;
            self.current_max_iter += 1;
        } else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
            self.current_k = 0;
            self.current_max_iter = 0;
            self.current_seed += 1;
        } else {
            // Last combination: push all indices past the end to signal exhaustion.
            self.current_k += 1;
            self.current_max_iter += 1;
            self.current_seed += 1;
        }

        Some(next)
    }
}

impl Default for KMeansSearchParameters {
    fn default() -> Self {
        let default_params = KMeansParameters::default();

        KMeansSearchParameters {
            k: vec![default_params.k],
            max_iter: vec![default_params.max_iter],
            seed: vec![default_params.seed],
        }
    }
}

impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    UnsupervisedEstimator<X, KMeansParameters> for KMeans<TX, TY, X, Y>
{
    fn fit(x: &X, parameters: KMeansParameters) -> Result<Self, Failed> {
        KMeans::fit(x, parameters)
    }
}

impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
    for KMeans<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y> {
    /// Fit the algorithm to an _NxM_ matrix where _N_ is the number of samples and _M_ is the number of features.
    /// * `data` - training instances to cluster
    /// * `parameters` - cluster parameters
    pub fn fit(data: &X, parameters: KMeansParameters) -> Result<KMeans<TX, TY, X, Y>, Failed> {
        // Validate parameters before building the BBD-tree, which is expensive.
        if parameters.k < 2 {
            return Err(Failed::fit(&format!(
                "invalid number of clusters: {}",
                parameters.k
            )));
        }

        if parameters.max_iter == 0 {
            return Err(Failed::fit(&format!(
                "invalid maximum number of iterations: {}",
                parameters.max_iter
            )));
        }

        let bbd = BBDTree::new(data);

        let (n, d) = data.shape();

        let mut distortion = f64::MAX;
        // Seed the clusters with k-means++, then compute the initial centroids
        // as the mean of each cluster's members.
        let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
        let mut size = vec![0; parameters.k];
        let mut centroids = vec![vec![0f64; d]; parameters.k];

        for i in 0..n {
            size[y[i]] += 1;
        }

        for i in 0..n {
            for j in 0..d {
                centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
            }
        }

        for i in 0..parameters.k {
            for j in 0..d {
                centroids[i][j] /= size[i] as f64;
            }
        }

        // Lloyd iterations: re-assign points with the BBD-tree, then move each
        // centroid to the mean of its cluster. Stop early once the total
        // distortion no longer decreases.
        let mut sums = vec![vec![0f64; d]; parameters.k];
        for _ in 1..=parameters.max_iter {
            let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
            for i in 0..parameters.k {
                if size[i] > 0 {
                    for j in 0..d {
                        centroids[i][j] = sums[i][j] / size[i] as f64;
                    }
                }
            }

            if distortion <= dist {
                break;
            } else {
                distortion = dist;
            }
        }

        Ok(KMeans {
            k: parameters.k,
            _y: y,
            size,
            _distortion: distortion,
            centroids,
            _phantom_tx: PhantomData,
            _phantom_ty: PhantomData,
            _phantom_x: PhantomData,
            _phantom_y: PhantomData,
        })
    }

    /// Predict clusters for `x`
    /// * `x` - matrix with new data to transform of size _NxM_, where _N_ is the number of new samples and _M_ is the number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (n, _) = x.shape();
        let mut result = Y::zeros(n);

        let mut row = vec![0f64; x.shape().1];

        for i in 0..n {
            let mut min_dist = f64::MAX;
            let mut best_cluster = 0;

            // Copy the row once, then compare it against every centroid.
            x.get_row(i)
                .iterator(0)
                .zip(row.iter_mut())
                .for_each(|(&x, r)| *r = x.to_f64().unwrap());

            for j in 0..self.k {
                let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
                if dist < min_dist {
                    min_dist = dist;
                    best_cluster = j;
                }
            }
            result.set(i, TY::from_usize(best_cluster).unwrap());
        }

        Ok(result)
    }

    fn kmeans_plus_plus(data: &X, k: usize, seed: Option<u64>) -> Vec<usize> {
        let mut rng = get_rng_impl(seed);
        let (n, _) = data.shape();
        let mut y = vec![0; n];
        // Pick the first center uniformly at random.
        let mut centroid: Vec<TX> = data
            .get_row(rng.gen_range(0..n))
            .iterator(0)
            .cloned()
            .collect();

        let mut d = vec![f64::MAX; n];
        let mut row = vec![TX::zero(); data.shape().1];

        for j in 1..k {
            // Update each point's squared distance to its nearest chosen center.
            for i in 0..n {
                data.get_row(i)
                    .iterator(0)
                    .zip(row.iter_mut())
                    .for_each(|(&x, r)| *r = x);
                let dist = Euclidian::squared_distance(&row, &centroid);

                if dist < d[i] {
                    d[i] = dist;
                    y[i] = j - 1;
                }
            }

            // Sample the next center with probability proportional to d[i].
            let sum: f64 = d.iter().sum();
            let cutoff = rng.gen::<f64>() * sum;
            let mut cost = 0f64;
            let mut index = 0;
            while index < n {
                cost += d[index];
                if cost >= cutoff {
                    break;
                }
                index += 1;
            }

            centroid = data.get_row(index).iterator(0).cloned().collect();
        }

        // Assign points to the last chosen center wherever it is the closest.
        for i in 0..n {
            data.get_row(i)
                .iterator(0)
                .zip(row.iter_mut())
                .for_each(|(&x, r)| *r = x);
            let dist = Euclidian::squared_distance(&row, &centroid);

            if dist < d[i] {
                d[i] = dist;
                y[i] = k - 1;
            }
        }

        y
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn invalid_k() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();

        assert!(KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
            &x,
            KMeansParameters::default().with_k(0)
        )
        .is_err());
        assert_eq!(
            "Fit failed: invalid number of clusters: 1",
            KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
                &x,
                KMeansParameters::default().with_k(1)
            )
            .unwrap_err()
            .to_string()
        );
    }

    #[test]
    fn search_parameters() {
        let parameters = KMeansSearchParameters {
            k: vec![2, 4],
            max_iter: vec![10, 100],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        let next = iter.next().unwrap();
        assert_eq!(next.k, 2);
        assert_eq!(next.max_iter, 10);
        let next = iter.next().unwrap();
        assert_eq!(next.k, 4);
        assert_eq!(next.max_iter, 10);
        let next = iter.next().unwrap();
        assert_eq!(next.k, 2);
        assert_eq!(next.max_iter, 100);
        let next = iter.next().unwrap();
        assert_eq!(next.k, 4);
        assert_eq!(next.max_iter, 100);
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn fit_predict() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();

        let kmeans = KMeans::fit(&x, Default::default()).unwrap();

        let y: Vec<usize> = kmeans.predict(&x).unwrap();

        // The labels predicted on the training data must match the fitted assignment.
        for (&y_i, &label) in y.iter().zip(kmeans._y.iter()) {
            assert_eq!(y_i, label);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ])
        .unwrap();

        let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
            KMeans::fit(&x, Default::default()).unwrap();

        let deserialized_kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
            serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();

        assert_eq!(kmeans, deserialized_kmeans);
    }
}