tract_onnx/ops/array/
topk.rs1use crate::model::ParsingContext;
2use crate::pb::NodeProto;
3use tract_hir::internal::*;
4
5pub fn topk(
6 _ctx: &ParsingContext,
7 node: &NodeProto,
8) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
9 let axis = node.get_attr_opt("axis")?.unwrap_or(-1i64);
10 let largest = node.get_attr_opt("largest")?.unwrap_or(1i64) == 1;
11 Ok((expand(Topk { axis, largest }), vec![]))
12}
13
14#[derive(Debug, Clone, new, Default, Hash)]
15struct Topk {
16 axis: i64,
17 largest: bool,
18}
19
20impl Expansion for Topk {
21 fn name(&self) -> StaticName {
22 "Topk".into()
23 }
24
25 fn rules<'r, 'p: 'r, 's: 'r>(
26 &'s self,
27 solver: &mut Solver<'r>,
28 inputs: &'p [TensorProxy],
29 outputs: &'p [TensorProxy],
30 ) -> InferenceResult {
31 check_input_arity(inputs, 2)?;
32 check_input_arity(outputs, 2)?;
33
34 solver.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
35 solver.equals(&outputs[1].datum_type, i64::datum_type())?;
36
37 solver.equals(&inputs[0].rank, &outputs[0].rank)?;
38 solver.equals(&inputs[0].rank, &outputs[1].rank)?;
39 solver.equals(&inputs[1].rank, 1)?;
40
41 solver.equals(&inputs[1].shape[0], 1.to_dim())?;
42
43 solver.given(&inputs[0].rank, move |s, rank| {
44 let axis = if self.axis >= 0 { self.axis } else { self.axis + rank } as usize;
45 for ix in 0..rank as usize {
46 if ix != axis {
47 s.equals(&inputs[0].shape[ix], &outputs[0].shape[ix])?;
48 s.equals(&inputs[0].shape[ix], &outputs[1].shape[ix])?;
49 } else {
50 s.given(&inputs[1].value, move |s, k| {
51 if let Ok(k) =
52 k.cast_to::<TDim>().and_then(|t| t.to_scalar::<TDim>().cloned())
53 {
54 s.equals(&outputs[0].shape[ix], k.clone())?;
55 s.equals(&outputs[1].shape[ix], k)?;
56 }
57 Ok(())
58 })?;
59 }
60 }
61 Ok(())
62 })
63 }
64
65 fn wire(
66 &self,
67 prefix: &str,
68 model: &mut TypedModel,
69 inputs: &[OutletId],
70 ) -> TractResult<TVec<OutletId>> {
71 let input = model.outlet_fact(inputs[0])?;
72 let rank = input.rank();
73 let axis = if self.axis >= 0 { self.axis } else { self.axis + rank as i64 } as usize;
74 let fallback_k = model.symbols.new_with_prefix("k").into();
75 model.wire_node(
76 prefix,
77 tract_core::ops::array::Topk { axis, fallback_k, largest: self.largest },
78 &[inputs[0], inputs[1]],
79 )
80 }
81
82 fn nboutputs(&self) -> TractResult<usize> {
83 Ok(2)
84 }
85}