1use std::cmp::Ordering;
2
3use rustfft::num_traits::Float;
4use tract_nnef::{
5 internal::*,
6 tract_ndarray::{s, ArrayView1},
7};
8
9pub fn register(registry: &mut Registry) {
10 registry.register_primitive(
11 "tract_onnx_non_max_suppression",
12 ¶meters(),
13 &[("output", TypeName::Integer.tensor())],
14 load,
15 );
16 registry.register_dumper(dump);
17}
18
/// Layout of the four coordinates describing one box.
#[derive(Clone, Copy, Debug, Hash)]
pub enum BoxRepr {
    /// Two opposite corners, read as `[y1, x1, y2, x2]`; the corners may come in any order.
    TwoPoints,
    /// Centre point and size, read as `[x_center, y_center, width, height]`.
    CenterWidthHeight,
}
26
27fn get_min_max<T: Float>(lhs: T, rhs: T) -> (T, T) {
28 if lhs >= rhs {
29 (rhs, lhs)
30 } else {
31 (lhs, rhs)
32 }
33}
34
35impl BoxRepr {
36 pub fn from_i64(val: i64) -> TractResult<BoxRepr> {
37 Ok(match val {
38 0 => BoxRepr::TwoPoints,
39 1 => BoxRepr::CenterWidthHeight,
40 other => bail!("unsupported center_point_box argument value: {}", other),
41 })
42 }
43
44 pub fn into_i64(self) -> i64 {
45 match self {
46 BoxRepr::TwoPoints => 0,
47 BoxRepr::CenterWidthHeight => 1,
48 }
49 }
50
51 fn should_suppress_by_iou<T: Datum + Float>(
53 &self,
54 box1: ArrayView1<T>,
55 box2: ArrayView1<T>,
56 iou_threshold: T,
57 ) -> bool {
58 let two = T::one() + T::one();
59 let (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max) = match self {
60 BoxRepr::TwoPoints => {
61 let (x1_min, x1_max) = get_min_max(box1[[1]], box1[[3]]);
62 let (x2_min, x2_max) = get_min_max(box2[[1]], box2[[3]]);
63
64 let (y1_min, y1_max) = get_min_max(box1[[0]], box1[[2]]);
65 let (y2_min, y2_max) = get_min_max(box2[[0]], box2[[2]]);
66
67 (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max)
68 }
69 BoxRepr::CenterWidthHeight => {
70 let (box1_width_half, box1_height_half) = (box1[[2]] / two, box1[[3]] / two);
71 let (box2_width_half, box2_height_half) = (box2[[2]] / two, box2[[3]] / two);
72
73 let (x1_min, x1_max) = (box1[[0]] - box1_width_half, box1[[0]] + box1_width_half);
74 let (x2_min, x2_max) = (box2[[0]] - box2_width_half, box2[[0]] + box2_width_half);
75
76 let (y1_min, y1_max) = (box1[[1]] - box1_height_half, box1[[1]] + box1_height_half);
77 let (y2_min, y2_max) = (box2[[1]] - box2_height_half, box2[[1]] + box2_height_half);
78
79 (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max)
80 }
81 };
82
83 let intersection_y_min = T::max(y1_min, y2_min);
84 let intersection_y_max = T::min(y1_max, y2_max);
85 if intersection_y_max <= intersection_y_min {
86 return false;
87 }
88
89 let intersection_x_min = T::max(x1_min, x2_min);
90 let intersection_x_max = T::min(x1_max, x2_max);
91 if intersection_x_max <= intersection_x_min {
92 return false;
93 }
94
95 let intersection_area =
96 (intersection_x_max - intersection_x_min) * (intersection_y_max - intersection_y_min);
97
98 if intersection_area.is_sign_negative() {
99 return false;
100 }
101
102 let area1 = (x1_max - x1_min) * (y1_max - y1_min);
103 let area2 = (x2_max - x2_min) * (y2_max - y2_min);
104
105 let union_area = area1 + area2 - intersection_area;
106
107 if area1.is_sign_negative() || area2.is_sign_negative() || union_area.is_sign_negative() {
108 return false;
109 }
110
111 let intersection_over_union = intersection_area / union_area;
112
113 intersection_over_union > iou_threshold
114 }
115}
116
117#[derive(Debug, Clone, Hash)]
118pub struct NonMaxSuppression {
119 pub center_point_box: BoxRepr,
120 pub num_selected_indices_symbol: Symbol,
121 pub has_score_threshold: bool,
122}
123
124impl NonMaxSuppression {
125 fn eval_t<T: Datum + Float>(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
126 let (boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) =
127 if self.has_score_threshold {
128 let (t1, t2, t3, t4, t5) = args_5!(inputs);
129 (t1, t2, t3, t4, Some(t5))
130 } else {
131 let (t1, t2, t3, t4) = args_4!(inputs);
132 (t1, t2, t3, t4, None)
133 };
134
135 let mut max_output_boxes_per_class = *max_output_boxes_per_class.to_scalar::<i64>()?;
136 let iou_threshold = *iou_threshold.to_scalar::<T>()?;
137 let score_threshold = score_threshold
138 .map_or(Ok::<_, TractError>(None), |val| Ok(Some(*val.to_scalar::<T>()?)))?;
139
140 if max_output_boxes_per_class == 0 {
141 max_output_boxes_per_class = i64::MAX;
142 }
143 let num_batches = scores.shape()[0];
146 let num_classes = scores.shape()[1];
147 let num_dim = scores.shape()[2];
148
149 let boxes = boxes.to_array_view::<T>()?;
150 let scores = scores.to_array_view::<T>()?;
151
152 let mut selected_global: TVec<(usize, usize, usize)> = tvec![];
154
155 for batch in 0..num_batches {
156 for class in 0..num_classes {
157 let mut candidates: TVec<(T, usize)> =
159 if let Some(score_threshold) = score_threshold {
160 (0..num_dim)
161 .map(|i| (scores[[batch, class, i]], i))
162 .filter(|(score, _)| *score > score_threshold)
163 .collect()
164 } else {
165 (0..num_dim).map(|i| (scores[[batch, class, i]], i)).collect()
166 };
167
168 candidates.sort_by(|(a, _), (b, _)| b.partial_cmp(a).unwrap_or(Ordering::Equal));
169
170 let mut selected_in_class: TVec<(T, usize)> = tvec![];
172
173 for (score, index) in candidates {
174 if selected_in_class.len() as i64 >= max_output_boxes_per_class {
175 break;
176 }
177
178 let box1 = boxes.slice(s![batch, index, ..]);
179 let suppr = selected_in_class.iter().any(|(_, index)| {
180 let box2 = boxes.slice(s![batch, *index, ..]);
181 self.center_point_box.should_suppress_by_iou(box1, box2, iou_threshold)
182 });
183 if !suppr {
184 selected_in_class.push((score, index));
185 selected_global.push((batch, class, index));
186 }
187 }
188 }
189 }
190
191 let num_selected = selected_global.len();
193 let v = selected_global
194 .into_iter()
195 .flat_map(|(batch, class, index)| [batch as i64, class as i64, index as i64])
196 .collect();
197 let res = tract_ndarray::ArrayD::from_shape_vec(&*tvec![num_selected, 3], v)?;
198
199 Ok(tvec![res.into_tvalue()])
200 }
201}
202
203impl Op for NonMaxSuppression {
204 fn name(&self) -> StaticName {
205 "NonMaxSuppression".into()
206 }
207
208 op_as_typed_op!();
209}
210
211impl EvalOp for NonMaxSuppression {
212 fn is_stateless(&self) -> bool {
213 true
214 }
215
216 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
217 let dt = inputs[0].datum_type();
218 dispatch_floatlike!(Self::eval_t(dt)(self, inputs))
219 }
220}
221
222impl TypedOp for NonMaxSuppression {
223 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
224 Ok(tvec![i64::fact([self.num_selected_indices_symbol.to_dim(), 3usize.to_dim()])])
225 }
226
227 as_op!();
228}
229
230fn parameters() -> Vec<Parameter> {
231 vec![
232 TypeName::Integer.tensor().named("boxes"),
233 TypeName::Scalar.tensor().named("scores"),
234 TypeName::Integer.named("max_output_boxes_per_class").default(0),
235 TypeName::Scalar.named("iou_threshold").default(0.0),
236 TypeName::Scalar.named("score_threshold"),
237 TypeName::Integer.named("center_point_box").default(0),
238 ]
239}
240
241fn dump(
242 ast: &mut IntoAst,
243 node: &TypedNode,
244 op: &NonMaxSuppression,
245) -> TractResult<Option<Arc<RValue>>> {
246 let boxes = ast.mapping[&node.inputs[0]].clone();
247 let scores = ast.mapping[&node.inputs[1]].clone();
248 let max_output_boxes_per_class = ast.mapping[&node.inputs[2]].clone();
249 let iou_threshold = ast.mapping[&node.inputs[3]].clone();
250 let score_threshold = node.inputs.get(4).map(|v| ast.mapping[v].clone());
251
252 let inv = if let Some(score_threshold) = score_threshold {
253 invocation(
254 "tract_onnx_non_max_suppression",
255 &[boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold],
256 &[("center_point_box", numeric(op.center_point_box.into_i64()))],
257 )
258 } else {
259 invocation(
260 "tract_onnx_non_max_suppression",
261 &[boxes, scores, max_output_boxes_per_class, iou_threshold],
262 &[("center_point_box", numeric(op.center_point_box.into_i64()))],
263 )
264 };
265
266 Ok(Some(inv))
267}
268
269fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
270 let boxes = invocation.named_arg_as(builder, "boxes")?;
271 let scores = invocation.named_arg_as(builder, "scores")?;
272 let max_output_boxes_per_class =
273 invocation.named_arg_as(builder, "max_output_boxes_per_class")?;
274 let iou_threshold = invocation.named_arg_as(builder, "iou_threshold")?;
275 let score_threshold = invocation.named_arg_as(builder, "score_threshold").ok();
276
277 let center_point_box =
278 BoxRepr::from_i64(invocation.named_arg_as(builder, "center_point_box")?)?;
279
280 let n = builder.model.symbols.sym("n");
281 let op = NonMaxSuppression {
282 center_point_box,
283 num_selected_indices_symbol: n,
284 has_score_threshold: score_threshold.is_some(),
285 };
286 if let Some(score_threshold) = score_threshold {
287 builder
288 .wire(op, &[boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold])
289 } else {
290 builder.wire(op, &[boxes, scores, max_output_boxes_per_class, iou_threshold])
291 }
292}