tract_hir/ops/array/
scatter_elements.rs1use tract_core::ops::cast::wire_cast;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct ScatterElements {
8 axis: i64,
9}
10
11impl Expansion for ScatterElements {
12 fn name(&self) -> StaticName {
13 "ScatterElements".into()
14 }
15
16 fn rules<'r, 'p: 'r, 's: 'r>(
17 &'s self,
18 s: &mut Solver<'r>,
19 inputs: &'p [TensorProxy],
20 outputs: &'p [TensorProxy],
21 ) -> InferenceResult {
22 check_input_arity(inputs, 3)?;
23 check_output_arity(outputs, 1)?;
24
25 s.given_2(&inputs[0].datum_type, &inputs[2].datum_type, move |s, input, updates| {
26 let super_type: DatumType = DatumType::super_type_for([input, updates])
27 .with_context(|| format!("No supertype found for {input:?} and {updates:?}"))?;
28 s.equals(&outputs[0].datum_type, super_type)
29 })?;
30 s.equals(&inputs[0].rank, &inputs[1].rank)?;
31 s.equals(&inputs[1].shape, &inputs[2].shape)?;
32 s.equals(&outputs[0].shape, &inputs[0].shape)?;
33 Ok(())
34 }
35
36 fn wire(
37 &self,
38 prefix: &str,
39 model: &mut TypedModel,
40 inputs: &[OutletId],
41 ) -> TractResult<TVec<OutletId>> {
42 let input_rank = model.outlet_fact(inputs[0])?.rank();
43 let axis = if self.axis < 0 { self.axis + input_rank as i64 } else { self.axis } as usize;
44 let super_type = if let Some(super_type) = DatumType::super_type_for([
45 model.outlet_fact(inputs[0])?.datum_type,
46 model.outlet_fact(inputs[2])?.datum_type,
47 ]) {
48 super_type
49 } else {
50 bail!("Can not type op");
51 };
52 let casted = wire_cast(prefix, model, &[inputs[0], inputs[2]], super_type)?;
53 model.wire_node(
54 prefix,
55 tract_core::ops::array::ScatterElements { axis },
56 &[casted[0], inputs[1], casted[1]],
57 )
58 }
59}