// tract_core/ops/array/scatter_elements.rs
use crate::internal::*;
2use ndarray::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct ScatterElements {
6 pub axis: usize,
7}
8
9impl Op for ScatterElements {
10 fn name(&self) -> StaticName {
11 "ScatterElements".into()
12 }
13
14 op_as_typed_op!();
15}
16
17impl ScatterElements {
18 unsafe fn eval_t<T: Datum>(
19 &self,
20 data: TValue,
21 indices: &ArrayViewD<i64>,
22 updates: TValue,
23 ) -> TractResult<TValue> {
24 let mut data = unsafe { data.into_tensor().into_array_unchecked::<T>() };
25 let updates_view = unsafe { updates.to_array_view_unchecked::<T>() };
26 for (mut coords, value) in updates_view.indexed_iter() {
27 let index = indices[&coords];
28 coords[self.axis] =
29 if index < 0 { index + data.shape()[self.axis] as i64 } else { index } as usize;
30 data[coords] = value.clone()
31 }
32 let mut tensor = data.into_tensor();
33 unsafe { tensor.set_datum_type(updates.datum_type()) };
34 Ok(tensor.into_tvalue())
35 }
36}
37
38impl TypedOp for ScatterElements {
39 as_op!();
40
41 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
42 Ok(tvec!(inputs[0].datum_type.fact(inputs[0].shape.clone())))
43 }
44}
45
46impl EvalOp for ScatterElements {
47 fn is_stateless(&self) -> bool {
48 true
49 }
50
51 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
52 let (data, indices, updates) = args_3!(inputs);
53 let indices = indices.cast_to::<i64>()?;
54 let indices = indices.to_array_view::<i64>()?;
55 if data.datum_type() != updates.datum_type() {
56 bail!(
57 "Data and update must be of the same type, got {:?} and {:?}",
58 data.datum_type(),
59 updates.datum_type()
60 );
61 }
62 unsafe {
63 Ok(tvec!(dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
64 self, data, &indices, updates
65 ))?))
66 }
67 }
68}