1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use crate::internal::*;
use ndarray::*;

/// Gather operator: builds its output by picking entries of a data tensor along `axis`,
/// at the positions listed in an `i64` indices tensor.
///
/// The `new` derive is what provides `Gather::new(axis)` (used by the tests in this file).
#[derive(Debug, Clone, new, Hash)]
pub struct Gather {
    /// Axis of the data tensor that the indices address.
    pub axis: usize,
}

impl Op for Gather {
    fn name(&self) -> Cow<str> {
        "Gather".into()
    }

    op_as_typed_op!();
}

impl Gather {
    pub fn compute_output_shape<D: DimLike>(
        &self,
        input_shape: &[D],
        indices_shape: &[D],
    ) -> TractResult<TVec<D>> {
        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
        output_shape.extend(indices_shape.iter().cloned());
        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
        Ok(output_shape)
    }

    unsafe fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<TValue> {
        let data_view = data.to_array_view_unchecked::<T>();
        let indices = indices.to_array_view::<i64>()?;
        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
        let mut output = Tensor::uninitialized::<T>(output_shape)?;
        let mut output_view = output.to_array_view_mut::<T>()?;
        for coords in tract_ndarray::indices(output_shape) {
            let ocoords = coords.as_array_view();
            let ocoords = ocoords.as_slice().unwrap();
            let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
            let kcoords = &ocoords[self.axis..][..indices.ndim()];
            let k = indices[kcoords];
            let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
            icoords.push(k);
            icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
            output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
        }
        unsafe { output.set_datum_type(data.datum_type()) };
        Ok(output.into_tvalue())
    }
}

impl TypedOp for Gather {
    as_op!();

    /// Derives the output fact: the datum type of the data input, with the shape given
    /// by `compute_output_shape`. The indices input must be of type `i64`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs[1].datum_type == i64::datum_type());
        let data_fact = inputs[0];
        let data_shape = data_fact.shape.to_tvec();
        let indices_shape = inputs[1].shape.to_tvec();
        let output_shape = self.compute_output_shape(&data_shape, &indices_shape)?;
        Ok(tvec!(data_fact.datum_type.fact(&*output_shape)))
    }

    /// Replaces the gather by a `Slice` when the indices are a constant rank-1 tensor
    /// holding a single value.
    ///
    /// The slice covers `index..index + 1` on `self.axis`; a negative index is first
    /// offset by the dimension of that axis in the data fact. Returns `Ok(None)` when
    /// the rewrite does not apply.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Guard: only constant, one-element, rank-1 indices are rewritten.
        let indices = match model.outlet_fact(node.inputs[1])?.konst.as_ref() {
            Some(konst) if konst.rank() == 1 && konst.len() == 1 => konst,
            _ => return Ok(None),
        };

        let mut patch = TypedModelPatch::default();
        let tap = patch.tap_model(model, node.inputs[0])?;

        let index = indices.cast_to_scalar::<i64>()?;
        let start = if index >= 0 {
            index.to_dim()
        } else {
            // Negative index: count from the end of the gathered axis.
            let data_fact = model.outlet_fact(node.inputs[0])?;
            data_fact.shape[self.axis].clone() + index.to_dim()
        };
        let end = start.clone() + 1;

        let slice = crate::ops::array::Slice { axis: self.axis, start, end };
        let wire = patch.wire_node(format!("{}.slice", node.name), slice, &[tap])?[0];
        patch.shunt_outside(model, node.id.into(), wire)?;
        Ok(Some(patch))
    }
}

impl EvalOp for Gather {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (data, indices) = args_2!(inputs);
        unsafe {
            Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
                self, data, &indices
            ))?))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A rank-0 index must yield a rank-0 output holding the selected element,
    /// for every valid position of the gathered axis.
    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        // data[i] == i + 1, so each index can be checked against its own value.
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }

    /// Negative indices count from the end of the gathered axis.
    #[test]
    fn test_should_gather_negative_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        for idx in -3..0 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            // Index -3 addresses position 0, whose value is 1.
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 3 + 1);
        }
    }

    /// Indices outside `-len..len` must be reported as errors.
    #[test]
    fn test_should_reject_out_of_range_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        for idx in [3i64, -4] {
            let index = Tensor::from(arr0(idx));
            let result = gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]);
            assert!(result.is_err());
        }
    }
}