// tract_tensorflow/ops/nn/s2b/raw.rs

1use tract_hir::internal::*;
2use tract_ndarray::prelude::*;
3
/// TensorFlow `SpaceToBatch` operator (raw, three-input form): pads the
/// spatial dimensions of the data input, then rearranges spatial blocks into
/// the batch dimension. Block shape and paddings arrive as runtime inputs.
#[derive(Debug, Clone, new, Hash)]
pub struct SpaceToBatch {
    // Element type of the data input/output, fixed at op creation.
    datum_type: DatumType,
}
8
9
10
impl Op for SpaceToBatch {
    fn name(&self) -> StaticName {
        "SpaceToBatch".into()
    }

    // This raw op is not a typed op; it is converted to a typed form via
    // `InferenceRulesOp::to_typed` below.
    not_a_typed_op!();
}
18
19impl EvalOp for SpaceToBatch {
20    fn is_stateless(&self) -> bool {
21        true
22    }
23
24    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
25        let (input, block_shape, paddings) = args_3!(inputs);
26        let block_shape = block_shape.cast_to::<i32>()?;
27        let block_shape = block_shape.to_array_view::<i32>()?.into_dimensionality()?;
28        let paddings = paddings.cast_to::<i32>()?;
29        let paddings = paddings.to_array_view::<i32>()?.into_dimensionality()?;
30        let r = dispatch_numbers!(super::space_to_batch(input.datum_type())(
31            input,
32            &block_shape.view(),
33            &paddings.view()
34        ))?;
35        Ok(tvec!(r))
36    }
37}
38
impl InferenceRulesOp for SpaceToBatch {
    /// Registers the inference rules of the operator.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 3)?;
        check_output_arity(outputs, 1)?;
        // The shared `rules` helper takes (batch, space, ...): for
        // SpaceToBatch, the output is the batch-major tensor and the input is
        // the spatial one.
        rules(s, self.datum_type, &outputs[0], &inputs[0], &inputs[1], &inputs[2])
    }

    as_op!();

    /// Lowers to the typed `SpaceToBatchUnary` op. Requires both the block
    /// shape (input 1) and the paddings (input 2) to be known constants;
    /// bails otherwise.
    fn to_typed(
        &self,
        _source: &InferenceModel,
        node: &InferenceNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        if let (Some(block_shape), Some(paddings)) = (
            target.outlet_fact(mapping[&node.inputs[1]])?.konst.clone(),
            target.outlet_fact(mapping[&node.inputs[2]])?.konst.clone(),
        ) {
            let paddings = paddings.cast_to::<TDim>()?;
            let paddings_view = paddings.to_array_view::<TDim>()?.into_dimensionality::<Ix2>()?;
            let mut paddings = tvec![];
            // Each row of the paddings matrix is a (before, after) pair; a
            // member that does not reduce to a concrete usize stays symbolic
            // ("Flex") in the padding strategy.
            for p in paddings_view.outer_iter() {
                let pad = match (p[0].to_usize(), p[1].to_usize()) {
                    (Ok(bef), Ok(aft)) => super::unary::PaddingStrat::FixedFixed(bef, aft),
                    (_, Ok(aft)) => super::unary::PaddingStrat::FlexFixed(aft),
                    (Ok(bef), _) => super::unary::PaddingStrat::FixedFlex(bef),
                    _ => bail!("Failed to unarize SpaceToBatch because of padding"),
                };
                paddings.push(pad);
            }
            let op = super::unary::SpaceToBatchUnary::new(
                self.datum_type,
                target.outlet_fact(mapping[&node.inputs[0]])?.shape.to_tvec(),
                // NOTE(review): assumes the output shape was fully inferred by
                // this point; the `unwrap` panics (rather than erroring out)
                // if it was not — confirm this invariant holds for all callers.
                node.outputs[0]
                    .fact
                    .shape
                    .concretize()
                    .unwrap()
                    .iter()
                    .cloned()
                    .collect::<TVec<_>>(),
                block_shape.into_tensor().into_array::<i32>()?.into_dimensionality()?,
                paddings,
            );
            // The unary op only keeps the data input; block shape and paddings
            // are baked into the op itself.
            target.wire_node(&*node.name, op, [mapping[&node.inputs[0]]].as_ref())
        } else {
            bail!("Need fixed block shape and padding")
        }
    }
}
97
/// TensorFlow `BatchToSpace` operator (raw, three-input form): the inverse of
/// `SpaceToBatch` — moves blocks from the batch dimension back into the
/// spatial dimensions, then crops. Block shape and crops are runtime inputs.
#[derive(Debug, Clone, new, Hash)]
pub struct BatchToSpace {
    // Element type of the data input/output, fixed at op creation.
    datum_type: DatumType,
}
102
103
104
impl Op for BatchToSpace {
    fn name(&self) -> StaticName {
        "BatchToSpace".into()
    }

