// tract_hir/ops/array/rm_dims.rs

use crate::infer::*;
use crate::internal::*;
use tract_itertools::Itertools;

5#[derive(Debug, Clone, new, Hash)]
6pub struct RmDims {
7    pub axes: Vec<isize>,
8}
12impl RmDims {
13    fn compute_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
14        let axes = self
15            .axes
16            .iter()
17            .map(|&a| if a < 0 { a + input.len() as isize } else { a } as usize)
18            .collect::<Vec<_>>();
19        input
20            .iter()
21            .enumerate()
22            .filter(|(ix, _d)| !axes.contains(ix))
23            .map(|(_ix, d)| d.clone())
24            .collect()
25    }
26}
28impl Expansion for RmDims {
29    fn name(&self) -> StaticName {
30        "RmDims".into()
31    }
32
33
34    fn rules<'r, 'p: 'r, 's: 'r>(
35        &'s self,
36        s: &mut Solver<'r>,
37        inputs: &'p [TensorProxy],
38        outputs: &'p [TensorProxy],
39    ) -> InferenceResult {
40        check_output_arity(outputs, 1)?;
41        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
42        s.equals(&outputs[0].rank, (&inputs[0].rank).bex() - self.axes.len() as i64)?;
43        s.given(&inputs[0].rank, move |s, rank| {
44            for axis in &self.axes {
45                let axis = if *axis < 0 { axis + rank as isize } else { *axis } as usize;
46                s.equals(&inputs[0].shape[axis], 1.to_dim())?;
47            }
48            Ok(())
49        })?;
50        s.given(&inputs[0].shape, move |s, shape| {
51            let output_shape = self.compute_shape(&shape);
52            s.equals(&outputs[0].shape, output_shape)
53        })
54    }
55
56    fn wire(
57        &self,
58        prefix: &str,
59        target: &mut TypedModel,
60        inputs: &[OutletId],
61    ) -> TractResult<TVec<OutletId>> {
62        let mut wire = inputs[0];
63        let rank = target.outlet_fact(inputs[0])?.rank();
64        let axes = self
65            .axes
66            .iter()
67            .map(|&a| if a < 0 { a + rank as isize } else { a } as usize)
68            .sorted()
69            .rev();
70        for axis in axes {
71            wire =
72                target.wire_node(format!("{prefix}.axis-{axis}"), AxisOp::Rm(axis), &[wire])?
73                    [0];
74        }
75        Ok(tvec!(wire))
76    }
77}