use crate::internal::*;
use ndarray::*;
/// ONNX-style `Gather`: selects slices of the data tensor along `axis`
/// according to an i64 indices tensor.
#[derive(Debug, Clone, new, Hash)]
pub struct Gather {
    /// Axis of the data tensor along which indices are applied.
    pub axis: usize,
}
impl Op for Gather {
    /// Operator name.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Gather")
    }
    op_as_typed_op!();
}
impl Gather {
    /// Output shape of a gather: the input shape with the `axis` dimension
    /// replaced by the entire indices shape, i.e.
    /// `input[..axis] ++ indices_shape ++ input[axis+1..]`.
    pub fn compute_output_shape<D: DimLike>(
        &self,
        input_shape: &[D],
        indices_shape: &[D],
    ) -> TractResult<TVec<D>> {
        let mut output_shape: TVec<D> = input_shape[..self.axis].into();
        output_shape.extend(indices_shape.iter().cloned());
        output_shape.extend(input_shape[self.axis + 1..].iter().cloned());
        Ok(output_shape)
    }
    /// Gather kernel. `T` is a size-matched stand-in for the data's real
    /// datum type (`eval` dispatches via `dispatch_datum_by_size!`), which is
    /// why the data is viewed with `to_array_view_unchecked` and why the true
    /// datum type is re-stamped on the output before returning.
    ///
    /// # Safety
    /// Caller must guarantee that `T` has the same size as `data`'s actual
    /// datum type so the unchecked view and bitwise copies are sound.
    unsafe fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<TValue> {
        let data_view = data.to_array_view_unchecked::<T>();
        let indices = indices.to_array_view::<i64>()?;
        let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
        let mut output = Tensor::uninitialized::<T>(output_shape)?;
        let mut output_view = output.to_array_view_mut::<T>()?;
        // For each output coordinate: the leading `axis` coords address the
        // data, the next `indices.ndim()` coords address the indices tensor,
        // and the trailing coords address the data after `axis`.
        for coords in tract_ndarray::indices(output_shape) {
            let ocoords = coords.as_array_view();
            let ocoords = ocoords.as_slice().unwrap();
            let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
            let kcoords = &ocoords[self.axis..][..indices.ndim()];
            let k = indices[kcoords];
            // Negative indices count from the end of the gathered axis.
            let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
            icoords.push(k);
            icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
            // Out-of-range indices surface as an "Invalid gather" error.
            output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
        }
        // Restore the data's real datum type (T was only size-matched).
        unsafe { output.set_datum_type(data.datum_type()) };
        Ok(output.into_tvalue())
    }
}
impl TypedOp for Gather {
    as_op!();
    /// Output carries the data input's datum type with the gathered shape.
    /// Indices are required to be i64.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs[1].datum_type == i64::datum_type());
        Ok(tvec!(inputs[0].datum_type.fact(
            &*self.compute_output_shape(&inputs[0].shape.to_tvec(), &inputs[1].shape.to_tvec())?
        )))
    }
    /// When the indices are a constant rank-1, length-1 tensor, the gather is
    /// equivalent to a width-1 `Slice` along `axis` (the single indices
    /// dimension maps onto the kept size-1 axis, so shapes match) — rewrite
    /// the node to that Slice.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let indices_fact = model.outlet_fact(node.inputs[1])?;
        if let Some(indices) = indices_fact.konst.as_ref() {
            if indices.rank() == 1 && indices.len() == 1 {
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                let index = indices.cast_to_scalar::<i64>()?;
                // Resolve a negative index against the (possibly symbolic)
                // axis dimension of the data input.
                let index = if index < 0 {
                    let data_fact = model.outlet_fact(node.inputs[0])?;
                    data_fact.shape[self.axis].clone() + index.to_dim()
                } else {
                    index.to_dim()
                };
                wire = patch.wire_node(
                    format!("{}.slice", node.name),
                    crate::ops::array::Slice {
                        axis: self.axis,
                        start: index.clone(),
                        end: index + 1,
                    },
                    &[wire],
                )?[0];
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
impl EvalOp for Gather {
    fn is_stateless(&self) -> bool {
        true
    }
    /// Dispatches `eval_t` on the data's datum *size* rather than its exact
    /// type, so one monomorphization serves all same-size datum types.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (data, indices) = args_2!(inputs);
        // SAFETY: dispatch_datum_by_size! picks a T whose size matches the
        // data's datum type, which is the contract eval_t requires.
        unsafe {
            Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
                self, data, &indices
            ))?))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Gathering with a scalar (rank-0) index on a rank-1 tensor must yield a
    /// rank-0 output containing the selected element.
    #[test]
    fn test_should_gather_scalar_index() {
        let data = Tensor::from(arr1(&[1i64, 2, 3]));
        let gatherer = Gather::new(0);
        // Sweep every valid index: the original loop (2..3) only ever
        // exercised idx == 2.
        for idx in 0..3 {
            let index = Tensor::from(arr0(idx));
            let outputs =
                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
            let output = &outputs[0];
            // Scalar index contributes no dimensions, and the gathered axis
            // is consumed, so the output is rank 0.
            assert_eq!(output.shape().len(), 0);
            assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);
        }
    }
}