1use yscv_tensor::Tensor;
2
3use crate::nms::validate_nms_args;
4use crate::{BoundingBox, CLASS_ID_PERSON, DetectError, Detection, non_max_suppression};
5
6#[derive(Debug, Default, Clone, PartialEq)]
11pub struct HeatmapDetectScratch {
12 active: Vec<bool>,
13 visited: Vec<bool>,
14 stack: Vec<usize>,
15 detections: Vec<Detection>,
16}
17
18pub fn detect_from_heatmap(
20 heatmap: &Tensor,
21 score_threshold: f32,
22 min_area: usize,
23 iou_threshold: f32,
24 max_detections: usize,
25) -> Result<Vec<Detection>, DetectError> {
26 let mut scratch = HeatmapDetectScratch::default();
27 detect_from_heatmap_with_scratch(
28 heatmap,
29 score_threshold,
30 min_area,
31 iou_threshold,
32 max_detections,
33 &mut scratch,
34 )
35}
36
37pub fn detect_from_heatmap_with_scratch(
39 heatmap: &Tensor,
40 score_threshold: f32,
41 min_area: usize,
42 iou_threshold: f32,
43 max_detections: usize,
44 scratch: &mut HeatmapDetectScratch,
45) -> Result<Vec<Detection>, DetectError> {
46 let (h, w, c) = map_shape(heatmap)?;
47 if c != 1 {
48 return Err(DetectError::InvalidChannelCount {
49 expected: 1,
50 got: c,
51 });
52 }
53 detect_from_heatmap_data_with_scratch(
54 (h, w),
55 heatmap.data(),
56 score_threshold,
57 min_area,
58 iou_threshold,
59 max_detections,
60 scratch,
61 )
62}
63
64pub(crate) fn detect_from_heatmap_data_with_scratch(
65 shape: (usize, usize),
66 data: &[f32],
67 score_threshold: f32,
68 min_area: usize,
69 iou_threshold: f32,
70 max_detections: usize,
71 scratch: &mut HeatmapDetectScratch,
72) -> Result<Vec<Detection>, DetectError> {
73 let (h, w) = shape;
74 if !score_threshold.is_finite() || !(0.0..=1.0).contains(&score_threshold) {
75 return Err(DetectError::InvalidThreshold {
76 threshold: score_threshold,
77 });
78 }
79 if min_area == 0 {
80 return Err(DetectError::InvalidMinArea { min_area });
81 }
82 validate_nms_args(iou_threshold, max_detections)?;
83 let pixel_count = h.saturating_mul(w);
84 debug_assert_eq!(data.len(), pixel_count);
85
86 if scratch.active.len() != pixel_count {
87 scratch.active.resize(pixel_count, false);
88 }
89 if scratch.visited.len() != pixel_count {
90 scratch.visited.resize(pixel_count, false);
91 }
92
93 for ((active, visited), value) in scratch
94 .active
95 .iter_mut()
96 .zip(scratch.visited.iter_mut())
97 .zip(data.iter().copied())
98 {
99 *active = is_active_score(value, score_threshold);
100 *visited = false;
101 }
102
103 scratch.stack.clear();
104 scratch.detections.clear();
105 for start in 0..pixel_count {
106 if scratch.visited[start] || !scratch.active[start] {
107 continue;
108 }
109
110 scratch.visited[start] = true;
111 scratch.stack.clear();
112 scratch.stack.push(start);
113
114 let start_y = start / w;
115 let start_x = start - start_y * w;
116 let mut min_x = start_x;
117 let mut max_x = start_x;
118 let mut min_y = start_y;
119 let mut max_y = start_y;
120 let mut area = 0usize;
121 let mut score_sum = 0.0f32;
122 let mut score_max = 0.0f32;
123
124 while let Some(current) = scratch.stack.pop() {
125 let cy = current / w;
126 let cx = current - cy * w;
127 let current_score = data[current];
128
129 area += 1;
130 score_sum += current_score;
131 score_max = score_max.max(current_score);
132 min_x = min_x.min(cx);
133 max_x = max_x.max(cx);
134 min_y = min_y.min(cy);
135 max_y = max_y.max(cy);
136
137 if cx > 0 {
138 visit_neighbor(
139 current - 1,
140 &scratch.active,
141 &mut scratch.visited,
142 &mut scratch.stack,
143 );
144 }
145 if cx + 1 < w {
146 visit_neighbor(
147 current + 1,
148 &scratch.active,
149 &mut scratch.visited,
150 &mut scratch.stack,
151 );
152 }
153 if cy > 0 {
154 visit_neighbor(
155 current - w,
156 &scratch.active,
157 &mut scratch.visited,
158 &mut scratch.stack,
159 );
160 }
161 if cy + 1 < h {
162 visit_neighbor(
163 current + w,
164 &scratch.active,
165 &mut scratch.visited,
166 &mut scratch.stack,
167 );
168 }
169 }
170
171 if area >= min_area {
172 let avg_score = score_sum / area as f32;
173 scratch.detections.push(Detection {
174 bbox: BoundingBox {
175 x1: min_x as f32,
176 y1: min_y as f32,
177 x2: (max_x + 1) as f32,
178 y2: (max_y + 1) as f32,
179 },
180 score: (avg_score + score_max) * 0.5,
181 class_id: CLASS_ID_PERSON,
182 });
183 }
184 }
185
186 Ok(non_max_suppression(
187 &scratch.detections,
188 iou_threshold,
189 max_detections,
190 ))
191}
192
193pub(crate) fn map_shape(input: &Tensor) -> Result<(usize, usize, usize), DetectError> {
194 if input.rank() != 3 {
195 return Err(DetectError::InvalidMapShape {
196 expected_rank: 3,
197 got: input.shape().to_vec(),
198 });
199 }
200 Ok((input.shape()[0], input.shape()[1], input.shape()[2]))
201}
202
/// Returns whether a heatmap value counts as foreground for `threshold`.
///
/// NaN and infinite values are never active; finite values are active when
/// they are at or above the threshold.
fn is_active_score(value: f32, threshold: f32) -> bool {
    if !value.is_finite() {
        return false;
    }
    value >= threshold
}
206
/// Marks pixel `index` as visited and queues it on `stack` if it is active.
///
/// A pixel that was already visited is left alone and never re-queued, so each
/// pixel enters the stack at most once. Inactive pixels are still marked as
/// visited so later scans skip them.
fn visit_neighbor(index: usize, active: &[bool], visited: &mut [bool], stack: &mut Vec<usize>) {
    let already_seen = std::mem::replace(&mut visited[index], true);
    if !already_seen && active[index] {
        stack.push(index);
    }
}