1use ndarray::{ArrayD, SliceInfoElem};
2
3use crate::common::{ArrayElement, BoxResult, F32IntoType, OperatorResult, TensorType};
4use crate::onnx::NodeProto;
5use crate::utils::pick_opset_version;
6
/// Opset versions at which the BatchNormalization spec changed; used by
/// `pick_opset_version` to resolve the implementation for a model's opset.
const OPSET_VERSIONS: [i64; 6] = [1, 6, 7, 9, 14, 15];
8
/// Attributes recognized by the BatchNormalization operator across opsets.
#[derive(Debug)]
struct BatchNormalizationAttrs {
    // Numerical-stability term added to the variance before the sqrt.
    epsilon: f32,
    // Running-statistics blend factor; `None` when the attribute is absent.
    momentum: Option<f32>,
    // Legacy `spatial` attribute (removed in opset 9); parsed but unused.
    _spatial: bool,
    // Legacy `is_test` flag (opsets 1-6): selects inference vs. training.
    is_test: bool,
    // `training_mode` flag (opset 14+): selects inference vs. training.
    training_mode: bool,
}
37
38impl BatchNormalizationAttrs {
39 fn new(node: &NodeProto) -> Self {
40 Self {
41 epsilon: node
42 .attribute
43 .iter()
44 .find(|a| a.name() == "epsilon")
45 .map_or(1e-5, |a| a.f.unwrap_or(1e-5)),
46 momentum: node
47 .attribute
48 .iter()
49 .find(|a| a.name() == "momentum")
50 .map(|a| a.f.unwrap_or(0.9)),
51 _spatial: node
52 .attribute
53 .iter()
54 .find(|a| a.name() == "spatial")
55 .map_or(true, |a| a.i.unwrap_or(1) == 1),
56 is_test: node
57 .attribute
58 .iter()
59 .find(|a| a.name() == "is_test")
60 .map_or(false, |a| a.i.unwrap_or(0) == 1),
61 training_mode: node
62 .attribute
63 .iter()
64 .find(|a| a.name() == "training_mode")
65 .map_or(false, |a| a.i.unwrap_or(0) == 1),
66 }
67 }
68}
69
70fn _batchnorm_test_mode<A: ArrayElement>(
71 x: &ArrayD<A>,
72 scale: &ArrayD<A>,
73 bias: &ArrayD<A>,
74 mean: &ArrayD<A>,
75 var: &ArrayD<A>,
76 epsilon: f32,
77) -> BoxResult<TensorType>
78where
79 TensorType: From<ArrayD<A>>,
80 f32: F32IntoType<A>,
81{
82 let dims_x = x.ndim();
83 let dim_ones_generator = std::iter::repeat(1).take(dims_x - 2);
84 let sshape = [scale.len()]
85 .into_iter()
86 .chain(dim_ones_generator.clone())
87 .collect::<Vec<_>>();
88 let s = scale.to_shape(sshape.as_slice())?;
89 let bshape = [bias.len()]
90 .into_iter()
91 .chain(dim_ones_generator.clone())
92 .collect::<Vec<_>>();
93 let b = bias.to_shape(bshape.as_slice())?;
94 let mshape = [mean.len()]
95 .into_iter()
96 .chain(dim_ones_generator.clone())
97 .collect::<Vec<_>>();
98 let m = mean.to_shape(mshape.as_slice())?;
99 let vshape = [var.len()]
100 .into_iter()
101 .chain(dim_ones_generator.clone())
102 .collect::<Vec<_>>();
103 let mut v = var.to_shape(vshape.as_slice())?;
104 v.par_mapv_inplace(|x| (x + epsilon.as_()).sqrt());
105 let y = &s * (x - &m) / v + b;
106 Ok(y.into())
107}
108
109fn _batchnorm_training_mode<A: ArrayElement + num::Float>(
111 x: &ArrayD<A>,
112 scale: &ArrayD<A>,
113 bias: &ArrayD<A>,
114 mean: &ArrayD<A>,
115 var: &ArrayD<A>,
116 momentum: f32,
117 epsilon: f32,
118) -> BoxResult<Vec<TensorType>>
119where
120 TensorType: From<ArrayD<A>>,
121 f32: F32IntoType<A>,
122{
123 let momentum = momentum.as_();
124 let axis = (0..x.ndim()).skip_while(|&i| i == 1).collect::<Vec<_>>();
125 let mut saved_mean = x.clone();
126 for ax in axis.iter().rev() {
127 saved_mean = saved_mean
128 .mean_axis(ndarray::Axis(*ax))
129 .ok_or(anyhow::anyhow!("BatchNormalization: mean_axis failed"))?;
130 }
131 let saved_var_len = x.shape()[1];
132 let mut saved_var = ArrayD::<A>::zeros(vec![saved_var_len]);
133 for i in 0..saved_var_len {
134 let sliceinfo = [(0..).into(), i.into()]
135 .into_iter()
136 .chain(std::iter::repeat((0..).into()).take(x.ndim() - 2))
137 .collect::<Vec<SliceInfoElem>>();
138 let sliced = x.slice(sliceinfo.as_slice()).to_owned();
139 saved_var[i] = sliced.var((0.0).as_());
140 }
141 let (output_mean, output_var) = {
142 let mut saved_mean = saved_mean.clone();
143 let mut saved_var = saved_var.clone();
144 saved_mean.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
145 saved_var.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
146 let mut vmean = mean.clone();
147 let mut vvar = var.clone();
148 vmean.par_mapv_inplace(|x| x * momentum);
149 vvar.par_mapv_inplace(|x| x * momentum);
150 let output_mean = vmean + saved_mean;
151 let output_var = vvar + saved_var;
152 (output_mean, output_var)
153 };
154 let y = _batchnorm_test_mode(x, scale, bias, &output_mean, &output_var, epsilon)?;
155 Ok(vec![
156 y,
157 saved_mean.into(),
158 saved_var.into(),
159 output_mean.into(),
160 output_var.into(),
161 ])
162}
163
164fn batchnormalization_1_6<A: ArrayElement + num::Float>(
165 x: &ArrayD<A>,
166 scale: &ArrayD<A>,
167 bias: &ArrayD<A>,
168 mean: &ArrayD<A>,
169 var: &ArrayD<A>,
170 attrs: BatchNormalizationAttrs,
171) -> BoxResult<Vec<TensorType>>
172where
173 TensorType: From<ArrayD<A>>,
174 f32: F32IntoType<A>,
175{
176 if attrs.is_test {
177 Ok(vec![_batchnorm_test_mode(
178 x,
179 scale,
180 bias,
181 mean,
182 var,
183 attrs.epsilon,
184 )?])
185 } else {
186 _batchnorm_training_mode(
187 x,
188 scale,
189 bias,
190 mean,
191 var,
192 attrs.momentum.unwrap_or(0.9),
193 attrs.epsilon,
194 )
195 }
196}
/// BatchNormalization for opsets 7-13.
///
/// When a `momentum` attribute is present, per-channel batch statistics are
/// computed and blended into the running statistics
/// (`output = momentum * running + (1 - momentum) * batch`), and `x` is
/// normalized with the blended values; otherwise the provided running
/// `mean`/`var` are used directly. Either way only `Y` is produced.
fn batchnormalization_7_9<A: ArrayElement + num::Float>(
    x: &ArrayD<A>,
    scale: &ArrayD<A>,
    bias: &ArrayD<A>,
    mean: &ArrayD<A>,
    var: &ArrayD<A>,
    attrs: BatchNormalizationAttrs,
) -> BoxResult<Vec<TensorType>>
where
    TensorType: From<ArrayD<A>>,
    f32: F32IntoType<A>,
{
    if let Some(momentum) = attrs.momentum {
        let momentum = momentum.as_();
        // Reduction axes for the per-channel batch mean: all except axis 1.
        let axis = (0..x.ndim()).filter(|i| *i != 1).collect::<Vec<_>>();
        let mut saved_mean = x.clone();
        for ax in axis.iter().rev() {
            saved_mean = saved_mean
                .mean_axis(ndarray::Axis(*ax))
                .ok_or(anyhow::anyhow!("BatchNormalization: mean_axis failed"))?;
        }
        // Per-channel population variance (ddof = 0), one channel at a time.
        let saved_var_len = x.shape()[1];
        let mut saved_var = ArrayD::<A>::zeros(vec![saved_var_len]);
        for i in 0..saved_var_len {
            // Full range on axis 0, index `i` on the channel axis, full range
            // on the remaining spatial axes.
            let sliceinfo = [(0..).into(), i.into()]
                .into_iter()
                .chain(std::iter::repeat((0..).into()).take(x.ndim() - 2))
                .collect::<Vec<SliceInfoElem>>();
            let sliced = x.slice(sliceinfo.as_slice()).to_owned();
            saved_var[i] = sliced.var((0.0).as_());
        }
        // Blend: output = momentum * running + (1 - momentum) * batch.
        saved_mean.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
        saved_var.par_mapv_inplace(|x| x * (1f32.as_() - momentum));
        let mut vmean = mean.clone();
        let mut vvar = var.clone();
        vmean.par_mapv_inplace(|x| x * momentum);
        vvar.par_mapv_inplace(|x| x * momentum);
        let output_mean = vmean + saved_mean;
        let output_var = vvar + saved_var;
        // Normalize with the blended statistics.
        let y = _batchnorm_test_mode(x, scale, bias, &output_mean, &output_var, attrs.epsilon)?;
        Ok(vec![y])
    } else {
        Ok(vec![_batchnorm_test_mode(
            x,
            scale,
            bias,
            mean,
            var,
            attrs.epsilon,
        )?])
    }
}
249fn batchnormalization_14_15<A: ArrayElement + num::Float>(
250 x: &ArrayD<A>,
251 scale: &ArrayD<A>,
252 bias: &ArrayD<A>,
253 mean: &ArrayD<A>,
254 var: &ArrayD<A>,
255 attrs: BatchNormalizationAttrs,
256) -> BoxResult<Vec<TensorType>>
257where
258 TensorType: From<ArrayD<A>>,
259 f32: F32IntoType<A>,
260{
261 if !attrs.training_mode {
262 let res = _batchnorm_test_mode(x, scale, bias, mean, var, attrs.epsilon)?;
263 Ok(vec![res])
264 } else {
265 let outputs = _batchnorm_training_mode(
266 x,
267 scale,
268 bias,
269 mean,
270 var,
271 attrs.momentum.unwrap_or(0.9),
272 attrs.epsilon,
273 )?;
274 Ok(outputs
275 .into_iter()
276 .enumerate()
277 .filter_map(|(i, v)| if i == 1 || i == 2 { None } else { Some(v) })
278 .collect::<Vec<_>>())
279 }
280}
281
282pub fn batchnormalization(
290 inputs: &[&TensorType],
291 node: &NodeProto,
292 opset_version: i64,
293 _output_len: usize,
294) -> BoxResult<OperatorResult> {
295 let attrs = BatchNormalizationAttrs::new(node);
296 let target_ver = pick_opset_version(opset_version, &OPSET_VERSIONS);
297 if inputs.len() < 5 {
298 return Err(anyhow::anyhow!(
299 "BatchNormalization: inputs must be at least 5"
300 ));
301 }
302 let x = if let TensorType::F32(x) = inputs[0] {
303 x
304 } else {
305 todo!("BatchNormalization for x type {}", inputs[0])
306 };
307 let scale = if let TensorType::F32(scale) = inputs[1] {
308 scale
309 } else {
310 todo!("BatchNormalization for scale type {}", inputs[1])
311 };
312 let bias = if let TensorType::F32(bias) = inputs[2] {
313 bias
314 } else {
315 todo!("BatchNormalization for bias type {}", inputs[2])
316 };
317 let mean = if let TensorType::F32(mean) = inputs[3] {
318 mean
319 } else {
320 todo!("BatchNormalization for mean type {}", inputs[3])
321 };
322 let var = if let TensorType::F32(var) = inputs[4] {
323 var
324 } else {
325 todo!("BatchNormalization for var type {}", inputs[4])
326 };
327
328 if target_ver < 7 {
329 Ok(batchnormalization_1_6(x, scale, bias, mean, var, attrs)?.into())
330 } else if target_ver < 14 {
331 Ok(batchnormalization_7_9(x, scale, bias, mean, var, attrs)?.into())
332 } else {
333 Ok(batchnormalization_14_15(x, scale, bias, mean, var, attrs)?.into())
334 }
335}