//! Computer vision metrics.

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

/// Intersection over Union (IoU) metric for segmentation tasks.
///
/// IoU = (Intersection) / (Union) = TP / (TP + FP + FN)
///
/// Also known as the Jaccard Index; a key metric for:
/// - Semantic segmentation
/// - Instance segmentation
/// - Object detection (bounding box overlap)
#[derive(Debug, Clone)]
pub struct IoU {
    /// Predictions at or above this value binarize to 1, below to 0.
    pub threshold: f64,
    /// Small epsilon to avoid division by zero
    pub epsilon: f64,
}

impl IoU {
    /// Create an IoU metric with a custom binarization threshold
    /// (epsilon keeps its default value).
    pub fn new(threshold: f64) -> Self {
        Self {
            threshold,
            ..Self::default()
        }
    }
}

impl Default for IoU {
    /// Threshold 0.5, epsilon 1e-7.
    fn default() -> Self {
        Self {
            threshold: 0.5,
            epsilon: 1e-7,
        }
    }
}
43impl Metric for IoU {
44    fn compute(
45        &self,
46        predictions: &ArrayView<f64, Ix2>,
47        targets: &ArrayView<f64, Ix2>,
48    ) -> TrainResult<f64> {
49        if predictions.shape() != targets.shape() {
50            return Err(TrainError::MetricsError(format!(
51                "Shape mismatch: predictions {:?} vs targets {:?}",
52                predictions.shape(),
53                targets.shape()
54            )));
55        }
56
57        let mut intersection = 0.0;
58        let mut union = 0.0;
59
60        for i in 0..predictions.nrows() {
61            for j in 0..predictions.ncols() {
62                let pred = if predictions[[i, j]] >= self.threshold {
63                    1.0
64                } else {
65                    0.0
66                };
67                let target = targets[[i, j]];
68
69                intersection += pred * target;
70                union += (pred + target - pred * target).max(0.0);
71            }
72        }
73
74        Ok(intersection / (union + self.epsilon))
75    }
76
77    fn name(&self) -> &str {
78        "iou"
79    }
80}
81
/// Mean Intersection over Union (mIoU) metric for multi-class segmentation.
///
/// Computes IoU for each class (column) separately and returns the mean.
/// This is the standard evaluation metric for semantic segmentation.
#[derive(Debug, Clone)]
pub struct MeanIoU {
    /// Threshold for converting predictions to binary
    pub threshold: f64,
    /// Small epsilon to avoid division by zero
    pub epsilon: f64,
}

impl Default for MeanIoU {
    /// Threshold 0.5, epsilon 1e-7.
    fn default() -> Self {
        Self {
            threshold: 0.5,
            epsilon: 1e-7,
        }
    }
}

impl MeanIoU {
    /// Create a MeanIoU metric with a custom binarization threshold.
    ///
    /// Mirrors [`IoU::new`] for API consistency; `epsilon` keeps its default.
    pub fn new(threshold: f64) -> Self {
        Self {
            threshold,
            epsilon: 1e-7,
        }
    }
}
103impl Metric for MeanIoU {
104    fn compute(
105        &self,
106        predictions: &ArrayView<f64, Ix2>,
107        targets: &ArrayView<f64, Ix2>,
108    ) -> TrainResult<f64> {
109        if predictions.shape() != targets.shape() {
110            return Err(TrainError::MetricsError(format!(
111                "Shape mismatch: predictions {:?} vs targets {:?}",
112                predictions.shape(),
113                targets.shape()
114            )));
115        }
116
117        let num_classes = predictions.ncols();
118        let mut class_ious = Vec::new();
119
120        // Compute IoU for each class
121        for class_idx in 0..num_classes {
122            let mut intersection = 0.0;
123            let mut union = 0.0;
124
125            for i in 0..predictions.nrows() {
126                let pred = if predictions[[i, class_idx]] >= self.threshold {
127                    1.0
128                } else {
129                    0.0
130                };
131                let target = targets[[i, class_idx]];
132
133                intersection += pred * target;
134                union += (pred + target - pred * target).max(0.0);
135            }
136
137            if union > self.epsilon {
138                class_ious.push(intersection / union);
139            }
140        }
141
142        if class_ious.is_empty() {
143            return Ok(0.0);
144        }
145
146        Ok(class_ious.iter().sum::<f64>() / class_ious.len() as f64)
147    }
148
149    fn name(&self) -> &str {
150        "mean_iou"
151    }
152}
153
/// Dice Coefficient metric (F1 Score variant for segmentation).
///
/// Dice = 2 * (Intersection) / (|A| + |B|) = 2TP / (2TP + FP + FN)
///
/// Often used in medical image segmentation.
/// Range: [0, 1] where 1 is perfect overlap.
#[derive(Debug, Clone)]
pub struct DiceCoefficient {
    /// Threshold for converting predictions to binary
    pub threshold: f64,
    /// Small epsilon to avoid division by zero
    pub epsilon: f64,
}

