1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use crate::internal::*;
use ndarray::*;

/// Gather operator: takes a data tensor and an indices tensor and picks
/// slices of the data along `axis` at the positions listed in the indices
/// (the evaluation logic lives in `Gather::eval_t`).
#[derive(Debug, Clone, new, Hash)]
pub struct Gather {
    /// Axis of the data input to gather along. Negative values count back
    /// from the last axis; `Gather::resolved_axis` maps it to `0..rank`.
    axis: i64,
}
// NOTE(review): presumably provides the dynamic-hash support tract requires
// of ops; see the `impl_dyn_hash!` macro in tract_linalg for the exact impl.
tract_linalg::impl_dyn_hash!(Gather);

impl Op for Gather {
    fn name(&self) -> Cow<str> {
        "Gather".into()
    }

    op_core_mir!();
    op_as_typed_op!();
    not_a_pulsed_op!();
}

impl Gather {
    /// Resolves `self.axis` against a tensor of rank `rank`.
    ///
    /// Negative axes count back from the last axis (`-1` is the last one), so
    /// the accepted range is `-rank..rank`.
    ///
    /// # Errors
    /// Fails when the axis lies outside `-rank..rank` (always the case for a
    /// rank-0 tensor).
    fn resolved_axis(&self, rank: usize) -> TractResult<usize> {
        let rank = rank as i64;
        if 0 <= self.axis && self.axis < rank {
            Ok(self.axis as usize)
        } else if -rank <= self.axis && self.axis < 0 {
            Ok((self.axis + rank) as usize)
        } else {
            bail!("Illegal combination of values for rank and axis")
        }
    }

    /// Maps a possibly-negative gather index onto `0..dim`.
    ///
    /// Negative indices count back from the end of the gathered axis, so the
    /// accepted range is `-dim..dim`.
    ///
    /// # Errors
    /// Fails when `index` is outside `-dim..dim`.
    fn resolved_index(index: i64, dim: usize) -> TractResult<usize> {
        let dim_i = dim as i64;
        let resolved = if index < 0 { index + dim_i } else { index };
        if 0 <= resolved && resolved < dim_i {
            Ok(resolved as usize)
        } else {
            bail!("Gather index {} is out of range for an axis of size {}", index, dim)
        }
    }

    /// Computes the Gather output shape.
    ///
    /// The result is `input_shape[..axis] ++ indices_shape ++ input_shape[axis + 1..]`:
    /// the gathered axis is replaced by the full shape of the indices.
    ///
    /// # Errors
    /// Fails when `self.axis` is not valid for `input_shape.len()`.
    pub fn compute_output_shape<D: DimLike>(
        &self,
        input_shape: &[D],
        indices_shape: &[D],
    ) -> TractResult<TVec<D>> {
        let axis = self.resolved_axis(input_shape.len())?;
        Ok(input_shape[..axis]
            .iter()
            .chain(indices_shape)
            .chain(&input_shape[axis + 1..])
            .cloned()
            .collect())
    }

    /// Gathers slices of `data` (element type `T`) along the resolved axis.
    ///
    /// `indices` is cast to `i64`. Negative indices count back from the end of
    /// the gathered axis. A rank-0 `indices` tensor removes the gathered axis
    /// from the output; otherwise the output shape follows
    /// [`Gather::compute_output_shape`].
    ///
    /// # Errors
    /// Fails on an invalid axis, when `data` is not of type `T`, when the
    /// indices cannot be cast to `i64`, or when any index is outside
    /// `-dim..dim` for the gathered axis of size `dim`.
    fn eval_t<T: Datum>(
        &self,
        data: Arc<Tensor>,
        indices: &Arc<Tensor>,
    ) -> TractResult<Arc<Tensor>> {
        let data_view = data.to_array_view::<T>()?;
        let axis = self.resolved_axis(data.shape().len())?;
        // Negative indices must wrap around the *gathered* axis, not axis 0.
        let axis_dim = data.shape()[axis];
        let indices = indices.cast_to::<i64>()?;
        if indices.shape().is_empty() {
            let index = Self::resolved_index(*indices.to_scalar::<i64>()?, axis_dim)?;
            return Ok(data_view.index_axis(Axis(axis), index).to_owned().into_arc_tensor());
        }

        let indices_view = indices.to_array_view::<i64>()?;
        // Validate every index before allocating the uninitialized output, so
        // no error path can bail out with a partially written tensor.
        // `iter()` and `indexed_iter()` both walk in logical (row-major) order,
        // so `resolved` lines up with the patterns zipped in below.
        let resolved = indices_view
            .iter()
            .map(|&index| Self::resolved_index(index, axis_dim))
            .collect::<TractResult<Vec<usize>>>()?;

        // SAFETY: every output element is written below. The output shape is
        // prefix ++ indices_shape ++ suffix. For each position `pattern` of the
        // indices tensor, the whole prefix x suffix sub-view at `pattern` is
        // assigned. All indices were validated above, so the loop cannot stop early.
        let mut output = unsafe {
            Tensor::uninitialized::<T>(&*self.compute_output_shape(data.shape(), indices.shape())?)?
        };
        {
            let mut output = output.to_array_view_mut::<T>()?;
            for ((pattern, _), &index) in indices_view.indexed_iter().zip(resolved.iter()) {
                let mut to_update = output.index_axis_mut(Axis(axis), pattern[0]);
                for idx in 1..pattern.ndim() {
                    // Removing the previous indices dimension shifts the next
                    // one into position `axis`.
                    to_update = to_update.index_axis_move(Axis(axis), pattern[idx]);
                }
                to_update.assign(&data_view.index_axis(Axis(axis), index));
            }
        }
        Ok(output.into_arc_tensor())
    }
}