    // This raw op is not a typed op; it is converted to a typed form via
    // `InferenceRulesOp::to_typed` below.
    not_a_typed_op!();
}
112
113impl EvalOp for BatchToSpace {
114    fn is_stateless(&self) -> bool {
115        true
116    }
117
118    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
119        let (input, block_shape, crops) = args_3!(inputs);
120        let block_shape = block_shape.cast_to::<i32>()?;
121        let block_shape = block_shape.to_array_view::<i32>()?.into_dimensionality()?;
122        let crops = crops.cast_to::<i32>()?;
123        let crops = crops.to_array_view::<i32>()?.into_dimensionality()?;
124        let r = dispatch_numbers!(super::batch_to_space(input.datum_type())(
125            input,
126            &block_shape.view(),
127            &crops.view()
128        ))?;
129        Ok(tvec!(r))
130    }
131}
132
133impl InferenceRulesOp for BatchToSpace {
134    /// Registers the inference rules of the operator.
135    fn rules<'r, 'p: 'r, 's: 'r>(
136        &'s self,
137        s: &mut Solver<'r>,
138        inputs: &'p [TensorProxy],
139        outputs: &'p [TensorProxy],
140    ) -> InferenceResult {
141        check_input_arity(inputs, 3)?;
142        check_output_arity(outputs, 1)?;
143        rules(s, self.datum_type, &inputs[0], &outputs[0], &inputs[1], &inputs[2])
144    }
145
146    fn to_typed(
147        &self,
148        _source: &InferenceModel,
149        node: &InferenceNode,
150        target: &mut TypedModel,
151        mapping: &HashMap<OutletId, OutletId>,
152    ) -> TractResult<TVec<OutletId>> {
153        if let (Some(block_shape), Some(paddings)) = (
154            target.outlet_fact(mapping[&node.inputs[1]])?.konst.clone(),
155            target.outlet_fact(mapping[&node.inputs[2]])?.konst.clone(),
156        ) {
157            let paddings = paddings.cast_to::<TDim>()?;
158            let paddings = paddings.to_array_view::<TDim>()?.into_dimensionality::<Ix2>()?;
159            let paddings = paddings
160                .outer_iter()
161                .map(|p| {
162                    Ok(match (p[0].to_usize(), p[1].to_usize()) {
163                        (Ok(bef), Ok(aft)) => super::unary::PaddingStrat::FixedFixed(bef, aft),
164                        (_, Ok(aft)) => super::unary::PaddingStrat::FlexFixed(aft),
165                        (Ok(bef), _) => super::unary::PaddingStrat::FixedFlex(bef),
166                        _ => bail!("Failed to unarize SpaceToBatch because of padding"),
167                    })
168                })
169                .collect::<TractResult<_>>()?;
170            let op = super::unary::BatchToSpaceUnary::new(
171                self.datum_type,
172                target.outlet_fact(mapping[&node.inputs[0]])?.shape.to_tvec(),
173                node.outputs[0]
174                    .fact
175                    .shape
176                    .concretize()
177                    .unwrap()
178                    .iter()
179                    .cloned()
180                    .collect::<TVec<_>>(),
181                block_shape.into_tensor().into_array::<i32>()?.into_dimensionality()?,
182                paddings,
183            );
184            target.wire_node(&*node.name, op, [mapping[&node.inputs[0]]].as_ref())
185        } else {
186            bail!("Need fixed block shape and padding")
187        }
188    }
189    as_op!();
190}
191
/// Shared shape/type inference rules for `SpaceToBatch` and `BatchToSpace`.
///
/// `batch` is the batch-major tensor (output of SpaceToBatch, input of
/// BatchToSpace) and `space` is the spatial tensor; callers pass them in the
/// appropriate order for each direction. `block_shape` must be a rank-1 i32
/// tensor and `paddings` a rank-2 (pairs of before/after) tensor with the
/// same leading dimension as `block_shape`.
fn rules<'r, 'p: 'r>(
    s: &mut Solver<'r>,
    datum_type: DatumType,
    batch: &'p TensorProxy,
    space: &'p TensorProxy,
    block_shape: &'p TensorProxy,
    paddings: &'p TensorProxy,
) -> InferenceResult {
    s.equals(&batch.datum_type, datum_type)?;
    s.equals(&batch.datum_type, &space.datum_type)?;
    s.equals(&block_shape.datum_type, DatumType::I32)?;
    s.equals(&batch.rank, &space.rank)?;
    s.equals(&block_shape.rank, 1)?;
    s.equals(&paddings.rank, 2)?;
    s.equals(&block_shape.shape[0], &paddings.shape[0])?;
    s.given(&block_shape.value, move |s, block_shape| {
        let block_shape = block_shape.into_tensor().into_array::<i32>()?;
        let block_shape_prod = block_shape.iter().map(|s| *s as usize).product::<usize>();
        // Batch dim of the batch-major tensor is the spatial batch times the
        // product of all block sizes.
        s.equals(&batch.shape[0], (block_shape_prod as i64) * space.shape[0].bex())?;
        s.given(&paddings.value, move |s, paddings| {
            let paddings = paddings.cast_to::<TDim>()?;
            let paddings = paddings.to_array_view::<TDim>()?.into_dimensionality()?;
            // For each blocked spatial axis d: padded spatial extent must be
            // exactly block_shape[d] times the batch-major extent.
            for d in 0..block_shape.len() {
                s.equals(
                    space.shape[1 + d].bex() + &paddings[(d, 0)] + &paddings[(d, 1)],
                    (block_shape[d] as i64) * batch.shape[1 + d].bex(),
                )?;
            }
            Ok(())
        })
    })?;
    s.given(&block_shape.value, move |s, block_shape| {
        let block_shape = block_shape.into_tensor().into_array::<i32>()?;
        s.given(&space.rank, move |s, rank: i64| {
            // Trailing axes beyond the blocked ones are passed through
            // unchanged.
            for d in block_shape.len() + 1..(rank as usize) {
                s.equals(&space.shape[d], &batch.shape[d])?
            }
            Ok(())
        })
    })
}
232}