// tract_core/ops/array/scatter_elements.rs

1use crate::internal::*;
2use ndarray::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct ScatterElements {
6    pub axis: usize,
7}
8
9
10impl Op for ScatterElements {
11    fn name(&self) -> Cow<str> {
12        "ScatterElements".into()
13    }
14
15    op_as_typed_op!();
16}
17
18impl ScatterElements {
19    unsafe fn eval_t<T: Datum>(
20        &self,
21        data: TValue,
22        indices: &ArrayViewD<i64>,
23        updates: TValue,
24    ) -> TractResult<TValue> {
25        let mut data = data.into_tensor().into_array_unchecked::<T>();
26        let updates_view = updates.to_array_view_unchecked::<T>();
27        for (mut coords, value) in updates_view.indexed_iter() {
28            let index = indices[&coords];
29            coords[self.axis] =
30                if index < 0 { index + data.shape()[self.axis] as i64 } else { index } as usize;
31            data[coords] = value.clone()
32        }
33        let mut tensor = data.into_tensor();
34        tensor.set_datum_type(updates.datum_type());
35        Ok(tensor.into_tvalue())
36    }
37}
38
39impl TypedOp for ScatterElements {
40    as_op!();
41
42    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
43        Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.clone())))
44    }
45}
46
47impl EvalOp for ScatterElements {
48    fn is_stateless(&self) -> bool {
49        true
50    }
51
52    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
53        let (data, indices, updates) = args_3!(inputs);
54        let indices = indices.cast_to::<i64>()?;
55        let indices = indices.to_array_view::<i64>()?;
56        if data.datum_type() != updates.datum_type() {
57            bail!(
58                "Data and update must be of the same type, got {:?} and {:?}",
59                data.datum_type(),
60                updates.datum_type()
61            );
62        }
63        unsafe {
64            Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
65                self, data, &indices, updates
66            ))?))
67        }
68    }
69}