// tract_core/ops/array/topk.rs

1use std::cmp::Ordering;
2
3use tract_data::itertools::Itertools;
4use tract_ndarray::{ArrayViewMutD, Axis, Dimension};
5
6use crate::internal::*;
7
8#[derive(Debug, Clone, new, Default, Hash)]
9pub struct Topk {
10    pub axis: usize,
11    pub largest: bool,
12    pub fallback_k: TDim,
13}
14
15impl Op for Topk {
16    fn name(&self) -> StaticName {
17        "Topk".into()
18    }
19
20    op_as_typed_op!();
21}
22
23impl EvalOp for Topk {
24    fn is_stateless(&self) -> bool {
25        true
26    }
27
28    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
29        let (input, k) = args_2!(inputs);
30        let mut output_shape: TVec<usize> = input.shape().into();
31        let k = k.cast_to_scalar::<i64>()? as usize;
32        output_shape[self.axis] = k;
33        let dt = input.datum_type();
34        let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
35        let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
36        let mut iterating_shape = output_shape.clone();
37        iterating_shape[self.axis] = 1;
38        let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
39        for coords in tract_ndarray::indices(&*iterating_shape) {
40            let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
41            dispatch_numbers!(Self::inner_loop_t(dt)(
42                self,
43                &mut coords,
44                &input,
45                &mut output_values,
46                &mut output_indices_view,
47                k
48            ))?;
49        }
50        Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
51    }
52}
53
54impl Topk {
55    fn inner_loop_t<T: Datum + PartialOrd>(
56        &self,
57        coords: &mut [usize],
58        input: &Tensor,
59        output_values: &mut Tensor,
60        output_indices_view: &mut ArrayViewMutD<i64>,
61        k: usize,
62    ) -> TractResult<()> {
63        let mut output_values_view = output_values.to_array_view_mut::<T>()?;
64        let mut view = input.to_array_view::<T>()?;
65        for (ix, x) in coords.iter().enumerate() {
66            if ix != self.axis {
67                view.collapse_axis(Axis(ix), *x);
68            }
69        }
70        for (ix, (argmax, max)) in view
71            .iter()
72            .cloned()
73            .enumerate()
74            .sorted_by(|a, b| {
75                let ord = { a.1.partial_cmp(&b.1).unwrap_or(Ordering::Less) };
76                if self.largest { ord.reverse() } else { ord }
77            })
78            .take(k)
79            .enumerate()
80        {
81            coords[self.axis] = ix;
82            output_values_view[&*coords] = max;
83            output_indices_view[&*coords] = argmax as i64;
84        }
85        Ok(())
86    }
87}
88
89impl TypedOp for Topk {
90    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
91        let mut fact_values = inputs[0].without_value();
92        let mut fact_indices = inputs[0].without_value();
93        let k: TDim = if let Some(k) = &inputs[1].konst {
94            k.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone()
95        } else {
96            self.fallback_k.clone()
97        };
98        fact_values.shape.set(self.axis, k.clone());
99        fact_indices.shape.set(self.axis, k);
100        fact_indices.datum_type = i64::datum_type();
101        Ok(tvec!(fact_values, fact_indices))
102    }
103
104    as_op!();
105}