tract_onnx_opl/
non_max_suppression.rs

1use std::cmp::Ordering;
2
3use rustfft::num_traits::Float;
4use tract_nnef::{
5    internal::*,
6    tract_ndarray::{s, ArrayView1},
7};
8
9pub fn register(registry: &mut Registry) {
10    registry.register_primitive(
11        "tract_onnx_non_max_suppression",
12        &parameters(),
13        &[("output", TypeName::Integer.tensor())],
14        load,
15    );
16    registry.register_dumper(dump);
17}
18
/// Layout of the four coordinates describing each box (the `center_point_box` argument).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum BoxRepr {
    /// Boxes are `[y1, x1, y2, x2]`: two opposite corners, accepted in either order.
    TwoPoints,
    /// Boxes are `[x_center, y_center, width, height]`.
    CenterWidthHeight,
}
26
27fn get_min_max<T: Float>(lhs: T, rhs: T) -> (T, T) {
28    if lhs >= rhs {
29        (rhs, lhs)
30    } else {
31        (lhs, rhs)
32    }
33}
34
35impl BoxRepr {
36    pub fn from_i64(val: i64) -> TractResult<BoxRepr> {
37        Ok(match val {
38            0 => BoxRepr::TwoPoints,
39            1 => BoxRepr::CenterWidthHeight,
40            other => bail!("unsupported center_point_box argument value: {}", other),
41        })
42    }
43
44    pub fn into_i64(self) -> i64 {
45        match self {
46            BoxRepr::TwoPoints => 0,
47            BoxRepr::CenterWidthHeight => 1,
48        }
49    }
50
51    // iou: intersection over union
52    fn should_suppress_by_iou<T: Datum + Float>(
53        &self,
54        box1: ArrayView1<T>,
55        box2: ArrayView1<T>,
56        iou_threshold: T,
57    ) -> bool {
58        let two = T::one() + T::one();
59        let (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max) = match self {
60            BoxRepr::TwoPoints => {
61                let (x1_min, x1_max) = get_min_max(box1[[1]], box1[[3]]);
62                let (x2_min, x2_max) = get_min_max(box2[[1]], box2[[3]]);
63
64                let (y1_min, y1_max) = get_min_max(box1[[0]], box1[[2]]);
65                let (y2_min, y2_max) = get_min_max(box2[[0]], box2[[2]]);
66
67                (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max)
68            }
69            BoxRepr::CenterWidthHeight => {
70                let (box1_width_half, box1_height_half) = (box1[[2]] / two, box1[[3]] / two);
71                let (box2_width_half, box2_height_half) = (box2[[2]] / two, box2[[3]] / two);
72
73                let (x1_min, x1_max) = (box1[[0]] - box1_width_half, box1[[0]] + box1_width_half);
74                let (x2_min, x2_max) = (box2[[0]] - box2_width_half, box2[[0]] + box2_width_half);
75
76                let (y1_min, y1_max) = (box1[[1]] - box1_height_half, box1[[1]] + box1_height_half);
77                let (y2_min, y2_max) = (box2[[1]] - box2_height_half, box2[[1]] + box2_height_half);
78
79                (x1_min, x1_max, x2_min, x2_max, y1_min, y1_max, y2_min, y2_max)
80            }
81        };
82
83        let intersection_y_min = T::max(y1_min, y2_min);
84        let intersection_y_max = T::min(y1_max, y2_max);
85        if intersection_y_max <= intersection_y_min {
86            return false;
87        }
88
89        let intersection_x_min = T::max(x1_min, x2_min);
90        let intersection_x_max = T::min(x1_max, x2_max);
91        if intersection_x_max <= intersection_x_min {
92            return false;
93        }
94
95        let intersection_area =
96            (intersection_x_max - intersection_x_min) * (intersection_y_max - intersection_y_min);
97
98        if intersection_area.is_sign_negative() {
99            return false;
100        }
101
102        let area1 = (x1_max - x1_min) * (y1_max - y1_min);
103        let area2 = (x2_max - x2_min) * (y2_max - y2_min);
104
105        let union_area = area1 + area2 - intersection_area;
106
107        if area1.is_sign_negative() || area2.is_sign_negative() || union_area.is_sign_negative() {
108            return false;
109        }
110
111        let intersection_over_union = intersection_area / union_area;
112
113        intersection_over_union > iou_threshold
114    }
115}
116
117#[derive(Debug, Clone, Hash)]
118pub struct NonMaxSuppression {
119    pub center_point_box: BoxRepr,
120    pub num_selected_indices_symbol: Symbol,
121    pub has_score_threshold: bool,
122}
123
124impl NonMaxSuppression {
125    fn eval_t<T: Datum + Float>(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
126        let (boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) =
127            if self.has_score_threshold {
128                let (t1, t2, t3, t4, t5) = args_5!(inputs);
129                (t1, t2, t3, t4, Some(t5))
130            } else {
131                let (t1, t2, t3, t4) = args_4!(inputs);
132                (t1, t2, t3, t4, None)
133            };
134
135        let mut max_output_boxes_per_class = *max_output_boxes_per_class.to_scalar::<i64>()?;
136        let iou_threshold = *iou_threshold.to_scalar::<T>()?;
137        let score_threshold = score_threshold
138            .map_or(Ok::<_, TractError>(None), |val| Ok(Some(*val.to_scalar::<T>()?)))?;
139
140        if max_output_boxes_per_class == 0 {
141            max_output_boxes_per_class = i64::MAX;
142        }
143        //        ensure!((0.0..=1.0).contains(&iou_threshold), "iou_threshold must be between 0 and 1");
144
145        let num_batches = scores.shape()[0];
146        let num_classes = scores.shape()[1];
147        let num_dim = scores.shape()[2];
148
149        let boxes = boxes.to_array_view::<T>()?;
150        let scores = scores.to_array_view::<T>()?;
151
152        // items: (batch, class, index)
153        let mut selected_global: TVec<(usize, usize, usize)> = tvec![];
154
155        for batch in 0..num_batches {
156            for class in 0..num_classes {
157                // items: (score, index)
158                let mut candidates: TVec<(T, usize)> =
159                    if let Some(score_threshold) = score_threshold {
160                        (0..num_dim)
161                            .map(|i| (scores[[batch, class, i]], i))
162                            .filter(|(score, _)| *score > score_threshold)
163                            .collect()
164                    } else {
165                        (0..num_dim).map(|i| (scores[[batch, class, i]], i)).collect()
166                    };
167
168                candidates.sort_by(|(a, _), (b, _)| b.partial_cmp(a).unwrap_or(Ordering::Equal));
169
170                // items: (score, index)
171                let mut selected_in_class: TVec<(T, usize)> = tvec![];
172
173                for (score, index) in candidates {
174                    if selected_in_class.len() as i64 >= max_output_boxes_per_class {
175                        break;
176                    }
177
178                    let box1 = boxes.slice(s![batch, index, ..]);
179                    let suppr = selected_in_class.iter().any(|(_, index)| {
180                        let box2 = boxes.slice(s![batch, *index, ..]);
181                        self.center_point_box.should_suppress_by_iou(box1, box2, iou_threshold)
182                    });
183                    if !suppr {
184                        selected_in_class.push((score, index));
185                        selected_global.push((batch, class, index));
186                    }
187                }
188            }
189        }
190
191        // output shape is [num_selected_indices, 3]; format is [batch_index, class_index, box_index]
192        let num_selected = selected_global.len();
193        let v = selected_global
194            .into_iter()
195            .flat_map(|(batch, class, index)| [batch as i64, class as i64, index as i64])
196            .collect();
197        let res = tract_ndarray::ArrayD::from_shape_vec(&*tvec![num_selected, 3], v)?;
198
199        Ok(tvec![res.into_tvalue()])
200    }
201}
202
203impl Op for NonMaxSuppression {
204    fn name(&self) -> StaticName {
205        "NonMaxSuppression".into()
206    }
207
208    op_as_typed_op!();
209}
210
211impl EvalOp for NonMaxSuppression {
212    fn is_stateless(&self) -> bool {
213        true
214    }
215
216    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
217        let dt = inputs[0].datum_type();
218        dispatch_floatlike!(Self::eval_t(dt)(self, inputs))
219    }
220}
221
222impl TypedOp for NonMaxSuppression {
223    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
224        Ok(tvec![i64::fact([self.num_selected_indices_symbol.to_dim(), 3usize.to_dim()])])
225    }
226
227    as_op!();
228}
229
230fn parameters() -> Vec<Parameter> {
231    vec![
232        TypeName::Integer.tensor().named("boxes"),
233        TypeName::Scalar.tensor().named("scores"),
234        TypeName::Integer.named("max_output_boxes_per_class").default(0),
235        TypeName::Scalar.named("iou_threshold").default(0.0),
236        TypeName::Scalar.named("score_threshold"),
237        TypeName::Integer.named("center_point_box").default(0),
238    ]
239}
240
241fn dump(
242    ast: &mut IntoAst,
243    node: &TypedNode,
244    op: &NonMaxSuppression,
245) -> TractResult<Option<Arc<RValue>>> {
246    let boxes = ast.mapping[&node.inputs[0]].clone();
247    let scores = ast.mapping[&node.inputs[1]].clone();
248    let max_output_boxes_per_class = ast.mapping[&node.inputs[2]].clone();
249    let iou_threshold = ast.mapping[&node.inputs[3]].clone();
250    let score_threshold = node.inputs.get(4).map(|v| ast.mapping[v].clone());
251
252    let inv = if let Some(score_threshold) = score_threshold {
253        invocation(
254            "tract_onnx_non_max_suppression",
255            &[boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold],
256            &[("center_point_box", numeric(op.center_point_box.into_i64()))],
257        )
258    } else {
259        invocation(
260            "tract_onnx_non_max_suppression",
261            &[boxes, scores, max_output_boxes_per_class, iou_threshold],
262            &[("center_point_box", numeric(op.center_point_box.into_i64()))],
263        )
264    };
265
266    Ok(Some(inv))
267}
268
269fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
270    let boxes = invocation.named_arg_as(builder, "boxes")?;
271    let scores = invocation.named_arg_as(builder, "scores")?;
272    let max_output_boxes_per_class =
273        invocation.named_arg_as(builder, "max_output_boxes_per_class")?;
274    let iou_threshold = invocation.named_arg_as(builder, "iou_threshold")?;
275    let score_threshold = invocation.named_arg_as(builder, "score_threshold").ok();
276
277    let center_point_box =
278        BoxRepr::from_i64(invocation.named_arg_as(builder, "center_point_box")?)?;
279
280    let n = builder.model.symbols.sym("n");
281    let op = NonMaxSuppression {
282        center_point_box,
283        num_selected_indices_symbol: n,
284        has_score_threshold: score_threshold.is_some(),
285    };
286    if let Some(score_threshold) = score_threshold {
287        builder
288            .wire(op, &[boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold])
289    } else {
290        builder.wire(op, &[boxes, scores, max_output_boxes_per_class, iou_threshold])
291    }
292}