tract_core/ops/array/
scatter_nd.rs

1use crate::internal::*;
2use ndarray::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct ScatterNd;
6
7
8
9impl Op for ScatterNd {
10    fn name(&self) -> Cow<str> {
11        "ScatterNd".into()
12    }
13
14    op_as_typed_op!();
15}
16
17impl ScatterNd {
18    unsafe fn eval_t<T: Datum>(
19        &self,
20        data: TValue,
21        indices: &ArrayViewD<i64>,
22        updates: TValue,
23    ) -> TractResult<TValue> {
24        let mut data = data.into_tensor().into_array_unchecked::<T>();
25        let updates_view = updates.to_array_view_unchecked::<T>();
26        for coords in tract_ndarray::indices(&indices.shape()[..indices.ndim() - 1]) {
27            let mut indices_into_data = indices.view();
28            let mut updates = updates_view.view();
29            for x in coords.slice() {
30                indices_into_data.index_axis_inplace(Axis(0), *x);
31                updates.index_axis_inplace(Axis(0), *x);
32            }
33            let mut data = data.view_mut();
34            for x in indices_into_data {
35                data.index_axis_inplace(Axis(0), *x as usize);
36            }
37
38            data.assign(&updates)
39        }
40        let mut tensor = data.into_tensor();
41        tensor.set_datum_type(updates.datum_type());
42        Ok(tensor.into_tvalue())
43    }
44}
45
46impl TypedOp for ScatterNd {
47    as_op!();
48
49    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
50        Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.to_tvec())))
51    }
52}
53
54impl EvalOp for ScatterNd {
55    fn is_stateless(&self) -> bool {
56        true
57    }
58
59    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
60        let (data, indices, updates) = args_3!(inputs);
61        let indices = indices.cast_to::<i64>()?;
62        let indices = indices.to_array_view::<i64>()?;
63        if data.datum_type() != updates.datum_type() {
64            bail!(
65                "Data and update must be of the same type, got {:?} and {:?}",
66                data.datum_type(),
67                updates.datum_type()
68            );
69        }
70        unsafe {
71            Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
72                self, data, &indices, updates
73            ))?))
74        }
75    }
76}