//! A layer-normalization layer.
use crate::Tensor;
use std::borrow::Borrow;

5#[derive(Debug, Clone, Copy)]
7pub struct LayerNormConfig {
8 pub cudnn_enabled: bool,
9 pub eps: f64,
10 pub elementwise_affine: bool,
11 pub ws_init: super::Init,
12 pub bs_init: super::Init,
13}
14
15impl Default for LayerNormConfig {
16 fn default() -> Self {
17 LayerNormConfig {
18 cudnn_enabled: true,
19 eps: 1e-5,
20 elementwise_affine: true,
21 ws_init: super::Init::Const(1.),
22 bs_init: super::Init::Const(0.),
23 }
24 }
25}
26
27#[derive(Debug)]
29pub struct LayerNorm {
30 config: LayerNormConfig,
31 pub ws: Option<Tensor>,
32 pub bs: Option<Tensor>,
33 pub normalized_shape: Vec<i64>,
34}
35
36pub fn layer_norm<'a, T: Borrow<super::Path<'a>>>(
37 vs: T,
38 normalized_shape: Vec<i64>,
39 config: LayerNormConfig,
40) -> LayerNorm {
41 let vs = vs.borrow();
42
43 let (ws, bs) = if config.elementwise_affine {
44 let ws = vs.var("weight", normalized_shape.as_slice(), config.ws_init);
45 let bs = vs.var("bias", normalized_shape.as_slice(), config.bs_init);
46 (Some(ws), Some(bs))
47 } else {
48 (None, None)
49 };
50
51 LayerNorm { config, ws, bs, normalized_shape }
52}
53
54impl super::module::Module for LayerNorm {
55 fn forward(&self, xs: &Tensor) -> Tensor {
56 Tensor::layer_norm(
57 xs,
58 self.normalized_shape.as_slice(),
59 self.ws.as_ref(),
60 self.bs.as_ref(),
61 self.config.eps,
62 self.config.cudnn_enabled,
63 )
64 }
65}