tract_core/ops/array/topk.rs

1use std::cmp::Ordering;
2
3use tract_data::itertools::Itertools;
4use tract_ndarray::{ArrayViewMutD, Axis, Dimension};
5
6use crate::internal::*;
7
8#[derive(Debug, Clone, new, Default, Hash)]
9pub struct Topk {
10    pub axis: usize,
11    pub largest: bool,
12    pub fallback_k: TDim,
13}
14
15impl Op for Topk {
16    fn name(&self) -> Cow<str> {
17        "Topk".into()
18    }
19
20    op_as_typed_op!();
21}
22
23impl EvalOp for Topk {
24    fn is_stateless(&self) -> bool {
25        true
26    }
27
28    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
29        let (input, k) = args_2!(inputs);
30        let mut output_shape: TVec<usize> = input.shape().into();
31        let k = k.cast_to_scalar::<i64>()? as usize;
32        output_shape[self.axis] = k;
33        let dt = input.datum_type();
34        let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
35        let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
36        let mut iterating_shape = output_shape.clone();
37        iterating_shape[self.axis] = 1;
38        let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
39        for coords in tract_ndarray::indices(&*iterating_shape) {
40            let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
41            dispatch_numbers!(Self::inner_loop_t(dt)(
42                self,
43                &mut coords,
44                &input,
45                &mut output_values,
46                &mut output_indices_view,
47                k
48            ))?;
49        }
50        Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
51    }
52}
53
54impl Topk {
55    fn inner_loop_t<T: Datum + PartialOrd>(
56        &self,
57        coords: &mut [usize],
58        input: &Tensor,
59        output_values: &mut Tensor,
60        output_indices_view: &mut ArrayViewMutD<i64>,
61        k: usize,
62    ) -> TractResult<()> {
63        let mut output_values_view = output_values.to_array_view_mut::<T>()?;
64        let mut view = input.to_array_view::<T>()?;
65        for (ix, x) in coords.iter().enumerate() {
66            if ix != self.axis {
67                view.collapse_axis(Axis(ix), *x);
68            }
69        }
70        for (ix, (argmax, max)) in view
71            .iter()
72            .cloned()
73            .enumerate()
74            .sorted_by(|a, b| {
75                let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) };
76                if self.largest {
77                    ord.reverse()
78                } else {
79                    ord
80                }
81            })
82            .take(k)
83            .enumerate()
84        {
85            coords[self.axis] = ix;
86            output_values_view[&*coords] = max;
87            output_indices_view[&*coords] = argmax as i64;
88        }
89        Ok(())
90    }
91}
92
93impl TypedOp for Topk {
94    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
95        let mut fact_values = inputs[0].without_value();
96        let mut fact_indices = inputs[0].without_value();
97        let k: TDim = if let Some(k) = &inputs[1].konst {
98            k.cast_to_scalar::<i64>()?.into()
99        } else {
100            self.fallback_k.clone()
101        };
102        fact_values.shape.set(self.axis, k.clone());
103        fact_indices.shape.set(self.axis, k);
104        fact_indices.datum_type = i64::datum_type();
105        Ok(tvec!(fact_values, fact_indices))
106    }
107
108    as_op!();
109}