tch_plus/nn/
batch_norm.rs1use crate::Tensor;
3use std::borrow::Borrow;
4
5#[derive(Debug, Clone, Copy)]
7pub struct BatchNormConfig {
8 pub cudnn_enabled: bool,
9 pub eps: f64,
10 pub momentum: f64,
11 pub affine: bool,
12 pub ws_init: super::Init,
13 pub bs_init: super::Init,
14}
15
16impl Default for BatchNormConfig {
17 fn default() -> Self {
18 BatchNormConfig {
19 cudnn_enabled: true,
20 eps: 1e-5,
21 momentum: 0.1,
22 affine: true,
23 ws_init: super::Init::Uniform { lo: 0., up: 1. },
24 bs_init: super::Init::Const(0.),
25 }
26 }
27}
28
29#[derive(Debug)]
31pub struct BatchNorm {
32 config: BatchNormConfig,
33 pub running_mean: Tensor,
34 pub running_var: Tensor,
35 pub ws: Option<Tensor>,
36 pub bs: Option<Tensor>,
37 pub nd: usize,
38}
39
40fn batch_norm<'a, T: Borrow<super::Path<'a>>>(
41 vs: T,
42 nd: usize,
43 out_dim: i64,
44 config: BatchNormConfig,
45) -> BatchNorm {
46 let vs = vs.borrow();
47 let (ws, bs) = if config.affine {
48 let ws = vs.var("weight", &[out_dim], config.ws_init);
49 let bs = vs.var("bias", &[out_dim], config.bs_init);
50 (Some(ws), Some(bs))
51 } else {
52 (None, None)
53 };
54 BatchNorm {
55 config,
56 running_mean: vs.zeros_no_train("running_mean", &[out_dim]),
57 running_var: vs.ones_no_train("running_var", &[out_dim]),
58 ws,
59 bs,
60 nd,
61 }
62}
63
64pub fn batch_norm1d<'a, T: Borrow<super::Path<'a>>>(
69 vs: T,
70 out_dim: i64,
71 config: BatchNormConfig,
72) -> BatchNorm {
73 batch_norm(vs, 1, out_dim, config)
74}
75
76pub fn batch_norm2d<'a, T: Borrow<super::Path<'a>>>(
81 vs: T,
82 out_dim: i64,
83 config: BatchNormConfig,
84) -> BatchNorm {
85 batch_norm(vs, 2, out_dim, config)
86}
87
88pub fn batch_norm3d<'a, T: Borrow<super::Path<'a>>>(
93 vs: T,
94 out_dim: i64,
95 config: BatchNormConfig,
96) -> BatchNorm {
97 batch_norm(vs, 3, out_dim, config)
98}
99
100impl super::module::ModuleT for BatchNorm {
101 fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
102 let dim = xs.dim();
103 if self.nd == 1 && dim != 2 && dim != 3 {
104 panic!(
105 "as nd={}, expected an input tensor with 2 or 3 dims, got {} ({:?})",
106 self.nd,
107 dim,
108 xs.size()
109 )
110 }
111 if self.nd > 1 && xs.dim() != self.nd + 2 {
112 panic!(
113 "as nd={}, expected an input tensor with {} dims, got {} ({:?})",
114 self.nd,
115 self.nd + 2,
116 dim,
117 xs.size()
118 )
119 };
120 Tensor::batch_norm(
121 xs,
122 self.ws.as_ref(),
123 self.bs.as_ref(),
124 Some(&self.running_mean),
125 Some(&self.running_var),
126 train,
127 self.config.momentum,
128 self.config.eps,
129 self.config.cudnn_enabled,
130 )
131 }
132}