tract_hir/ops/array/
array_feature_extractor.rs1use tract_core::ops::array::Gather;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct ArrayFeatureExtractor;
8
9impl Expansion for ArrayFeatureExtractor {
10 fn name(&self) -> StaticName {
11 "ArrayFeatureExtractor".into()
12 }
13
14 fn wire(
15 &self,
16 prefix: &str,
17 model: &mut TypedModel,
18 inputs: &[OutletId],
19 ) -> TractResult<TVec<OutletId>> {
20 let last_axis = model.outlet_fact(inputs[0])?.rank() - 1;
21 let gather_op = Gather { axis: last_axis, output_type: None };
22
23 model.wire_node(prefix, gather_op, inputs)
24 }
25
26 fn rules<'r, 'p: 'r, 's: 'r>(
27 &'s self,
28 s: &mut Solver<'r>,
29 inputs: &'p [TensorProxy],
30 outputs: &'p [TensorProxy],
31 ) -> InferenceResult {
32 check_input_arity(inputs, 2)?;
36
37 check_output_arity(outputs, 1)?;
39
40 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
42 s.equals(&inputs[1].datum_type, i64::datum_type())?;
43
44 s.equals(inputs[0].rank.bex() - 1 + inputs[1].rank.bex(), outputs[0].rank.bex())?;
46
47 s.given_2(&inputs[0].shape, &inputs[1].shape, move |s, input_shape, indices_shape| {
49 let input_rank = input_shape.len();
50 let mut output_shape = tvec![];
51 output_shape.extend(input_shape.iter().take(input_rank - 1).cloned());
52 output_shape.extend(indices_shape.iter().cloned());
53 s.equals(&outputs[0].shape, output_shape)?;
54 Ok(())
55 })?;
56 Ok(())
57 }
58}