tract_tensorflow/ops/nn/
fused_batch_norm.rs1use tract_hir::internal::*;
2use tract_itertools::izip;
3
4use crate::model::ParsingContext;
5use crate::tfpb::tensorflow::NodeDef;
6
7pub fn fused_batch_norm(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
8 let epsilon = pb.get_attr_float::<f32>("epsilon")?;
9 Ok(expand(FusedBatchNorm::new(epsilon)))
10}
11
12#[derive(Debug, Clone, new)]
13struct FusedBatchNorm {
14 epsilon: f32,
15}
16
17impl Expansion for FusedBatchNorm {
18 fn name(&self) -> StaticName {
19 "FusedBatchNorm".into()
20 }
21
22 fn validation(&self) -> Validation {
23 Validation::Rounding
24 }
25
26
27 fn rules<'r, 'p: 'r, 's: 'r>(
29 &'s self,
30 s: &mut Solver<'r>,
31 inputs: &'p [TensorProxy],
32 outputs: &'p [TensorProxy],
33 ) -> InferenceResult {
34 check_input_arity(inputs, 5)?;
35 check_output_arity(outputs, 1)?;
36 s.equals(&inputs[0].datum_type, f32::datum_type())?;
37 s.equals(&inputs[1].datum_type, f32::datum_type())?;
38 s.equals(&inputs[2].datum_type, f32::datum_type())?;
39 s.equals(&inputs[3].datum_type, f32::datum_type())?;
40 s.equals(&inputs[4].datum_type, f32::datum_type())?;
41 s.equals(&outputs[0].datum_type, f32::datum_type())?;
42 s.equals(&outputs[0].shape, &inputs[0].shape)?;
43 s.equals(&inputs[0].rank, 4)?;
44 s.equals(&inputs[1].rank, 1)?;
45 s.equals(&inputs[2].rank, 1)?;
46 s.equals(&inputs[3].rank, 1)?;
47 s.equals(&inputs[4].rank, 1)?;
48 s.equals(&inputs[0].shape, &outputs[0].shape)?;
49 s.equals(&inputs[1].shape[0], &inputs[0].shape[3])?;
50 s.equals(&inputs[2].shape[0], &inputs[0].shape[3])?;
51 s.equals(&inputs[3].shape[0], &inputs[0].shape[3])?;
52 s.equals(&inputs[4].shape[0], &inputs[0].shape[3])?;
53 Ok(())
54 }
55
56 fn wire(
57 &self,
58 prefix: &str,
59 target: &mut TypedModel,
60 inputs: &[OutletId],
61 ) -> TractResult<TVec<OutletId>> {
62 let scale = target.outlet_fact(inputs[1])?;
63 let offset = target.outlet_fact(inputs[2])?;
64 let mean = target.outlet_fact(inputs[3])?;
65 let variance = target.outlet_fact(inputs[4])?;
66 if let (Some(scale), Some(offset), Some(mean), Some(variance)) =
67 (&scale.konst, &offset.konst, &mean.konst, &variance.konst)
68 {
69 let scale = scale.as_slice::<f32>()?;
70 let offset = offset.as_slice::<f32>()?;
71 let mean = mean.as_slice::<f32>()?;
72 let variance = variance.as_slice::<f32>()?;
73 let slope: Vec<f32> =
74 izip!(variance, scale).map(|(v, s)| s / (v + self.epsilon).sqrt()).collect();
75 let inter: Vec<f32> = izip!(offset, mean, &slope).map(|(o, m, s)| o - m * s).collect();
76 let shape = tvec!(1, 1, 1, scale.len());
77 let slope = tensor1(&slope).into_shape(&shape)?;
78 let inter = tensor1(&inter).into_shape(&shape)?;
79 let slope = target.add_const(prefix.to_string() + ".slope", slope)?;
80 let inter = target.add_const(prefix.to_string() + ".inter", inter)?;
81 let wire = target.wire_node(
82 format!("{prefix}.mul"),
83 tract_hir::ops::math::mul(),
84 &[inputs[0], slope],
85 )?;
86 return target.wire_node(
87 format!("{prefix}.add"),
88 tract_hir::ops::math::add(),
89 &[wire[0], inter],
90 );
91 };
92 bail!("Batch norm parameters expected to be known")
93 }
94}