tract_core/ops/array/
scatter_nd.rs1use crate::internal::*;
2use ndarray::*;
3
/// How a `ScatterNd` update is combined with the element it lands on.
///
/// `None` (the default) replaces the element with the update. `Add`, `Mul`, `Min` and `Max`
/// combine the existing element and the update with the corresponding operation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScatterReduction {
    /// Replace the existing element with the update.
    #[default]
    None,
    /// Accumulate the update into the existing element by addition.
    Add,
    /// Accumulate the update into the existing element by multiplication.
    Mul,
    /// Keep the smaller of the existing element and the update.
    Min,
    /// Keep the larger of the existing element and the update.
    Max,
}
13
14impl ScatterReduction {
15 pub fn as_str(&self) -> &'static str {
16 match self {
17 ScatterReduction::None => "none",
18 ScatterReduction::Add => "add",
19 ScatterReduction::Mul => "mul",
20 ScatterReduction::Min => "min",
21 ScatterReduction::Max => "max",
22 }
23 }
24
25 pub fn parse(s: &str) -> TractResult<Self> {
26 Ok(match s {
27 "none" => ScatterReduction::None,
28 "add" => ScatterReduction::Add,
29 "mul" => ScatterReduction::Mul,
30 "min" => ScatterReduction::Min,
31 "max" => ScatterReduction::Max,
32 s => bail!("Unknown scatter reduction: {s}"),
33 })
34 }
35}
36
37#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
38pub struct ScatterNd {
39 pub reduction: ScatterReduction,
40}
41
42impl Op for ScatterNd {
43 fn name(&self) -> StaticName {
44 "ScatterNd".into()
45 }
46
47 op_as_typed_op!();
48}
49
50impl ScatterNd {
51 unsafe fn eval_t<T: Datum>(
52 data: TValue,
53 indices: &ArrayViewD<i64>,
54 updates: TValue,
55 ) -> TractResult<TValue> {
56 let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
57 let updates_plain = updates.try_as_plain()?;
58 let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
59 for coords in tract_ndarray::indices(&indices.shape()[..indices.ndim() - 1]) {
60 let mut indices_into_data = indices.view();
61 let mut updates = updates_view.view();
62 for x in coords.slice() {
63 indices_into_data.index_axis_inplace(Axis(0), *x);
64 updates.index_axis_inplace(Axis(0), *x);
65 }
66 let mut data = data.view_mut();
67 for x in indices_into_data {
68 data.index_axis_inplace(Axis(0), *x as usize);
69 }
70 data.assign(&updates)
71 }
72 let mut tensor = data.into_tensor();
73 unsafe { tensor.set_datum_type(updates.datum_type()) };
74 Ok(tensor.into_tvalue())
75 }
76
77 unsafe fn eval_t_reduce<T: Datum + PartialOrd + std::ops::AddAssign + std::ops::MulAssign>(
78 data: TValue,
79 indices: &ArrayViewD<i64>,
80 updates: TValue,
81 reduction: ScatterReduction,
82 ) -> TractResult<TValue> {
83 let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
84 let updates_plain = updates.try_as_plain()?;
85 let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
86 for coords in tract_ndarray::indices(&indices.shape()[..indices.ndim() - 1]) {
87 let mut indices_into_data = indices.view();
88 let mut updates = updates_view.view();
89 for x in coords.slice() {
90 indices_into_data.index_axis_inplace(Axis(0), *x);
91 updates.index_axis_inplace(Axis(0), *x);
92 }
93 let mut data = data.view_mut();
94 for x in indices_into_data {
95 data.index_axis_inplace(Axis(0), *x as usize);
96 }
97 Zip::from(&mut data).and(&updates).for_each(|d, u| match reduction {
98 ScatterReduction::Add => *d += u.clone(),
99 ScatterReduction::Mul => *d *= u.clone(),
100 ScatterReduction::Min => {
101 if u < d {
102 *d = u.clone()
103 }
104 }
105 ScatterReduction::Max => {
106 if u > d {
107 *d = u.clone()
108 }
109 }
110 ScatterReduction::None => unreachable!(),
111 });
112 }
113 let mut tensor = data.into_tensor();
114 unsafe { tensor.set_datum_type(updates.datum_type()) };
115 Ok(tensor.into_tvalue())
116 }
117}
118
119impl TypedOp for ScatterNd {
120 as_op!();
121
122 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
123 Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.to_tvec())))
124 }
125}
126
127impl EvalOp for ScatterNd {
128 fn is_stateless(&self) -> bool {
129 true
130 }
131
132 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
133 let (data, indices, updates) = args_3!(inputs);
134 let indices = indices.cast_to::<i64>()?;
135 let indices = indices.to_plain_array_view::<i64>()?;
136 if data.datum_type() != updates.datum_type() {
137 bail!(
138 "Data and update must be of the same type, got {:?} and {:?}",
139 data.datum_type(),
140 updates.datum_type()
141 );
142 }
143 unsafe {
144 match self.reduction {
145 ScatterReduction::None => {
146 Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
147 data, &indices, updates
148 ))?))
149 }
150 reduction => Ok(tvec!(dispatch_numbers!(Self::eval_t_reduce(data.datum_type())(
151 data, &indices, updates, reduction
152 ))?)),
153 }
154 }
155 }
156}