tensorlogic_train/metrics/
basic.rs

//! Basic classification metrics.

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

/// Returns an error if `predictions` and `targets` differ in shape.
fn check_shapes(
    predictions: &ArrayView<f64, Ix2>,
    targets: &ArrayView<f64, Ix2>,
) -> TrainResult<()> {
    if predictions.shape() != targets.shape() {
        return Err(TrainError::MetricsError(format!(
            "Shape mismatch: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }
    Ok(())
}

/// Index of the largest value in row `row` (ties go to the lowest index).
fn argmax_row(values: &ArrayView<f64, Ix2>, row: usize) -> usize {
    let mut best_class = 0;
    let mut best_value = values[[row, 0]];
    for j in 1..values.ncols() {
        if values[[row, j]] > best_value {
            best_value = values[[row, j]];
            best_class = j;
        }
    }
    best_class
}

/// Accuracy metric for classification.
///
/// Multi-column inputs are compared by row-wise argmax; single-column inputs
/// are treated as binary scores and binarized with `threshold`.
#[derive(Debug, Clone)]
pub struct Accuracy {
    /// Decision threshold for single-column (binary) predictions.
    pub threshold: f64,
}

impl Default for Accuracy {
    fn default() -> Self {
        Self { threshold: 0.5 }
    }
}

impl Metric for Accuracy {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        check_shapes(predictions, targets)?;

        let total = predictions.nrows();
        if total == 0 {
            // Guard against 0 / 0, which would silently return NaN.
            return Err(TrainError::MetricsError(
                "Cannot compute accuracy on an empty batch".to_string(),
            ));
        }

        let mut correct = 0;
        for i in 0..total {
            let (pred_class, true_class) = if predictions.ncols() == 1 {
                // Binary case: binarize the score with the configured threshold.
                (
                    usize::from(predictions[[i, 0]] >= self.threshold),
                    usize::from(targets[[i, 0]] >= 0.5),
                )
            } else {
                // Multi-class case: predicted and true class are the row argmax.
                (argmax_row(predictions, i), argmax_row(targets, i))
            };

            if pred_class == true_class {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    fn name(&self) -> &str {
        "accuracy"
    }
}

/// Precision metric for classification.
#[derive(Debug, Clone, Default)]
pub struct Precision {
    /// Class to compute precision for (`None` = macro average).
    pub class_id: Option<usize>,
}

impl Metric for Precision {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        check_shapes(predictions, targets)?;

        let num_classes = predictions.ncols();
        let mut true_positives = vec![0usize; num_classes];
        let mut predicted_positives = vec![0usize; num_classes];

        for i in 0..predictions.nrows() {
            let pred_class = argmax_row(predictions, i);
            let true_class = argmax_row(targets, i);

            predicted_positives[pred_class] += 1;
            if pred_class == true_class {
                true_positives[pred_class] += 1;
            }
        }

        if let Some(class_id) = self.class_id {
            // Precision for a specific class: TP / (TP + FP).
            if class_id >= num_classes {
                return Err(TrainError::MetricsError(format!(
                    "class_id {} out of range for {} classes",
                    class_id, num_classes
                )));
            }
            if predicted_positives[class_id] == 0 {
                Ok(0.0)
            } else {
                Ok(true_positives[class_id] as f64 / predicted_positives[class_id] as f64)
            }
        } else {
            // Macro average over classes that were predicted at least once.
            let mut total_precision = 0.0;
            let mut valid_classes = 0;

            for class_id in 0..num_classes {
                if predicted_positives[class_id] > 0 {
                    total_precision +=
                        true_positives[class_id] as f64 / predicted_positives[class_id] as f64;
                    valid_classes += 1;
                }
            }

            if valid_classes == 0 {
                Ok(0.0)
            } else {
                Ok(total_precision / valid_classes as f64)
            }
        }
    }

    fn name(&self) -> &str {
        "precision"
    }
}

/// Recall metric for classification.
#[derive(Debug, Clone, Default)]
pub struct Recall {
    /// Class to compute recall for (`None` = macro average).
    pub class_id: Option<usize>,
}

impl Metric for Recall {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        check_shapes(predictions, targets)?;

        let num_classes = predictions.ncols();
        let mut true_positives = vec![0usize; num_classes];
        let mut actual_positives = vec![0usize; num_classes];

        for i in 0..predictions.nrows() {
            let pred_class = argmax_row(predictions, i);
            let true_class = argmax_row(targets, i);

            actual_positives[true_class] += 1;
            if pred_class == true_class {
                true_positives[true_class] += 1;
            }
        }

        if let Some(class_id) = self.class_id {
            // Recall for a specific class: TP / (TP + FN).
            if class_id >= num_classes {
                return Err(TrainError::MetricsError(format!(
                    "class_id {} out of range for {} classes",
                    class_id, num_classes
                )));
            }
            if actual_positives[class_id] == 0 {
                Ok(0.0)
            } else {
                Ok(true_positives[class_id] as f64 / actual_positives[class_id] as f64)
            }
        } else {
            // Macro average over classes that actually occur in the targets.
            let mut total_recall = 0.0;
            let mut valid_classes = 0;

            for class_id in 0..num_classes {
                if actual_positives[class_id] > 0 {
                    total_recall +=
                        true_positives[class_id] as f64 / actual_positives[class_id] as f64;
                    valid_classes += 1;
                }
            }

            if valid_classes == 0 {
                Ok(0.0)
            } else {
                Ok(total_recall / valid_classes as f64)
            }
        }
    }

    fn name(&self) -> &str {
        "recall"
    }
}

/// F1 score metric for classification: the harmonic mean of precision and
/// recall, `2 * P * R / (P + R)`.
#[derive(Debug, Clone, Default)]
pub struct F1Score {
    /// Class to compute F1 for (`None` = macro average).
    pub class_id: Option<usize>,
}

impl Metric for F1Score {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let precision = Precision {
            class_id: self.class_id,
        }
        .compute(predictions, targets)?;
        let recall = Recall {
            class_id: self.class_id,
        }
        .compute(predictions, targets)?;

        if precision + recall == 0.0 {
            Ok(0.0)
        } else {
            Ok(2.0 * precision * recall / (precision + recall))
        }
    }

    fn name(&self) -> &str {
        "f1_score"
    }
}

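// A minimal usage sketch (commented out; not part of the module API): how a
// caller might evaluate one batch against several metrics at once. This
// assumes `Metric` is object-safe as declared in `super`; the loop itself is
// illustrative rather than an API this crate promises.
//
//     use scirs2_core::ndarray::array;
//
//     let metrics: Vec<Box<dyn Metric>> = vec![
//         Box::new(Accuracy::default()),
//         Box::new(Precision::default()),
//         Box::new(Recall::default()),
//         Box::new(F1Score::default()),
//     ];
//     let predictions = array![[0.9, 0.1], [0.2, 0.8]];
//     let targets = array![[1.0, 0.0], [0.0, 1.0]];
//     for metric in &metrics {
//         let value = metric.compute(&predictions.view(), &targets.view())?;
//         println!("{}: {:.4}", metric.name(), value);
//     }
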
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_accuracy() {
        let metric = Accuracy::default();

        // Perfect predictions
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert_eq!(accuracy, 1.0);

        // Two of three predictions correct
        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((accuracy - 2.0 / 3.0).abs() < 1e-6);
    }
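
    // Exercises the single-column (binary) path, where `threshold` decides the
    // predicted label instead of argmax.
    #[test]
    fn test_accuracy_binary_threshold() {
        let metric = Accuracy { threshold: 0.5 };

        let predictions = array![[0.9], [0.4], [0.6]];
        let targets = array![[1.0], [0.0], [0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        // Rows 0 and 1 are classified correctly; row 2 (0.6 >= 0.5 vs target 0) is not.
        assert!((accuracy - 2.0 / 3.0).abs() < 1e-6);
    }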

    #[test]
    fn test_precision() {
        let metric = Precision::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let precision = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&precision));
    }
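
    // A hand-checked case for per-class precision: class 0 is predicted twice
    // (rows 0 and 2) but correct only once, so its precision is 0.5.
    #[test]
    fn test_precision_per_class() {
        let metric = Precision { class_id: Some(0) };

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let precision = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((precision - 0.5).abs() < 1e-6);
    }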

    #[test]
    fn test_recall() {
        let metric = Recall::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let recall = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&recall));
    }
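
    // A hand-checked case for per-class recall: class 1 appears twice in the
    // targets (rows 1 and 2) but is recovered only once, so its recall is 0.5.
    #[test]
    fn test_recall_per_class() {
        let metric = Recall { class_id: Some(1) };

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let recall = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((recall - 0.5).abs() < 1e-6);
    }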

    #[test]
    fn test_f1_score() {
        let metric = F1Score::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let f1 = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&f1));
    }
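
    // Shape mismatches should surface as errors rather than silently
    // truncating or panicking.
    #[test]
    fn test_shape_mismatch_is_an_error() {
        let metric = Accuracy::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0]];

        assert!(metric
            .compute(&predictions.view(), &targets.view())
            .is_err());
    }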
}