1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
use tract_data::itertools::Itertools;
use tract_ndarray::{Axis, Dimension};

use crate::internal::*;

/// Top-k selection along a single axis.
///
/// Evaluates to two outputs: the `k` selected values along `axis`, and their
/// positions (as i64) along that axis.
#[derive(Clone, Debug, Default, Hash, new)]
pub struct Topk {
    /// Axis along which values are ranked and selected.
    pub axis: usize,
    /// When true, keep the greatest values; otherwise keep the smallest ones.
    pub largest: bool,
    /// Number of elements kept along `axis`.
    pub k: usize,
}

impl Op for Topk {
    /// Returns the static name of this operator.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Topk")
    }

    op_as_typed_op!();
}

impl EvalOp for Topk {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let mut output_shape: TVec<usize> = input.shape().into();
        output_shape[self.axis] = self.k;
        let mut output_values = Tensor::zero::<f32>(&output_shape)?;
        let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
        let mut iterating_shape = output_shape.clone();
        iterating_shape[self.axis] = 1;
        let mut output_values_view = output_values.to_array_view_mut::<f32>()?;
        let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
        for coords in tract_ndarray::indices(&*iterating_shape) {
            let mut coords: TVec<usize> = coords.as_array_view().as_slice().unwrap().into();
            let mut view = input.to_array_view::<f32>()?;
            for (ix, x) in coords.iter().enumerate() {
                if ix != self.axis {
                    view.index_axis_inplace(Axis(ix), *x);
                }
            }
            for (ix, (argmax, max)) in view
                .iter()
                .cloned()
                .map(|x| if self.largest { -x } else { x })
                .enumerate()
                .sorted_by(|a, b| a.1.total_cmp(&b.1))
                .take(self.k)
                .map(|(pos, val)| if self.largest { (pos, -val) } else { (pos, val) })
                .enumerate()
            {
                coords[self.axis] = ix;
                output_values_view[&*coords] = max;
                output_indices_view[&*coords] = argmax as i64;
            }
        }
        Ok(tvec!(output_values.into_tvalue(), output_indices.into_tvalue()))
    }
}

impl TypedOp for Topk {
    /// Declares the two outputs. Both take the input's shape with `self.axis` resized
    /// to `k`. The values output keeps the input's datum type; the indices output is
    /// i64. Neither carries a constant value.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut values = inputs[0].without_value();
        values.shape.set(self.axis, self.k.to_dim());
        // The indices share the resized shape; only their datum type differs.
        let mut indices = values.clone();
        indices.datum_type = i64::datum_type();
        Ok(tvec!(values, indices))
    }

    as_op!();
}