1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use std::cmp::Ordering;

use tract_data::itertools::Itertools;
use tract_ndarray::{ArrayViewMutD, Axis, Dimension};

use crate::internal::*;

/// Top-k selection operator: along one axis, keeps the `k` best-ranked
/// elements and reports both their values and their positions on that axis.
///
/// `k` itself is not stored here; it is read from the second input at
/// evaluation time.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct Topk {
    /// Index of the axis along which elements are ranked and selected.
    pub axis: usize,
    /// When `true` the ascending sort order is reversed, so the greatest
    /// elements come first; when `false` the smallest come first.
    pub largest: bool,
    /// Dimension used for the selected axis when inferring output facts and
    /// the `k` input is not a known constant.
    pub fallback_k: TDim,
}

impl Op for Topk {
    /// Returns the operator name, `"Topk"`, as a borrowed string.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Topk")
    }

    op_as_typed_op!();
}

impl EvalOp for Topk {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (input, k) = args_2!(inputs);
        let mut output_shape: TVec<usize> = input.shape().into();
        let k = k.cast_to_scalar::<i64>()? as usize;
        output_shape[self.axis] = k;
        let dt = input.datum_type();
        let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
        let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
        let mut iterating_shape = output_shape.clone();
        iterating_shape[self.axis] = 1;
        let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
        for coords in tract_ndarray::indices(&*iterating_shape) {
            let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
            dispatch_numbers!(Self::inner_loop_t(dt)(
                self,
                &mut coords,
                &input,
                &mut output_values,
                &mut output_indices_view,
                k
            ))?;
        }
        Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
    }
}

impl Topk {
    /// Processes a single lane of `input` along `self.axis`.
    ///
    /// `coords` gives the position of the lane on every axis other than
    /// `self.axis`; its entry at `self.axis` is overwritten while writing
    /// results. The lane elements are ranked (greatest first when
    /// `self.largest` is set, smallest first otherwise), and the first `k` of
    /// them are written to `output_values`, with their original positions
    /// along the axis written to `output_indices_view`.
    ///
    /// The sort is stable, so elements that compare equal keep their original
    /// relative order. Elements that cannot be compared with themselves
    /// (floating-point NaN) are ranked after every comparable element in both
    /// directions, which keeps the comparator a consistent total order.
    ///
    /// # Errors
    ///
    /// Fails if `input` or `output_values` cannot be viewed as arrays of `T`.
    fn inner_loop_t<T: Datum + PartialOrd>(
        &self,
        coords: &mut [usize],
        input: &Tensor,
        output_values: &mut Tensor,
        output_indices_view: &mut ArrayViewMutD<i64>,
        k: usize,
    ) -> TractResult<()> {
        let mut output_values_view = output_values.to_array_view_mut::<T>()?;
        let mut view = input.to_array_view::<T>()?;
        for (ix, x) in coords.iter().enumerate() {
            if ix != self.axis {
                // `collapse_axis` keeps the rank (the axis length becomes 1), so
                // the axis numbers used by later iterations stay valid. Removing
                // the axis instead would shift every following axis index.
                view.collapse_axis(Axis(ix), *x);
            }
        }
        let largest = self.largest;
        let ranked = view.iter().cloned().enumerate().sorted_by(|a, b| {
            match a.1.partial_cmp(&b.1) {
                Some(ord) if largest => ord.reverse(),
                Some(ord) => ord,
                None => {
                    // At least one side is not comparable: a value is NaN-like
                    // exactly when it does not compare with itself.
                    let a_unordered = a.1.partial_cmp(&a.1).is_none();
                    let b_unordered = b.1.partial_cmp(&b.1).is_none();
                    match (a_unordered, b_unordered) {
                        (true, false) => Ordering::Greater,
                        (false, true) => Ordering::Less,
                        _ => Ordering::Equal,
                    }
                }
            }
        });
        for (rank, (source_ix, value)) in ranked.take(k).enumerate() {
            coords[self.axis] = rank;
            output_values_view[&*coords] = value;
            output_indices_view[&*coords] = source_ix as i64;
        }
        Ok(())
    }
}

impl TypedOp for Topk {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut fact_values = inputs[0].without_value();
        let mut fact_indices = inputs[0].without_value();
        let k: TDim = if let Some(k) = &inputs[1].konst {
            k.cast_to_scalar::<i64>()?.into()
        } else {
            self.fallback_k.clone()
        };
        fact_values.shape.set(self.axis, k.clone());
        fact_indices.shape.set(self.axis, k);
        fact_indices.datum_type = i64::datum_type();
        Ok(tvec!(fact_values, fact_indices))
    }

    as_op!();
}