stonnx_api/operators/batchnormalization.rs

use ndarray::{ArrayD, SliceInfoElem};

use crate::common::{ArrayElement, BoxResult, F32IntoType, OperatorResult, TensorType};
use crate::onnx::NodeProto;
use crate::utils::pick_opset_version;

const OPSET_VERSIONS: [i64; 6] = [1, 6, 7, 9, 14, 15];

/// ver 1 | ver 6
/// inputs: X, scale, B, mean, var
/// outputs: Y, mean (opt), var (opt), saved_mean (opt), saved_var (opt)
/// attributes: epsilon (default 1e-5), is_test (default 0), momentum (default 0.9), spatial (default 1)

/// ver 7
/// inputs: X, scale, B, mean, var
/// outputs: Y, mean (opt), var (opt), saved_mean (opt), saved_var (opt)
/// attributes: epsilon (default 1e-5), momentum (default 0.9), spatial (default 1)

/// ver 9
/// inputs: X, scale, B, mean, var
/// outputs: Y, mean (opt), var (opt), saved_mean (opt), saved_var (opt)
/// attributes: epsilon (default 1e-5), momentum (default 0.9)

/// ver 14 | ver 15
/// inputs: X, scale, B, input_mean, input_var
/// outputs: Y, running_mean (opt), running_var (opt)
/// attributes: epsilon (default 1e-5), momentum (default 0.9), training_mode (default 0)
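///
/// In every version, inference ("test") mode computes
/// `Y = scale * (X - mean) / sqrt(var + epsilon) + B`,
/// with the per-channel parameters broadcast along the channel axis (axis 1).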

#[derive(Debug)]
struct BatchNormalizationAttrs {
    epsilon: f32,
    /// `None` when the attribute is absent, so the version handlers can
    /// distinguish "no momentum given" from an explicit value.
    momentum: Option<f32>,
    _spatial: bool, // unused?
    is_test: bool,
    training_mode: bool,
}

impl BatchNormalizationAttrs {
    fn new(node: &NodeProto) -> Self {
        Self {
            epsilon: node
                .attribute
                .iter()
                .find(|a| a.name() == "epsilon")
                .map_or(1e-5, |a| a.f.unwrap_or(1e-5)),
            momentum: node
                .attribute
                .iter()
                .find(|a| a.name() == "momentum")
                .map(|a| a.f.unwrap_or(0.9)),
            _spatial: node
                .attribute
                .iter()
                .find(|a| a.name() == "spatial")
                .map_or(true, |a| a.i.unwrap_or(1) == 1),
            is_test: node
                .attribute
                .iter()
                .find(|a| a.name() == "is_test")
                .map_or(false, |a| a.i.unwrap_or(0) == 1),
            training_mode: node
                .attribute
                .iter()
                .find(|a| a.name() == "training_mode")
                .map_or(false, |a| a.i.unwrap_or(0) == 1),
        }
    }
}

/// Inference-mode batch normalization:
/// `Y = scale * (X - mean) / sqrt(var + epsilon) + B`.
fn _batchnorm_test_mode<A: ArrayElement>(
    x: &ArrayD<A>,
    scale: &ArrayD<A>,
    bias: &ArrayD<A>,
    mean: &ArrayD<A>,
    var: &ArrayD<A>,
    epsilon: f32,
) -> BoxResult<TensorType>
where
    TensorType: From<ArrayD<A>>,
    f32: F32IntoType<A>,
{
    // Reshape each per-channel parameter to [C, 1, ..., 1] (one trailing 1
    // per spatial axis) so it broadcasts against x of shape [N, C, d1, ...].
    let dims_x = x.ndim();
    let param_shape = |len: usize| {
        [len].into_iter()
            .chain(std::iter::repeat(1).take(dims_x - 2))
            .collect::<Vec<_>>()
    };
    let s = scale.to_shape(param_shape(scale.len()).as_slice())?;
    let b = bias.to_shape(param_shape(bias.len()).as_slice())?;
    let m = mean.to_shape(param_shape(mean.len()).as_slice())?;
    let mut v = var.to_shape(param_shape(var.len()).as_slice())?;
    // v becomes sqrt(var + epsilon), then Y = s * (x - m) / v + b.
    v.par_mapv_inplace(|e| (e + epsilon.as_()).sqrt());
    let y = &s * (x - &m) / v + b;
    Ok(y.into())
}

// FIXME: remove this num::Float requirement, it is needed for ArrayBase::var
fn _batchnorm_training_mode<A: ArrayElement + num::Float>(
    x: &ArrayD<A>,
    scale: &ArrayD<A>,
    bias: &ArrayD<A>,
    mean: &ArrayD<A>,
    var: &ArrayD<A>,
    momentum: f32,
    epsilon: f32,
) -> BoxResult<Vec<TensorType>>
where
    TensorType: From<ArrayD<A>>,
    f32: F32IntoType<A>,
{
    let momentum = momentum.as_();
    // Batch statistics are taken over every axis except the channel axis (1).
    let axis = (0..x.ndim()).filter(|&i| i != 1).collect::<Vec<_>>();
    let mut saved_mean = x.clone();
    for ax in axis.iter().rev() {
        saved_mean = saved_mean
            .mean_axis(ndarray::Axis(*ax))
            .ok_or_else(|| anyhow::anyhow!("BatchNormalization: mean_axis failed"))?;
    }
    // Per-channel variance: slice out each channel and take the population
    // variance (ddof = 0) of the remaining elements.
    let saved_var_len = x.shape()[1];
    let mut saved_var = ArrayD::<A>::zeros(vec![saved_var_len]);
    for i in 0..saved_var_len {
        let sliceinfo = [(0..).into(), i.into()]
            .into_iter()
            .chain(std::iter::repeat((0..).into()).take(x.ndim() - 2))
            .collect::<Vec<SliceInfoElem>>();
        let sliced = x.slice(sliceinfo.as_slice()).to_owned();
        saved_var[i] = sliced.var((0.0).as_());
    }
    // Running statistics, as in the ONNX spec:
    //   output_mean = momentum * input_mean + (1 - momentum) * saved_mean
    //   output_var  = momentum * input_var  + (1 - momentum) * saved_var
    let (output_mean, output_var) = {
        let mut saved_mean = saved_mean.clone();
        let mut saved_var = saved_var.clone();
        saved_mean.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
        saved_var.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
        let mut vmean = mean.clone();
        let mut vvar = var.clone();
        vmean.par_mapv_inplace(|x| x * momentum);
        vvar.par_mapv_inplace(|x| x * momentum);
        let output_mean = vmean + saved_mean;
        let output_var = vvar + saved_var;
        (output_mean, output_var)
    };
    // Normalize with the batch statistics, as the Python reference does;
    // the running statistics are only returned as extra outputs.
    let y = _batchnorm_test_mode(x, scale, bias, &saved_mean, &saved_var, epsilon)?;
    Ok(vec![
        y,
        saved_mean.into(),
        saved_var.into(),
        output_mean.into(),
        output_var.into(),
    ])
}

/// opset 1 and 6: the `is_test` attribute selects inference vs. training behavior.
fn batchnormalization_1_6<A: ArrayElement + num::Float>(
    x: &ArrayD<A>,
    scale: &ArrayD<A>,
    bias: &ArrayD<A>,
    mean: &ArrayD<A>,
    var: &ArrayD<A>,
    attrs: BatchNormalizationAttrs,
) -> BoxResult<Vec<TensorType>>
where
    TensorType: From<ArrayD<A>>,
    f32: F32IntoType<A>,
{
    if attrs.is_test {
        Ok(vec![_batchnorm_test_mode(
            x,
            scale,
            bias,
            mean,
            var,
            attrs.epsilon,
        )?])
    } else {
        _batchnorm_training_mode(
            x,
            scale,
            bias,
            mean,
            var,
            attrs.momentum.unwrap_or(0.9),
            attrs.epsilon,
        )
    }
}

/// opset 7 and 9: when a `momentum` attribute is present, the input mean and
/// variance are blended with the batch statistics before normalizing;
/// otherwise the inputs are used as-is (inference mode).
fn batchnormalization_7_9<A: ArrayElement + num::Float>(
    x: &ArrayD<A>,
    scale: &ArrayD<A>,
    bias: &ArrayD<A>,
    mean: &ArrayD<A>,
    var: &ArrayD<A>,
    attrs: BatchNormalizationAttrs,
) -> BoxResult<Vec<TensorType>>
where
    TensorType: From<ArrayD<A>>,
    f32: F32IntoType<A>,
{
    if let Some(momentum) = attrs.momentum {
        let momentum = momentum.as_();
        // Batch statistics over every axis except the channel axis (1).
        let axis = (0..x.ndim()).filter(|&i| i != 1).collect::<Vec<_>>();
        let mut saved_mean = x.clone();
        for ax in axis.iter().rev() {
            saved_mean = saved_mean
                .mean_axis(ndarray::Axis(*ax))
                .ok_or_else(|| anyhow::anyhow!("BatchNormalization: mean_axis failed"))?;
        }
        let saved_var_len = x.shape()[1];
        let mut saved_var = ArrayD::<A>::zeros(vec![saved_var_len]);
        for i in 0..saved_var_len {
            let sliceinfo = [(0..).into(), i.into()]
                .into_iter()
                .chain(std::iter::repeat((0..).into()).take(x.ndim() - 2))
                .collect::<Vec<SliceInfoElem>>();
            let sliced = x.slice(sliceinfo.as_slice()).to_owned();
            saved_var[i] = sliced.var((0.0).as_());
        }
        // output = momentum * input statistic + (1 - momentum) * batch statistic
        saved_mean.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
        saved_var.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
        let mut vmean = mean.clone();
        let mut vvar = var.clone();
        vmean.par_mapv_inplace(|x| x * momentum);
        vvar.par_mapv_inplace(|x| x * momentum);
        let output_mean = vmean + saved_mean;
        let output_var = vvar + saved_var;
        let y = _batchnorm_test_mode(x, scale, bias, &output_mean, &output_var, attrs.epsilon)?;
        Ok(vec![y])
    } else {
        Ok(vec![_batchnorm_test_mode(
            x,
            scale,
            bias,
            mean,
            var,
            attrs.epsilon,
        )?])
    }
}

/// opset 14 and 15: the `training_mode` attribute selects the behavior; in
/// training mode only Y, running_mean and running_var are output.
fn batchnormalization_14_15<A: ArrayElement + num::Float>(
    x: &ArrayD<A>,
    scale: &ArrayD<A>,
    bias: &ArrayD<A>,
    mean: &ArrayD<A>,
    var: &ArrayD<A>,
    attrs: BatchNormalizationAttrs,
) -> BoxResult<Vec<TensorType>>
where
    TensorType: From<ArrayD<A>>,
    f32: F32IntoType<A>,
{
    if !attrs.training_mode {
        let res = _batchnorm_test_mode(x, scale, bias, mean, var, attrs.epsilon)?;
        Ok(vec![res])
    } else {
        let outputs = _batchnorm_training_mode(
            x,
            scale,
            bias,
            mean,
            var,
            attrs.momentum.unwrap_or(0.9),
            attrs.epsilon,
        )?;
        // Keep Y, output_mean and output_var; drop saved_mean and saved_var
        // (indices 1 and 2), which opset 14 removed from the outputs.
        Ok(outputs
            .into_iter()
            .enumerate()
            .filter_map(|(i, v)| if i == 1 || i == 2 { None } else { Some(v) })
            .collect::<Vec<_>>())
    }
}

/// Carries out batch normalization as described in the paper <https://arxiv.org/abs/1502.03167>.
///
/// There are five required inputs: `X`, `scale`, `B`, `input_mean` and `input_var`;
/// which outputs are produced depends on the mode in which the operator is run.
///
/// [Python reference](https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_batch_normalization.py)
///
/// [ONNX Documentation](https://onnx.ai/onnx/operators/onnx__BatchNormalization.html)
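///
/// A minimal usage sketch (marked `ignore`: it assumes it runs inside this
/// crate, in a function returning a `BoxResult`, and that `NodeProto::default()`
/// yields a node with no attributes, so every attribute takes its default):
///
/// ```ignore
/// let x = TensorType::F32(ArrayD::zeros(vec![1, 2, 4, 4]));
/// let scale = TensorType::F32(ArrayD::ones(vec![2]));
/// let b = TensorType::F32(ArrayD::zeros(vec![2]));
/// let mean = TensorType::F32(ArrayD::zeros(vec![2]));
/// let var = TensorType::F32(ArrayD::ones(vec![2]));
/// // Opset 15, training_mode defaults to 0: the single output is Y.
/// let result = batchnormalization(&[&x, &scale, &b, &mean, &var], &NodeProto::default(), 15, 1)?;
/// ```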
pub fn batchnormalization(
    inputs: &[&TensorType],
    node: &NodeProto,
    opset_version: i64,
    _output_len: usize,
) -> BoxResult<OperatorResult> {
    let attrs = BatchNormalizationAttrs::new(node);
    let target_ver = pick_opset_version(opset_version, &OPSET_VERSIONS);
    if inputs.len() < 5 {
        return Err(anyhow::anyhow!(
            "BatchNormalization: expected at least 5 inputs, got {}",
            inputs.len()
        ));
    }
    // Only f32 tensors are handled so far.
    let TensorType::F32(x) = inputs[0] else {
        todo!("BatchNormalization for x type {}", inputs[0])
    };
    let TensorType::F32(scale) = inputs[1] else {
        todo!("BatchNormalization for scale type {}", inputs[1])
    };
    let TensorType::F32(bias) = inputs[2] else {
        todo!("BatchNormalization for bias type {}", inputs[2])
    };
    let TensorType::F32(mean) = inputs[3] else {
        todo!("BatchNormalization for mean type {}", inputs[3])
    };
    let TensorType::F32(var) = inputs[4] else {
        todo!("BatchNormalization for var type {}", inputs[4])
    };

    // Dispatch on the nearest supported opset version.
    if target_ver < 7 {
        Ok(batchnormalization_1_6(x, scale, bias, mean, var, attrs)?.into())
    } else if target_ver < 14 {
        Ok(batchnormalization_7_9(x, scale, bias, mean, var, attrs)?.into())
    } else {
        Ok(batchnormalization_14_15(x, scale, bias, mean, var, attrs)?.into())
    }
}
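
// A small sanity check, kept as a sketch: it relies on `f32` satisfying the
// `ArrayElement` bound and on `TensorType` having an `F32(ArrayD<f32>)`
// variant, both of which the dispatch code above already assumes.
#[cfg(test)]
mod tests {
    use super::*;

    // With zero mean, unit variance, unit scale and zero bias, test mode
    // reduces to Y = X / sqrt(1 + epsilon), i.e. very nearly the identity.
    #[test]
    fn test_mode_identity_params() -> BoxResult<()> {
        let x = ArrayD::<f32>::from_shape_vec(vec![1, 2, 2], vec![1.0, 2.0, 3.0, 4.0])?;
        let scale = ArrayD::<f32>::ones(vec![2]);
        let bias = ArrayD::<f32>::zeros(vec![2]);
        let mean = ArrayD::<f32>::zeros(vec![2]);
        let var = ArrayD::<f32>::ones(vec![2]);
        match _batchnorm_test_mode(&x, &scale, &bias, &mean, &var, 1e-5)? {
            TensorType::F32(y) => {
                for (got, want) in y.iter().zip(x.iter()) {
                    assert!((got - want).abs() < 1e-3);
                }
            }
            other => panic!("expected an F32 output, got {}", other),
        }
        Ok(())
    }
}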