tract_hir/ops/array/
squeeze.rs1use crate::infer::*;
2use crate::internal::*;
3
4use super::RmDims;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct Squeeze {
8 axes: Option<Vec<isize>>,
9}
10
11impl Squeeze {
12 pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TractResult<TVec<D>> {
13 if let Some(ref axes) = self.axes {
14 let axes = axes
15 .iter()
16 .map(|&a| if a < 0 { a + input.len() as isize } else { a } as usize)
17 .collect::<Vec<_>>();
18 let mut shape: TVec<D> = input.iter().cloned().collect();
19 for &axis in axes.iter().rev() {
20 if shape.remove(axis) != D::one() {
21 bail!(
22 "Attempt to squeeze an axis which dimension is not one {:?}, {:?}",
23 self,
24 input
25 );
26 }
27 }
28 Ok(shape)
29 } else {
30 Ok(input.iter().filter(|&d| d != &D::one()).cloned().collect())
31 }
32 }
33}
34
35impl Expansion for Squeeze {
36 fn name(&self) -> StaticName {
37 "Squeeze".into()
38 }
39
40 fn rules<'r, 'p: 'r, 's: 'r>(
41 &'s self,
42 s: &mut Solver<'r>,
43 inputs: &'p [TensorProxy],
44 outputs: &'p [TensorProxy],
45 ) -> InferenceResult {
46 check_output_arity(outputs, 1)?;
47 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
48 if let Some(ref axes) = self.axes {
49 s.equals(&outputs[0].rank, (&inputs[0].rank).bex() - axes.len() as i64)?;
50 }
51 s.given(&inputs[0].shape, move |s, shape| {
52 let output_shape = self.output_shape(&shape)?;
53 s.equals(&outputs[0].shape, output_shape)
54 })
55 }
56
57 fn wire(
58 &self,
59 prefix: &str,
60 target: &mut TypedModel,
61 inputs: &[OutletId],
62 ) -> TractResult<TVec<OutletId>> {
63 let input = inputs[0];
64 let axes = if let Some(axes) = &self.axes {
65 axes.clone()
66 } else {
67 let input_fact = target.outlet_fact(input)?;
68 input_fact
69 .shape
70 .iter()
71 .enumerate()
72 .filter(|(_ix, d)| d.is_one())
73 .map(|(ix, _d)| ix as isize)
74 .collect()
75 };
76 RmDims::new(axes).wire(prefix, target, inputs)
77 }
78}