1use crate::internal::*;
2
3use crate::ops::cnn::{PaddingSpec, Patch, PatchSpec};
4use crate::ops::nn::{BaseDataShape, DataFormat, DataShape, SymDataShape};
5
6use super::padding::ComputedPaddedDim;
7
/// Geometric and channel parameters of a pool operation: tensor layout, kernel, padding,
/// strides, dilations and channel counts.
#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
pub struct PoolSpec {
    /// Layout of the input and output tensors (NCHW, NHWC, CHW or HWC).
    pub data_format: DataFormat,
    /// Kernel size along each geometric axis; its length is the spec's rank.
    pub kernel_shape: TVec<usize>,
    /// Padding applied along the geometric axes.
    pub padding: PaddingSpec,
    /// Per-axis dilations; `None` means a dilation of 1 on every axis.
    pub dilations: Option<TVec<usize>>,
    /// Per-axis strides; `None` means a stride of 1 on every axis.
    pub strides: Option<TVec<usize>>,
    /// Expected size of the input's channel axis (checked by `output_shape`).
    pub input_channels: usize,
    /// Size of the channel axis in the computed output shape.
    pub output_channels: usize,
}
18
19impl PoolSpec {
20 pub fn info(&self) -> Vec<String> {
21 vec![
22 format!("Data format: {:?}", self.data_format),
23 format!(
24 "Kernel shape:{:?} (strides:{:?}, padding:{:?}, dilations:{:?})",
25 self.kernel_shape, self.strides, self.padding, self.dilations,
26 ),
27 ]
28 }
29
30 pub fn rank(&self) -> usize {
31 self.kernel_shape.len()
32 }
33
34 pub fn dilation(&self, geo_axis: usize) -> usize {
35 self.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1)
36 }
37
38 pub fn dilations(&self) -> Cow<[usize]> {
39 self.dilations
40 .as_deref()
41 .map_or_else(|| vec![1; self.kernel_shape.len()].into(), |d| d.into())
42 }
43
44 pub fn stride(&self, geo_axis: usize) -> usize {
45 self.strides.as_ref().map(|s| s[geo_axis]).unwrap_or(1)
46 }
47
48 pub fn strides(&self) -> Cow<[usize]> {
49 self.strides
50 .as_deref()
51 .map_or_else(|| vec![1; self.kernel_shape.len()].into(), |d| d.into())
52 }
53
54 pub fn computed_padding<D: DimLike>(&self, input_hw: &[D]) -> TVec<ComputedPaddedDim<D>> {
55 self.padding.compute(input_hw, &self.kernel_shape, &self.dilations(), &self.strides())
56 }
57
58 pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TractResult<BaseDataShape<D, TVec<D>>> {
59 let ishape: BaseDataShape<D, TVec<D>> = self.data_format.shape(input.into())?;
60 ensure!(ishape.c().to_dim() == self.input_channels.to_dim());
61 let computed = self.computed_padding(ishape.hw_dims());
62 let spatial_dims = computed.into_iter().map(|d| d.convoluted).collect::<TVec<D>>();
63 let oshape = self.data_format.from_n_c_hw(
64 ishape.n().cloned().unwrap_or_else(|| 1.into()),
65 self.output_channels.into(),
66 spatial_dims,
67 )?;
68 Ok(oshape)
69 }
70
71 pub fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
72 let oshape = self.output_shape(&inputs[0].shape)?;
73 Ok(tvec!(inputs[0].datum_type.fact(oshape.shape)))
74 }
75
76 pub fn dispose_n_axis(&self) -> PoolSpec {
77 PoolSpec { data_format: self.data_format.dispose_n_axis(), ..self.clone() }
78 }
79
80 pub fn compute_geo(&self, input_full_shape: &[TDim]) -> TractResult<PoolGeometry> {
81 let output_shape = self.output_shape(input_full_shape)?;
82 let input_shape: SymDataShape = self.data_format.shape(input_full_shape.into())?;
83 Ok(PoolGeometry::Symbolic(SymbolicPoolGeometry {
84 pool_spec: self.clone(),
85 input_shape,
86 output_shape,
87 }))
88 }
89
90 pub fn change_geo_axes(&self, op: &AxisOp) -> TractResult<PoolSpec> {
91 let mut dilations = self.dilations().into_owned().into();
92 op.change_shape_array(&mut dilations, false)?;
93 let mut kernel_shape = self.kernel_shape.clone();
94 op.change_shape_array(&mut kernel_shape, false)?;
95 let mut strides = self.strides().into_owned().into();
96 op.change_shape_array(&mut strides, false)?;
97 let padding = self.padding.change_geo_axes(op)?;
98 Ok(PoolSpec {
99 kernel_shape,
100 padding,
101 dilations: Some(dilations),
102 strides: Some(strides),
103 ..self.clone()
104 })
105 }
106
107 pub fn declutter(&self, input: &[TDim]) -> TractResult<Option<PoolSpec>> {
108 if let PaddingSpec::ExplicitOnnxPool(before, after, _) = &self.padding {
109 let input = self.data_format.shape(input)?;
110 let input_hw = input.hw_dims();
111 let reference = self.computed_padding(input_hw);
112 for replacement in [
113 PaddingSpec::Valid,
114 PaddingSpec::SameUpper,
115 PaddingSpec::SameLower,
116 PaddingSpec::Explicit(before.clone(), after.clone()),
117 ] {
118 let new_pool_spec = PoolSpec { padding: replacement, ..self.clone() };
119 if new_pool_spec.computed_padding(input_hw) == reference {
120 return Ok(Some(new_pool_spec));
121 }
122 }
123 }
124 Ok(None)
125 }
126}
127
/// Pool geometry bound either as a `SymbolicPoolGeometry` (as produced by
/// `PoolSpec::compute_geo`) or as a `ConcretePoolGeometry`, via `super::GeometryBound`.
pub type PoolGeometry = super::GeometryBound<SymbolicPoolGeometry, ConcretePoolGeometry>;
129
/// Pool geometry computed from a possibly symbolic (`TDim`) input shape, as built by
/// `PoolSpec::compute_geo`. It can be resolved into a `ConcretePoolGeometry` once a concrete
/// input shape is available.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SymbolicPoolGeometry {
    /// The spec this geometry was derived from.
    pub pool_spec: PoolSpec,
    /// Input shape, interpreted with `pool_spec.data_format`.
    pub input_shape: SymDataShape,
    /// Output shape computed by `PoolSpec::output_shape`.
    pub output_shape: SymDataShape,
}
136
/// Pool geometry resolved against a concrete (`usize`) input shape.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConcretePoolGeometry {
    /// Concrete input shape.
    pub input_shape: DataShape,
    /// Patch built from the spec's kernel shape, padding, strides and dilations.
    pub patch: Patch,
    /// Concrete output shape; its spatial dims are taken from `patch.output_shape`.
    pub output_shape: DataShape,
}
143
144impl super::ResolveTo<ConcretePoolGeometry> for SymbolicPoolGeometry {
145 type Param = [usize];
146 fn resolve(&self, input_full_shape: &[usize]) -> TractResult<ConcretePoolGeometry> {
147 let input_shape = self.pool_spec.data_format.shape(input_full_shape.into())?;
148 let output_inner_stride = match self.pool_spec.data_format {
149 DataFormat::NCHW | DataFormat::CHW => 1,
150 DataFormat::NHWC | DataFormat::HWC => self.pool_spec.output_channels,
151 };
152 let mut spec = PatchSpec::for_full_shape(self.pool_spec.data_format, input_full_shape)?
153 .with_output_inner_stride(output_inner_stride)
154 .with_kernel_shape(self.pool_spec.kernel_shape.clone())
155 .with_padding(self.pool_spec.padding.clone());
156 if let Some(strides) = self.pool_spec.strides.clone() {
157 spec = spec.with_strides(strides);
158 }
159 if let Some(dilations) = self.pool_spec.dilations.clone() {
160 spec = spec.with_dilations(dilations);
161 }
162 let patch = spec.into_patch();
163 let output_shape = input_shape.fmt.from_n_c_hw(
164 *input_shape.n().unwrap_or(&1),
165 self.pool_spec.output_channels,
166 &*patch.output_shape,
167 )?;
168 Ok(ConcretePoolGeometry { input_shape, patch, output_shape })
169 }
170}