rusty_machine/analysis/score.rs

//! Functions for scoring a set of predictions, i.e. evaluating
//! how close predictions and truth are. All functions in this
//! module obey the convention that higher is better.

use libnum::{Zero, One};

use linalg::{BaseMatrix, Matrix};
use learning::toolkit::cost_fn::{CostFunc, MeanSqError};

// ************************************
// Classification Scores
// ************************************

/// Returns the fraction of outputs which match their target.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels.
/// * `targets` - Iterator of expected (actual) labels.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::accuracy;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(accuracy(outputs.iter(), targets.iter()), 0.5);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
pub fn accuracy<I>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator,
          I::Item: PartialEq
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
    let len = outputs.len() as f64;
    let correct = outputs
        .zip(targets)
        .filter(|&(ref x, ref y)| x == y)
        .count();
    correct as f64 / len
}

/// Returns the fraction of output rows which match their target.
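///
/// # Examples
///
/// A minimal example, assuming `Matrix` is the matrix type re-exported at
/// `rusty_machine::linalg` as in the rest of the crate:
///
/// ```
/// use rusty_machine::linalg::Matrix;
/// use rusty_machine::analysis::score::row_accuracy;
///
/// // Rows 0 and 1 match their targets exactly; row 2 does not.
/// let outputs = Matrix::new(3, 2, vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
/// let targets = Matrix::new(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0]);
///
/// assert_eq!(row_accuracy(&outputs, &targets), 2.0 / 3.0);
/// ```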
pub fn row_accuracy(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
    accuracy(outputs.iter_rows(), targets.iter_rows())
}

/// Returns the precision score for 2-class classification.
///
/// Precision is calculated as true-positives / (true-positives + false-positives);
/// see [Precision and Recall](https://en.wikipedia.org/wiki/Precision_and_recall) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels, containing only 0 or 1.
/// * `targets` - Iterator of expected (actual) labels, containing only 0 or 1.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::precision;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
/// - `outputs` or `targets` contains a value other than 0 or 1
pub fn precision<'a, I, T>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator<Item=&'a T>,
          T: 'a + PartialEq + Zero + One
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let mut tpfp = 0.0f64;
    let mut tp = 0.0f64;

    for (ref o, ref t) in outputs.zip(targets) {
        if *o == &T::one() {
            tpfp += 1.0f64;
            if *t == &T::one() {
                tp += 1.0f64;
            }
        }
        if ((*t != &T::zero()) & (*t != &T::one())) |
           ((*o != &T::zero()) & (*o != &T::one())) {
            panic!("precision must be used for 2 class classification")
        }
    }
    tp / tpfp
}

/// Returns the recall score for 2-class classification.
///
/// Recall is calculated as true-positives / (true-positives + false-negatives);
/// see [Precision and Recall](https://en.wikipedia.org/wiki/Precision_and_recall) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels, containing only 0 or 1.
/// * `targets` - Iterator of expected (actual) labels, containing only 0 or 1.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::recall;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(recall(outputs.iter(), targets.iter()), 0.5);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
/// - `outputs` or `targets` contains a value other than 0 or 1
pub fn recall<'a, I, T>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator<Item=&'a T>,
          T: 'a + PartialEq + Zero + One
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let mut tpfn = 0.0f64;
    let mut tp = 0.0f64;

    for (ref o, ref t) in outputs.zip(targets) {
        if *t == &T::one() {
            tpfn += 1.0f64;
            if *o == &T::one() {
                tp += 1.0f64;
            }
        }
        if ((*t != &T::zero()) & (*t != &T::one())) |
           ((*o != &T::zero()) & (*o != &T::one())) {
            panic!("recall must be used for 2 class classification")
        }
    }
    tp / tpfn
}

/// Returns the F1 score for 2-class classification.
///
/// The F1 score is calculated as 2 * precision * recall / (precision + recall);
/// see [F1 score](https://en.wikipedia.org/wiki/F1_score) for details.
///
/// # Arguments
///
/// * `outputs` - Iterator of output (predicted) labels, containing only 0 or 1.
/// * `targets` - Iterator of expected (actual) labels, containing only 0 or 1.
///
/// # Examples
///
/// ```
/// use rusty_machine::analysis::score::f1;
/// let outputs = [1, 1, 1, 0, 0, 0];
/// let targets = [1, 1, 0, 0, 1, 1];
///
/// assert_eq!(f1(outputs.iter(), targets.iter()), 0.5714285714285714);
/// ```
///
/// # Panics
///
/// - `outputs` and `targets` have different lengths
/// - `outputs` or `targets` contains a value other than 0 or 1
pub fn f1<'a, I, T>(outputs: I, targets: I) -> f64
    where I: ExactSizeIterator<Item=&'a T>,
          T: 'a + PartialEq + Zero + One
{
    assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");

    let mut tpos = 0.0f64;
    let mut fpos = 0.0f64;
    let mut fneg = 0.0f64;

    for (ref o, ref t) in outputs.zip(targets) {
        if (*o == &T::one()) & (*t == &T::one()) {
            // predicted positive and actually positive: true positive
            tpos += 1.0f64;
        } else if *t == &T::one() {
            // actually positive but predicted negative: false negative
            fneg += 1.0f64;
        } else if *o == &T::one() {
            // predicted positive but actually negative: false positive
            fpos += 1.0f64;
        }
        if ((*t != &T::zero()) & (*t != &T::one())) |
           ((*o != &T::zero()) & (*o != &T::one())) {
            panic!("f1-score must be used for 2 class classification")
        }
    }
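    // 2 * tp / (2 * tp + fn + fp) is algebraically identical to
    // 2 * precision * recall / (precision + recall).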
    2.0f64 * tpos / (2.0f64 * tpos + fneg + fpos)
}

// ************************************
// Regression Scores
// ************************************

// TODO: generalise to accept arbitrary iterators of diff-able things
/// Returns the additive inverse of the mean-squared-error of the
/// outputs, so higher is better. The returned value is never positive.
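///
/// # Examples
///
/// A minimal example mirroring the 1-d unit test below, assuming `Matrix`
/// is re-exported at `rusty_machine::linalg`:
///
/// ```
/// use rusty_machine::linalg::Matrix;
/// use rusty_machine::analysis::score::neg_mean_squared_error;
///
/// let outputs = Matrix::new(3, 1, vec![1.0, 2.0, 3.0]);
/// let targets = Matrix::new(3, 1, vec![2.0, 4.0, 3.0]);
///
/// // The squared errors are 1, 4 and 0; their mean is 5/3, negated.
/// assert_eq!(neg_mean_squared_error(&outputs, &targets), -5.0 / 3.0);
/// ```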
pub fn neg_mean_squared_error(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64
{
    // MeanSqError divides the actual mean squared error by two.
    -2f64 * MeanSqError::cost(outputs, targets)
}

#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::{accuracy, precision, recall, f1, neg_mean_squared_error};

    #[test]
    fn test_accuracy() {
        let outputs = [1, 2, 3, 4, 5, 6];
        let targets = [1, 2, 3, 3, 5, 1];
        assert_eq!(accuracy(outputs.iter(), targets.iter()), 2f64/3f64);

        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 1, 0, 0, 1];
        assert_eq!(accuracy(outputs.iter(), targets.iter()), 5.0f64 / 6.0f64);
    }

    #[test]
    fn test_precision() {
        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);

        let outputs = [1, 1, 1, 0, 1, 1];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(precision(outputs.iter(), targets.iter()), 0.8);

        let outputs = [0, 0, 0, 1, 1, 1];
        let targets = [1, 1, 1, 1, 1, 0];
        assert_eq!(precision(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);

        let outputs = [1, 1, 1, 1, 1, 0];
        let targets = [0, 0, 0, 1, 1, 1];
        assert_eq!(precision(outputs.iter(), targets.iter()), 0.4);
    }

    #[test]
    #[should_panic]
    fn test_precision_outputs_not_2class() {
        let outputs = [1, 2, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        precision(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_precision_targets_not_2class() {
        let outputs = [1, 0, 1, 0, 0, 0];
        let targets = [1, 2, 0, 0, 1, 1];
        precision(outputs.iter(), targets.iter());
    }

    #[test]
    fn test_recall() {
        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(recall(outputs.iter(), targets.iter()), 0.5);

        let outputs = [1, 1, 1, 0, 1, 1];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(recall(outputs.iter(), targets.iter()), 1.0);

        let outputs = [0, 0, 0, 1, 1, 1];
        let targets = [1, 1, 1, 1, 1, 0];
        assert_eq!(recall(outputs.iter(), targets.iter()), 0.4);

        let outputs = [1, 1, 1, 1, 1, 0];
        let targets = [0, 0, 0, 1, 1, 1];
        assert_eq!(recall(outputs.iter(), targets.iter()), 2.0f64 / 3.0f64);
    }

    #[test]
    #[should_panic]
    fn test_recall_outputs_not_2class() {
        let outputs = [1, 2, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        recall(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_recall_targets_not_2class() {
        let outputs = [1, 0, 1, 0, 0, 0];
        let targets = [1, 2, 0, 0, 1, 1];
        recall(outputs.iter(), targets.iter());
    }

    #[test]
    fn test_f1() {
        let outputs = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.5714285714285714);

        let outputs = [1, 1, 1, 0, 1, 1];
        let targets = [1, 1, 0, 0, 1, 1];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.8888888888888888);

        let outputs = [0, 0, 0, 1, 1, 1];
        let targets = [1, 1, 1, 1, 1, 0];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.5);

        let outputs = [1, 1, 1, 1, 1, 0];
        let targets = [0, 0, 0, 1, 1, 1];
        assert_eq!(f1(outputs.iter(), targets.iter()), 0.5);
    }

    #[test]
    #[should_panic]
    fn test_f1_outputs_not_2class() {
        let outputs = [1, 2, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 1, 1];
        f1(outputs.iter(), targets.iter());
    }

    #[test]
    #[should_panic]
    fn test_f1_targets_not_2class() {
        let outputs = [1, 0, 1, 0, 0, 0];
        let targets = [1, 2, 0, 0, 1, 1];
        f1(outputs.iter(), targets.iter());
    }

    #[test]
    fn test_neg_mean_squared_error_1d() {
        let outputs = Matrix::new(3, 1, vec![1f64, 2f64, 3f64]);
        let targets = Matrix::new(3, 1, vec![2f64, 4f64, 3f64]);
        assert_eq!(neg_mean_squared_error(&outputs, &targets), -5f64/3f64);
    }

    #[test]
    fn test_neg_mean_squared_error_2d() {
        let outputs = Matrix::new(3, 2, vec![
            1f64, 2f64,
            3f64, 4f64,
            5f64, 6f64
            ]);
        let targets = Matrix::new(3, 2, vec![
            1.5f64, 2.5f64,
            5f64,   6f64,
            5.5f64, 6.5f64
            ]);
        assert_eq!(neg_mean_squared_error(&outputs, &targets), -3f64);
    }
}