rusty_machine/analysis/
cross_validation.rs

1//! Module for performing cross-validation of models.
2
3use std::cmp;
4use std::iter::Chain;
5use std::slice::Iter;
6use linalg::{BaseMatrix, Matrix};
7use learning::{LearningResult, SupModel};
8use learning::toolkit::rand_utils::in_place_fisher_yates;
9
10/// Randomly splits the inputs into k 'folds'. For each fold a model
11/// is trained using all inputs except for that fold, and tested on the
12/// data in the fold. Returns the scores for each fold.
13///
14/// # Arguments
15/// * `model` - Used to train and predict for each fold.
16/// * `inputs` - All input samples.
17/// * `targets` - All targets.
18/// * `k` - Number of folds to use.
19/// * `score` - Used to compare the outputs for each fold to the targets. Higher scores are better. See the `analysis::score` module for examples.
20///
21/// # Examples
22/// ```
23/// use rusty_machine::analysis::cross_validation::k_fold_validate;
24/// use rusty_machine::analysis::score::row_accuracy;
25/// use rusty_machine::learning::naive_bayes::{NaiveBayes, Bernoulli};
26/// use rusty_machine::linalg::{BaseMatrix, Matrix};
27///
28/// let inputs = Matrix::new(3, 2, vec![1.0, 1.1,
29///                                     5.2, 4.3,
30///                                     6.2, 7.3]);
31///
32/// let targets = Matrix::new(3, 3, vec![1.0, 0.0, 0.0,
33///                                      0.0, 0.0, 1.0,
34///                                      0.0, 0.0, 1.0]);
35///
36/// let mut model = NaiveBayes::<Bernoulli>::new();
37///
38/// let accuracy_per_fold: Vec<f64> = k_fold_validate(
39///     &mut model,
40///     &inputs,
41///     &targets,
42///     3,
43///     // Score each fold by the fraction of test samples where
44///     // the model's prediction equals the target.
45///     row_accuracy
46/// ).unwrap();
47/// ```
48pub fn k_fold_validate<M, S>(model: &mut M,
49                             inputs: &Matrix<f64>,
50                             targets: &Matrix<f64>,
51                             k: usize,
52                             score: S) -> LearningResult<Vec<f64>>
53    where S: Fn(&Matrix<f64>, &Matrix<f64>) -> f64,
54          M: SupModel<Matrix<f64>, Matrix<f64>>,
55{
56    assert_eq!(inputs.rows(), targets.rows());
57    let num_samples = inputs.rows();
58    let shuffled_indices = create_shuffled_indices(num_samples);
59    let folds = Folds::new(&shuffled_indices, k);
60
61    let mut costs: Vec<f64> = Vec::new();
62
63    for p in folds {
64        // TODO: don't allocate fresh buffers for every fold
65        let train_inputs = inputs.select_rows(p.train_indices_iter.clone());
66        let train_targets = targets.select_rows(p.train_indices_iter.clone());
67        let test_inputs = inputs.select_rows(p.test_indices_iter.clone());
68        let test_targets = targets.select_rows(p.test_indices_iter.clone());
69
70        let _ = try!(model.train(&train_inputs, &train_targets));
71        let outputs = try!(model.predict(&test_inputs));
72        costs.push(score(&outputs, &test_targets));
73    }
74
75    Ok(costs)
76}
77
78/// A permutation of 0..n.
79struct ShuffledIndices(Vec<usize>);
80
81/// Permute the indices of the inputs samples.
82fn create_shuffled_indices(num_samples: usize) -> ShuffledIndices {
83    let mut indices: Vec<usize> = (0..num_samples).collect();
84    in_place_fisher_yates(&mut indices);
85    ShuffledIndices(indices)
86}
87
/// A partition of indices of all available samples into
/// a training set and a test set.
struct Partition<'a> {
    /// Yields the indices of the samples to train on.
    train_indices_iter: TrainingIndices<'a>,
    /// Yields the indices of the held-out samples to test on.
    test_indices_iter: TestIndices<'a>
}

/// Iterator over the test indices of a fold: one contiguous slice.
#[derive(Clone)]
struct TestIndices<'a>(Iter<'a, usize>);

/// Iterator over the training indices of a fold: two slices visited
/// one after the other.
#[derive(Clone)]
struct TrainingIndices<'a> {
    /// The two slices, chained.
    chain: Chain<Iter<'a, usize>, Iter<'a, usize>>,
    /// Number of indices that have not been yielded yet.
    size: usize
}

impl<'a> TestIndices<'a> {
    /// Creates an iterator over `indices`.
    fn new(indices: &'a [usize]) -> TestIndices<'a> {
        TestIndices(indices.iter())
    }
}

impl<'a> Iterator for TestIndices<'a> {
    type Item = &'a usize;

    fn next(&mut self) -> Option<&'a usize> {
        self.0.next()
    }

    // `ExactSizeIterator` requires `size_hint` to report the exact remaining
    // length; the default `(0, None)` would break that contract.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<'a> ExactSizeIterator for TestIndices<'a> {
    fn len(&self) -> usize {
        self.0.len()
    }
}

impl<'a> TrainingIndices<'a> {
    /// Creates an iterator that yields every index in `left` followed by
    /// every index in `right`.
    fn new(left: &'a [usize], right: &'a [usize]) -> TrainingIndices<'a> {
        let chain = left.iter().chain(right.iter());
        TrainingIndices {
            chain: chain,
            size: left.len() + right.len()
        }
    }
}