impl Default for DiceCoefficient {
    /// Threshold 0.5, epsilon 1e-7.
    fn default() -> Self {
        Self {
            threshold: 0.5,
            epsilon: 1e-7,
        }
    }
}

impl DiceCoefficient {
    /// Create a Dice metric with a custom binarization threshold.
    ///
    /// Mirrors [`IoU::new`] for API consistency; `epsilon` keeps its default.
    pub fn new(threshold: f64) -> Self {
        Self {
            threshold,
            epsilon: 1e-7,
        }
    }
}
177impl Metric for DiceCoefficient {
178    fn compute(
179        &self,
180        predictions: &ArrayView<f64, Ix2>,
181        targets: &ArrayView<f64, Ix2>,
182    ) -> TrainResult<f64> {
183        if predictions.shape() != targets.shape() {
184            return Err(TrainError::MetricsError(format!(
185                "Shape mismatch: predictions {:?} vs targets {:?}",
186                predictions.shape(),
187                targets.shape()
188            )));
189        }
190
191        let mut intersection = 0.0;
192        let mut pred_sum = 0.0;
193        let mut target_sum = 0.0;
194
195        for i in 0..predictions.nrows() {
196            for j in 0..predictions.ncols() {
197                let pred = if predictions[[i, j]] >= self.threshold {
198                    1.0
199                } else {
200                    0.0
201                };
202                let target = targets[[i, j]];
203
204                intersection += pred * target;
205                pred_sum += pred;
206                target_sum += target;
207            }
208        }
209
210        Ok((2.0 * intersection) / (pred_sum + target_sum + self.epsilon))
211    }
212
213    fn name(&self) -> &str {
214        "dice_coefficient"
215    }
216}
217
/// Mean Average Precision (mAP) metric for object detection and retrieval.
///
/// Computes the average precision (AP) for each class and returns the mean.
/// This is a simplified version for multi-label classification scenarios.
///
/// For true object detection mAP with IoU thresholds, use specialized computer vision libraries.
#[derive(Debug, Clone)]
pub struct MeanAveragePrecision {
    /// Number of recall points to sample for AP calculation
    pub num_recall_points: usize,
}

impl Default for MeanAveragePrecision {
    fn default() -> Self {
        Self {
            num_recall_points: 11, // Standard 11-point interpolation (PASCAL VOC style)
        }
    }
}

impl MeanAveragePrecision {
    /// Create with a custom number of recall points.
    ///
    /// `0` recall points yields an AP of 0.0; `1` samples precision at
    /// recall level 0 only.
    pub fn new(num_recall_points: usize) -> Self {
        Self { num_recall_points }
    }

