tract_tensorflow/ops/nn/s2b/
raw.rs1use tract_hir::internal::*;
2use tract_ndarray::prelude::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct SpaceToBatch {
6 datum_type: DatumType,
7}
8
9
10
11impl Op for SpaceToBatch {
12 fn name(&self) -> StaticName {
13 "SpaceToBatch".into()
14 }
15
16 not_a_typed_op!();
17}
18
19impl EvalOp for SpaceToBatch {
20 fn is_stateless(&self) -> bool {
21 true
22 }
23
24 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
25 let (input, block_shape, paddings) = args_3!(inputs);
26 let block_shape = block_shape.cast_to::<i32>()?;
27 let block_shape = block_shape.to_array_view::<i32>()?.into_dimensionality()?;
28 let paddings = paddings.cast_to::<i32>()?;
29 let paddings = paddings.to_array_view::<i32>()?.into_dimensionality()?;
30 let r = dispatch_numbers!(super::space_to_batch(input.datum_type())(
31 input,
32 &block_shape.view(),
33 &paddings.view()
34 ))?;
35 Ok(tvec!(r))
36 }
37}
38
39impl InferenceRulesOp for SpaceToBatch {
40 fn rules<'r, 'p: 'r, 's: 'r>(
42 &'s self,
43 s: &mut Solver<'r>,
44 inputs: &'p [TensorProxy],
45 outputs: &'p [TensorProxy],
46 ) -> InferenceResult {
47 check_input_arity(inputs, 3)?;
48 check_output_arity(outputs, 1)?;
49 rules(s, self.datum_type, &outputs[0], &inputs[0], &inputs[1], &inputs[2])
50 }
51
52 as_op!();
53
54 fn to_typed(
55 &self,
56 _source: &InferenceModel,
57 node: &InferenceNode,
58 target: &mut TypedModel,
59 mapping: &HashMap<OutletId, OutletId>,
60 ) -> TractResult<TVec<OutletId>> {
61 if let (Some(block_shape), Some(paddings)) = (
62 target.outlet_fact(mapping[&node.inputs[1]])?.konst.clone(),
63 target.outlet_fact(mapping[&node.inputs[2]])?.konst.clone(),
64 ) {
65 let paddings = paddings.cast_to::<TDim>()?;
66 let paddings_view = paddings.to_array_view::<TDim>()?.into_dimensionality::<Ix2>()?;
67 let mut paddings = tvec![];
68 for p in paddings_view.outer_iter() {
69 let pad = match (p[0].to_usize(), p[1].to_usize()) {
70 (Ok(bef), Ok(aft)) => super::unary::PaddingStrat::FixedFixed(bef, aft),
71 (_, Ok(aft)) => super::unary::PaddingStrat::FlexFixed(aft),
72 (Ok(bef), _) => super::unary::PaddingStrat::FixedFlex(bef),
73 _ => bail!("Failed to unarize SpaceToBatch because of padding"),
74 };
75 paddings.push(pad);
76 }
77 let op = super::unary::SpaceToBatchUnary::new(
78 self.datum_type,
79 target.outlet_fact(mapping[&node.inputs[0]])?.shape.to_tvec(),
80 node.outputs[0]
81 .fact
82 .shape
83 .concretize()
84 .unwrap()
85 .iter()
86 .cloned()
87 .collect::<TVec<_>>(),
88 block_shape.into_tensor().into_array::<i32>()?.into_dimensionality()?,
89 paddings,
90 );
91 target.wire_node(&*node.name, op, [mapping[&node.inputs[0]]].as_ref())
92 } else {
93 bail!("Need fixed block shape and padding")
94 }
95 }
96}
97
98#[derive(Debug, Clone, new, Hash)]
99pub struct BatchToSpace {
100 datum_type: DatumType,
101}
102
103
104
105impl Op for BatchToSpace {
106 fn name(&self) -> StaticName {
107 "BatchToSpace".into()
108 }
109
110 not_a_typed_op!();
111}
112
113impl EvalOp for BatchToSpace {
114 fn is_stateless(&self) -> bool {
115 true
116 }
117
118 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
119 let (input, block_shape, crops) = args_3!(inputs);
120 let block_shape = block_shape.cast_to::<i32>()?;
121 let block_shape = block_shape.to_array_view::<i32>()?.into_dimensionality()?;
122 let crops = crops.cast_to::<i32>()?;
123 let crops = crops.to_array_view::<i32>()?.into_dimensionality()?;
124 let r = dispatch_numbers!(super::batch_to_space(input.datum_type())(
125 input,
126 &block_shape.view(),
127 &crops.view()
128 ))?;
129 Ok(tvec!(r))
130 }
131}
132
133impl InferenceRulesOp for BatchToSpace {
134 fn rules<'r, 'p: 'r, 's: 'r>(
136 &'s self,
137 s: &mut Solver<'r>,
138 inputs: &'p [TensorProxy],
139 outputs: &'p [TensorProxy],
140 ) -> InferenceResult {
141 check_input_arity(inputs, 3)?;
142 check_output_arity(outputs, 1)?;
143 rules(s, self.datum_type, &inputs[0], &outputs[0], &inputs[1], &inputs[2])
144 }
145
146 fn to_typed(
147 &self,
148 _source: &InferenceModel,
149 node: &InferenceNode,
150 target: &mut TypedModel,
151 mapping: &HashMap<OutletId, OutletId>,
152 ) -> TractResult<TVec<OutletId>> {
153 if let (Some(block_shape), Some(paddings)) = (
154 target.outlet_fact(mapping[&node.inputs[1]])?.konst.clone(),
155 target.outlet_fact(mapping[&node.inputs[2]])?.konst.clone(),
156 ) {
157 let paddings = paddings.cast_to::<TDim>()?;
158 let paddings = paddings.to_array_view::<TDim>()?.into_dimensionality::<Ix2>()?;
159 let paddings = paddings
160 .outer_iter()
161 .map(|p| {
162 Ok(match (p[0].to_usize(), p[1].to_usize()) {
163 (Ok(bef), Ok(aft)) => super::unary::PaddingStrat::FixedFixed(bef, aft),
164 (_, Ok(aft)) => super::unary::PaddingStrat::FlexFixed(aft),
165 (Ok(bef), _) => super::unary::PaddingStrat::FixedFlex(bef),
166 _ => bail!("Failed to unarize SpaceToBatch because of padding"),
167 })
168 })
169 .collect::<TractResult<_>>()?;
170 let op = super::unary::BatchToSpaceUnary::new(
171 self.datum_type,
172 target.outlet_fact(mapping[&node.inputs[0]])?.shape.to_tvec(),
173 node.outputs[0]
174 .fact
175 .shape
176 .concretize()
177 .unwrap()
178 .iter()
179 .cloned()
180 .collect::<TVec<_>>(),
181 block_shape.into_tensor().into_array::<i32>()?.into_dimensionality()?,
182 paddings,
183 );
184 target.wire_node(&*node.name, op, [mapping[&node.inputs[0]]].as_ref())
185 } else {
186 bail!("Need fixed block shape and padding")
187 }
188 }
189 as_op!();
190}
191
192fn rules<'r, 'p: 'r>(
193 s: &mut Solver<'r>,
194 datum_type: DatumType,
195 batch: &'p TensorProxy,
196 space: &'p TensorProxy,
197 block_shape: &'p TensorProxy,
198 paddings: &'p TensorProxy,
199) -> InferenceResult {
200 s.equals(&batch.datum_type, datum_type)?;
201 s.equals(&batch.datum_type, &space.datum_type)?;
202 s.equals(&block_shape.datum_type, DatumType::I32)?;
203 s.equals(&batch.rank, &space.rank)?;
204 s.equals(&block_shape.rank, 1)?;
205 s.equals(&paddings.rank, 2)?;
206 s.equals(&block_shape.shape[0], &paddings.shape[0])?;
207 s.given(&block_shape.value, move |s, block_shape| {
208 let block_shape = block_shape.into_tensor().into_array::<i32>()?;
209 let block_shape_prod = block_shape.iter().map(|s| *s as usize).product::<usize>();
210 s.equals(&batch.shape[0], (block_shape_prod as i64) * space.shape[0].bex())?;
211 s.given(&paddings.value, move |s, paddings| {
212 let paddings = paddings.cast_to::<TDim>()?;
213 let paddings = paddings.to_array_view::<TDim>()?.into_dimensionality()?;
214 for d in 0..block_shape.len() {
215 s.equals(
216 space.shape[1 + d].bex() + &paddings[(d, 0)] + &paddings[(d, 1)],
217 (block_shape[d] as i64) * batch.shape[1 + d].bex(),
218 )?;
219 }
220 Ok(())
221 })
222 })?;
223 s.given(&block_shape.value, move |s, block_shape| {
224 let block_shape = block_shape.into_tensor().into_array::<i32>()?;
225 s.given(&space.rank, move |s, rank: i64| {
226 for d in block_shape.len() + 1..(rank as usize) {
227 s.equals(&space.shape[d], &batch.shape[d])?
228 }
229 Ok(())
230 })
231 })
232}