Skip to main content

tract_onnx/ops/array/
topk.rs

1use crate::model::ParsingContext;
2use crate::pb::NodeProto;
3use tract_hir::internal::*;
4
5pub fn topk(
6    _ctx: &ParsingContext,
7    node: &NodeProto,
8) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
9    let axis = node.get_attr_opt("axis")?.unwrap_or(-1i64);
10    let largest = node.get_attr_opt("largest")?.unwrap_or(1i64) == 1;
11    Ok((expand(Topk { axis, largest }), vec![]))
12}
13
14#[derive(Debug, Clone, new, Default, Hash)]
15struct Topk {
16    axis: i64,
17    largest: bool,
18}
19
20impl Expansion for Topk {
21    fn name(&self) -> StaticName {
22        "Topk".into()
23    }
24
25    fn rules<'r, 'p: 'r, 's: 'r>(
26        &'s self,
27        solver: &mut Solver<'r>,
28        inputs: &'p [TensorProxy],
29        outputs: &'p [TensorProxy],
30    ) -> InferenceResult {
31        check_input_arity(inputs, 2)?;
32        check_input_arity(outputs, 2)?;
33
34        solver.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
35        solver.equals(&outputs[1].datum_type, i64::datum_type())?;
36
37        solver.equals(&inputs[0].rank, &outputs[0].rank)?;
38        solver.equals(&inputs[0].rank, &outputs[1].rank)?;
39        solver.equals(&inputs[1].rank, 1)?;
40
41        solver.equals(&inputs[1].shape[0], 1.to_dim())?;
42
43        solver.given(&inputs[0].rank, move |s, rank| {
44            let axis = if self.axis >= 0 { self.axis } else { self.axis + rank } as usize;
45            for ix in 0..rank as usize {
46                if ix != axis {
47                    s.equals(&inputs[0].shape[ix], &outputs[0].shape[ix])?;
48                    s.equals(&inputs[0].shape[ix], &outputs[1].shape[ix])?;
49                } else {
50                    s.given(&inputs[1].value, move |s, k| {
51                        if let Ok(k) =
52                            k.cast_to::<TDim>().and_then(|t| t.to_scalar::<TDim>().cloned())
53                        {
54                            s.equals(&outputs[0].shape[ix], k.clone())?;
55                            s.equals(&outputs[1].shape[ix], k)?;
56                        }
57                        Ok(())
58                    })?;
59                }
60            }
61            Ok(())
62        })
63    }
64
65    fn wire(
66        &self,
67        prefix: &str,
68        model: &mut TypedModel,
69        inputs: &[OutletId],
70    ) -> TractResult<TVec<OutletId>> {
71        let input = model.outlet_fact(inputs[0])?;
72        let rank = input.rank();
73        let axis = if self.axis >= 0 { self.axis } else { self.axis + rank as i64 } as usize;
74        let fallback_k = model.symbols.new_with_prefix("k").into();
75        model.wire_node(
76            prefix,
77            tract_core::ops::array::Topk { axis, fallback_k, largest: self.largest },
78            &[inputs[0], inputs[1]],
79        )
80    }
81
82    fn nboutputs(&self) -> TractResult<usize> {
83        Ok(2)
84    }
85}