Skip to main content

tract_core/ops/array/
scatter_elements.rs

1use super::scatter_nd::ScatterReduction;
2use crate::internal::*;
3use ndarray::*;
4
5#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
6pub struct ScatterElements {
7    pub axis: usize,
8    pub reduction: ScatterReduction,
9}
10
11impl Op for ScatterElements {
12    fn name(&self) -> StaticName {
13        "ScatterElements".into()
14    }
15
16    op_as_typed_op!();
17}
18
19impl ScatterElements {
20    unsafe fn eval_t<T: Datum>(
21        data: TValue,
22        indices: &ArrayViewD<i64>,
23        updates: TValue,
24        axis: usize,
25    ) -> TractResult<TValue> {
26        let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
27        let updates_plain = updates.try_as_plain()?;
28        let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
29        for (mut coords, value) in updates_view.indexed_iter() {
30            let index = indices[&coords];
31            coords[axis] =
32                if index < 0 { index + data.shape()[axis] as i64 } else { index } as usize;
33            data[coords] = value.clone()
34        }
35        let mut tensor = data.into_tensor();
36        unsafe { tensor.set_datum_type(updates.datum_type()) };
37        Ok(tensor.into_tvalue())
38    }
39
40    unsafe fn eval_t_reduce<T: Datum + PartialOrd + std::ops::AddAssign + std::ops::MulAssign>(
41        data: TValue,
42        indices: &ArrayViewD<i64>,
43        updates: TValue,
44        axis: usize,
45        reduction: ScatterReduction,
46    ) -> TractResult<TValue> {
47        let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
48        let updates_plain = updates.try_as_plain()?;
49        let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
50        for (mut coords, value) in updates_view.indexed_iter() {
51            let index = indices[&coords];
52            coords[axis] =
53                if index < 0 { index + data.shape()[axis] as i64 } else { index } as usize;
54            let d = &mut data[coords];
55            match reduction {
56                ScatterReduction::Add => *d += value.clone(),
57                ScatterReduction::Mul => *d *= value.clone(),
58                ScatterReduction::Min => {
59                    if value < d {
60                        *d = value.clone()
61                    }
62                }
63                ScatterReduction::Max => {
64                    if value > d {
65                        *d = value.clone()
66                    }
67                }
68                ScatterReduction::None => unreachable!(),
69            }
70        }
71        let mut tensor = data.into_tensor();
72        unsafe { tensor.set_datum_type(updates.datum_type()) };
73        Ok(tensor.into_tvalue())
74    }
75}
76
77impl TypedOp for ScatterElements {
78    as_op!();
79
80    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
81        Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.clone())))
82    }
83}
84
85impl EvalOp for ScatterElements {
86    fn is_stateless(&self) -> bool {
87        true
88    }
89
90    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
91        let (data, indices, updates) = args_3!(inputs);
92        let indices = indices.cast_to::<i64>()?;
93        let indices = indices.to_plain_array_view::<i64>()?;
94        if data.datum_type() != updates.datum_type() {
95            bail!(
96                "Data and update must be of the same type, got {:?} and {:?}",
97                data.datum_type(),
98                updates.datum_type()
99            );
100        }
101        unsafe {
102            match self.reduction {
103                ScatterReduction::None => {
104                    Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
105                        data, &indices, updates, self.axis
106                    ))?))
107                }
108                reduction => Ok(tvec!(dispatch_numbers!(Self::eval_t_reduce(data.datum_type())(
109                    data, &indices, updates, self.axis, reduction
110                ))?)),
111            }
112        }
113    }
114}