1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
use crate::model::ParsingContext;
use crate::tfpb::tensorflow::NodeDef;
use tract_hir::internal::*;
use tract_hir::ops::array::Squeeze;

/// Builds a tract [`Squeeze`] inference op from a TensorFlow `Squeeze` node.
///
/// Reads the optional list-of-ints attribute `squeeze_dims` from `pb`:
///
/// - If it is present and non-empty, the axes are sorted in ascending order and passed to
///   [`Squeeze::new`].
/// - If it is absent or empty, [`Squeeze::default`] is returned. This presumably squeezes every
///   size-1 axis, like `Squeeze::new(None)` in the tests below; confirm this.
///
/// `_ctx` is not used.
///
/// # Errors
///
/// Propagates any error returned while reading the `squeeze_dims` attribute.
pub fn squeeze(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    match pb.get_attr_opt_list_int("squeeze_dims")? {
        // An empty axis list is treated exactly like a missing attribute.
        Some(mut squeeze_dims) if !squeeze_dims.is_empty() => {
            // NOTE(review): TensorFlow permits negative axes here. Sorting places them before
            // the non-negative ones; confirm that `Squeeze` normalizes negative axes.
            squeeze_dims.sort_unstable();
            Ok(Box::new(Squeeze::new(Some(squeeze_dims))))
        }
        _ => Ok(Box::new(Squeeze::default())),
    }
}

#[cfg(test)]
mod tests {
    #![allow(non_snake_case)]
    use super::*;
    use tract_ndarray::Array;

    /// Evaluates `op` on a single input and returns its single output as a tensor.
    fn run<I>(op: Squeeze, input: I) -> Tensor
    where
        I: Into<Tensor>,
    {
        op.eval(tvec![input.into().into()]).unwrap().pop().unwrap().into_tensor()
    }

    #[test]
    fn squeeze_1() {
        assert_eq!(
            run(Squeeze::new(None), Array::from_elem([1, 2, 1, 3, 1, 1], 0)).shape(),
            &[2, 3]
        );
    }

    #[test]
    fn squeeze_2() {
        assert_eq!(
            run(Squeeze::new(Some(vec![2, 4])), Array::from_elem([1, 2, 1, 3, 1, 1], 0)).shape(),
            &[1, 2, 3, 1]
        );
    }

    #[test]
    fn squeeze_inference_1() {
        let input = InferenceFact::default()
            .with_datum_type(DatumType::TDim)
            .with_shape(shapefactoid![1, 1, (TDim::stream() - 2), 16]);
        let any = InferenceFact::default();

        let mut op = Squeeze::new(Some(vec![1]));
        let inferred = op.infer_facts(tvec!(&input), tvec!(&any), tvec!()).unwrap();

        let expect: TVec<_> = tvec!(InferenceFact::default()
            .with_datum_type(DatumType::TDim)
            .with_shape(shapefactoid![1, (TDim::stream() - 2), 16]));

        assert_eq!(inferred.1, expect);
    }
}