// tract_core/ops/array/scatter_nd.rs
use crate::internal::*;
use ndarray::*;

4#[derive(Debug, Clone, new, Hash)]
5pub struct ScatterNd;
6
7
8
9impl Op for ScatterNd {
10 fn name(&self) -> Cow<str> {
11 "ScatterNd".into()
12 }
13
14 op_as_typed_op!();
15}
16
17impl ScatterNd {
18 unsafe fn eval_t<T: Datum>(
19 &self,
20 data: TValue,
21 indices: &ArrayViewD<i64>,
22 updates: TValue,
23 ) -> TractResult<TValue> {
24 let mut data = data.into_tensor().into_array_unchecked::<T>();
25 let updates_view = updates.to_array_view_unchecked::<T>();
26 for coords in tract_ndarray::indices(&indices.shape()[..indices.ndim() - 1]) {
27 let mut indices_into_data = indices.view();
28 let mut updates = updates_view.view();
29 for x in coords.slice() {
30 indices_into_data.index_axis_inplace(Axis(0), *x);
31 updates.index_axis_inplace(Axis(0), *x);
32 }
33 let mut data = data.view_mut();
34 for x in indices_into_data {
35 data.index_axis_inplace(Axis(0), *x as usize);
36 }
37
38 data.assign(&updates)
39 }
40 let mut tensor = data.into_tensor();
41 tensor.set_datum_type(updates.datum_type());
42 Ok(tensor.into_tvalue())
43 }
44}
45
46impl TypedOp for ScatterNd {
47 as_op!();
48
49 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
50 Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.to_tvec())))
51 }
52}
53
54impl EvalOp for ScatterNd {
55 fn is_stateless(&self) -> bool {
56 true
57 }
58
59 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
60 let (data, indices, updates) = args_3!(inputs);
61 let indices = indices.cast_to::<i64>()?;
62 let indices = indices.to_array_view::<i64>()?;
63 if data.datum_type() != updates.datum_type() {
64 bail!(
65 "Data and update must be of the same type, got {:?} and {:?}",
66 data.datum_type(),
67 updates.datum_type()
68 );
69 }
70 unsafe {
71 Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
72 self, data, &indices, updates
73 ))?))
74 }
75 }
76}