Skip to main content

tract_core/ops/array/
scatter_nd.rs

1use crate::internal::*;
2use ndarray::*;
3
/// How a scatter combines each update with the value already stored at the
/// addressed position of `data`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ScatterReduction {
    /// Plain overwrite: the update replaces the existing value (the default).
    #[default]
    None,
    /// The existing value is incremented by the update.
    Add,
    /// The existing value is multiplied by the update.
    Mul,
    /// The smaller of the existing value and the update is kept.
    Min,
    /// The larger of the existing value and the update is kept.
    Max,
}
13
14impl ScatterReduction {
15    pub fn as_str(&self) -> &'static str {
16        match self {
17            ScatterReduction::None => "none",
18            ScatterReduction::Add => "add",
19            ScatterReduction::Mul => "mul",
20            ScatterReduction::Min => "min",
21            ScatterReduction::Max => "max",
22        }
23    }
24
25    pub fn parse(s: &str) -> TractResult<Self> {
26        Ok(match s {
27            "none" => ScatterReduction::None,
28            "add" => ScatterReduction::Add,
29            "mul" => ScatterReduction::Mul,
30            "min" => ScatterReduction::Min,
31            "max" => ScatterReduction::Max,
32            s => bail!("Unknown scatter reduction: {s}"),
33        })
34    }
35}
36
37#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
38pub struct ScatterNd {
39    pub reduction: ScatterReduction,
40}
41
42impl Op for ScatterNd {
43    fn name(&self) -> StaticName {
44        "ScatterNd".into()
45    }
46
47    op_as_typed_op!();
48}
49
50impl ScatterNd {
51    unsafe fn eval_t<T: Datum>(
52        data: TValue,
53        indices: &ArrayViewD<i64>,
54        updates: TValue,
55    ) -> TractResult<TValue> {
56        let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
57        let updates_plain = updates.try_as_plain()?;
58        let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
59        for coords in tract_ndarray::indices(&indices.shape()[..indices.ndim() - 1]) {
60            let mut indices_into_data = indices.view();
61            let mut updates = updates_view.view();
62            for x in coords.slice() {
63                indices_into_data.index_axis_inplace(Axis(0), *x);
64                updates.index_axis_inplace(Axis(0), *x);
65            }
66            let mut data = data.view_mut();
67            for x in indices_into_data {
68                data.index_axis_inplace(Axis(0), *x as usize);
69            }
70            data.assign(&updates)
71        }
72        let mut tensor = data.into_tensor();
73        unsafe { tensor.set_datum_type(updates.datum_type()) };
74        Ok(tensor.into_tvalue())
75    }
76
77    unsafe fn eval_t_reduce<T: Datum + PartialOrd + std::ops::AddAssign + std::ops::MulAssign>(
78        data: TValue,
79        indices: &ArrayViewD<i64>,
80        updates: TValue,
81        reduction: ScatterReduction,
82    ) -> TractResult<TValue> {
83        let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
84        let updates_plain = updates.try_as_plain()?;
85        let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
86        for coords in tract_ndarray::indices(&indices.shape()[..indices.ndim() - 1]) {
87            let mut indices_into_data = indices.view();
88            let mut updates = updates_view.view();
89            for x in coords.slice() {
90                indices_into_data.index_axis_inplace(Axis(0), *x);
91                updates.index_axis_inplace(Axis(0), *x);
92            }
93            let mut data = data.view_mut();
94            for x in indices_into_data {
95                data.index_axis_inplace(Axis(0), *x as usize);
96            }
97            Zip::from(&mut data).and(&updates).for_each(|d, u| match reduction {
98                ScatterReduction::Add => *d += u.clone(),
99                ScatterReduction::Mul => *d *= u.clone(),
100                ScatterReduction::Min => {
101                    if u < d {
102                        *d = u.clone()
103                    }
104                }
105                ScatterReduction::Max => {
106                    if u > d {
107                        *d = u.clone()
108                    }
109                }
110                ScatterReduction::None => unreachable!(),
111            });
112        }
113        let mut tensor = data.into_tensor();
114        unsafe { tensor.set_datum_type(updates.datum_type()) };
115        Ok(tensor.into_tvalue())
116    }
117}
118
119impl TypedOp for ScatterNd {
120    as_op!();
121
122    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
123        Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.to_tvec())))
124    }
125}
126
127impl EvalOp for ScatterNd {
128    fn is_stateless(&self) -> bool {
129        true
130    }
131
132    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
133        let (data, indices, updates) = args_3!(inputs);
134        let indices = indices.cast_to::<i64>()?;
135        let indices = indices.to_plain_array_view::<i64>()?;
136        if data.datum_type() != updates.datum_type() {
137            bail!(
138                "Data and update must be of the same type, got {:?} and {:?}",
139                data.datum_type(),
140                updates.datum_type()
141            );
142        }
143        unsafe {
144            match self.reduction {
145                ScatterReduction::None => {
146                    Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
147                        data, &indices, updates
148                    ))?))
149                }
150                reduction => Ok(tvec!(dispatch_numbers!(Self::eval_t_reduce(data.datum_type())(
151                    data, &indices, updates, reduction
152                ))?)),
153            }
154        }
155    }
156}