rusty_machine/learning/
k_means.rs1use linalg::{Matrix, MatrixSlice, Axes, Vector, BaseMatrix};
46use learning::{LearningResult, UnSupModel};
47use learning::error::{Error, ErrorKind};
48
49use rand::{Rng, thread_rng};
50use libnum::abs;
51
52use std::fmt::Debug;
53
54#[derive(Debug)]
68pub struct KMeansClassifier<InitAlg: Initializer> {
69 iters: usize,
71 k: usize,
73 centroids: Option<Matrix<f64>>,
75 init_algorithm: InitAlg,
77}
78
79impl<InitAlg: Initializer> UnSupModel<Matrix<f64>, Vector<usize>> for KMeansClassifier<InitAlg> {
80 fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<usize>> {
84 if let Some(ref centroids) = self.centroids {
85 Ok(KMeansClassifier::<InitAlg>::find_closest_centroids(centroids.as_slice(), inputs).0)
86 } else {
87 Err(Error::new_untrained())
88 }
89 }
90
91 fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
93 try!(self.init_centroids(inputs));
94 let mut cost = 0.0;
95 let eps = 1e-14;
96
97 for _i in 0..self.iters {
98 let (idx, distances) = try!(self.get_closest_centroids(inputs));
99 self.update_centroids(inputs, idx);
100
101 let cost_i = distances.sum();
102 if abs(cost - cost_i) < eps {
103 break;
104 }
105
106 cost = cost_i;
107 }
108
109 Ok(())
110 }
111}
112
113impl KMeansClassifier<KPlusPlus> {
114 pub fn new(k: usize) -> KMeansClassifier<KPlusPlus> {
127 KMeansClassifier {
128 iters: 100,
129 k: k,
130 centroids: None,
131 init_algorithm: KPlusPlus,
132 }
133 }
134}
135
136impl<InitAlg: Initializer> KMeansClassifier<InitAlg> {
137 pub fn new_specified(k: usize, iters: usize, algo: InitAlg) -> KMeansClassifier<InitAlg> {
150 KMeansClassifier {
151 iters: iters,
152 k: k,
153 centroids: None,
154 init_algorithm: algo,
155 }
156 }
157
158 pub fn k(&self) -> usize {
160 self.k
161 }
162
163 pub fn iters(&self) -> usize {
165 self.iters
166 }
167
168 pub fn init_algorithm(&self) -> &InitAlg {
170 &self.init_algorithm
171 }
172
173 pub fn centroids(&self) -> &Option<Matrix<f64>> {
175 &self.centroids
176 }
177
178 pub fn set_iters(&mut self, iters: usize) {
180 self.iters = iters;
181 }
182
183 fn init_centroids(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
187 if self.k > inputs.rows() {
188 Err(Error::new(ErrorKind::InvalidData,
189 format!("Number of clusters ({0}) exceeds number of data points \
190 ({1}).",
191 self.k,
192 inputs.rows())))
193 } else {
194 let centroids = try!(self.init_algorithm.init_centroids(self.k, inputs));
195
196 if centroids.rows() != self.k {
197 Err(Error::new(ErrorKind::InvalidState,
198 "Initial centroids must have exactly k rows."))
199 } else if centroids.cols() != inputs.cols() {
200 Err(Error::new(ErrorKind::InvalidState,
201 "Initial centroids must have the same column count as inputs."))
202 } else {
203 self.centroids = Some(centroids);
204 Ok(())
205 }
206 }
207
208 }
209
210 fn update_centroids(&mut self, inputs: &Matrix<f64>, classes: Vector<usize>) {
214 let mut new_centroids = Vec::with_capacity(self.k * inputs.cols());
215
216 let mut row_indexes = vec![Vec::new(); self.k];
217 for (i, c) in classes.into_vec().into_iter().enumerate() {
218 row_indexes.get_mut(c as usize).map(|v| v.push(i));
219 }
220
221 for vec_i in row_indexes {
222 let mat_i = inputs.select_rows(&vec_i);
223 new_centroids.extend(mat_i.mean(Axes::Row).into_vec());
224 }
225
226 self.centroids = Some(Matrix::new(self.k, inputs.cols(), new_centroids));
227 }
228
229 fn get_closest_centroids(&self,
230 inputs: &Matrix<f64>)
231 -> LearningResult<(Vector<usize>, Vector<f64>)> {
232 if let Some(ref c) = self.centroids {
233 Ok(KMeansClassifier::<InitAlg>::find_closest_centroids(c.as_slice(), inputs))
234 } else {
235 Err(Error::new(ErrorKind::InvalidState,
236 "Centroids not correctly initialized."))
237 }
238 }
239
240 fn find_closest_centroids(centroids: MatrixSlice<f64>,
245 inputs: &Matrix<f64>)
246 -> (Vector<usize>, Vector<f64>) {
247 let mut idx = Vec::with_capacity(inputs.rows());
248 let mut distances = Vec::with_capacity(inputs.rows());
249
250 for i in 0..inputs.rows() {
251 let centroid_diff = centroids - inputs.select_rows(&vec![i; centroids.rows()]);
253 let dist = ¢roid_diff.elemul(¢roid_diff).sum_cols();
254
255 let (min_idx, min_dist) = dist.argmin();
257 idx.push(min_idx);
258 distances.push(min_dist);
259 }
260
261 (Vector::new(idx), Vector::new(distances))
262 }
263}
264
/// Algorithm that chooses the starting centroids for k-means training.
pub trait Initializer: Debug {
    /// Computes initial centroids for `k` clusters from `inputs`, one centroid
    /// per row.
    ///
    /// `KMeansClassifier` rejects a result with `ErrorKind::InvalidState` unless
    /// it has exactly `k` rows and `inputs.cols()` columns.
    fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>>;
}
272
273#[derive(Debug)]
275pub struct Forgy;
276
277impl Initializer for Forgy {
278 fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
279 let mut random_choices = Vec::with_capacity(k);
280 let mut rng = thread_rng();
281 while random_choices.len() < k {
282 let r = rng.gen_range(0, inputs.rows());
283
284 if !random_choices.contains(&r) {
285 random_choices.push(r);
286 }
287 }
288
289 Ok(inputs.select_rows(&random_choices))
290 }
291}
292
293#[derive(Debug)]
295pub struct RandomPartition;
296
297impl Initializer for RandomPartition {
298 fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
299
300 let mut random_assignments = (0..k).map(|i| vec![i]).collect::<Vec<Vec<usize>>>();
302 let mut rng = thread_rng();
303 for i in k..inputs.rows() {
304 let idx = rng.gen_range(0, k);
305 unsafe {
306 random_assignments.get_unchecked_mut(idx).push(i);
307 }
308 }
309
310 let mut init_centroids = Vec::with_capacity(k * inputs.cols());
311
312 for vec_i in random_assignments {
313 let mat_i = inputs.select_rows(&vec_i);
314 init_centroids.extend_from_slice(&*mat_i.mean(Axes::Row).into_vec());
315 }
316
317 Ok(Matrix::new(k, inputs.cols(), init_centroids))
318 }
319}
320
321#[derive(Debug)]
323pub struct KPlusPlus;
324
325impl Initializer for KPlusPlus {
326 fn init_centroids(&self, k: usize, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
327 let mut rng = thread_rng();
328
329 let mut init_centroids = Vec::with_capacity(k * inputs.cols());
330 let first_cen = rng.gen_range(0usize, inputs.rows());
331
332 unsafe {
333 init_centroids.extend_from_slice(inputs.get_row_unchecked(first_cen));
334 }
335
336 for i in 1..k {
337 unsafe {
338 let temp_centroids = MatrixSlice::from_raw_parts(init_centroids.as_ptr(),
339 i,
340 inputs.cols(),
341 inputs.cols());
342 let (_, dist) =
343 KMeansClassifier::<KPlusPlus>::find_closest_centroids(temp_centroids, &inputs);
344
345 if !dist.data().iter().all(|x| x.is_finite()) {
347 return Err(Error::new(ErrorKind::InvalidData,
348 "Input data led to invalid centroid distances during \
349 initialization."));
350 }
351
352 let next_cen = sample_discretely(dist);
353 init_centroids.extend_from_slice(inputs.get_row_unchecked(next_cen));
354 }
355 }
356
357 Ok(Matrix::new(k, inputs.cols(), init_centroids))
358 }
359}
360
361fn sample_discretely(unnorm_dist: Vector<f64>) -> usize {
365 assert!(unnorm_dist.size() > 0, "No entries in distribution vector.");
366
367 let sum = unnorm_dist.sum();
368
369 let rand = thread_rng().gen_range(0.0f64, sum);
370
371 let mut tempsum = 0.0;
372 for (i, p) in unnorm_dist.data().iter().enumerate() {
373 tempsum += *p;
374
375 if rand < tempsum {
376 return i;
377 }
378 }
379
380 panic!("No random value was sampled! There may be more clusters than unique data points.");
381}