// tract_hir/ops/array/scatter_nd.rs
use crate::infer::*;
use crate::internal::*;

pub use tract_core::ops::array::ScatterNd;

7impl InferenceRulesOp for ScatterNd {
8 fn rules<'r, 'p: 'r, 's: 'r>(
9 &'s self,
10 s: &mut Solver<'r>,
11 inputs: &'p [TensorProxy],
12 outputs: &'p [TensorProxy],
13 ) -> InferenceResult {
14 check_input_arity(inputs, 3)?;
15 check_output_arity(outputs, 1)?;
16 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
17 s.equals(&inputs[2].datum_type, &inputs[0].datum_type)?;
18 s.equals(&outputs[0].shape, &inputs[0].shape)?;
19
20 s.given_2(&inputs[0].rank, &inputs[1].rank, move |s, p ,q| {
21 s.given(&inputs[1].shape[q as usize - 1], move |s, r| {
22 if let Ok(r) = r.to_i64() {
23 s.equals(&inputs[2].rank, p + q - r - 1)?;
24 }
25 Ok(())
26 })
27 })?;
28 Ok(())
29 }
30
31 as_op!();
32 to_typed!();
33}