tract_hir/ops/array/squeeze.rs

1use crate::infer::*;
2use crate::internal::*;
3
4use super::RmDims;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct Squeeze {
8    axes: Option<Vec<isize>>,
9}
10
11impl Squeeze {
12    pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TractResult<TVec<D>> {
13        if let Some(ref axes) = self.axes {
14            let axes = axes
15                .iter()
16                .map(|&a| if a < 0 { a + input.len() as isize } else { a } as usize)
17                .collect::<Vec<_>>();
18            let mut shape: TVec<D> = input.iter().cloned().collect();
19            for &axis in axes.iter().rev() {
20                if shape.remove(axis) != D::one() {
21                    bail!(
22                        "Attempt to squeeze an axis which dimension is not one {:?}, {:?}",
23                        self,
24                        input
25                    );
26                }
27            }
28            Ok(shape)
29        } else {
30            Ok(input.iter().filter(|&d| d != &D::one()).cloned().collect())
31        }
32    }
33}
34
35impl Expansion for Squeeze {
36    fn name(&self) -> StaticName {
37        "Squeeze".into()
38    }
39
40    fn rules<'r, 'p: 'r, 's: 'r>(
41        &'s self,
42        s: &mut Solver<'r>,
43        inputs: &'p [TensorProxy],
44        outputs: &'p [TensorProxy],
45    ) -> InferenceResult {
46        check_output_arity(outputs, 1)?;
47        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
48        if let Some(ref axes) = self.axes {
49            s.equals(&outputs[0].rank, (&inputs[0].rank).bex() - axes.len() as i64)?;
50        }
51        s.given(&inputs[0].shape, move |s, shape| {
52            let output_shape = self.output_shape(&shape)?;
53            s.equals(&outputs[0].shape, output_shape)
54        })
55    }
56
57    fn wire(
58        &self,
59        prefix: &str,
60        target: &mut TypedModel,
61        inputs: &[OutletId],
62    ) -> TractResult<TVec<OutletId>> {
63        let input = inputs[0];
64        let axes = if let Some(axes) = &self.axes {
65            axes.clone()
66        } else {
67            let input_fact = target.outlet_fact(input)?;
68            input_fact
69                .shape
70                .iter()
71                .enumerate()
72                .filter(|(_ix, d)| d.is_one())
73                .map(|(ix, _d)| ix as isize)
74                .collect()
75        };
76        RmDims::new(axes).wire(prefix, target, inputs)
77    }
78}