rusty_machine/learning/
k_means.rs

1//! K-means Classification
2//!
3//! Provides implementation of K-Means classification.
4//!
5//! # Usage
6//!
7//! ```
8//! use rusty_machine::linalg::Matrix;
9//! use rusty_machine::learning::k_means::KMeansClassifier;
10//! use rusty_machine::learning::UnSupModel;
11//!
12//! let inputs = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
13//! let test_inputs = Matrix::new(1, 2, vec![1.0, 3.5]);
14//!
15//! // Create model with k(=2) classes.
16//! let mut model = KMeansClassifier::new(2);
17//!
18//! // Where inputs is a Matrix with features in columns.
19//! model.train(&inputs).unwrap();
20//!
21//! // Where test_inputs is a Matrix with features in columns.
22//! let a = model.predict(&test_inputs).unwrap();
23//! ```
24//!
25//! Additionally you can control the initialization
26//! algorithm and max number of iterations.
27//!
28//! # Initializations
29//!
30//! Three initialization algorithms are supported.
31//!
32//! ## Forgy initialization
33//!
34//! Choose initial centroids randomly from the data.
35//!
36//! ## Random Partition initialization
37//!
38//! Randomly assign each data point to one of k clusters.
39//! The initial centroids are the mean of the data in their class.
40//!
41//! ## K-means++ initialization
42//!
43//! The [k-means++](https://en.wikipedia.org/wiki/K-means%2B%2B) scheme.
44
45use linalg::{Matrix, MatrixSlice, Axes, Vector, BaseMatrix};
46use learning::{LearningResult, UnSupModel};
47use learning::error::{Error, ErrorKind};
48
49use rand::{Rng, thread_rng};
50use libnum::abs;
51
52use std::fmt::Debug;
53
54/// K-Means Classification model.
55///
56/// Contains option for centroids.
57/// Specifies iterations and number of classes.
58///
59/// # Usage
60///
61/// This model is used through the `UnSupModel` trait. The model is
62/// trained via the `train` function with a matrix containing rows of
63/// feature vectors.
64///
65/// The model will not check to ensure the data coming in is all valid.
66/// This responsibility lies with the user (for now).
67#[derive(Debug)]
68pub struct KMeansClassifier<InitAlg: Initializer> {
69    /// Max iterations of algorithm to run.
70    iters: usize,
71    /// The number of classes.
72    k: usize,
73    /// The fitted centroids .
74    centroids: Option<Matrix<f64>>,
75    /// The initial algorithm to use.
76    init_algorithm: InitAlg,
77}
78
79impl<InitAlg: Initializer> UnSupModel<Matrix<f64>, Vector<usize>> for KMeansClassifier<InitAlg> {
80    /// Predict classes from data.
81    ///
82    /// Model must be trained.
83    fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<usize>> {
84        if let Some(ref centroids) = self.centroids {
85            Ok(KMeansClassifier::<InitAlg>::find_closest_centroids(centroids.as_slice(), inputs).0)
86        } else {
87            Err(Error::new_untrained())
88        }
89    }
90
91    /// Train the classifier using input data.
92    fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
93        try!(self.init_centroids(inputs));
94        let mut cost = 0.0;
95        let eps = 1e-14;
96
97        for _i in 0..self.iters {
98            let (idx, distances) = try!(self.get_closest_centroids(inputs));
99            self.update_centroids(inputs, idx);
100
101            let cost_i = distances.sum();
102            if abs(cost - cost_i) < eps {
103                break;
104            }
105
106            cost = cost_i;
107        }
108
109        Ok(())
110    }
111}
112
113impl KMeansClassifier<KPlusPlus> {
114    /// Constructs untrained k-means classifier model.
115    ///
116    /// Requires number of classes to be specified.
117    /// Defaults to 100 iterations and kmeans++ initialization.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// use rusty_machine::learning::k_means::KMeansClassifier;
123    ///
124    /// let model = KMeansClassifier::new(5);
125    /// ```
126    pub fn new(k: usize) -> KMeansClassifier<KPlusPlus> {
127        KMeansClassifier {
128            iters: 100,
129            k: k,
130            centroids: None,
131            init_algorithm: KPlusPlus,
132        }
133    }
134}
135
136impl<InitAlg: Initializer> KMeansClassifier<InitAlg> {
137    /// Constructs untrained k-means classifier model.
138    ///
139    /// Requires number of classes, number of iterations, and
140    /// the initialization algorithm to use.
141    ///
142    /// # Examples
143    ///
144    /// ```
145    /// use rusty_machine::learning::k_means::{KMeansClassifier, Forgy};
146    ///
147    /// let model = KMeansClassifier::new_specified(5, 42, Forgy);
148    /// ```
149    pub fn new_specified(k: usize, iters: usize, algo: InitAlg) -> KMeansClassifier<InitAlg> {
150        KMeansClassifier {
151            iters: iters,
152            k: k,
153            centroids: None,
154            init_algorithm: algo,
155        }
156    }
157
158    /// Get the number of classes.
159    pub fn k(&self) -> usize {
160        self.k
161    }
162
163    /// Get the number of iterations.
164    pub fn iters(&self) -> usize {
165        self.iters
166    }
167
168    /// Get the initialization algorithm.
169    pub fn init_algorithm(&self) -> &InitAlg {
170        &self.init_algorithm
171    }
172
173    /// Get the centroids `Option<Matrix<f64>>`.
174    pub fn centroids(&self) -> &Option<Matrix<f64>> {
175        &self.centroids
176    }
177
178    /// Set the number of iterations.
179    pub fn set_iters(&mut self, iters: usize) {
180        self.iters = iters;
181    }
182
183    /// Initialize the centroids.
184    ///
185    /// Used internally within model.
186    fn init_centroids(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
187        if self.k > inputs.rows() {
188            Err(Error::new(ErrorKind::InvalidData,
189                           format!("Number of clusters ({0}) exceeds number of data points \
190                                    ({1}).",
191                                   self.k,
192                                   inputs.rows())))
193        } else {
194            let centroids = try!(self.init_algorithm.init_centroids(self.k, inputs));
195
196            if centroids.rows() != self.k {
197                Err(Error::new(ErrorKind::InvalidState,
198                                    "Initial centroids must have exactly k rows."))
199            } else if centroids.cols() != inputs.cols() {
200                Err(Error::new(ErrorKind::InvalidState,
201                                    "Initial centroids must have the same column count as inputs."))
202            } else {
203                self.centroids = Some(centroids);
204                Ok(())
205            }
206        }
207
208    }
209
210    /// Updated the centroids by computing means of assigned classes.
211    ///
212    /// Used internally within model.
213    fn update_centroids(&mut self, inputs: &Matrix<f64>, classes: Vector<usize>) {
214        let mut new_centroids = Vec::with_capacity(self.k * inputs.cols());
215
216        let mut row_indexes = vec![Vec::new(); self.k];
217        for (i, c) in classes.into_vec().into_iter().enumerate() {
218            row_indexes.get_mut(c as usize).map(|v| v.push(i));
219        }
220
221        for vec_i in row_indexes {
222            let mat_i = inputs.select_rows(&vec_i);
223            new_centroids.extend(mat_i.mean(Axes::Row).into_vec());
224        }
225
226        self.centroids = Some(Matrix::new(self.k, inputs.cols(), new_centroids));
227    }
228
229    fn get_closest_centroids(&self,
230                             inputs: &Matrix<f64>)
231                             -> LearningResult<(Vector<usize>, Vector<f64>)> {
232        if let Some(ref c) = self.centroids {
233            Ok(KMeansClassifier::<InitAlg>::find_closest_centroids(c.as_slice(), inputs))
234        } else {
235            Err(Error::new(ErrorKind::InvalidState,
236                           "Centroids not correctly initialized."))
237        }
238    }
239
240    /// Find the centroid closest to each data point.
241    ///
242    /// Used internally within model.
243    /// Returns the index of the closest centroid and the distance to it.
244    fn find_closest_centroids(centroids: MatrixSlice<f64>,
245                              inputs: &Matrix<f64>)
246                              -> (Vector<usize>, Vector<f64>) {
247        let mut idx = Vec::with_capacity(inputs.rows());
248        let mut distances = Vec::with_capacity(inputs.rows());
249
250        for i in 0..inputs.rows() {
251            // This works like repmat pulling out row i repeatedly.
252            let centroid_diff = centroids - inputs.select_rows(&vec![i; centroids.rows()]);
253            let dist = &centroid_diff.elemul(&centroid_diff).sum_cols();
254
255            // Now take argmin and this is the centroid.
256            let (min_idx, min_dist) = dist.argmin();
257            idx.push(min_idx);
258            distances.push(min_dist);
259        }
260
261        (Vector::new(idx), Vector::new(distances))
262    }
263}
264
265/// Trait for algorithms initializing the K-means centroids.
266pub trait Initializer: Debug {
267    /// Initialize the centroids for the initial state of the K-Means model.
268    ///
269    /// The `Matrix` returned must have `k` rows and the same column count as `inputs`.
270    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>>;
271}
272
273/// The Forgy initialization scheme.
274#[derive(Debug)]
275pub struct Forgy;
276
277impl Initializer for Forgy {
278    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
279        let mut random_choices = Vec::with_capacity(k);
280        let mut rng = thread_rng();
281        while random_choices.len() < k {
282            let r = rng.gen_range(0, inputs.rows());
283
284            if !random_choices.contains(&r) {
285                random_choices.push(r);
286            }
287        }
288
289        Ok(inputs.select_rows(&random_choices))
290    }
291}
292
293/// The Random Partition initialization scheme.
294#[derive(Debug)]
295pub struct RandomPartition;
296
297impl Initializer for RandomPartition {
298    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
299
300        // Populate so we have something in each class.
301        let mut random_assignments = (0..k).map(|i| vec![i]).collect::<Vec<Vec<usize>>>();
302        let mut rng = thread_rng();
303        for i in k..inputs.rows() {
304            let idx = rng.gen_range(0, k);
305            unsafe {
306                random_assignments.get_unchecked_mut(idx).push(i);
307            }
308        }
309
310        let mut init_centroids = Vec::with_capacity(k * inputs.cols());
311
312        for vec_i in random_assignments {
313            let mat_i = inputs.select_rows(&vec_i);
314            init_centroids.extend_from_slice(&*mat_i.mean(Axes::Row).into_vec());
315        }
316
317        Ok(Matrix::new(k, inputs.cols(), init_centroids))
318    }
319}
320
321/// The K-means ++ initialization scheme.
322#[derive(Debug)]
323pub struct KPlusPlus;
324
325impl Initializer for KPlusPlus {
326    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
327        let mut rng = thread_rng();
328
329        let mut init_centroids = Vec::with_capacity(k * inputs.cols());
330        let first_cen = rng.gen_range(0usize, inputs.rows());
331
332        unsafe {
333            init_centroids.extend_from_slice(inputs.get_row_unchecked(first_cen));
334        }
335
336        for i in 1..k {
337            unsafe {
338                let temp_centroids = MatrixSlice::from_raw_parts(init_centroids.as_ptr(),
339                                                                 i,
340                                                                 inputs.cols(),
341                                                                 inputs.cols());
342                let (_, dist) =
343                    KMeansClassifier::<KPlusPlus>::find_closest_centroids(temp_centroids, &inputs);
344
345                // A relatively cheap way to validate our input data
346                if !dist.data().iter().all(|x| x.is_finite()) {
347                    return Err(Error::new(ErrorKind::InvalidData,
348                                          "Input data led to invalid centroid distances during \
349                                           initialization."));
350                }
351
352                let next_cen = sample_discretely(dist);
353                init_centroids.extend_from_slice(inputs.get_row_unchecked(next_cen));
354            }
355        }
356
357        Ok(Matrix::new(k, inputs.cols(), init_centroids))
358    }
359}
360
361/// Sample from an unnormalized distribution.
362///
363/// The input to this function is assumed to have all positive entries.
364fn sample_discretely(unnorm_dist: Vector<f64>) -> usize {
365    assert!(unnorm_dist.size() > 0, "No entries in distribution vector.");
366
367    let sum = unnorm_dist.sum();
368
369    let rand = thread_rng().gen_range(0.0f64, sum);
370
371    let mut tempsum = 0.0;
372    for (i, p) in unnorm_dist.data().iter().enumerate() {
373        tempsum += *p;
374
375        if rand < tempsum {
376            return i;
377        }
378    }
379
380    panic!("No random value was sampled! There may be more clusters than unique data points.");
381}