use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

/// Intersection over Union (Jaccard index) for binary segmentation masks.
///
/// Predictions are binarized at `threshold`; targets are assumed to already
/// be binary (0.0 or 1.0). The score is `intersection / union`, with
/// `epsilon` added to the denominator to avoid division by zero.
#[derive(Debug, Clone)]
pub struct IoU {
    /// Probability threshold above which a prediction counts as positive.
    pub threshold: f64,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f64,
}

impl Default for IoU {
    fn default() -> Self {
        Self {
            threshold: 0.5,
            epsilon: 1e-7,
        }
    }
}

impl IoU {
    /// Create an IoU metric with the given binarization threshold and the
    /// default epsilon.
    pub fn new(threshold: f64) -> Self {
        Self {
            threshold,
            epsilon: 1e-7,
        }
    }
}

impl Metric for IoU {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut union = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                // Binarize the prediction; the target is used as-is and is
                // expected to be 0.0 or 1.0.
                let pred = if predictions[[i, j]] >= self.threshold {
                    1.0
                } else {
                    0.0
                };
                let target = targets[[i, j]];

                intersection += pred * target;
                union += (pred + target - pred * target).max(0.0);
            }
        }

        Ok(intersection / (union + self.epsilon))
    }

    fn name(&self) -> &str {
        "iou"
    }
}

/// Mean Intersection over Union across classes.
///
/// Treats each column as one class, computes a per-class IoU, and averages
/// over the classes that have a non-empty union; classes absent from both
/// predictions and targets are skipped.
#[derive(Debug, Clone)]
pub struct MeanIoU {
    /// Probability threshold above which a prediction counts as positive.
    pub threshold: f64,
    /// Small constant used to guard against division by a near-zero union.
    pub epsilon: f64,
}

impl Default for MeanIoU {
    fn default() -> Self {
        Self {
            threshold: 0.5,
            epsilon: 1e-7,
        }
    }
}

impl Metric for MeanIoU {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        let mut class_ious = Vec::new();

        for class_idx in 0..num_classes {
            // Accumulate intersection and union over this class's column.
            let mut intersection = 0.0;
            let mut union = 0.0;

            for i in 0..predictions.nrows() {
                let pred = if predictions[[i, class_idx]] >= self.threshold {
                    1.0
                } else {
                    0.0
                };
                let target = targets[[i, class_idx]];

                intersection += pred * target;
                union += (pred + target - pred * target).max(0.0);
            }

            // Only classes that actually appear (non-empty union) contribute
            // to the mean.
            if union > self.epsilon {
                class_ious.push(intersection / union);
            }
        }

        if class_ious.is_empty() {
            return Ok(0.0);
        }

        Ok(class_ious.iter().sum::<f64>() / class_ious.len() as f64)
    }

    fn name(&self) -> &str {
        "mean_iou"
    }
}

/// Sørensen-Dice coefficient: `2|A ∩ B| / (|A| + |B|)`.
///
/// Predictions are binarized at `threshold`; targets are assumed binary.
#[derive(Debug, Clone)]
pub struct DiceCoefficient {
    /// Probability threshold above which a prediction counts as positive.
    pub threshold: f64,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f64,
}

impl Default for DiceCoefficient {
    fn default() -> Self {
        Self {
            threshold: 0.5,
            epsilon: 1e-7,
        }
    }
}

impl Metric for DiceCoefficient {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = if predictions[[i, j]] >= self.threshold {
                    1.0
                } else {
                    0.0
                };
                let target = targets[[i, j]];

                intersection += pred * target;
                pred_sum += pred;
                target_sum += target;
            }
        }

        Ok((2.0 * intersection) / (pred_sum + target_sum + self.epsilon))
    }

    fn name(&self) -> &str {
        "dice_coefficient"
    }
}

/// Mean Average Precision (mAP) with N-point interpolation.
///
/// Each column is treated as one class; average precision is computed per
/// class from the ranked prediction scores and then averaged over classes.
#[derive(Debug, Clone)]
pub struct MeanAveragePrecision {
    /// Number of evenly spaced recall levels used for interpolation.
    pub num_recall_points: usize,
}

impl Default for MeanAveragePrecision {
    fn default() -> Self {
        Self {
            // 11 recall levels (0.0, 0.1, ..., 1.0), as in PASCAL VOC 2007.
            num_recall_points: 11,
        }
    }
}

impl MeanAveragePrecision {
    /// Create a mAP metric with the given number of recall levels.
    pub fn new(num_recall_points: usize) -> Self {
        Self { num_recall_points }
    }

    /// Average precision for one class, using `num_recall_points`-point
    /// interpolation of the precision-recall curve.
    fn compute_ap(&self, predictions: &[f64], targets: &[bool]) -> f64 {
        if predictions.is_empty() || targets.is_empty() {
            return 0.0;
        }
        // Interpolation needs at least two recall levels to span [0, 1];
        // bail out early instead of producing NaN below.
        if self.num_recall_points < 2 {
            return 0.0;
        }

        // Sort detections by score, highest first.
        let mut paired: Vec<(f64, bool)> = predictions
            .iter()
            .copied()
            .zip(targets.iter().copied())
            .collect();
        paired.sort_by(|a, b| b.0.total_cmp(&a.0));

        let total_positives = targets.iter().filter(|&&t| t).count() as f64;
        if total_positives == 0.0 {
            return 0.0;
        }

        // Sweep the ranked list, recording precision and recall at each rank.
        let mut true_positives = 0.0;
        let mut false_positives = 0.0;
        let mut precisions = Vec::new();
        let mut recalls = Vec::new();

        for (_, target) in paired {
            if target {
                true_positives += 1.0;
            } else {
                false_positives += 1.0;
            }

            let precision = true_positives / (true_positives + false_positives);
            let recall = true_positives / total_positives;

            precisions.push(precision);
            recalls.push(recall);
        }

        // Interpolated AP: at each recall level, take the maximum precision
        // achieved at any recall greater than or equal to that level.
        let mut ap = 0.0;
        for i in 0..self.num_recall_points {
            let recall_level = i as f64 / (self.num_recall_points - 1) as f64;

            let max_precision = recalls
                .iter()
                .enumerate()
                .filter(|(_, &r)| r >= recall_level)
                .map(|(idx, _)| precisions[idx])
                .fold(0.0, f64::max);

            ap += max_precision;
        }

        ap / self.num_recall_points as f64
    }
}

impl Metric for MeanAveragePrecision {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        let mut aps = Vec::new();

        for class_idx in 0..num_classes {
            // Gather this class's scores and binarized ground-truth labels.
            let mut class_preds = Vec::new();
            let mut class_targets = Vec::new();

            for i in 0..predictions.nrows() {
                class_preds.push(predictions[[i, class_idx]]);
                class_targets.push(targets[[i, class_idx]] > 0.5);
            }

            let ap = self.compute_ap(&class_preds, &class_targets);
            aps.push(ap);
        }

        if aps.is_empty() {
            return Ok(0.0);
        }

        Ok(aps.iter().sum::<f64>() / aps.len() as f64)
    }

    fn name(&self) -> &str {
        "mean_average_precision"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_iou() {
        let metric = IoU::default();

        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.9, 0.1]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];

        let iou = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((iou - 1.0).abs() < 1e-6);

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];

        let iou = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&iou));
        assert!(iou < 1.0);
    }
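
    // Added edge-case test (not part of the original suite): mismatched
    // input shapes should surface as an error, not a score.
    #[test]
    fn test_iou_shape_mismatch() {
        let metric = IoU::default();
        let predictions = array![[0.9, 0.1]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];
        assert!(metric
            .compute(&predictions.view(), &targets.view())
            .is_err());
    }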

    #[test]
    fn test_mean_iou() {
        let metric = MeanIoU::default();

        let predictions = array![[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let miou = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((miou - 1.0).abs() < 1e-6);
    }
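
    // Added edge-case test (not part of the original suite): when no class
    // has any positive prediction or target, every per-class union is empty,
    // so the mean IoU falls back to 0.0.
    #[test]
    fn test_mean_iou_no_support() {
        let metric = MeanIoU::default();
        let predictions = array![[0.1, 0.2], [0.3, 0.1]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];
        let miou = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert_eq!(miou, 0.0);
    }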

    #[test]
    fn test_dice_coefficient() {
        let metric = DiceCoefficient::default();

        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.9, 0.1]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];

        let dice = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((dice - 1.0).abs() < 1e-6);

        let predictions = array![[0.1, 0.9], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let dice = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!(dice < 0.1);
    }
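
    // Added test (not part of the original suite): with one of two predicted
    // positives correct and one of two targets recovered, Dice should be
    // 2 * 1 / (2 + 2) = 0.5 (up to epsilon).
    #[test]
    fn test_dice_partial_overlap() {
        let metric = DiceCoefficient::default();
        let predictions = array![[0.9], [0.9], [0.1], [0.1]];
        let targets = array![[1.0], [0.0], [1.0], [0.0]];
        let dice = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((dice - 0.5).abs() < 1e-3);
    }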

    #[test]
    fn test_mean_average_precision() {
        let metric = MeanAveragePrecision::default();

        let predictions = array![[0.9, 0.8], [0.8, 0.7], [0.3, 0.2], [0.2, 0.1]];
        let targets = array![[1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]];

        let map = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((map - 1.0).abs() < 1e-6);

        let predictions = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let map = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&map));
    }
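
    // Added test (not part of the original suite): with a false positive
    // ranked first and both true positives recovered afterwards, the maximum
    // precision at every recall level is 2/3, so interpolated AP = 2/3.
    #[test]
    fn test_mean_average_precision_false_positive_first() {
        let metric = MeanAveragePrecision::default();
        let predictions = array![[0.9], [0.8], [0.7], [0.2]];
        let targets = array![[0.0], [1.0], [1.0], [0.0]];
        let map = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((map - 2.0 / 3.0).abs() < 1e-6);
    }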

    #[test]
    fn test_iou_custom_threshold() {
        let metric = IoU::new(0.7);

        let predictions = array![[0.8, 0.2], [0.6, 0.4]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let iou = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&iou));
        assert!(iou < 1.0);
    }
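
    // Added edge-case test (not part of the original suite): with no positive
    // predictions or targets, epsilon keeps the result at 0.0 instead of NaN.
    #[test]
    fn test_iou_all_negative() {
        let metric = IoU::default();
        let predictions = array![[0.1, 0.2], [0.3, 0.4]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];
        let iou = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert_eq!(iou, 0.0);
    }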

    #[test]
    fn test_mean_average_precision_custom_points() {
        let metric = MeanAveragePrecision::new(5);

        let predictions = array![[0.9], [0.8], [0.3], [0.2]];
        let targets = array![[1.0], [1.0], [0.0], [0.0]];

        let map = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&map));
    }
}