impl<'a> Iterator for TrainingIndices<'a> {
    type Item = &'a usize;

    fn next(&mut self) -> Option<&'a usize> {
        let item = self.chain.next();
        if item.is_some() {
            // Keep `size` equal to the number of items still to come so that
            // `len()` and `size_hint()` stay accurate mid-iteration. `size`
            // starts as the total item count, so this cannot underflow.
            self.size -= 1;
        }
        item
    }

    // `Chain` does not implement `ExactSizeIterator`, so the remaining
    // length is tracked in `size` and reported exactly from here.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.size, Some(self.size))
    }
}

impl<'a> ExactSizeIterator for TrainingIndices<'a> {
    fn len(&self) -> usize {
        self.size
    }
}
147
148/// An iterator over the sets of indices required for k-fold cross validation.
149struct Folds<'a> {
150    num_folds: usize,
151    indices: &'a[usize],
152    count: usize
153}
154
155impl<'a> Folds<'a> {
156    /// Let n = indices.len(), and k = num_folds.
157    /// The first n % k folds have size n / k + 1 and the
158    /// rest have size n / k. (In particular, if n % k == 0 then all
159    /// folds are the same size.)
160    fn new(indices: &'a ShuffledIndices, num_folds: usize) -> Folds<'a> {
161        let num_samples = indices.0.len();
162        assert!(num_folds > 1 && num_samples >= num_folds,
163            "Require num_folds > 1 && num_samples >= num_folds");
164
165        Folds {
166            num_folds: num_folds,
167            indices: &indices.0,
168            count: 0
169        }
170    }
171}
172
173impl<'a> Iterator for Folds<'a> {
174    type Item = Partition<'a>;
175
176    fn next(&mut self) -> Option<Self::Item> {
177        if self.count >= self.num_folds {
178            return None;
179        }
180
181        let num_samples = self.indices.len();
182        let q = num_samples / self.num_folds;
183        let r = num_samples % self.num_folds;
184        let fold_start = self.count * q + cmp::min(self.count, r);
185        let fold_size = if self.count >= r {q} else {q + 1};
186        let fold_end = fold_start + fold_size;
187
188        self.count += 1;
189
190        let prefix = &self.indices[..fold_start];
191        let suffix = &self.indices[fold_end..];
192        let infix = &self.indices[fold_start..fold_end];
193        Some(Partition {
194            train_indices_iter: TrainingIndices::new(prefix, suffix),
195            test_indices_iter: TestIndices::new(infix)
196        })
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::{ShuffledIndices, Folds};

    // In the comments below, n is the number of samples and k is the
    // number of folds.

    // n % k == 0
    #[test]
    fn test_folds_n6_k3() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3, 4, 5]);
        let folds = collect_folds(Folds::new(&idxs, 3));

        assert_eq!(folds, vec![
            (vec![2, 3, 4, 5], vec![0, 1]),
            (vec![0, 1, 4, 5], vec![2, 3]),
            (vec![0, 1, 2, 3], vec![4, 5])
            ]);
    }

    // n % k == 1
    #[test]
    fn test_folds_n5_k2() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3, 4]);
        let folds = collect_folds(Folds::new(&idxs, 2));

        assert_eq!(folds, vec![
            (vec![3, 4], vec![0, 1, 2]),
            (vec![0, 1, 2], vec![3, 4])
            ]);
    }

    // n % k == 2
    #[test]
    fn test_folds_n6_k4() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3, 4, 5]);
        let folds = collect_folds(Folds::new(&idxs, 4));

        assert_eq!(folds, vec![
            (vec![2, 3, 4, 5], vec![0, 1]),
            (vec![0, 1, 4, 5], vec![2, 3]),
            (vec![0, 1, 2, 3, 5], vec![4]),
            (vec![0, 1, 2, 3, 4], vec![5])
            ]);
    }

    // k == n
    #[test]
    fn test_folds_n4_k4() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3]);
        let folds = collect_folds(Folds::new(&idxs, 4));

        assert_eq!(folds, vec![
            (vec![1, 2, 3], vec![0]),
            (vec![0, 2, 3], vec![1]),
            (vec![0, 1, 3], vec![2]),
            (vec![0, 1, 2], vec![3])
            ]);
    }

    // k > n must be rejected.
    #[test]
    #[should_panic]
    fn test_folds_rejects_large_k() {
        let idxs = ShuffledIndices(vec![0, 1, 2]);
        let _ = collect_folds(Folds::new(&idxs, 4));
    }

    // Check we're really returning iterators into the shuffled
    // indices rather than into (0..n).
    #[test]
    fn test_folds_unordered_indices() {
        let idxs = ShuffledIndices(vec![5, 4, 3, 2, 1, 0]);
        let folds = collect_folds(Folds::new(&idxs, 3));

        assert_eq!(folds, vec![
            (vec![3, 2, 1, 0], vec![5, 4]),
            (vec![5, 4, 1, 0], vec![3, 2]),
            (vec![5, 4, 3, 2], vec![1, 0])
            ]);
    }

    /// Drains `folds` into one `(training indices, test indices)` pair
    /// per fold so the result can be compared with `assert_eq!`.
    fn collect_folds<'a>(folds: Folds<'a>) -> Vec<(Vec<usize>, Vec<usize>)> {
        folds
            .map(|p|
                (p.train_indices_iter.cloned().collect::<Vec<_>>(),
                 p.test_indices_iter.cloned().collect::<Vec<_>>()))
            .collect::<Vec<(Vec<usize>, Vec<usize>)>>()
    }
}