tensorlogic_train/metrics/
basic.rs1use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{ArrayView, Ix2};
5
6use super::Metric;
7
/// Classification accuracy metric.
///
/// The accompanying `Metric` implementation reports the fraction of rows whose
/// highest-scoring prediction column matches the highest-scoring target column.
#[derive(Clone, Debug)]
pub struct Accuracy {
    /// Decision threshold; `Default` sets it to `0.5`.
    ///
    /// NOTE(review): the `Metric::compute` implementation in this file never
    /// reads this field — presumably reserved for binary classification; confirm.
    pub threshold: f64,
}
14
15impl Default for Accuracy {
16 fn default() -> Self {
17 Self { threshold: 0.5 }
18 }
19}
20
21impl Metric for Accuracy {
22 fn compute(
23 &self,
24 predictions: &ArrayView<f64, Ix2>,
25 targets: &ArrayView<f64, Ix2>,
26 ) -> TrainResult<f64> {
27 if predictions.shape() != targets.shape() {
28 return Err(TrainError::MetricsError(format!(
29 "Shape mismatch: predictions {:?} vs targets {:?}",
30 predictions.shape(),
31 targets.shape()
32 )));
33 }
34
35 let mut correct = 0;
36 let total = predictions.nrows();
37
38 for i in 0..total {
39 let mut pred_class = 0;
41 let mut max_pred = predictions[[i, 0]];
42 for j in 1..predictions.ncols() {
43 if predictions[[i, j]] > max_pred {
44 max_pred = predictions[[i, j]];
45 pred_class = j;
46 }
47 }
48
49 let mut true_class = 0;
51 let mut max_true = targets[[i, 0]];
52 for j in 1..targets.ncols() {
53 if targets[[i, j]] > max_true {
54 max_true = targets[[i, j]];
55 true_class = j;
56 }
57 }
58
59 if pred_class == true_class {
60 correct += 1;
61 }
62 }
63
64 Ok(correct as f64 / total as f64)
65 }
66
67 fn name(&self) -> &str {
68 "accuracy"
69 }
70}
71
/// Precision metric for multi-class classification.
#[derive(Default, Clone, Debug)]
pub struct Precision {
    /// Class whose precision is reported.
    ///
    /// With `None` (the default), the `Metric` implementation averages the
    /// per-class precision over every class that was predicted at least once.
    pub class_id: Option<usize>,
}
78
79impl Metric for Precision {
80 fn compute(
81 &self,
82 predictions: &ArrayView<f64, Ix2>,
83 targets: &ArrayView<f64, Ix2>,
84 ) -> TrainResult<f64> {
85 if predictions.shape() != targets.shape() {
86 return Err(TrainError::MetricsError(format!(
87 "Shape mismatch: predictions {:?} vs targets {:?}",
88 predictions.shape(),
89 targets.shape()
90 )));
91 }
92
93 let num_classes = predictions.ncols();
94 let mut true_positives = vec![0; num_classes];
95 let mut predicted_positives = vec![0; num_classes];
96
97 for i in 0..predictions.nrows() {
98 let mut pred_class = 0;
100 let mut max_pred = predictions[[i, 0]];
101 for j in 1..num_classes {
102 if predictions[[i, j]] > max_pred {
103 max_pred = predictions[[i, j]];
104 pred_class = j;
105 }
106 }
107
108 let mut true_class = 0;
110 let mut max_true = targets[[i, 0]];
111 for j in 1..num_classes {
112 if targets[[i, j]] > max_true {
113 max_true = targets[[i, j]];
114 true_class = j;
115 }
116 }
117
118 predicted_positives[pred_class] += 1;
119 if pred_class == true_class {
120 true_positives[pred_class] += 1;
121 }
122 }
123
124 if let Some(class_id) = self.class_id {
125 if predicted_positives[class_id] == 0 {
127 Ok(0.0)
128 } else {
129 Ok(true_positives[class_id] as f64 / predicted_positives[class_id] as f64)
130 }
131 } else {
132 let mut total_precision = 0.0;
134 let mut valid_classes = 0;
135
136 for class_id in 0..num_classes {
137 if predicted_positives[class_id] > 0 {
138 total_precision +=
139 true_positives[class_id] as f64 / predicted_positives[class_id] as f64;
140 valid_classes += 1;
141 }
142 }
143
144 if valid_classes == 0 {
145 Ok(0.0)
146 } else {
147 Ok(total_precision / valid_classes as f64)
148 }
149 }
150 }
151
152 fn name(&self) -> &str {
153 "precision"
154 }
155}
156
/// Recall metric for multi-class classification.
#[derive(Default, Clone, Debug)]
pub struct Recall {
    /// Class whose recall is reported.
    ///
    /// With `None` (the default), the `Metric` implementation averages the
    /// per-class recall over every class that occurs at least once in the targets.
    pub class_id: Option<usize>,
}
163
164impl Metric for Recall {
165 fn compute(
166 &self,
167 predictions: &ArrayView<f64, Ix2>,
168 targets: &ArrayView<f64, Ix2>,
169 ) -> TrainResult<f64> {
170 if predictions.shape() != targets.shape() {
171 return Err(TrainError::MetricsError(format!(
172 "Shape mismatch: predictions {:?} vs targets {:?}",
173 predictions.shape(),
174 targets.shape()
175 )));
176 }
177
178 let num_classes = predictions.ncols();
179 let mut true_positives = vec![0; num_classes];
180 let mut actual_positives = vec![0; num_classes];
181
182 for i in 0..predictions.nrows() {
183 let mut pred_class = 0;
185 let mut max_pred = predictions[[i, 0]];
186 for j in 1..num_classes {
187 if predictions[[i, j]] > max_pred {
188 max_pred = predictions[[i, j]];
189 pred_class = j;
190 }
191 }
192
193 let mut true_class = 0;
195 let mut max_true = targets[[i, 0]];
196 for j in 1..num_classes {
197 if targets[[i, j]] > max_true {
198 max_true = targets[[i, j]];
199 true_class = j;
200 }
201 }
202
203 actual_positives[true_class] += 1;
204 if pred_class == true_class {
205 true_positives[pred_class] += 1;
206 }
207 }
208
209 if let Some(class_id) = self.class_id {
210 if actual_positives[class_id] == 0 {
212 Ok(0.0)
213 } else {
214 Ok(true_positives[class_id] as f64 / actual_positives[class_id] as f64)
215 }
216 } else {
217 let mut total_recall = 0.0;
219 let mut valid_classes = 0;
220
221 for class_id in 0..num_classes {
222 if actual_positives[class_id] > 0 {
223 total_recall +=
224 true_positives[class_id] as f64 / actual_positives[class_id] as f64;
225 valid_classes += 1;
226 }
227 }
228
229 if valid_classes == 0 {
230 Ok(0.0)
231 } else {
232 Ok(total_recall / valid_classes as f64)
233 }
234 }
235 }
236
237 fn name(&self) -> &str {
238 "recall"
239 }
240}
241
/// F1 score metric for multi-class classification.
#[derive(Default, Clone, Debug)]
pub struct F1Score {
    /// Class whose F1 score is reported.
    ///
    /// The value is forwarded unchanged to the [`Precision`] and [`Recall`]
    /// metrics that the `Metric` implementation combines; `None` is the default.
    pub class_id: Option<usize>,
}
248
249impl Metric for F1Score {
250 fn compute(
251 &self,
252 predictions: &ArrayView<f64, Ix2>,
253 targets: &ArrayView<f64, Ix2>,
254 ) -> TrainResult<f64> {
255 let precision = Precision {
256 class_id: self.class_id,
257 }
258 .compute(predictions, targets)?;
259 let recall = Recall {
260 class_id: self.class_id,
261 }
262 .compute(predictions, targets)?;
263
264 if precision + recall == 0.0 {
265 Ok(0.0)
266 } else {
267 Ok(2.0 * precision * recall / (precision + recall))
268 }
269 }
270
271 fn name(&self) -> &str {
272 "f1_score"
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Tolerance for comparing floating-point metric values.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_accuracy() {
        let metric = Accuracy::default();

        // Every row's argmax matches its target.
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert_eq!(accuracy, 1.0);

        // The second row is misclassified: 2 of 3 correct.
        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((accuracy - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_accuracy_shape_mismatch() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let result = Accuracy::default().compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_precision() {
        let metric = Precision::default();

        // Predicted classes: 0, 1, 0. True classes: 0, 1, 1.
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let precision = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&precision));
        // Class 0: 1 of 2 predictions correct; class 1: 1 of 1. Mean = 0.75.
        assert!((precision - 0.75).abs() < EPS);
    }

    #[test]
    fn test_precision_per_class() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let class0 = Precision { class_id: Some(0) }
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        let class1 = Precision { class_id: Some(1) }
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((class0 - 0.5).abs() < EPS);
        assert!((class1 - 1.0).abs() < EPS);
    }

    #[test]
    fn test_recall() {
        let metric = Recall::default();

        // Predicted classes: 0, 1, 0. True classes: 0, 1, 1.
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let recall = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&recall));
        // Class 0: 1 of 1 actual found; class 1: 1 of 2. Mean = 0.75.
        assert!((recall - 0.75).abs() < EPS);
    }

    #[test]
    fn test_recall_per_class() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let class0 = Recall { class_id: Some(0) }
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        let class1 = Recall { class_id: Some(1) }
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((class0 - 1.0).abs() < EPS);
        assert!((class1 - 0.5).abs() < EPS);
    }

    #[test]
    fn test_f1_score() {
        let metric = F1Score::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let f1 = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&f1));
        // Precision and recall are both 0.75, so their harmonic mean is 0.75.
        assert!((f1 - 0.75).abs() < EPS);
    }

    #[test]
    fn test_f1_score_per_class() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        // Class 0: precision 0.5, recall 1.0 -> F1 = 2 * 0.5 * 1.0 / 1.5 = 2/3.
        let f1 = F1Score { class_id: Some(0) }
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((f1 - 2.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn test_f1_score_shape_mismatch() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let result = F1Score::default().compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }
}