rusty_ai/metrics/
confusion.rs

use std::{collections::HashSet, error::Error};

use nalgebra::{DMatrix, DVector};

use crate::data::dataset::WholeNumber;

/// A confusion matrix whose rows index the true class and whose columns
/// index the predicted class.
type ConfusionMatrix = DMatrix<usize>;

/// Classification metrics derived from a confusion matrix, provided as
/// default methods for any classifier over whole-number labels.
pub trait ClassificationMetrics<T: WholeNumber> {
    /// Computes the confusion matrix from the true labels and predicted labels.
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true labels.
    /// * `y_pred` - The predicted labels.
    ///
    /// # Returns
    ///
    /// The confusion matrix as a `Result` containing a `ConfusionMatrix`, or an
    /// error if the two label vectors differ in length.
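    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring the unit tests below; marked `ignore` since it
    /// assumes `u8` implements this crate's `WholeNumber` trait:
    ///
    /// ```ignore
    /// struct Model;
    /// impl ClassificationMetrics<u8> for Model {}
    ///
    /// let y_true = DVector::from_vec(vec![1u8, 0, 1, 0, 1]);
    /// let y_pred = DVector::from_vec(vec![1u8, 1, 0, 0, 1]);
    /// let matrix = Model.confusion_matrix(&y_true, &y_pred).unwrap();
    /// assert_eq!(matrix[(1, 1)], 2); // class 1 predicted correctly twice
    /// ```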
    fn confusion_matrix(
        &self,
        y_true: &DVector<T>,
        y_pred: &DVector<T>,
    ) -> Result<ConfusionMatrix, Box<dyn Error>> {
        if y_true.len() != y_pred.len() {
            return Err("Predictions and labels are of different sizes.".into());
        }

        // Collect every class label observed in either vector.
        let mut classes_set = HashSet::<T>::new();
        classes_set.extend(y_true);
        classes_set.extend(y_pred);

        // Sort the labels so that row and column indices are deterministic.
        let mut classes = Vec::from_iter(classes_set.iter().cloned());
        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mut matrix = DMatrix::zeros(classes_set.len(), classes_set.len());

        // Rows index the true class, columns the predicted class.
        for (y_t, y_p) in y_true.iter().zip(y_pred.iter()) {
            let matrix_row = classes.iter().position(|&c| c == *y_t).unwrap();
            let matrix_col = classes.iter().position(|&c| c == *y_p).unwrap();
            matrix[(matrix_row, matrix_col)] += 1;
        }

        Ok(matrix)
    }

    /// Computes the accuracy based on the true labels and predicted labels.
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true labels.
    /// * `y_pred` - The predicted labels.
    ///
    /// # Returns
    ///
    /// The accuracy as a `Result` containing an `f64` value or an error message.
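    ///
    /// # Examples
    ///
    /// A sketch (marked `ignore` for the same `WholeNumber` assumption as above):
    ///
    /// ```ignore
    /// struct Model;
    /// impl ClassificationMetrics<u8> for Model {}
    ///
    /// let y_true = DVector::from_vec(vec![1u8, 0, 1, 0, 1]);
    /// let y_pred = DVector::from_vec(vec![1u8, 1, 0, 0, 1]);
    /// // Three of the five predictions match the labels: accuracy = 3 / 5.
    /// assert_eq!(Model.accuracy(&y_true, &y_pred).unwrap(), 0.6);
    /// ```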
    fn accuracy(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
        let matrix = self.confusion_matrix(y_true, y_pred)?;

        // Correct predictions sit on the diagonal of the confusion matrix.
        let correct: usize = matrix.diagonal().iter().sum();

        Ok(correct as f64 / y_true.len() as f64)
    }

    /// Computes the precision based on the true labels and predicted labels.
    ///
    /// For binary problems this is the precision of the positive class; for
    /// multiclass problems the per-class precisions are macro-averaged.
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true labels.
    /// * `y_pred` - The predicted labels.
    ///
    /// # Returns
    ///
    /// The precision as a `Result` containing an `f64` value or an error message.
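    ///
    /// # Examples
    ///
    /// A sketch of the binary case (marked `ignore` for the same `WholeNumber`
    /// assumption as above):
    ///
    /// ```ignore
    /// struct Model;
    /// impl ClassificationMetrics<u8> for Model {}
    ///
    /// let y_true = DVector::from_vec(vec![1u8, 0, 1, 0, 1]);
    /// let y_pred = DVector::from_vec(vec![1u8, 1, 0, 0, 1]);
    /// // TP = 2, FP = 1, so precision = 2 / (2 + 1).
    /// assert_eq!(Model.precision(&y_true, &y_pred).unwrap(), 2.0 / 3.0);
    /// ```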
    fn precision(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
        let matrix = self.confusion_matrix(y_true, y_pred)?;

        let num_classes = matrix.nrows();

        // Binary case: report the precision of the positive class directly.
        if num_classes == 2 {
            let tp = matrix[(1, 1)];
            let fp = matrix[(0, 1)];

            if tp + fp > 0 {
                return Ok(tp as f64 / (tp + fp) as f64);
            }
        }

        // Macro-average the per-class precisions. Classes that are never
        // predicted contribute zero to the total but still count toward
        // the average.
        let mut precision_total = 0.0;
        for class in 0..num_classes {
            let tp = matrix[(class, class)];
            // False positives are the rest of the predicted-class column.
            let fp = matrix.column(class).sum() - tp;

            if tp + fp > 0 {
                let precision = tp as f64 / (tp + fp) as f64;
                precision_total += precision;
            }
        }

        Ok(precision_total / num_classes as f64)
    }

    /// Computes the recall based on the true labels and predicted labels.
    ///
    /// For binary problems this is the recall of the positive class; for
    /// multiclass problems the per-class recalls are macro-averaged.
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true labels.
    /// * `y_pred` - The predicted labels.
    ///
    /// # Returns
    ///
    /// The recall as a `Result` containing an `f64` value or an error message.
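    ///
    /// # Examples
    ///
    /// A sketch of the binary case (marked `ignore` for the same `WholeNumber`
    /// assumption as above):
    ///
    /// ```ignore
    /// struct Model;
    /// impl ClassificationMetrics<u8> for Model {}
    ///
    /// let y_true = DVector::from_vec(vec![1u8, 0, 1, 0, 1]);
    /// let y_pred = DVector::from_vec(vec![1u8, 1, 0, 0, 1]);
    /// // TP = 2, FN = 1, so recall = 2 / (2 + 1).
    /// assert_eq!(Model.recall(&y_true, &y_pred).unwrap(), 2.0 / 3.0);
    /// ```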
    fn recall(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
        let matrix = self.confusion_matrix(y_true, y_pred)?;

        let num_classes = matrix.nrows();

        // Binary case: report the recall of the positive class directly.
        if num_classes == 2 {
            let tp = matrix[(1, 1)];
            let fn_ = matrix[(1, 0)];

            if tp + fn_ > 0 {
                return Ok(tp as f64 / (tp + fn_) as f64);
            }
        }

        // Macro-average the per-class recalls. Classes absent from `y_true`
        // contribute zero to the total but still count toward the average.
        let mut recall_total = 0.0;

        for class in 0..num_classes {
            let tp = matrix[(class, class)];
            // False negatives are the rest of the true-class row.
            let fn_ = matrix.row(class).sum() - tp;

            if tp + fn_ > 0 {
                let recall = tp as f64 / (tp + fn_) as f64;
                recall_total += recall;
            }
        }

        Ok(recall_total / num_classes as f64)
    }

    /// Computes the F1 score, the harmonic mean of precision and recall,
    /// based on the true labels and predicted labels.
    ///
    /// # Arguments
    ///
    /// * `y_true` - The true labels.
    /// * `y_pred` - The predicted labels.
    ///
    /// # Returns
    ///
    /// The F1 score as a `Result` containing an `f64` value, or an error if
    /// precision and recall are both zero.
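    ///
    /// # Examples
    ///
    /// A sketch of the binary case (marked `ignore` for the same `WholeNumber`
    /// assumption as above); precision and recall are both 2/3 here, so their
    /// harmonic mean is also 2/3:
    ///
    /// ```ignore
    /// struct Model;
    /// impl ClassificationMetrics<u8> for Model {}
    ///
    /// let y_true = DVector::from_vec(vec![1u8, 0, 1, 0, 1]);
    /// let y_pred = DVector::from_vec(vec![1u8, 1, 0, 0, 1]);
    /// assert_eq!(Model.f1_score(&y_true, &y_pred).unwrap(), 2.0 / 3.0);
    /// ```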
    fn f1_score(&self, y_true: &DVector<T>, y_pred: &DVector<T>) -> Result<f64, Box<dyn Error>> {
        let precision = self.precision(y_true, y_pred)?;
        let recall = self.recall(y_true, y_pred)?;

        // The harmonic mean is undefined when both components are zero.
        match (precision + recall).abs() < f64::EPSILON {
            true => Err("Precision and recall are both 0, F1 score undefined.".into()),
            false => Ok(2.0 * (precision * recall) / (precision + recall)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    struct MockClassifier;

    impl ClassificationMetrics<u8> for MockClassifier {}

    #[test]
    fn test_confusion_matrix() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.confusion_matrix(&y_true, &y_pred).unwrap();

        // `DMatrix::from_vec` fills column-major: this is [[1, 1], [1, 2]].
        let expected = DMatrix::from_vec(2, 2, vec![1, 1, 1, 2]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_confusion_matrix_unequal() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1, 0]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.confusion_matrix(&y_true, &y_pred);

        assert!(result.is_err());
    }

    #[test]
    fn test_confusion_matrix_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.confusion_matrix(&y_true, &y_pred).unwrap();
        // Column-major layout of [[2, 0, 0], [0, 1, 1], [0, 1, 1]].
        let expected = DMatrix::from_vec(3, 3, vec![2, 0, 0, 0, 1, 1, 0, 1, 1]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_accuracy() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.accuracy(&y_true, &y_pred).unwrap();

        let expected = 0.6;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_accuracy_perfect_classification() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 0, 1, 0, 1]);

        let result = classifier.accuracy(&y_true, &y_pred).unwrap();
        let expected = 1.0;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_precision() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.precision(&y_true, &y_pred).unwrap();

        let expected = 2.0 / 3.0;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_precision_no_positive_predictions() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 1, 1, 1, 1]);
        let y_pred = DVector::from_vec(vec![0, 0, 0, 0, 0]);

        let result = classifier.precision(&y_true, &y_pred).unwrap();

        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_precision_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.precision(&y_true, &y_pred).unwrap();
        let expected = (2.0 / 2.0 + 1.0 / 2.0 + 1.0 / 2.0) / 3.0;

        assert!((result - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_recall() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.recall(&y_true, &y_pred).unwrap();

        let expected = 2.0 / 3.0;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_recall_no_true_positives() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 1, 1, 1, 1]);
        let y_pred = DVector::from_vec(vec![0, 0, 0, 0, 0]);

        let result = classifier.recall(&y_true, &y_pred).unwrap();
        let expected = 0.0;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_recall_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.recall(&y_true, &y_pred).unwrap();
        let expected = (2.0 / 2.0 + 1.0 / 2.0 + 1.0 / 2.0) / 3.0;

        assert!((result - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_f1_score() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 1, 0, 0, 1]);

        let result = classifier.f1_score(&y_true, &y_pred).unwrap();

        let expected = 2.0 / 3.0;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_f1_score_perfect_classification() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 0, 1, 0, 1]);
        let y_pred = DVector::from_vec(vec![1, 0, 1, 0, 1]);

        let result = classifier.f1_score(&y_true, &y_pred).unwrap();
        let expected = 1.0;

        assert_eq!(result, expected);
    }

    #[test]
    fn test_f1_score_error() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![1, 1, 1, 1, 1]);
        let y_pred = DVector::from_vec(vec![0, 0, 0, 0, 0]);

        let result = classifier.f1_score(&y_true, &y_pred);

        assert!(result.is_err());
    }

    #[test]
    fn test_f1_score_multiclass() {
        let classifier = MockClassifier;

        let y_true = DVector::from_vec(vec![0, 1, 2, 1, 0, 2]);
        let y_pred = DVector::from_vec(vec![0, 2, 1, 1, 0, 2]);

        let result = classifier.f1_score(&y_true, &y_pred).unwrap();
        let precision = classifier.precision(&y_true, &y_pred).unwrap();
        let recall = classifier.recall(&y_true, &y_pred).unwrap();
        let expected = 2.0 * (precision * recall) / (precision + recall); // Harmonic mean of precision and recall

        assert!((result - expected).abs() < f64::EPSILON);
    }
}