1use std::cmp;
4use std::iter::Chain;
5use std::slice::Iter;
6use linalg::{BaseMatrix, Matrix};
7use learning::{LearningResult, SupModel};
8use learning::toolkit::rand_utils::in_place_fisher_yates;
9
10pub fn k_fold_validate<M, S>(model: &mut M,
49 inputs: &Matrix<f64>,
50 targets: &Matrix<f64>,
51 k: usize,
52 score: S) -> LearningResult<Vec<f64>>
53 where S: Fn(&Matrix<f64>, &Matrix<f64>) -> f64,
54 M: SupModel<Matrix<f64>, Matrix<f64>>,
55{
56 assert_eq!(inputs.rows(), targets.rows());
57 let num_samples = inputs.rows();
58 let shuffled_indices = create_shuffled_indices(num_samples);
59 let folds = Folds::new(&shuffled_indices, k);
60
61 let mut costs: Vec<f64> = Vec::new();
62
63 for p in folds {
64 let train_inputs = inputs.select_rows(p.train_indices_iter.clone());
66 let train_targets = targets.select_rows(p.train_indices_iter.clone());
67 let test_inputs = inputs.select_rows(p.test_indices_iter.clone());
68 let test_targets = targets.select_rows(p.test_indices_iter.clone());
69
70 let _ = try!(model.train(&train_inputs, &train_targets));
71 let outputs = try!(model.predict(&test_inputs));
72 costs.push(score(&outputs, &test_targets));
73 }
74
75 Ok(costs)
76}
77
/// Row indices in the order `create_shuffled_indices` produced them (the
/// range `0..num_samples` after shuffling); `Folds` borrows this to carve out
/// train/test partitions.
struct ShuffledIndices(Vec<usize>);
80
81fn create_shuffled_indices(num_samples: usize) -> ShuffledIndices {
83 let mut indices: Vec<usize> = (0..num_samples).collect();
84 in_place_fisher_yates(&mut indices);
85 ShuffledIndices(indices)
86}
87
88struct Partition<'a> {
91 train_indices_iter: TrainingIndices<'a>,
92 test_indices_iter: TestIndices<'a>
93}
94
/// Held-out (test) indices for one fold: an iterator over a single contiguous
/// sub-slice of the shuffled indices.
///
/// `Clone` lets the same fold be walked more than once, which `k_fold_validate`
/// does to select both input rows and target rows.
#[derive(Clone)]
struct TestIndices<'a>(Iter<'a, usize>);
97
/// Training indices for one fold: the indices before the fold chained with
/// the indices after it.
///
/// Cloneable so the same index set can be walked for both inputs and targets.
#[derive(Clone)]
struct TrainingIndices<'a> {
    /// Combined length of both halves, recorded when the value is constructed.
    size: usize,
    /// Iterator over the prefix slice followed by the suffix slice.
    chain: Chain<Iter<'a, usize>, Iter<'a, usize>>,
}
103
104impl<'a> TestIndices<'a> {
105 fn new(indices: &'a [usize]) -> TestIndices<'a> {
106 TestIndices(indices.iter())
107 }
108}
109
110impl<'a> Iterator for TestIndices<'a> {
111 type Item = &'a usize;
112
113 fn next(&mut self) -> Option<&'a usize> {
114 self.0.next()
115 }
116}
117
118impl <'a> ExactSizeIterator for TestIndices<'a> {
119 fn len(&self) -> usize {
120 self.0.len()
121 }
122}
123
124impl<'a> TrainingIndices<'a> {
125 fn new(left: &'a [usize], right: &'a [usize]) -> TrainingIndices<'a> {
126 let chain = left.iter().chain(right.iter());
127 TrainingIndices {
128 chain: chain,
129 size: left.len() + right.len()
130 }
131 }
132}
133
134impl<'a> Iterator for TrainingIndices<'a> {
135 type Item = &'a usize;
136
137 fn next(&mut self) -> Option<&'a usize> {
138 self.chain.next()
139 }
140}
141
142impl <'a> ExactSizeIterator for TrainingIndices<'a> {
143 fn len(&self) -> usize {
144 self.size
145 }
146}
147
/// Iterator producing `num_folds` train/test partitions of a shuffled index slice.
struct Folds<'a> {
    /// The shuffled indices being partitioned.
    indices: &'a [usize],
    /// Total number of folds to produce.
    num_folds: usize,
    /// Folds produced so far; iteration ends once this reaches `num_folds`.
    count: usize,
}
154
155impl<'a> Folds<'a> {
156 fn new(indices: &'a ShuffledIndices, num_folds: usize) -> Folds<'a> {
161 let num_samples = indices.0.len();
162 assert!(num_folds > 1 && num_samples >= num_folds,
163 "Require num_folds > 1 && num_samples >= num_folds");
164
165 Folds {
166 num_folds: num_folds,
167 indices: &indices.0,
168 count: 0
169 }
170 }
171}
172
173impl<'a> Iterator for Folds<'a> {
174 type Item = Partition<'a>;
175
176 fn next(&mut self) -> Option<Self::Item> {
177 if self.count >= self.num_folds {
178 return None;
179 }
180
181 let num_samples = self.indices.len();
182 let q = num_samples / self.num_folds;
183 let r = num_samples % self.num_folds;
184 let fold_start = self.count * q + cmp::min(self.count, r);
185 let fold_size = if self.count >= r {q} else {q + 1};
186 let fold_end = fold_start + fold_size;
187
188 self.count += 1;
189
190 let prefix = &self.indices[..fold_start];
191 let suffix = &self.indices[fold_end..];
192 let infix = &self.indices[fold_start..fold_end];
193 Some(Partition {
194 train_indices_iter: TrainingIndices::new(prefix, suffix),
195 test_indices_iter: TestIndices::new(infix)
196 })
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::{ShuffledIndices, Folds};

    // Six samples split evenly into three folds of two.
    #[test]
    fn test_folds_n6_k3() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3, 4, 5]);
        let folds = collect_folds(Folds::new(&idxs, 3));

        assert_eq!(folds, vec![
            (vec![2, 3, 4, 5], vec![0, 1]),
            (vec![0, 1, 4, 5], vec![2, 3]),
            (vec![0, 1, 2, 3], vec![4, 5])
        ]);
    }

    // Uneven split: the leftover sample goes to the first fold.
    #[test]
    fn test_folds_n5_k2() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3, 4]);
        let folds = collect_folds(Folds::new(&idxs, 2));

        assert_eq!(folds, vec![
            (vec![3, 4], vec![0, 1, 2]),
            (vec![0, 1, 2], vec![3, 4])
        ]);
    }

    // Uneven split: the first `n % k` folds each get one extra sample.
    #[test]
    fn test_folds_n6_k4() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3, 4, 5]);
        let folds = collect_folds(Folds::new(&idxs, 4));

        assert_eq!(folds, vec![
            (vec![2, 3, 4, 5], vec![0, 1]),
            (vec![0, 1, 4, 5], vec![2, 3]),
            (vec![0, 1, 2, 3, 5], vec![4]),
            (vec![0, 1, 2, 3, 4], vec![5])
        ]);
    }

    // Leave-one-out: as many folds as samples.
    #[test]
    fn test_folds_n4_k4() {
        let idxs = ShuffledIndices(vec![0, 1, 2, 3]);
        let folds = collect_folds(Folds::new(&idxs, 4));

        assert_eq!(folds, vec![
            (vec![1, 2, 3], vec![0]),
            (vec![0, 2, 3], vec![1]),
            (vec![0, 1, 3], vec![2]),
            (vec![0, 1, 2], vec![3])
        ]);
    }

    #[test]
    #[should_panic]
    fn test_folds_rejects_large_k() {
        let idxs = ShuffledIndices(vec![0, 1, 2]);
        let _ = collect_folds(Folds::new(&idxs, 4));
    }

    // A single fold would leave nothing to train on, so it must be rejected.
    #[test]
    #[should_panic]
    fn test_folds_rejects_single_fold() {
        let idxs = ShuffledIndices(vec![0, 1, 2]);
        let _ = collect_folds(Folds::new(&idxs, 1));
    }

    // Folds are taken positionally from the shuffled order, not by index value.
    #[test]
    fn test_folds_unordered_indices() {
        let idxs = ShuffledIndices(vec![5, 4, 3, 2, 1, 0]);
        let folds = collect_folds(Folds::new(&idxs, 3));

        assert_eq!(folds, vec![
            (vec![3, 2, 1, 0], vec![5, 4]),
            (vec![5, 4, 1, 0], vec![3, 2]),
            (vec![5, 4, 3, 2], vec![1, 0])
        ]);
    }

    // For every valid (n, k): exactly k folds, every index is tested exactly
    // once, train and test are disjoint, and fold sizes differ by at most one.
    #[test]
    fn test_folds_partition_every_index_exactly_once() {
        for n in 2..20 {
            let idxs = ShuffledIndices((0..n).collect());
            for k in 2..(n + 1) {
                let folds = collect_folds(Folds::new(&idxs, k));
                assert_eq!(folds.len(), k);

                let mut tested: Vec<usize> = folds.iter()
                    .flat_map(|f| f.1.iter().cloned())
                    .collect();
                tested.sort();
                assert_eq!(tested, (0..n).collect::<Vec<usize>>());

                for &(ref train, ref test) in &folds {
                    assert_eq!(train.len() + test.len(), n);
                    assert!(test.len() == n / k || test.len() == n / k + 1);
                    assert!(test.iter().all(|t| !train.contains(t)));
                }
            }
        }
    }

    /// Materialises each partition as `(train_indices, test_indices)`.
    fn collect_folds<'a>(folds: Folds<'a>) -> Vec<(Vec<usize>, Vec<usize>)> {
        folds
            .map(|p| {
                (p.train_indices_iter.cloned().collect::<Vec<usize>>(),
                 p.test_indices_iter.cloned().collect::<Vec<usize>>())
            })
            .collect()
    }
}