tract_core/ops/array/
scatter_elements.rs

1use crate::internal::*;
2use ndarray::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct ScatterElements {
6    pub axis: usize,
7}
8
9impl Op for ScatterElements {
10    fn name(&self) -> StaticName {
11        "ScatterElements".into()
12    }
13
14    op_as_typed_op!();
15}
16
17impl ScatterElements {
18    unsafe fn eval_t<T: Datum>(
19        &self,
20        data: TValue,
21        indices: &ArrayViewD<i64>,
22        updates: TValue,
23    ) -> TractResult<TValue> {
24        let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
25        let updates_view = unsafe { updates.to_array_view_unchecked::<T>() };
26        for (mut coords, value) in updates_view.indexed_iter() {
27            let index = indices[&coords];
28            coords[self.axis] =
29                if index < 0 { index + data.shape()[self.axis] as i64 } else { index } as usize;
30            data[coords] = value.clone()
31        }
32        let mut tensor = data.into_tensor();
33        unsafe { tensor.set_datum_type(updates.datum_type()) };
34        Ok(tensor.into_tvalue())
35    }
36}
37
38impl TypedOp for ScatterElements {
39    as_op!();
40
41    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
42        Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.clone())))
43    }
44}
45
46impl EvalOp for ScatterElements {
47    fn is_stateless(&self) -> bool {
48        true
49    }
50
51    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
52        let (data, indices, updates) = args_3!(inputs);
53        let indices = indices.cast_to::<i64>()?;
54        let indices = indices.to_array_view::<i64>()?;
55        if data.datum_type() != updates.datum_type() {
56            bail!(
57                "Data and update must be of the same type, got {:?} and {:?}",
58                data.datum_type(),
59                updates.datum_type()
60            );
61        }
62        unsafe {
63            Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
64                self, data, &indices, updates
65            ))?))
66        }
67    }
68}