tract_core/ops/array/
topk.rs1use std::cmp::Ordering;
2
3use tract_data::itertools::Itertools;
4use tract_ndarray::{ArrayViewMutD, Axis, Dimension};
5
6use crate::internal::*;
7
8#[derive(Debug, Clone, new, Default, Hash)]
9pub struct Topk {
10 pub axis: usize,
11 pub largest: bool,
12 pub fallback_k: TDim,
13}
14
15impl Op for Topk {
16 fn name(&self) -> StaticName {
17 "Topk".into()
18 }
19
20 op_as_typed_op!();
21}
22
23impl EvalOp for Topk {
24 fn is_stateless(&self) -> bool {
25 true
26 }
27
28 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
29 let (input, k) = args_2!(inputs);
30 let mut output_shape: TVec<usize> = input.shape().into();
31 let k = k.cast_to_scalar::<i64>()? as usize;
32 output_shape[self.axis] = k;
33 let dt = input.datum_type();
34 let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
35 let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
36 let mut iterating_shape = output_shape.clone();
37 iterating_shape[self.axis] = 1;
38 let mut output_indices_dense = output_indices.try_as_dense_mut()?;
39 let mut output_indices_view = output_indices_dense.to_array_view_mut::<i64>()?;
40 for coords in tract_ndarray::indices(&*iterating_shape) {
41 let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
42 dispatch_numbers!(Self::inner_loop_t(dt)(
43 self,
44 &mut coords,
45 &input,
46 &mut output_values,
47 &mut output_indices_view,
48 k
49 ))?;
50 }
51 Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
52 }
53}
54
55impl Topk {
56 fn inner_loop_t<T: Datum + PartialOrd>(
57 &self,
58 coords: &mut [usize],
59 input: &Tensor,
60 output_values: &mut Tensor,
61 output_indices_view: &mut ArrayViewMutD<i64>,
62 k: usize,
63 ) -> TractResult<()> {
64 let mut output_values_dense = output_values.try_as_dense_mut()?;
65 let mut output_values_view = output_values_dense.to_array_view_mut::<T>()?;
66 let mut view = input.to_dense_array_view::<T>()?;
67 for (ix, x) in coords.iter().enumerate() {
68 if ix != self.axis {
69 view.collapse_axis(Axis(ix), *x);
70 }
71 }
72 for (ix, (argmax, max)) in view
73 .iter()
74 .cloned()
75 .enumerate()
76 .sorted_by(|a, b| {
77 let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) };
78 if self.largest { ord.reverse() } else { ord }
79 })
80 .take(k)
81 .enumerate()
82 {
83 coords[self.axis] = ix;
84 output_values_view[&*coords] = max;
85 output_indices_view[&*coords] = argmax as i64;
86 }
87 Ok(())
88 }
89}
90
91impl TypedOp for Topk {
92 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
93 let mut fact_values = inputs[0].without_value();
94 let mut fact_indices = inputs[0].without_value();
95 let k: TDim = if let Some(k) = &inputs[1].konst {
96 k.cast_to::<TDim>()?.try_as_dense()?.to_scalar::<TDim>()?.clone()
97 } else {
98 self.fallback_k.clone()
99 };
100 fact_values.shape.set(self.axis, k.clone());
101 fact_indices.shape.set(self.axis, k);
102 fact_indices.datum_type = i64::datum_type();
103 Ok(tvec!(fact_values, fact_indices))
104 }
105
106 as_op!();
107}