1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use crate::internal::*;
use ndarray::*;
/// Gather operator: selects slices of an input tensor along a single axis
/// according to an indices tensor (ONNX `Gather` semantics).
#[derive(Debug, Clone, new, Hash)]
pub struct Gather {
    /// Axis to gather along; may be negative, counting back from the last axis.
    axis: i64,
}
tract_linalg::impl_dyn_hash!(Gather);
impl Op for Gather {
    fn name(&self) -> Cow<str> {
        "Gather".into()
    }

    // Framework boilerplate: core-mir op classification, typed-op downcast
    // support, and no pulsed form for this op.
    op_core_mir!();
    op_as_typed_op!();
    not_a_pulsed_op!();
}
impl Gather {
    /// Normalize `self.axis` to a positive axis index for a tensor of rank
    /// `rank`. Negative values count back from the last axis (ONNX
    /// convention); anything out of `[-rank, rank)` is an error.
    fn resolved_axis(&self, rank: usize) -> TractResult<usize> {
        if 0 <= self.axis && self.axis <= rank as i64 - 1 {
            Ok(self.axis as usize)
        } else if -(rank as i64) <= self.axis && self.axis < 0 {
            Ok((self.axis + rank as i64) as usize)
        } else {
            bail!("Illegal combination of values for rank and axis")
        }
    }

    /// Output shape: the input shape with the gathered axis replaced by the
    /// full indices shape (so a rank-0 indices tensor removes the axis).
    pub fn compute_output_shape<D: DimLike>(
        &self,
        input_shape: &[D],
        indices_shape: &[D],
    ) -> TractResult<TVec<D>> {
        let axis = self.resolved_axis(input_shape.len())?;
        let mut output_shape = tvec![];
        for (idx, dim) in input_shape.iter().enumerate() {
            if idx != axis {
                output_shape.push(dim.clone());
            } else {
                for idx2 in indices_shape {
                    output_shape.push(idx2.clone());
                }
            }
        }
        Ok(output_shape)
    }

    /// Gather slices of `data` along the resolved axis according to
    /// `indices`. A rank-0 indices tensor selects a single slice (dropping
    /// the axis); otherwise each index contributes one slice of `data`.
    fn eval_t<T: Datum>(
        &self,
        data: Arc<Tensor>,
        indices: &Arc<Tensor>,
    ) -> TractResult<Arc<Tensor>> {
        let data_view = data.to_array_view::<T>()?;
        let axis = self.resolved_axis(data.shape().len())?;
        let axis_dim = data_view.shape()[axis] as i64;
        let indices = indices.cast_to::<i64>()?;
        if indices.shape().len() == 0 {
            let mut index = *indices.to_scalar::<i64>()?;
            if index < 0 {
                // Negative indices count back from the end of the gathered
                // axis. Fix: wrap with the gathered axis' dimension, not the
                // first axis' (was `data_view.shape()[0]`).
                index += axis_dim;
            }
            return Ok(data_view
                .index_axis(Axis(axis), index as usize)
                .to_owned()
                .into_arc_tensor());
        }
        // SAFETY: every element of `output` is written by the `assign` in the
        // loop below (the loop covers the full indices shape) before the
        // tensor is returned.
        let mut output = unsafe {
            Tensor::uninitialized::<T>(&*self.compute_output_shape(data.shape(), indices.shape())?)?
        };
        {
            let mut output = output.to_array_view_mut::<T>()?;
            for (pattern, index) in indices.to_array_view::<i64>()?.indexed_iter() {
                // Walk into the output view following the indices tensor's
                // multi-dimensional coordinate `pattern`.
                let mut to_update = output.index_axis_mut(Axis(axis), pattern[0]);
                for idx in 1..pattern.ndim() {
                    to_update = to_update.index_axis_move(Axis(0), pattern[idx]);
                }
                // Apply the same negative-index wrap-around as the scalar
                // path (previously a negative index here would wrap to a huge
                // usize and panic).
                let index = if *index < 0 { *index + axis_dim } else { *index };
                to_update.assign(&data_view.index_axis(Axis(axis), index as usize));
            }
        }
        Ok(output.into_arc_tensor())
    }
}
impl TypedOp for Gather {
    as_op!();

    /// Single output fact: same datum type as the data input, shape derived
    /// from the data and indices shapes via `compute_output_shape`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let data_shape = inputs[0].shape.to_tvec();
        let indices_shape = inputs[1].shape.to_tvec();
        let output_shape = self.compute_output_shape(&*data_shape, &*indices_shape)?;
        Ok(tvec!(TypedFact::dt_shape(inputs[0].datum_type, &*output_shape)?))
    }
}
impl StatelessOp for Gather {
fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
let (data, indices) = args_2!(inputs);
Ok(tvec!(dispatch_datum!(Self::eval_t(data.datum_type())(&self, data, &indices))?))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathering with a rank-0 index must drop the gathered axis and yield a
    /// scalar whose value is the selected element.
    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        // Exercise every valid index; the original loop (`2..3`) only ever
        // ran a single iteration with idx == 2.
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx as i64));
            let outputs = gatherer.eval(tvec![data.clone().into(), index.into()]).unwrap();
            let output = &outputs[0];
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }
}