tract_core/ops/array/
scatter_elements.rs1use super::scatter_nd::ScatterReduction;
2use crate::internal::*;
3use ndarray::*;
4
5#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
6pub struct ScatterElements {
7 pub axis: usize,
8 pub reduction: ScatterReduction,
9}
10
11impl Op for ScatterElements {
12 fn name(&self) -> StaticName {
13 "ScatterElements".into()
14 }
15
16 op_as_typed_op!();
17}
18
19impl ScatterElements {
20 unsafe fn eval_t<T: Datum>(
21 data: TValue,
22 indices: &ArrayViewD<i64>,
23 updates: TValue,
24 axis: usize,
25 ) -> TractResult<TValue> {
26 let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
27 let updates_plain = updates.try_as_plain()?;
28 let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
29 for (mut coords, value) in updates_view.indexed_iter() {
30 let index = indices[&coords];
31 coords[axis] =
32 if index < 0 { index + data.shape()[axis] as i64 } else { index } as usize;
33 data[coords] = value.clone()
34 }
35 let mut tensor = data.into_tensor();
36 unsafe { tensor.set_datum_type(updates.datum_type()) };
37 Ok(tensor.into_tvalue())
38 }
39
40 unsafe fn eval_t_reduce<T: Datum + PartialOrd + std::ops::AddAssign + std::ops::MulAssign>(
41 data: TValue,
42 indices: &ArrayViewD<i64>,
43 updates: TValue,
44 axis: usize,
45 reduction: ScatterReduction,
46 ) -> TractResult<TValue> {
47 let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
48 let updates_plain = updates.try_as_plain()?;
49 let updates_view = unsafe { updates_plain.to_array_view_unchecked::<T>() };
50 for (mut coords, value) in updates_view.indexed_iter() {
51 let index = indices[&coords];
52 coords[axis] =
53 if index < 0 { index + data.shape()[axis] as i64 } else { index } as usize;
54 let d = &mut data[coords];
55 match reduction {
56 ScatterReduction::Add => *d += value.clone(),
57 ScatterReduction::Mul => *d *= value.clone(),
58 ScatterReduction::Min => {
59 if value < d {
60 *d = value.clone()
61 }
62 }
63 ScatterReduction::Max => {
64 if value > d {
65 *d = value.clone()
66 }
67 }
68 ScatterReduction::None => unreachable!(),
69 }
70 }
71 let mut tensor = data.into_tensor();
72 unsafe { tensor.set_datum_type(updates.datum_type()) };
73 Ok(tensor.into_tvalue())
74 }
75}
76
77impl TypedOp for ScatterElements {
78 as_op!();
79
80 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
81 Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.clone())))
82 }
83}
84
85impl EvalOp for ScatterElements {
86 fn is_stateless(&self) -> bool {
87 true
88 }
89
90 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
91 let (data, indices, updates) = args_3!(inputs);
92 let indices = indices.cast_to::<i64>()?;
93 let indices = indices.to_plain_array_view::<i64>()?;
94 if data.datum_type() != updates.datum_type() {
95 bail!(
96 "Data and update must be of the same type, got {:?} and {:?}",
97 data.datum_type(),
98 updates.datum_type()
99 );
100 }
101 unsafe {
102 match self.reduction {
103 ScatterReduction::None => {
104 Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
105 data, &indices, updates, self.axis
106 ))?))
107 }
108 reduction => Ok(tvec!(dispatch_numbers!(Self::eval_t_reduce(data.datum_type())(
109 data, &indices, updates, self.axis, reduction
110 ))?)),
111 }
112 }
113 }
114}