tract_hir/ops/array/
gather_elements.rs1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Default, Hash)]
5pub struct GatherElements {
6 axis: i64,
7}
8
9
10impl Expansion for GatherElements {
11 fn name(&self) -> StaticName {
12 "GatherElements".into()
13 }
14
15
16 fn rules<'r, 'p: 'r, 's: 'r>(
17 &'s self,
18 s: &mut Solver<'r>,
19 inputs: &'p [TensorProxy],
20 outputs: &'p [TensorProxy],
21 ) -> InferenceResult {
22 check_input_arity(inputs, 2)?;
23 check_output_arity(outputs, 1)?;
24 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
25 s.equals(&inputs[0].rank, &inputs[1].rank)?;
26 s.equals(&outputs[0].shape, &inputs[1].shape)?;
27 Ok(())
28 }
29
30 fn wire(
31 &self,
32 prefix: &str,
33 model: &mut TypedModel,
34 inputs: &[OutletId],
35 ) -> TractResult<TVec<OutletId>> {
36 let input_rank = model.outlet_fact(inputs[0])?.rank();
37 let axis = if self.axis < 0 { self.axis + input_rank as i64 } else { self.axis } as usize;
38 model.wire_node(prefix, tract_core::ops::array::GatherElements { axis }, inputs)
39 }
40}