1use libnum::{Zero, One};
6
7use linalg::{BaseMatrix, Matrix};
8use learning::toolkit::cost_fn::{CostFunc, MeanSqError};
9
/// Returns the fraction of positions at which `outputs` and `targets` hold equal items.
///
/// Items are paired in order (first with first, second with second, ...) and compared with
/// `PartialEq`; the number of equal pairs is divided by the common length.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels.
/// * `targets` - Iterator of expected (actual) labels; must report the same length as `outputs`.
///
/// # Returns
///
/// A value in `[0, 1]`, or `NaN` when both iterators are empty (the division is `0.0 / 0.0`).
///
/// # Panics
///
/// - `outputs` and `targets` report different lengths.
pub fn accuracy<I>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator,
          I::Item: PartialEq
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let total = outputs.len();
    let mut matches = 0usize;
    for (output, target) in outputs.zip(targets) {
        if output == target {
            matches += 1;
        }
    }

    matches as f64 / total as f64
}
46
47pub fn row_accuracy(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
49 accuracy(outputs.iter_rows(), targets.iter_rows())
50}
51
52pub fn precision<'a, I, T>(outputs: I, targets: I) -> f64
77 where I: ExactSizeIterator<Item=&'a T>,
78 T: 'a + PartialEq + Zero + One
79{
80 assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
81
82 let mut tpfp = 0.0f64;
83 let mut tp = 0.0f64;
84
85 for (ref o, ref t) in outputs.zip(targets) {
86 if *o == &T::one() {
87 tpfp += 1.0f64;
88 if *t == &T::one() {
89 tp += 1.0f64;
90 }
91 }
92 if ((*t != &T::zero()) & (*t != &T::one())) |
93 ((*o != &T::zero()) & (*o != &T::one())) {
94 panic!("precision must be used for 2 class classification")
95 }
96 }
97 tp / tpfp
98}
99
100pub fn recall<'a, I, T>(outputs: I, targets: I) -> f64
125 where I: ExactSizeIterator<Item=&'a T>,
126 T: 'a + PartialEq + Zero + One
127{
128 assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
129
130 let mut tpfn = 0.0f64;
131 let mut tp = 0.0f64;
132
133 for (ref o, ref t) in outputs.zip(targets) {
134 if *t == &T::one() {
135 tpfn += 1.0f64;
136 if *o == &T::one() {
137 tp += 1.0f64;
138 }
139 }
140 if ((*t != &T::zero()) & (*t != &T::one())) |
141 ((*o != &T::zero()) & (*o != &T::one())) {
142 panic!("recall must be used for 2 class classification")
143 }
144 }
145 tp / tpfn
146}
147
148pub fn f1<'a, I, T>(outputs: I, targets: I) -> f64
173 where I: ExactSizeIterator<Item=&'a T>,
174 T: 'a + PartialEq + Zero + One
175{
176 assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
177
178 let mut tpos = 0.0f64;
179 let mut fpos = 0.0f64;
180 let mut fneg = 0.0f64;
181
182 for (ref o, ref t) in outputs.zip(targets) {
183 if (*o == &T::one()) & (*t == &T::one()) {
184 tpos += 1.0f64;
185 } else if *t == &T::one() {
186 fpos += 1.0f64;
187 } else if *o == &T::one() {
188 fneg += 1.0f64;
189 }
190 if ((*t != &T::zero()) & (*t != &T::one())) |
191 ((*o != &T::zero()) & (*o != &T::one())) {
192 panic!("f1-score must be used for 2 class classification")
193 }
194 }
195 2.0f64 * tpos / (2.0f64 * tpos + fneg + fpos)
196}
197
198pub fn neg_mean_squared_error(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64
207{
208 -2f64 * MeanSqError::cost(outputs, targets)
210}
211
#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::{accuracy, row_accuracy, precision, recall, f1, neg_mean_squared_error};

    #[test]
    fn test_accuracy() {
        let outputs = [1, 2, 3, 4, 5, 6];
        let targets = [1, 2, 3, 3, 5, 1];
        assert_eq!(accuracy(outputs.iter(), targets.iter()), 2f64/3f64);

        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 1, 0, 0, 1];
        assert_eq!(accuracy(outputs.iter(), targets.iter()), 5.0f64 / 6.0f64);
    }

    // The length check fires before any pair is compared.
    #[test]
    #[should_panic]
    fn test_accuracy_different_lengths() {
        let outputs = [1, 1, 0];
        let targets = [1, 1];
        accuracy(outputs.iter(), targets.iter());
    }

    #[test]
    fn test_row_accuracy() {
        // Rows 0 and 2 match exactly; row 1 differs, so 2 of 3 rows are correct.
        let outputs = Matrix::new(3, 2, vec![
            1f64, 0f64,
            0f64, 1f64,
            1f64, 1f64
        ]);
        let targets = Matrix::new(3, 2, vec![
            1f64, 0f64,
            1f64, 0f64,
            1f64, 1f64
        ]);
        assert_eq!(row_accuracy(&outputs, &targets), 2f64 / 3f64);
    }

    // Expected values are TP / (TP + FP), counted over positions where outputs == 1.
    #[test]
    fn test_precision() {
        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);

        let outputs = [1, 1, 1, 0, 1, 1];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(precision(outputs.iter(), targets.iter()), 0.8);

        let outputs = [0, 0, 0, 1, 1, 1];
        let targets = [1, 1, 1, 1, 1, 0];
        assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);

        let outputs = [1, 1, 1, 1, 1, 0];
        let targets = [0, 0, 0, 1, 1, 1];
        assert_eq!(precision(outputs.iter(), targets.iter()), 0.4);
    }

    #[test]
    #[should_panic]
    fn test_precision_outputs_not_2class() {
        let outputs = [1, 2, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        precision(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_precision_targets_not_2class() {
        let outputs = [1, 0, 1, 0, 0, 0];
        let targets = [1, 2, 0, 0, 1, 1];
        precision(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_precision_different_lengths() {
        let outputs = [1, 0, 1];
        let targets = [1, 0];
        precision(outputs.iter(), targets.iter());
    }

    // Expected values are TP / (TP + FN), counted over positions where targets == 1.
    #[test]
    fn test_recall() {
        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(recall(outputs.iter(), targets.iter()), 0.5);

        let outputs = [1, 1, 1, 0, 1, 1];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(recall(outputs.iter(), targets.iter()), 1.0);

        let outputs = [0, 0, 0, 1, 1, 1];
        let targets = [1, 1, 1, 1, 1, 0];
        assert_eq!(recall(outputs.iter(), targets.iter()), 0.4);

        let outputs = [1, 1, 1, 1, 1, 0];
        let targets = [0, 0, 0, 1, 1, 1];
        assert_eq!(recall(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
    }

    #[test]
    #[should_panic]
    fn test_recall_outputs_not_2class() {
        let outputs = [1, 2, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        recall(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_recall_targets_not_2class() {
        let outputs = [1, 0, 1, 0, 0, 0];
        let targets = [1, 2, 0, 0, 1, 1];
        recall(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_recall_different_lengths() {
        let outputs = [1, 0, 1];
        let targets = [1, 0];
        recall(outputs.iter(), targets.iter());
    }

    // Expected values are 2TP / (2TP + FP + FN).
    #[test]
    fn test_f1() {
        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.5714285714285714);

        let outputs = [1, 1, 1, 0, 1, 1];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.8888888888888888);

        let outputs = [0, 0, 0, 1, 1, 1];
        let targets = [1, 1, 1, 1, 1, 0];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.5);

        let outputs = [1, 1, 1, 1, 1, 0];
        let targets = [0, 0, 0, 1, 1, 1];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.5);
    }

    #[test]
    #[should_panic]
    fn test_f1_outputs_not_2class() {
        let outputs = [1, 2, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        f1(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_f1_targets_not_2class() {
        let outputs = [1, 0, 1, 0, 0, 0];
        let targets = [1, 2, 0, 0, 1, 1];
        f1(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_f1_different_lengths() {
        let outputs = [1, 0, 1];
        let targets = [1, 0];
        f1(outputs.iter(), targets.iter());
    }

    // Squared differences 1 + 4 + 0 = 5 over 3 rows.
    #[test]
    fn test_neg_mean_squared_error_1d() {
        let outputs = Matrix::new(3, 1, vec![1f64, 2f64, 3f64]);
        let targets = Matrix::new(3, 1, vec![2f64, 4f64, 3f64]);
        assert_eq!(neg_mean_squared_error(&outputs, &targets), -5f64/3f64);
    }

    // Squared differences sum to 9 and the result is -3: averaged over the 3 rows.
    #[test]
    fn test_neg_mean_squared_error_2d() {
        let outputs = Matrix::new(3, 2, vec![
            1f64, 2f64,
            3f64, 4f64,
            5f64, 6f64
        ]);
        let targets = Matrix::new(3, 2, vec![
            1.5f64, 2.5f64,
            5f64, 6f64,
            5.5f64, 6.5f64
        ]);
        assert_eq!(neg_mean_squared_error(&outputs, &targets), -3f64);
    }
}
355}