1use yscv_tensor::Tensor;
2
3use super::super::ImgProcError;
4use super::super::shape::hwc_shape;
5
6#[derive(Debug, Clone, Copy, PartialEq)]
8pub struct BBox {
9 pub x1: f32,
10 pub y1: f32,
11 pub x2: f32,
12 pub y2: f32,
13 pub score: f32,
14}
15
16pub fn nms(boxes: &[BBox], iou_threshold: f32) -> Vec<usize> {
20 let mut indices: Vec<usize> = (0..boxes.len()).collect();
21 indices.sort_by(|&a, &b| {
22 boxes[b]
23 .score
24 .partial_cmp(&boxes[a].score)
25 .unwrap_or(std::cmp::Ordering::Equal)
26 });
27
28 let mut keep = Vec::new();
29 let mut suppressed = vec![false; boxes.len()];
30
31 for &i in &indices {
32 if suppressed[i] {
33 continue;
34 }
35 keep.push(i);
36 for &j in &indices {
37 if suppressed[j] || j == i {
38 continue;
39 }
40 if iou(&boxes[i], &boxes[j]) > iou_threshold {
41 suppressed[j] = true;
42 }
43 }
44 }
45
46 keep
47}
48
49fn iou(a: &BBox, b: &BBox) -> f32 {
50 let x1 = a.x1.max(b.x1);
51 let y1 = a.y1.max(b.y1);
52 let x2 = a.x2.min(b.x2);
53 let y2 = a.y2.min(b.y2);
54 let inter = (x2 - x1).max(0.0) * (y2 - y1).max(0.0);
55 let area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
56 let area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
57 let union = area_a + area_b - inter;
58 if union <= 0.0 { 0.0 } else { inter / union }
59}
60
/// Scoring method used by [`template_match`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemplateMatchMethod {
    /// Sum of squared differences between the template and the image window.
    /// Lower is better; 0.0 is an exact match.
    Ssd,
    /// Zero-mean normalized cross-correlation. Higher is better, nominally in `[-1, 1]`.
    /// Scored as 0.0 when the window or the template has (near) zero variance.
    Ncc,
}
69
/// Best match found by [`template_match`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TemplateMatchResult {
    /// Image column of the matched window's first (leftmost) pixel.
    pub x: usize,
    /// Image row of the matched window's first row.
    pub y: usize,
    /// Score of that window under the method used (see [`TemplateMatchMethod`] for
    /// whether lower or higher is better).
    pub score: f32,
}
77
78pub fn template_match(
82 image: &Tensor,
83 template: &Tensor,
84 method: TemplateMatchMethod,
85) -> Result<TemplateMatchResult, ImgProcError> {
86 let (ih, iw, ic) = hwc_shape(image)?;
87 let (th, tw, tc) = hwc_shape(template)?;
88 if ic != 1 || tc != 1 {
89 return Err(ImgProcError::InvalidChannelCount {
90 expected: 1,
91 got: if ic != 1 { ic } else { tc },
92 });
93 }
94 if th > ih || tw > iw {
95 return Err(ImgProcError::InvalidSize {
96 height: th,
97 width: tw,
98 });
99 }
100 let img = image.data();
101 let tmpl = template.data();
102 let rh = ih - th + 1;
103 let rw = iw - tw + 1;
104
105 let mut best = TemplateMatchResult {
106 x: 0,
107 y: 0,
108 score: match method {
109 TemplateMatchMethod::Ssd => f32::MAX,
110 TemplateMatchMethod::Ncc => f32::NEG_INFINITY,
111 },
112 };
113
114 let tmpl_mean: f32 = tmpl.iter().sum::<f32>() / tmpl.len() as f32;
116 let tmpl_std: f32 = {
117 let var: f32 = tmpl
118 .iter()
119 .map(|&v| (v - tmpl_mean) * (v - tmpl_mean))
120 .sum::<f32>()
121 / tmpl.len() as f32;
122 var.sqrt()
123 };
124
125 for y in 0..rh {
126 for x in 0..rw {
127 let score = match method {
128 TemplateMatchMethod::Ssd => {
129 let mut sum = 0.0f32;
130 for ty in 0..th {
131 for tx in 0..tw {
132 let diff = img[(y + ty) * iw + x + tx] - tmpl[ty * tw + tx];
133 sum += diff * diff;
134 }
135 }
136 sum
137 }
138 TemplateMatchMethod::Ncc => {
139 let patch_size = (th * tw) as f32;
140 let mut patch_mean = 0.0f32;
141 for ty in 0..th {
142 for tx in 0..tw {
143 patch_mean += img[(y + ty) * iw + x + tx];
144 }
145 }
146 patch_mean /= patch_size;
147 let mut num = 0.0f32;
148 let mut den_patch = 0.0f32;
149 for ty in 0..th {
150 for tx in 0..tw {
151 let pi = img[(y + ty) * iw + x + tx] - patch_mean;
152 let ti = tmpl[ty * tw + tx] - tmpl_mean;
153 num += pi * ti;
154 den_patch += pi * pi;
155 }
156 }
157 let den = (den_patch.sqrt()) * (tmpl_std * patch_size.sqrt());
158 if den.abs() < 1e-10 { 0.0 } else { num / den }
159 }
160 };
161
162 let is_better = match method {
163 TemplateMatchMethod::Ssd => score < best.score,
164 TemplateMatchMethod::Ncc => score > best.score,
165 };
166 if is_better {
167 best = TemplateMatchResult { x, y, score };
168 }
169 }
170 }
171
172 Ok(best)
173}