tract_core/ops/array/
topk.rs1use std::cmp::Ordering;
2
3use tract_data::itertools::Itertools;
4use tract_ndarray::{ArrayViewMutD, Axis, Dimension};
5
6use crate::internal::*;
7
8#[derive(Debug, Clone, new, Default, Hash)]
9pub struct Topk {
10 pub axis: usize,
11 pub largest: bool,
12 pub fallback_k: TDim,
13}
14
15impl Op for Topk {
16 fn name(&self) -> Cow<str> {
17 "Topk".into()
18 }
19
20 op_as_typed_op!();
21}
22
23impl EvalOp for Topk {
24 fn is_stateless(&self) -> bool {
25 true
26 }
27
28 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
29 let (input, k) = args_2!(inputs);
30 let mut output_shape: TVec<usize> = input.shape().into();
31 let k = k.cast_to_scalar::<i64>()? as usize;
32 output_shape[self.axis] = k;
33 let dt = input.datum_type();
34 let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
35 let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
36 let mut iterating_shape = output_shape.clone();
37 iterating_shape[self.axis] = 1;
38 let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
39 for coords in tract_ndarray::indices(&*iterating_shape) {
40 let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
41 dispatch_numbers!(Self::inner_loop_t(dt)(
42 self,
43 &mut coords,
44 &input,
45 &mut output_values,
46 &mut output_indices_view,
47 k
48 ))?;
49 }
50 Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
51 }
52}
53
54impl Topk {
55 fn inner_loop_t<T: Datum + PartialOrd>(
56 &self,
57 coords: &mut [usize],
58 input: &Tensor,
59 output_values: &mut Tensor,
60 output_indices_view: &mut ArrayViewMutD<i64>,
61 k: usize,
62 ) -> TractResult<()> {
63 let mut output_values_view = output_values.to_array_view_mut::<T>()?;
64 let mut view = input.to_array_view::<T>()?;
65 for (ix, x) in coords.iter().enumerate() {
66 if ix != self.axis {
67 view.collapse_axis(Axis(ix), *x);
68 }
69 }
70 for (ix, (argmax, max)) in view
71 .iter()
72 .cloned()
73 .enumerate()
74 .sorted_by(|a, b| {
75 let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) };
76 if self.largest {
77 ord.reverse()
78 } else {
79 ord
80 }
81 })
82 .take(k)
83 .enumerate()
84 {
85 coords[self.axis] = ix;
86 output_values_view[&*coords] = max;
87 output_indices_view[&*coords] = argmax as i64;
88 }
89 Ok(())
90 }
91}
92
93impl TypedOp for Topk {
94 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
95 let mut fact_values = inputs[0].without_value();
96 let mut fact_indices = inputs[0].without_value();
97 let k: TDim = if let Some(k) = &inputs[1].konst {
98 k.cast_to_scalar::<i64>()?.into()
99 } else {
100 self.fallback_k.clone()
101 };
102 fact_values.shape.set(self.axis, k.clone());
103 fact_indices.shape.set(self.axis, k);
104 fact_indices.datum_type = i64::datum_type();
105 Ok(tvec!(fact_values, fact_indices))
106 }
107
108 as_op!();
109}