tract_hir/ops/cnn/
pools.rs1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::ops::cnn::MaxPool;
5use tract_core::ops::cnn::PoolSpec;
6use tract_core::ops::cnn::SumPool;
7
8#[derive(Debug, Clone, new, Hash)]
9pub struct HirSumPool {
10 pub pool_spec: PoolSpec,
11 pub count_include_pad: bool,
12 pub normalize: bool,
13}
14
15impl Expansion for HirSumPool {
16 fn name(&self) -> StaticName {
17 "SumPool".into()
18 }
19
20 fn rules<'r, 'p: 'r, 's: 'r>(
21 &'s self,
22 s: &mut Solver<'r>,
23 inputs: &'p [TensorProxy],
24 outputs: &'p [TensorProxy],
25 ) -> InferenceResult {
26 check_input_arity(inputs, 1)?;
27 check_output_arity(outputs, 1)?;
28 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
29 rules_for_shape(&self.pool_spec, s, inputs, outputs)
30 }
31
32 fn wire(
33 &self,
34 prefix: &str,
35 model: &mut TypedModel,
36 inputs: &[OutletId],
37 ) -> TractResult<TVec<OutletId>> {
38 let c = self
39 .pool_spec
40 .data_format
41 .shape(&model.outlet_fact(inputs[0])?.shape)?
42 .c()
43 .to_usize()
44 .context("Expect constant integer depth")?;
45 let pool_spec =
46 PoolSpec { input_channels: c, output_channels: c, ..self.pool_spec.clone() };
47 model.wire_node(
48 prefix,
49 SumPool {
50 pool_spec,
51 count_include_pad: self.count_include_pad,
52 normalize: self.normalize,
53 },
54 inputs,
55 )
56 }
57}
58
59#[derive(Debug, Clone, new, Hash)]
60pub struct HirMaxPool {
61 pub pool_spec: PoolSpec,
62 pub with_index_outputs: Option<DatumType>,
63}
64
65impl Expansion for HirMaxPool {
66 fn name(&self) -> StaticName {
67 "MaxPool".into()
68 }
69
70 fn rules<'r, 'p: 'r, 's: 'r>(
71 &'s self,
72 s: &mut Solver<'r>,
73 inputs: &'p [TensorProxy],
74 outputs: &'p [TensorProxy],
75 ) -> InferenceResult {
76 check_output_arity(outputs, 1 + self.with_index_outputs.is_some() as usize)?;
77 s.equals(&outputs[0].rank, &inputs[0].rank)?;
78 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
79 if let Some(idt) = self.with_index_outputs {
80 s.equals(&outputs[1].datum_type, idt)?;
81 s.equals(&outputs[1].shape, &outputs[0].shape)?;
82 }
83 rules_for_shape(&self.pool_spec, s, inputs, outputs)
84 }
85
86 fn nboutputs(&self) -> TractResult<usize> {
87 Ok(1 + self.with_index_outputs.is_some() as usize)
88 }
89
90 fn wire(
91 &self,
92 prefix: &str,
93 model: &mut TypedModel,
94 inputs: &[OutletId],
95 ) -> TractResult<TVec<OutletId>> {
96 let c = self
97 .pool_spec
98 .data_format
99 .shape(&model.outlet_fact(inputs[0])?.shape)?
100 .c()
101 .to_usize()
102 .context("Expect constant integer depth")?;
103 let pool_spec =
104 PoolSpec { input_channels: c, output_channels: c, ..self.pool_spec.clone() };
105 model.wire_node(
106 prefix,
107 MaxPool { pool_spec, with_index_outputs: self.with_index_outputs },
108 inputs,
109 )
110 }
111}
112
113pub fn rules_for_shape<'r, 'p: 'r, 's: 'r>(
114 pool_spec: &'s PoolSpec,
115 s: &mut Solver<'r>,
116 inputs: &'p [TensorProxy],
117 outputs: &'p [TensorProxy],
118) -> InferenceResult {
119 s.equals(&outputs[0].rank, &inputs[0].rank)?;
120 s.given(&inputs[0].shape, move |s, ishape| {
121 let ishape = pool_spec.data_format.shape(ishape)?;
122 let ones = tvec![1; ishape.hw_rank()];
123 let computed = pool_spec.padding.compute(
124 ishape.hw_dims(),
125 &pool_spec.kernel_shape,
126 pool_spec.dilations.as_ref().unwrap_or(&ones),
127 pool_spec.strides.as_ref().unwrap_or(&ones),
128 );
129 for o in outputs {
130 for (ix, d) in computed.iter().enumerate() {
131 s.equals(&o.shape[ix + ishape.h_axis()], &d.convoluted)?;
132 }
133 if ishape.n_axis().is_some() {
134 s.equals(&o.shape[ishape.n_axis().unwrap()], ishape.n_dim().unwrap())?;
135 }
136 if pool_spec.input_channels == 0 && pool_spec.output_channels == 0 {
138 s.equals(&o.shape[ishape.c_axis()], ishape.c_dim())?;
139 }
140 }
141 Ok(())
142 })
143}