Skip to main content

tract_core/ops/array/
topk.rs

1use std::cmp::Ordering;
2
3use tract_data::itertools::Itertools;
4use tract_ndarray::{ArrayViewMutD, Axis, Dimension};
5
6use crate::internal::*;
7
/// TopK operator: along `axis`, selects the `k` largest (or smallest) elements and
/// produces two outputs, the selected values and their `i64` positions along `axis`.
///
/// NOTE(review): field order presumably feeds the derived `new` constructor's parameter
/// order (and the derived `Hash`) — confirm before reordering fields.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct Topk {
    /// Axis along which elements are ranked; it is resized to `k` in both outputs.
    pub axis: usize,
    /// `true` selects the largest elements (descending order), `false` the smallest.
    pub largest: bool,
    /// `k` used by `output_facts` when the `k` input is not a known constant.
    pub fallback_k: TDim,
}
14
15impl Op for Topk {
16    fn name(&self) -> StaticName {
17        "Topk".into()
18    }
19
20    op_as_typed_op!();
21}
22
23impl EvalOp for Topk {
24    fn is_stateless(&self) -> bool {
25        true
26    }
27
28    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
29        let (input, k) = args_2!(inputs);
30        let mut output_shape: TVec<usize> = input.shape().into();
31        let k = k.cast_to_scalar::<i64>()? as usize;
32        output_shape[self.axis] = k;
33        let dt = input.datum_type();
34        let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
35        let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
36        let mut iterating_shape = output_shape.clone();
37        iterating_shape[self.axis] = 1;
38        let mut output_indices_dense = output_indices.try_as_dense_mut()?;
39        let mut output_indices_view = output_indices_dense.to_array_view_mut::<i64>()?;
40        for coords in tract_ndarray::indices(&*iterating_shape) {
41            let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
42            dispatch_numbers!(Self::inner_loop_t(dt)(
43                self,
44                &mut coords,
45                &input,
46                &mut output_values,
47                &mut output_indices_view,
48                k
49            ))?;
50        }
51        Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
52    }
53}
54
55impl Topk {
56    fn inner_loop_t<T: Datum + PartialOrd>(
57        &self,
58        coords: &mut [usize],
59        input: &Tensor,
60        output_values: &mut Tensor,
61        output_indices_view: &mut ArrayViewMutD<i64>,
62        k: usize,
63    ) -> TractResult<()> {
64        let mut output_values_dense = output_values.try_as_dense_mut()?;
65        let mut output_values_view = output_values_dense.to_array_view_mut::<T>()?;
66        let mut view = input.to_dense_array_view::<T>()?;
67        for (ix, x) in coords.iter().enumerate() {
68            if ix != self.axis {
69                view.collapse_axis(Axis(ix), *x);
70            }
71        }
72        for (ix, (argmax, max)) in view
73            .iter()
74            .cloned()
75            .enumerate()
76            .sorted_by(|a, b| {
77                let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) };
78                if self.largest { ord.reverse() } else { ord }
79            })
80            .take(k)
81            .enumerate()
82        {
83            coords[self.axis] = ix;
84            output_values_view[&*coords] = max;
85            output_indices_view[&*coords] = argmax as i64;
86        }
87        Ok(())
88    }
89}
90
91impl TypedOp for Topk {
92    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
93        let mut fact_values = inputs[0].without_value();
94        let mut fact_indices = inputs[0].without_value();
95        let k: TDim = if let Some(k) = &inputs[1].konst {
96            k.cast_to::<TDim>()?.try_as_dense()?.to_scalar::<TDim>()?.clone()
97        } else {
98            self.fallback_k.clone()
99        };
100        fact_values.shape.set(self.axis, k.clone());
101        fact_indices.shape.set(self.axis, k);
102        fact_indices.datum_type = i64::datum_type();
103        Ok(tvec!(fact_values, fact_indices))
104    }
105
106    as_op!();
107}