tract_tensorflow/ops/nn/
fused_batch_norm.rs

1use tract_hir::internal::*;
2use tract_itertools::izip;
3
4use crate::model::ParsingContext;
5use crate::tfpb::tensorflow::NodeDef;
6
7pub fn fused_batch_norm(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
8    let epsilon = pb.get_attr_float::<f32>("epsilon")?;
9    Ok(expand(FusedBatchNorm::new(epsilon)))
10}
11
12#[derive(Debug, Clone, new)]
13struct FusedBatchNorm {
14    epsilon: f32,
15}
16
17impl Expansion for FusedBatchNorm {
18    fn name(&self) -> StaticName {
19        "FusedBatchNorm".into()
20    }
21
22    fn validation(&self) -> Validation {
23        Validation::Rounding
24    }
25
26
27    /// Registers the inference rules of the operator.
28    fn rules<'r, 'p: 'r, 's: 'r>(
29        &'s self,
30        s: &mut Solver<'r>,
31        inputs: &'p [TensorProxy],
32        outputs: &'p [TensorProxy],
33    ) -> InferenceResult {
34        check_input_arity(inputs, 5)?;
35        check_output_arity(outputs, 1)?;
36        s.equals(&inputs[0].datum_type, f32::datum_type())?;
37        s.equals(&inputs[1].datum_type, f32::datum_type())?;
38        s.equals(&inputs[2].datum_type, f32::datum_type())?;
39        s.equals(&inputs[3].datum_type, f32::datum_type())?;
40        s.equals(&inputs[4].datum_type, f32::datum_type())?;
41        s.equals(&outputs[0].datum_type, f32::datum_type())?;
42        s.equals(&outputs[0].shape, &inputs[0].shape)?;
43        s.equals(&inputs[0].rank, 4)?;
44        s.equals(&inputs[1].rank, 1)?;
45        s.equals(&inputs[2].rank, 1)?;
46        s.equals(&inputs[3].rank, 1)?;
47        s.equals(&inputs[4].rank, 1)?;
48        s.equals(&inputs[0].shape, &outputs[0].shape)?;
49        s.equals(&inputs[1].shape[0], &inputs[0].shape[3])?;
50        s.equals(&inputs[2].shape[0], &inputs[0].shape[3])?;
51        s.equals(&inputs[3].shape[0], &inputs[0].shape[3])?;
52        s.equals(&inputs[4].shape[0], &inputs[0].shape[3])?;
53        Ok(())
54    }
55
56    fn wire(
57        &self,
58        prefix: &str,
59        target: &mut TypedModel,
60        inputs: &[OutletId],
61    ) -> TractResult<TVec<OutletId>> {
62        let scale = target.outlet_fact(inputs[1])?;
63        let offset = target.outlet_fact(inputs[2])?;
64        let mean = target.outlet_fact(inputs[3])?;
65        let variance = target.outlet_fact(inputs[4])?;
66        if let (Some(scale), Some(offset), Some(mean), Some(variance)) =
67            (&scale.konst, &offset.konst, &mean.konst, &variance.konst)
68        {
69            let scale = scale.as_slice::<f32>()?;
70            let offset = offset.as_slice::<f32>()?;
71            let mean = mean.as_slice::<f32>()?;
72            let variance = variance.as_slice::<f32>()?;
73            let slope: Vec<f32> =
74                izip!(variance, scale).map(|(v, s)| s / (v + self.epsilon).sqrt()).collect();
75            let inter: Vec<f32> = izip!(offset, mean, &slope).map(|(o, m, s)| o - m * s).collect();
76            let shape = tvec!(1, 1, 1, scale.len());
77            let slope = tensor1(&slope).into_shape(&shape)?;
78            let inter = tensor1(&inter).into_shape(&shape)?;
79            let slope = target.add_const(prefix.to_string() + ".slope", slope)?;
80            let inter = target.add_const(prefix.to_string() + ".inter", inter)?;
81            let wire = target.wire_node(
82                format!("{prefix}.mul"),
83                tract_hir::ops::math::mul(),
84                &[inputs[0], slope],
85            )?;
86            return target.wire_node(
87                format!("{prefix}.add"),
88                tract_hir::ops::math::add(),
89                &[wire[0], inter],
90            );
91        };
92        bail!("Batch norm parameters expected to be known")
93    }
94}