#[cfg(test)]
mod gather_index_tests {
    use super::*;

    #[test]
    fn negative_scalar_index_on_non_leading_axis() {
        let data = Tensor::from(arr2(&[[1i64, 2, 3], [4, 5, 6]]));
        let outputs =
            Gather::new(1).eval(tvec![data.into(), Tensor::from(arr0(-1i64)).into()]).unwrap();
        assert_eq!(outputs[0].to_array_view::<i64>().unwrap(), arr1(&[3i64, 6]).into_dyn());
    }

    #[test]
    fn matrix_indices_on_non_leading_axis() {
        let data = Tensor::from(arr2(&[[1i64, 2, 3], [4, 5, 6]]));
        let indices = Tensor::from(arr2(&[[0i64, 2], [1, 0]]));
        let outputs = Gather::new(1).eval(tvec![data.into(), indices.into()]).unwrap();
        let expected = arr3(&[[[1i64, 3], [2, 1]], [[4, 6], [5, 4]]]).into_dyn();
        assert_eq!(outputs[0].to_array_view::<i64>().unwrap(), expected);
    }

    #[test]
    fn negative_tensor_indices() {
        let data = Tensor::from(arr1(&[10i64, 20, 30]));
        let indices = Tensor::from(arr1(&[-1i64, 0, -3]));
        let outputs = Gather::new(0).eval(tvec![data.into(), indices.into()]).unwrap();
        assert_eq!(
            outputs[0].to_array_view::<i64>().unwrap(),
            arr1(&[30i64, 10, 10]).into_dyn()
        );
    }

    #[test]
    fn out_of_range_index_is_an_error() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let too_big = Gather::new(0)
            .eval(tvec![data.clone().into(), Tensor::from(arr1(&[3i64])).into()]);
        assert!(too_big.is_err());
        let too_small =
            Gather::new(0).eval(tvec![data.into(), Tensor::from(arr0(-4i64)).into()]);
        assert!(too_small.is_err());
    }
}

impl TypedOp for Gather {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(TypedFact::dt_shape(
            inputs[0].datum_type,
            &*self
                .compute_output_shape(&*inputs[0].shape.to_tvec(), &*inputs[1].shape.to_tvec())?
        )?))
    }
}

impl StatelessOp for Gather {
    /// Evaluates the operation given the input tensors.
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let (data, indices) = args_2!(inputs);
        Ok(tvec!(dispatch_datum!(Self::eval_t(data.datum_type())(&self, data, &indices))?))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Gathering a rank-0 index from a 1-D tensor yields a scalar.
    #[test]
    fn test_should_gather_scalar_index() {
        let gatherer = Gather::new(0);
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        for position in 2i64..3 {
            let result = gatherer
                .eval(tvec![data.clone().into(), Tensor::from(arr0(position)).into()])
                .unwrap();
            let gathered = &result[0];
            assert!(gathered.shape().is_empty());
            assert_eq!(*gathered.to_scalar::<i64>().unwrap(), position + 1);
        }
    }
}