Skip to main content

yscv_imgproc/ops/
nms.rs

1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5
/// Bounding box with confidence for NMS.
///
/// The box is described by two corners, `(x1, y1)` and `(x2, y2)`. The overlap
/// computation in this file takes `x2 - x1` as the width and `y2 - y1` as the
/// height, so `(x1, y1)` is expected to be the minimum corner and `(x2, y2)`
/// the maximum corner.
// NOTE(review): coordinate units (pixels vs. normalized) are not fixed by this
// file — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BBox {
    /// X coordinate of the minimum corner.
    pub x1: f32,
    /// Y coordinate of the minimum corner.
    pub y1: f32,
    /// X coordinate of the maximum corner.
    pub x2: f32,
    /// Y coordinate of the maximum corner.
    pub y2: f32,
    /// Confidence of the detection; `nms` processes boxes from highest to
    /// lowest score.
    pub score: f32,
}
15
16/// Greedy non-maximum suppression on bounding boxes.
17///
18/// Returns indices of kept boxes, sorted by descending confidence.
19pub fn nms(boxes: &[BBox], iou_threshold: f32) -> Vec<usize> {
20    let mut indices: Vec<usize> = (0..boxes.len()).collect();
21    indices.sort_by(|&a, &b| {
22        boxes[b]
23            .score
24            .partial_cmp(&boxes[a].score)
25            .unwrap_or(std::cmp::Ordering::Equal)
26    });
27
28    let mut keep = Vec::new();
29    let mut suppressed = vec![false; boxes.len()];
30
31    for &i in &indices {
32        if suppressed[i] {
33            continue;
34        }
35        keep.push(i);
36        for &j in &indices {
37            if suppressed[j] || j == i {
38                continue;
39            }
40            if iou(&boxes[i], &boxes[j]) > iou_threshold {
41                suppressed[j] = true;
42            }
43        }
44    }
45
46    keep
47}
48
49fn iou(a: &BBox, b: &BBox) -> f32 {
50    let x1 = a.x1.max(b.x1);
51    let y1 = a.y1.max(b.y1);
52    let x2 = a.x2.min(b.x2);
53    let y2 = a.y2.min(b.y2);
54    let inter = (x2 - x1).max(0.0) * (y2 - y1).max(0.0);
55    let area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
56    let area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
57    let union = area_a + area_b - inter;
58    if union <= 0.0 { 0.0 } else { inter / union }
59}
60
/// Template matching method.
///
/// Selects how `template_match` scores each candidate window and which
/// direction counts as a better match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemplateMatchMethod {
    /// Sum of squared differences (lower = better match).
    ///
    /// The score is the sum over the window of the squared difference between
    /// the image value and the template value.
    Ssd,
    /// Normalized cross-correlation (higher = better match).
    ///
    /// Both the window and the template have their mean subtracted before
    /// correlating; a window whose normalization term is below `1e-10` scores
    /// `0.0`.
    Ncc,
}
69
/// Template matching result: top-left location and score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TemplateMatchResult {
    /// Column of the window's top-left corner in the image.
    pub x: usize,
    /// Row of the window's top-left corner in the image.
    pub y: usize,
    /// Score of the window under the chosen `TemplateMatchMethod`
    /// (lower is better for SSD, higher is better for NCC).
    pub score: f32,
}
77
78/// Slides `template` over `image` computing a similarity map, returns the best match.
79///
80/// Both must be single-channel `[H, W, 1]`.
81pub fn template_match(
82    image: &Tensor,
83    template: &Tensor,
84    method: TemplateMatchMethod,
85) -> Result<TemplateMatchResult, ImgProcError> {
86    let (ih, iw, ic) = hwc_shape(image)?;
87    let (th, tw, tc) = hwc_shape(template)?;
88    if ic != 1 || tc != 1 {
89        return Err(ImgProcError::InvalidChannelCount {
90            expected: 1,
91            got: if ic != 1 { ic } else { tc },
92        });
93    }
94    if th > ih || tw > iw {
95        return Err(ImgProcError::InvalidSize {
96            height: th,
97            width: tw,
98        });
99    }
100    let img = image.data();
101    let tmpl = template.data();
102    let rh = ih - th + 1;
103    let rw = iw - tw + 1;
104
105    let mut best = TemplateMatchResult {
106        x: 0,
107        y: 0,
108        score: match method {
109            TemplateMatchMethod::Ssd => f32::MAX,
110            TemplateMatchMethod::Ncc => f32::NEG_INFINITY,
111        },
112    };
113
114    // Precompute template stats for NCC
115    let tmpl_mean: f32 = tmpl.iter().sum::<f32>() / tmpl.len() as f32;
116    let tmpl_std: f32 = {
117        let var: f32 = tmpl
118            .iter()
119            .map(|&v| (v - tmpl_mean) * (v - tmpl_mean))
120            .sum::<f32>()
121            / tmpl.len() as f32;
122        var.sqrt()
123    };
124
125    for y in 0..rh {
126        for x in 0..rw {
127            let score = match method {
128                TemplateMatchMethod::Ssd => {
129                    let mut sum = 0.0f32;
130                    for ty in 0..th {
131                        for tx in 0..tw {
132                            let diff = img[(y + ty) * iw + x + tx] - tmpl[ty * tw + tx];
133                            sum += diff * diff;
134                        }
135                    }
136                    sum
137                }
138                TemplateMatchMethod::Ncc => {
139                    let patch_size = (th * tw) as f32;
140                    let mut patch_mean = 0.0f32;
141                    for ty in 0..th {
142                        for tx in 0..tw {
143                            patch_mean += img[(y + ty) * iw + x + tx];
144                        }
145                    }
146                    patch_mean /= patch_size;
147                    let mut num = 0.0f32;
148                    let mut den_patch = 0.0f32;
149                    for ty in 0..th {
150                        for tx in 0..tw {
151                            let pi = img[(y + ty) * iw + x + tx] - patch_mean;
152                            let ti = tmpl[ty * tw + tx] - tmpl_mean;
153                            num += pi * ti;
154                            den_patch += pi * pi;
155                        }
156                    }
157                    let den = (den_patch.sqrt()) * (tmpl_std * patch_size.sqrt());
158                    if den.abs() < 1e-10 { 0.0 } else { num / den }
159                }
160            };
161
162            let is_better = match method {
163                TemplateMatchMethod::Ssd => score < best.score,
164                TemplateMatchMethod::Ncc => score > best.score,
165            };
166            if is_better {
167                best = TemplateMatchResult { x, y, score };
168            }
169        }
170    }
171
172    Ok(best)
173}