tract_tensorflow/ops/nn/s2b/
unary.rs

1use tract_hir::internal::*;
2use tract_ndarray::prelude::*;
3
4use tract_hir::tract_core::ops::cnn::{Conv, PoolSpec};
5
/// How the padding of one spatial axis is split between its leading and trailing side.
///
/// A "flex" side is not stored: when the op is evaluated it receives whatever is left of the
/// axis' total padding once the fixed side has been subtracted.
#[derive(Clone, Copy, Debug, Hash)]
pub enum PaddingStrat {
    /// Trailing padding is fixed to the given amount; leading padding takes the remainder.
    FlexFixed(usize),
    /// Leading padding is fixed to the given amount; trailing padding takes the remainder.
    FixedFlex(usize),
    /// Both sides are given explicitly, as `(before, after)`.
    FixedFixed(usize, usize),
}
12
13#[derive(Debug, Clone, new, Hash)]
14pub struct SpaceToBatchUnary {
15    pub datum_type: DatumType,
16    pub space_shape: TVec<TDim>,
17    pub batch_shape: TVec<TDim>,
18    pub block_shape: Array1<i32>,
19    pub pad: TVec<PaddingStrat>,
20}
21
22impl Op for SpaceToBatchUnary {
23    fn name(&self) -> StaticName {
24        "SpaceToBatchUnary".into()
25    }
26
27    op_as_typed_op!();
28}
29
30impl EvalOp for SpaceToBatchUnary {
31    fn is_stateless(&self) -> bool {
32        true
33    }
34
35    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
36        let input = args_1!(inputs);
37        let mut paddings = Array2::zeros((self.block_shape.len(), 2));
38        for (ax, &strat) in self.pad.iter().enumerate() {
39            let spread = (self.batch_shape[2 + ax].clone() * self.block_shape[ax]
40                - &self.space_shape[2 + ax])
41                .to_usize()?;
42            let (bef, aft) = match strat {
43                PaddingStrat::FlexFixed(f) => (spread - f, f),
44                PaddingStrat::FixedFlex(f) => (f, spread - f),
45                PaddingStrat::FixedFixed(a, b) => (a, b),
46            };
47            paddings[(ax, 0)] = bef as i32;
48            paddings[(ax, 1)] = aft as i32;
49        }
50        let r = dispatch_numbers!(super::space_to_batch(input.datum_type())(
51            input,
52            &self.block_shape.view(),
53            &paddings.view()
54        ))?;
55        Ok(tvec!(r))
56    }
57}
58
59impl TypedOp for SpaceToBatchUnary {
60    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
61        Ok(tvec!(inputs[0].datum_type.fact(&self.batch_shape)))
62    }
63
64    fn declutter(
65        &self,
66        model: &TypedModel,
67        node: &TypedNode,
68    ) -> TractResult<Option<TypedModelPatch>> {
69        let [succ] = &*model.node(node.id).outputs[0].successors else { return Ok(None) };
70        let conv_node = model.node(succ.node);
71        let Some(conv_op) = conv_node.op_as::<Conv>() else { return Ok(None) };
72        let [succ] = &*conv_node.outputs[0].successors else { return Ok(None) };
73        let b2s_node = model.node(succ.node);
74        let Some(_bs2_op) = b2s_node.op_as::<BatchToSpaceUnary>() else { return Ok(None) };
75        let op = Conv {
76            pool_spec: PoolSpec {
77                dilations: Some(self.block_shape.iter().map(|&i| i as usize).collect()),
78                ..conv_op.pool_spec.clone()
79            },
80            ..conv_op.clone()
81        };
82        let mut patch = TypedModelPatch::default();
83        let taps_s2b = patch.taps(model, &node.inputs)?;
84        let mut taps_conv = patch.taps(model, &conv_node.inputs)?;
85        taps_conv[0] = taps_s2b[0];
86        let out = patch.model.wire_node(&*conv_node.name, op, &taps_conv)?[0];
87        patch.shunt_outside(model, OutletId::new(b2s_node.id, 0), out)?;
88        Ok(Some(patch))
89    }
90
91    as_op!();
92}
93
94#[derive(Debug, Clone, new, Hash)]
95pub struct BatchToSpaceUnary {
96    datum_type: DatumType,
97    batch_shape: TVec<TDim>,
98    space_shape: TVec<TDim>,
99    block_shape: Array1<i32>,
100    pad: Vec<PaddingStrat>,
101}
102
103impl Op for BatchToSpaceUnary {
104    fn name(&self) -> StaticName {
105        "BatchToSpaceUnary".into()
106    }
107
108    op_as_typed_op!();
109}
110
111impl EvalOp for BatchToSpaceUnary {
112    fn is_stateless(&self) -> bool {
113        true
114    }
115
116    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
117        let input = args_1!(inputs);
118        let mut paddings = Array2::zeros((self.block_shape.len(), 2));
119        for (ax, &strat) in self.pad.iter().enumerate() {
120            let spread = (self.batch_shape[2 + ax].clone() * self.block_shape[ax]
121                - &self.space_shape[2 + ax])
122                .to_usize()?;
123            let (bef, aft) = match strat {
124                PaddingStrat::FlexFixed(f) => (spread - f, f),
125                PaddingStrat::FixedFlex(f) => (f, spread - f),
126                PaddingStrat::FixedFixed(a, b) => (a, b),
127            };
128            paddings[(ax, 0)] = bef as i32;
129            paddings[(ax, 1)] = aft as i32;
130        }
131        let r = dispatch_numbers!(super::batch_to_space(input.datum_type())(
132            input,
133            &self.block_shape.view(),
134            &paddings.view()
135        ))?;
136        Ok(tvec!(r))
137    }
138}
139
140impl TypedOp for BatchToSpaceUnary {
141    as_op!();
142
143    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
144        Ok(tvec!(inputs[0].datum_type.fact(&self.space_shape)))
145    }
146}