tract_hir/ops/array/
rm_dims.rs1use crate::infer::*;
2use crate::internal::*;
3use tract_itertools::Itertools;
4
5#[derive(Debug, Clone, new, Hash)]
6pub struct RmDims {
7 pub axes: Vec<isize>,
8}
9
10
11
12impl RmDims {
13 fn compute_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
14 let axes = self
15 .axes
16 .iter()
17 .map(|&a| if a < 0 { a + input.len() as isize } else { a } as usize)
18 .collect::<Vec<_>>();
19 input
20 .iter()
21 .enumerate()
22 .filter(|(ix, _d)| !axes.contains(ix))
23 .map(|(_ix, d)| d.clone())
24 .collect()
25 }
26}
27
28impl Expansion for RmDims {
29 fn name(&self) -> StaticName {
30 "RmDims".into()
31 }
32
33
34 fn rules<'r, 'p: 'r, 's: 'r>(
35 &'s self,
36 s: &mut Solver<'r>,
37 inputs: &'p [TensorProxy],
38 outputs: &'p [TensorProxy],
39 ) -> InferenceResult {
40 check_output_arity(outputs, 1)?;
41 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
42 s.equals(&outputs[0].rank, (&inputs[0].rank).bex() - self.axes.len() as i64)?;
43 s.given(&inputs[0].rank, move |s, rank| {
44 for axis in &self.axes {
45 let axis = if *axis < 0 { axis + rank as isize } else { *axis } as usize;
46 s.equals(&inputs[0].shape[axis], 1.to_dim())?;
47 }
48 Ok(())
49 })?;
50 s.given(&inputs[0].shape, move |s, shape| {
51 let output_shape = self.compute_shape(&shape);
52 s.equals(&outputs[0].shape, output_shape)
53 })
54 }
55
56 fn wire(
57 &self,
58 prefix: &str,
59 target: &mut TypedModel,
60 inputs: &[OutletId],
61 ) -> TractResult<TVec<OutletId>> {
62 let mut wire = inputs[0];
63 let rank = target.outlet_fact(inputs[0])?.rank();
64 let axes = self
65 .axes
66 .iter()
67 .map(|&a| if a < 0 { a + rank as isize } else { a } as usize)
68 .sorted()
69 .rev();
70 for axis in axes {
71 wire =
72 target.wire_node(format!("{prefix}.axis-{axis}"), AxisOp::Rm(axis), &[wire])?
73 [0];
74 }
75 Ok(tvec!(wire))
76 }
77}