tract_core/ops/cnn/
pools.rs

1use crate::internal::*;
2
3use crate::ops::cnn::{PaddingSpec, Patch, PatchSpec};
4use crate::ops::nn::{BaseDataShape, DataFormat, DataShape, SymDataShape};
5
6use super::padding::ComputedPaddedDim;
7
8#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
9pub struct PoolSpec {
10    pub data_format: DataFormat,
11    pub kernel_shape: TVec<usize>,
12    pub padding: PaddingSpec,
13    pub dilations: Option<TVec<usize>>,
14    pub strides: Option<TVec<usize>>,
15    pub input_channels: usize,
16    pub output_channels: usize,
17}
18
19impl PoolSpec {
20    pub fn info(&self) -> Vec<String> {
21        vec![
22            format!("Data format: {:?}", self.data_format),
23            format!(
24                "Kernel shape:{:?} (strides:{:?}, padding:{:?}, dilations:{:?})",
25                self.kernel_shape, self.strides, self.padding, self.dilations,
26            ),
27        ]
28    }
29
30    pub fn rank(&self) -> usize {
31        self.kernel_shape.len()
32    }
33
34    pub fn dilation(&self, geo_axis: usize) -> usize {
35        self.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1)
36    }
37
38    pub fn dilations(&self) -> Cow<[usize]> {
39        self.dilations
40            .as_deref()
41            .map_or_else(|| vec![1; self.kernel_shape.len()].into(), |d| d.into())
42    }
43
44    pub fn stride(&self, geo_axis: usize) -> usize {
45        self.strides.as_ref().map(|s| s[geo_axis]).unwrap_or(1)
46    }
47
48    pub fn strides(&self) -> Cow<[usize]> {
49        self.strides
50            .as_deref()
51            .map_or_else(|| vec![1; self.kernel_shape.len()].into(), |d| d.into())
52    }
53
54    pub fn computed_padding<D: DimLike>(&self, input_hw: &[D]) -> TVec<ComputedPaddedDim<D>> {
55        self.padding.compute(input_hw, &self.kernel_shape, &self.dilations(), &self.strides())
56    }
57
58    pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TractResult<BaseDataShape<D, TVec<D>>> {
59        let ishape: BaseDataShape<D, TVec<D>> = self.data_format.shape(input.into())?;
60        ensure!(ishape.c().to_dim() == self.input_channels.to_dim());
61        let computed = self.computed_padding(ishape.hw_dims());
62        let spatial_dims = computed.into_iter().map(|d| d.convoluted).collect::<TVec<D>>();
63        let oshape = self.data_format.from_n_c_hw(
64            ishape.n().cloned().unwrap_or_else(|| 1.into()),
65            self.output_channels.into(),
66            spatial_dims,
67        )?;
68        Ok(oshape)
69    }
70
71    pub fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
72        let oshape = self.output_shape(&inputs[0].shape)?;
73        Ok(tvec!(inputs[0].datum_type.fact(oshape.shape)))
74    }
75
76    pub fn dispose_n_axis(&self) -> PoolSpec {
77        PoolSpec { data_format: self.data_format.dispose_n_axis(), ..self.clone() }
78    }
79
80    pub fn compute_geo(&self, input_full_shape: &[TDim]) -> TractResult<PoolGeometry> {
81        let output_shape = self.output_shape(input_full_shape)?;
82        let input_shape: SymDataShape = self.data_format.shape(input_full_shape.into())?;
83        Ok(PoolGeometry::Symbolic(SymbolicPoolGeometry {
84            pool_spec: self.clone(),
85            input_shape,
86            output_shape,
87        }))
88    }
89
90    pub fn change_geo_axes(&self, op: &AxisOp) -> TractResult<PoolSpec> {
91        let mut dilations = self.dilations().into_owned().into();
92        op.change_shape_array(&mut dilations, false)?;
93        let mut kernel_shape = self.kernel_shape.clone();
94        op.change_shape_array(&mut kernel_shape, false)?;
95        let mut strides = self.strides().into_owned().into();
96        op.change_shape_array(&mut strides, false)?;
97        let padding = self.padding.change_geo_axes(op)?;
98        Ok(PoolSpec {
99            kernel_shape,
100            padding,
101            dilations: Some(dilations),
102            strides: Some(strides),
103            ..self.clone()
104        })
105    }
106
107    pub fn declutter(&self, input: &[TDim]) -> TractResult<Option<PoolSpec>> {
108        if let PaddingSpec::ExplicitOnnxPool(before, after, _) = &self.padding {
109            let input = self.data_format.shape(input)?;
110            let input_hw = input.hw_dims();
111            let reference = self.computed_padding(input_hw);
112            for replacement in [
113                PaddingSpec::Valid,
114                PaddingSpec::SameUpper,
115                PaddingSpec::SameLower,
116                PaddingSpec::Explicit(before.clone(), after.clone()),
117            ] {
118                let new_pool_spec = PoolSpec { padding: replacement, ..self.clone() };
119                if new_pool_spec.computed_padding(input_hw) == reference {
120                    return Ok(Some(new_pool_spec));
121                }
122            }
123        }
124        Ok(None)
125    }
126}
127
128pub type PoolGeometry = super::GeometryBound<SymbolicPoolGeometry, ConcretePoolGeometry>;
129
130#[derive(Debug, Clone, Hash, PartialEq, Eq)]
131pub struct SymbolicPoolGeometry {
132    pub pool_spec: PoolSpec,
133    pub input_shape: SymDataShape,
134    pub output_shape: SymDataShape,
135}
136
137#[derive(Debug, Clone, Hash, PartialEq, Eq)]
138pub struct ConcretePoolGeometry {
139    pub input_shape: DataShape,
140    pub patch: Patch,
141    pub output_shape: DataShape,
142}
143
144impl super::ResolveTo<ConcretePoolGeometry> for SymbolicPoolGeometry {
145    type Param = [usize];
146    fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcretePoolGeometry> {
147        let input_shape = self.pool_spec.data_format.shape(input_full_shape.into())?;
148        let output_inner_stride = match self.pool_spec.data_format {
149            DataFormat::NCHW | DataFormat::CHW => 1,
150            DataFormat::NHWC | DataFormat::HWC => self.pool_spec.output_channels,
151        };
152        let mut spec = PatchSpec::for_full_shape(self.pool_spec.data_format, input_full_shape)?
153            .with_output_inner_stride(output_inner_stride)
154            .with_kernel_shape(self.pool_spec.kernel_shape.clone())
155            .with_padding(self.pool_spec.padding.clone());
156        if let Some(strides) = self.pool_spec.strides.clone() {
157            spec = spec.with_strides(strides);
158        }
159        if let Some(dilations) = self.pool_spec.dilations.clone() {
160            spec = spec.with_dilations(dilations);
161        }
162        let patch = spec.into_patch();
163        let output_shape = input_shape.fmt.from_n_c_hw(
164            *input_shape.n().unwrap_or(&1),
165            self.pool_spec.output_channels,
166            &*patch.output_shape,
167        )?;
168        Ok(ConcretePoolGeometry { input_shape, patch, output_shape })
169    }
170}