    /// Compute Average Precision for a single class.
    ///
    /// `predictions` are confidence scores and `targets` mark which samples
    /// are actually positive (extra entries in the longer slice are ignored).
    /// Pairs are ranked by descending score, then precision is interpolated
    /// at `num_recall_points` evenly spaced recall levels in [0, 1].
    ///
    /// Returns 0.0 when there is no data, no positives, or zero recall points.
    fn compute_ap(&self, predictions: &[f64], targets: &[bool]) -> f64 {
        // Guard degenerate configurations up front: without data or recall
        // points the interpolation below would divide by zero (NaN).
        if predictions.is_empty() || targets.is_empty() || self.num_recall_points == 0 {
            return 0.0;
        }

        // Sort by prediction scores (descending); NaN scores compare as equal
        // and keep their relative position.
        let mut paired: Vec<(f64, bool)> = predictions
            .iter()
            .copied()
            .zip(targets.iter().copied())
            .collect();
        paired.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        let total_positives = targets.iter().filter(|&&t| t).count() as f64;
        if total_positives == 0.0 {
            return 0.0;
        }

        // Sweep the ranked list, recording the precision/recall curve.
        let mut true_positives = 0.0;
        let mut false_positives = 0.0;
        let mut precisions = Vec::with_capacity(paired.len());
        let mut recalls = Vec::with_capacity(paired.len());

        for (_, target) in paired {
            if target {
                true_positives += 1.0;
            } else {
                false_positives += 1.0;
            }

            precisions.push(true_positives / (true_positives + false_positives));
            recalls.push(true_positives / total_positives);
        }

        // Interpolate precision at evenly spaced recall levels. With a single
        // recall point the only level sampled is recall 0; previously
        // `(num_recall_points - 1)` made this 0/0 and returned NaN.
        let step = self.num_recall_points.saturating_sub(1).max(1) as f64;
        let mut ap = 0.0;
        for i in 0..self.num_recall_points {
            let recall_level = i as f64 / step;

            // Standard interpolation: max precision at any recall >= level.
            let max_precision = recalls
                .iter()
                .enumerate()
                .filter(|(_, &r)| r >= recall_level)
                .map(|(idx, _)| precisions[idx])
                .fold(0.0, f64::max);

            ap += max_precision;
        }

        ap / self.num_recall_points as f64
    }
}
302impl Metric for MeanAveragePrecision {
303    fn compute(
304        &self,
305        predictions: &ArrayView<f64, Ix2>,
306        targets: &ArrayView<f64, Ix2>,
307    ) -> TrainResult<f64> {
308        if predictions.shape() != targets.shape() {
309            return Err(TrainError::MetricsError(format!(
310                "Shape mismatch: predictions {:?} vs targets {:?}",
311                predictions.shape(),
312                targets.shape()
313            )));
314        }
315
316        let num_classes = predictions.ncols();
317        let mut aps = Vec::new();
318
319        // Compute AP for each class
320        for class_idx in 0..num_classes {
321            let mut class_preds = Vec::new();
322            let mut class_targets = Vec::new();
323
324            for i in 0..predictions.nrows() {
325                class_preds.push(predictions[[i, class_idx]]);
326                class_targets.push(targets[[i, class_idx]] > 0.5);
327            }
328
329            let ap = self.compute_ap(&class_preds, &class_targets);
330            aps.push(ap);
331        }
332
333        if aps.is_empty() {
334            return Ok(0.0);
335        }
336
337        Ok(aps.iter().sum::<f64>() / aps.len() as f64)
338    }
339
340    fn name(&self) -> &str {
341        "mean_average_precision"
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_iou() {
        let metric = IoU::default();

        // Thresholded predictions agree with targets everywhere -> IoU = 1.
        let preds = array![[0.9, 0.1], [0.8, 0.2], [0.9, 0.1]];
        let truth = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((score - 1.0).abs() < 1e-6);

        // Some pixels disagree -> valid score strictly below 1.
        let preds = array![[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]];
        let truth = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((0.0..=1.0).contains(&score));
        assert!(score < 1.0);
    }

    #[test]
    fn test_mean_iou() {
        let metric = MeanIoU::default();

        // Every class column predicted correctly -> mIoU = 1.
        let preds = array![[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.1, 0.9]];
        let truth = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_coefficient() {
        let metric = DiceCoefficient::default();

        // Perfect overlap -> Dice = 1.
        let preds = array![[0.9, 0.1], [0.8, 0.2], [0.9, 0.1]];
        let truth = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((score - 1.0).abs() < 1e-6);

        // Disjoint masks -> Dice near 0.
        let preds = array![[0.1, 0.9], [0.2, 0.8]];
        let truth = array![[1.0, 0.0], [1.0, 0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!(score < 0.1);
    }

    #[test]
    fn test_mean_average_precision() {
        let metric = MeanAveragePrecision::default();

        // All positives ranked above all negatives -> mAP = 1.
        let preds = array![[0.9, 0.8], [0.8, 0.7], [0.3, 0.2], [0.2, 0.1]];
        let truth = array![[1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((score - 1.0).abs() < 1e-6);

        // Uninformative (tied) scores still yield a value in [0, 1].
        let preds = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
        let truth = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[test]
    fn test_iou_custom_threshold() {
        // With threshold 0.7 the 0.6 prediction binarizes to 0, so IoU < 1.
        let metric = IoU::new(0.7);
        let preds = array![[0.8, 0.2], [0.6, 0.4]];
        let truth = array![[1.0, 0.0], [1.0, 0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((0.0..=1.0).contains(&score));
        assert!(score < 1.0);
    }

    #[test]
    fn test_mean_average_precision_custom_points() {
        // 5-point interpolation instead of the default 11.
        let metric = MeanAveragePrecision::new(5);
        let preds = array![[0.9], [0.8], [0.3], [0.2]];
        let truth = array![[1.0], [1.0], [0.0], [0.0]];
        let score = metric.compute(&preds.view(), &truth.view()).unwrap();
        assert!((0.0..=1.0).contains(&score));
    }
}