tract_hir/ops/array/scatter_elements.rs

1use tract_core::ops::cast::wire_cast;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct ScatterElements {
8    axis: i64,
9}
10
11impl Expansion for ScatterElements {
12    fn name(&self) -> StaticName {
13        "ScatterElements".into()
14    }
15
16    fn rules<'r, 'p: 'r, 's: 'r>(
17        &'s self,
18        s: &mut Solver<'r>,
19        inputs: &'p [TensorProxy],
20        outputs: &'p [TensorProxy],
21    ) -> InferenceResult {
22        check_input_arity(inputs, 3)?;
23        check_output_arity(outputs, 1)?;
24
25        s.given_2(&inputs[0].datum_type, &inputs[2].datum_type, move |s, input, updates| {
26            let super_type: DatumType = DatumType::super_type_for([input, updates])
27                .with_context(|| format!("No supertype found for {input:?} and {updates:?}"))?;
28            s.equals(&outputs[0].datum_type, super_type)
29        })?;
30        s.equals(&inputs[0].rank, &inputs[1].rank)?;
31        s.equals(&inputs[1].shape, &inputs[2].shape)?;
32        s.equals(&outputs[0].shape, &inputs[0].shape)?;
33        Ok(())
34    }
35
36    fn wire(
37        &self,
38        prefix: &str,
39        model: &mut TypedModel,
40        inputs: &[OutletId],
41    ) -> TractResult<TVec<OutletId>> {
42        let input_rank = model.outlet_fact(inputs[0])?.rank();
43        let axis = if self.axis < 0 { self.axis + input_rank as i64 } else { self.axis } as usize;
44        let super_type = if let Some(super_type) = DatumType::super_type_for([
45            model.outlet_fact(inputs[0])?.datum_type,
46            model.outlet_fact(inputs[2])?.datum_type,
47        ]) {
48            super_type
49        } else {
50            bail!("Can not type op");
51        };
52        let casted = wire_cast(prefix, model, &[inputs[0], inputs[2]], super_type)?;
53        model.wire_node(
54            prefix,
55            tract_core::ops::array::ScatterElements { axis },
56            &[casted[0], inputs[1], casted[1]],
57        )
58    }
59}