tract_hir/ops/array/
array_feature_extractor.rs

1use tract_core::ops::array::Gather;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct ArrayFeatureExtractor;
8
9impl Expansion for ArrayFeatureExtractor {
10    fn name(&self) -> StaticName {
11        "ArrayFeatureExtractor".into()
12    }
13
14    fn wire(
15        &self,
16        prefix: &str,
17        model: &mut TypedModel,
18        inputs: &[OutletId],
19    ) -> TractResult<TVec<OutletId>> {
20        let last_axis = model.outlet_fact(inputs[0])?.rank() - 1;
21        let gather_op = Gather { axis: last_axis, output_type: None };
22
23        model.wire_node(prefix, gather_op, inputs)
24    }
25
26    fn rules<'r, 'p: 'r, 's: 'r>(
27        &'s self,
28        s: &mut Solver<'r>,
29        inputs: &'p [TensorProxy],
30        outputs: &'p [TensorProxy],
31    ) -> InferenceResult {
32        // Expect two inputs:
33        // - X: data to be selected
34        // - Y: the indices that'll be applied to the last axis
35        check_input_arity(inputs, 2)?;
36
37        // We return one tensor containing the selection
38        check_output_arity(outputs, 1)?;
39
40        // Check types
41        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
42        s.equals(&inputs[1].datum_type, i64::datum_type())?;
43
44        // Check ranks
45        s.equals(inputs[0].rank.bex() - 1 + inputs[1].rank.bex(), outputs[0].rank.bex())?;
46
47        // Check shapes
48        s.given_2(&inputs[0].shape, &inputs[1].shape, move |s, input_shape, indices_shape| {
49            let input_rank = input_shape.len();
50            let mut output_shape = tvec![];
51            output_shape.extend(input_shape.iter().take(input_rank - 1).cloned());
52            output_shape.extend(indices_shape.iter().cloned());
53            s.equals(&outputs[0].shape, output_shape)?;
54            Ok(())
55        })?;
56        Ok(())
57    }
58}