tch_plus/nn/batch_norm.rs

//! A batch-normalization layer.
use crate::Tensor;
use std::borrow::Borrow;

/// Batch-normalization config.
#[derive(Debug, Clone, Copy)]
pub struct BatchNormConfig {
    /// Whether to use the cuDNN implementation when it is available.
    pub cudnn_enabled: bool,
    /// Small constant added to the variance for numerical stability.
    pub eps: f64,
    /// Momentum used when updating the running mean and variance.
    pub momentum: f64,
    /// Whether to learn a per-channel scale (weight) and shift (bias).
    pub affine: bool,
    /// Initialization scheme for the weight (scale) parameter.
    pub ws_init: super::Init,
    /// Initialization scheme for the bias (shift) parameter.
    pub bs_init: super::Init,
}

impl Default for BatchNormConfig {
    fn default() -> Self {
        BatchNormConfig {
            cudnn_enabled: true,
            eps: 1e-5,
            momentum: 0.1,
            affine: true,
            ws_init: super::Init::Uniform { lo: 0., up: 1. },
            bs_init: super::Init::Const(0.),
        }
    }
}

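// Sketch of overriding the defaults via struct-update syntax; this only relies on
// `BatchNormConfig` implementing `Default` as above. The field values and channel
// count are purely illustrative, and `path` is assumed to be a `&super::Path`
// (or anything implementing `Borrow<Path>`) obtained from the variable store.
//
//     let config = BatchNormConfig { eps: 1e-3, momentum: 0.01, ..Default::default() };
//     let bn = batch_norm2d(path, 64, config);
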
/// A batch-normalization layer.
#[derive(Debug)]
pub struct BatchNorm {
    config: BatchNormConfig,
    pub running_mean: Tensor,
    pub running_var: Tensor,
    pub ws: Option<Tensor>,
    pub bs: Option<Tensor>,
    pub nd: usize,
}

/// Creates a batch-normalization layer for an input with `nd` trailing
/// dimensions after the batch and channel dimensions.
///
/// The weight and bias are only created when `config.affine` is true; the
/// running mean and variance are registered as non-trainable variables.
fn batch_norm<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    nd: usize,
    out_dim: i64,
    config: BatchNormConfig,
) -> BatchNorm {
    let vs = vs.borrow();
    let (ws, bs) = if config.affine {
        let ws = vs.var("weight", &[out_dim], config.ws_init);
        let bs = vs.var("bias", &[out_dim], config.bs_init);
        (Some(ws), Some(bs))
    } else {
        (None, None)
    };
    BatchNorm {
        config,
        running_mean: vs.zeros_no_train("running_mean", &[out_dim]),
        running_var: vs.ones_no_train("running_var", &[out_dim]),
        ws,
        bs,
        nd,
    }
}

/// Applies Batch Normalization over a three-dimensional input.
///
/// The input shape is expected to be (N, C, L) (or (N, C)). Normalization
/// statistics are computed per channel C, over the batch dimension N and the
/// remaining dimension.
pub fn batch_norm1d<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    out_dim: i64,
    config: BatchNormConfig,
) -> BatchNorm {
    batch_norm(vs, 1, out_dim, config)
}
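// Sketch of the 1-D case. It assumes `path` is a `&super::Path`, that the
// `ModuleT` trait is in scope, and that `xs1`/`xs2` are tensors of the shapes
// shown; both (N, C) and (N, C, L) inputs are accepted by the dimension check
// in `forward_t` below.
//
//     let bn = batch_norm1d(path, 32, Default::default());
//     // xs1: (N, C) = (8, 32), xs2: (N, C, L) = (8, 32, 100)
//     let ys1 = bn.forward_t(&xs1, /* train = */ true);
//     let ys2 = bn.forward_t(&xs2, /* train = */ true);
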

/// Applies Batch Normalization over a four-dimensional input.
///
/// The input shape is expected to be (N, C, H, W). Normalization statistics
/// are computed per channel C, over the batch dimension N and the spatial
/// dimensions H and W.
pub fn batch_norm2d<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    out_dim: i64,
    config: BatchNormConfig,
) -> BatchNorm {
    batch_norm(vs, 2, out_dim, config)
}
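// Sketch of a typical 2-D use on an image batch. It assumes the crate exposes a
// `nn::VarStore`, path composition via `/`, and `Tensor::randn` analogous to
// upstream tch, and that the `ModuleT` trait is in scope; the shapes and the
// "bn" name are illustrative.
//
//     let vs = nn::VarStore::new(Device::Cpu);
//     let root = vs.root();
//     let bn = batch_norm2d(&root / "bn", 16, Default::default());
//     // xs: (N, C, H, W) = (8, 16, 32, 32)
//     let xs = Tensor::randn(&[8, 16, 32, 32], (Kind::Float, Device::Cpu));
//     let ys = bn.forward_t(&xs, /* train = */ true); // same shape as xs
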

/// Applies Batch Normalization over a five-dimensional input.
///
/// The input shape is expected to be (N, C, D, H, W). Normalization statistics
/// are computed per channel C, over the batch dimension N and the spatial
/// dimensions D, H and W.
pub fn batch_norm3d<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    out_dim: i64,
    config: BatchNormConfig,
) -> BatchNorm {
    batch_norm(vs, 3, out_dim, config)
}

impl super::module::ModuleT for BatchNorm {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        let dim = xs.dim();
        // With nd == 1, both (N, C) and (N, C, L) inputs are accepted; otherwise
        // the input must have exactly nd + 2 dimensions.
        if self.nd == 1 && dim != 2 && dim != 3 {
            panic!(
                "as nd={}, expected an input tensor with 2 or 3 dims, got {} ({:?})",
                self.nd,
                dim,
                xs.size()
            )
        }
        if self.nd > 1 && dim != self.nd + 2 {
            panic!(
                "as nd={}, expected an input tensor with {} dims, got {} ({:?})",
                self.nd,
                self.nd + 2,
                dim,
                xs.size()
            )
        }
        Tensor::batch_norm(
            xs,
            self.ws.as_ref(),
            self.bs.as_ref(),
            Some(&self.running_mean),
            Some(&self.running_var),
            train,
            self.config.momentum,
            self.config.eps,
            self.config.cudnn_enabled,
        )
    }
}
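
// Sketch of the train/eval distinction for `forward_t`, grounded in the call to
// `Tensor::batch_norm` above: with `train == true` the batch statistics are used
// and `running_mean`/`running_var` are updated using `config.momentum`; with
// `train == false` the stored running statistics are used instead. `bn` and `xs`
// are assumed to be a `BatchNorm` and an input tensor as in the sketches above.
//
//     let ys_train = bn.forward_t(&xs, true);  // updates bn.running_mean / bn.running_var
//     let ys_eval = bn.forward_t(&xs, false);  // uses the stored running statistics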