tract_tensorflow/ops/nn/s2b/unary.rs

use tract_hir::internal::*;
use tract_ndarray::prelude::*;

use tract_hir::tract_core::ops::cnn::{Conv, PoolSpec};

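/// How the padding on one spatial axis is split between "before" and "after"
/// once the total amount to pad (`spread` in `eval` below) is known: one side
/// is fixed and the other absorbs the remainder, or both sides are fixed.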
#[derive(Debug, Copy, Clone, Hash)]
pub enum PaddingStrat {
    FlexFixed(usize),
    FixedFlex(usize),
    FixedFixed(usize, usize),
}

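/// SpaceToBatch with its block shape, padding strategy and input/output
/// shapes pre-resolved as attributes, so only the data tensor remains as a
/// runtime input (hence "Unary").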
#[derive(Debug, Clone, new, Hash)]
pub struct SpaceToBatchUnary {
    pub datum_type: DatumType,
    pub space_shape: TVec<TDim>,
    pub batch_shape: TVec<TDim>,
    pub block_shape: Array1<i32>,
    pub pad: TVec<PaddingStrat>,
}

impl Op for SpaceToBatchUnary {
    fn name(&self) -> StaticName {
        "SpaceToBatchUnary".into()
    }

    op_as_typed_op!();
}

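// Evaluation resolves the symbolic shapes into concrete per-axis padding
// amounts, then dispatches to the generic space_to_batch kernel for the
// input's datum type.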
impl EvalOp for SpaceToBatchUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let mut paddings = Array2::zeros((self.block_shape.len(), 2));
        for (ax, &strat) in self.pad.iter().enumerate() {
            // Total padding on this axis: batched extent times block size,
            // minus the original spatial extent.
            let spread = (self.batch_shape[2 + ax].clone() * self.block_shape[ax]
                - &self.space_shape[2 + ax])
                .to_usize()?;
            let (bef, aft) = match strat {
                PaddingStrat::FlexFixed(f) => (spread - f, f),
                PaddingStrat::FixedFlex(f) => (f, spread - f),
                PaddingStrat::FixedFixed(a, b) => (a, b),
            };
            paddings[(ax, 0)] = bef as i32;
            paddings[(ax, 1)] = aft as i32;
        }
        let r = dispatch_numbers!(super::space_to_batch(input.datum_type())(
            input,
            &self.block_shape.view(),
            &paddings.view()
        ))?;
        Ok(tvec!(r))
    }
}

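// Declutter looks for the TensorFlow dilated-convolution pattern
// (SpaceToBatch -> Conv -> BatchToSpace) and rewrites it as a single Conv
// whose dilations are the block shape, bypassing both reshuffling ops.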
impl TypedOp for SpaceToBatchUnary {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].datum_type.fact(&self.batch_shape)))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Only rewrite the exact SpaceToBatch -> Conv -> BatchToSpace chain.
        let [succ] = &*model.node(node.id).outputs[0].successors else { return Ok(None) };
        let conv_node = model.node(succ.node);
        let Some(conv_op) = conv_node.op_as::<Conv>() else { return Ok(None) };
        let [succ] = &*conv_node.outputs[0].successors else { return Ok(None) };
        let b2s_node = model.node(succ.node);
        let Some(_b2s_op) = b2s_node.op_as::<BatchToSpaceUnary>() else { return Ok(None) };
        // Same convolution, but with the block shape expressed as dilations.
        let op = Conv {
            pool_spec: PoolSpec {
                dilations: Some(self.block_shape.iter().map(|&i| i as usize).collect()),
                ..conv_op.pool_spec.clone()
            },
            ..conv_op.clone()
        };
        let mut patch = TypedModelPatch::default();
        let taps_s2b = patch.taps(model, &node.inputs)?;
        let mut taps_conv = patch.taps(model, &conv_node.inputs)?;
        // Feed the convolution directly from the SpaceToBatch input...
        taps_conv[0] = taps_s2b[0];
        let out = patch.model.wire_node(&*conv_node.name, op, &taps_conv)?[0];
        // ...and route its output to wherever the BatchToSpace output went.
        patch.shunt_outside(model, OutletId::new(b2s_node.id, 0), out)?;
        Ok(Some(patch))
    }

    as_op!();
}

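/// The inverse operation: BatchToSpace with its block shape, per-axis
/// amounts and shapes pre-resolved as attributes, taking only the data
/// tensor at runtime.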
#[derive(Debug, Clone, new, Hash)]
pub struct BatchToSpaceUnary {
    datum_type: DatumType,
    batch_shape: TVec<TDim>,
    space_shape: TVec<TDim>,
    block_shape: Array1<i32>,
    pad: Vec<PaddingStrat>,
}

impl Op for BatchToSpaceUnary {
    fn name(&self) -> StaticName {
        "BatchToSpaceUnary".into()
    }

    op_as_typed_op!();
}

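// Mirror of SpaceToBatchUnary::eval: resolve the per-axis amounts from the
// strategies, then dispatch to the generic batch_to_space kernel.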
impl EvalOp for BatchToSpaceUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let mut paddings = Array2::zeros((self.block_shape.len(), 2));
        for (ax, &strat) in self.pad.iter().enumerate() {
            let spread = (self.batch_shape[2 + ax].clone() * self.block_shape[ax]
                - &self.space_shape[2 + ax])
                .to_usize()?;
            let (bef, aft) = match strat {
                PaddingStrat::FlexFixed(f) => (spread - f, f),
                PaddingStrat::FixedFlex(f) => (f, spread - f),
                PaddingStrat::FixedFixed(a, b) => (a, b),
            };
            paddings[(ax, 0)] = bef as i32;
            paddings[(ax, 1)] = aft as i32;
        }
        let r = dispatch_numbers!(super::batch_to_space(input.datum_type())(
            input,
            &self.block_shape.view(),
            &paddings.view()
        ))?;
        Ok(tvec!(r))
    }
}

impl TypedOp for BatchToSpaceUnary {
    as_op!();

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].datum_type.fact(&self.space_shape)))
    }
}
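
// A minimal sketch (not part of the original file) of how the three
// PaddingStrat variants split a total padding amount, mirroring the match
// used by both eval() implementations above. The `split` helper and the
// test values are hypothetical, for illustration only.
#[cfg(test)]
mod padding_strat_sketch {
    use super::PaddingStrat;

    // Hypothetical helper: split `spread` into (before, after) amounts.
    fn split(strat: PaddingStrat, spread: usize) -> (usize, usize) {
        match strat {
            PaddingStrat::FlexFixed(f) => (spread - f, f),
            PaddingStrat::FixedFlex(f) => (f, spread - f),
            PaddingStrat::FixedFixed(a, b) => (a, b),
        }
    }

    #[test]
    fn splits_total_padding() {
        assert_eq!(split(PaddingStrat::FlexFixed(1), 3), (2, 1));
        assert_eq!(split(PaddingStrat::FixedFlex(1), 3), (1, 2));
        assert_eq!(split(PaddingStrat::FixedFixed(1, 2), 3), (1, 2));
    }
}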