tract_tensorflow/ops/array/
squeeze.rs1use crate::model::ParsingContext;
2use crate::tfpb::tensorflow::NodeDef;
3use tract_hir::internal::*;
4use tract_hir::ops::array::Squeeze;
5
6pub fn squeeze(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
7 let squeeze_dims = pb.get_attr_opt_list_int("squeeze_dims")?;
8 if let Some(mut squeeze_dims) = squeeze_dims {
9 if squeeze_dims.len() > 0 {
10 squeeze_dims.sort();
11 return Ok(expand(Squeeze::new(Some(squeeze_dims))));
12 }
13 }
14 Ok(expand(Squeeze::default()))
15}
16
#[cfg(test)]
mod tests {
    // NOTE(review): no non-snake-case names are visible in this module; this
    // allow is kept as-is and may be removable.
    #![allow(non_snake_case)]
    use super::*;
    use tract_ndarray::Array;

    /// Expands `op` to its inference form, evaluates it on the single `input`
    /// and returns the one tensor it produces. Panics on any failure.
    fn eval_single<I: Into<Tensor>>(op: Squeeze, input: I) -> Tensor {
        let inputs = tvec![input.into().into()];
        let mut outputs = expand(op).eval(inputs).unwrap();
        outputs.pop().unwrap().into_tensor()
    }

    /// With no explicit axes, every unit axis of `[1, 2, 1, 3, 1, 1]` is dropped.
    #[test]
    fn squeeze_1() {
        let input = Array::from_elem([1, 2, 1, 3, 1, 1], 0);
        let output = eval_single(Squeeze::new(None), input);
        assert_eq!(output.shape(), &[2, 3]);
    }

    /// With axes `[2, 4]`, only those two unit axes are dropped.
    #[test]
    fn squeeze_2() {
        let input = Array::from_elem([1, 2, 1, 3, 1, 1], 0);
        let output = eval_single(Squeeze::new(Some(vec![2, 4])), input);
        assert_eq!(output.shape(), &[1, 2, 3, 1]);
    }
}