tract_hir/ops/array/
gather_elements.rs

1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Default, Hash)]
5pub struct GatherElements {
6    axis: i64,
7}
8
9
10impl Expansion for GatherElements {
11    fn name(&self) -> StaticName {
12        "GatherElements".into()
13    }
14
15
16    fn rules<'r, 'p: 'r, 's: 'r>(
17        &'s self,
18        s: &mut Solver<'r>,
19        inputs: &'p [TensorProxy],
20        outputs: &'p [TensorProxy],
21    ) -> InferenceResult {
22        check_input_arity(inputs, 2)?;
23        check_output_arity(outputs, 1)?;
24        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
25        s.equals(&inputs[0].rank, &inputs[1].rank)?;
26        s.equals(&outputs[0].shape, &inputs[1].shape)?;
27        Ok(())
28    }
29
30    fn wire(
31        &self,
32        prefix: &str,
33        model: &mut TypedModel,
34        inputs: &[OutletId],
35    ) -> TractResult<TVec<OutletId>> {
36        let input_rank = model.outlet_fact(inputs[0])?.rank();
37        let axis = if self.axis < 0 { self.axis + input_rank as i64 } else { self.axis } as usize;
38        model.wire_node(prefix, tract_core::ops::array::GatherElements { axis }, inputs)
39    }
40}