// tch_plus/nn/layer_norm.rs

1//! A layer-normalization layer.
2use crate::Tensor;
3use std::borrow::Borrow;
4
/// Layer-normalization config.
#[derive(Debug, Clone, Copy)]
pub struct LayerNormConfig {
    /// Forwarded to `Tensor::layer_norm` as its cuDNN-enable flag on every
    /// `forward` call.
    pub cudnn_enabled: bool,
    /// Epsilon value forwarded to `Tensor::layer_norm` (added for numerical
    /// stability when normalizing).
    pub eps: f64,
    /// When true, the `layer_norm` constructor registers learnable "weight"
    /// and "bias" variables; when false the layer has no parameters.
    pub elementwise_affine: bool,
    /// Initialization used for the "weight" variable (only read when
    /// `elementwise_affine` is true).
    pub ws_init: super::Init,
    /// Initialization used for the "bias" variable (only read when
    /// `elementwise_affine` is true).
    pub bs_init: super::Init,
}
14
15impl Default for LayerNormConfig {
16    fn default() -> Self {
17        LayerNormConfig {
18            cudnn_enabled: true,
19            eps: 1e-5,
20            elementwise_affine: true,
21            ws_init: super::Init::Const(1.),
22            bs_init: super::Init::Const(0.),
23        }
24    }
25}
26
/// A layer-normalization layer.
#[derive(Debug)]
pub struct LayerNorm {
    // Configuration the layer was built with; `eps` and `cudnn_enabled` are
    // forwarded on every `forward` call.
    config: LayerNormConfig,
    /// Learnable "weight" tensor; `None` when the layer was constructed with
    /// `elementwise_affine` set to false.
    pub ws: Option<Tensor>,
    /// Learnable "bias" tensor; `None` when the layer was constructed with
    /// `elementwise_affine` set to false.
    pub bs: Option<Tensor>,
    /// Shape passed to `Tensor::layer_norm` in `forward`, and the shape of
    /// the weight/bias variables when they exist.
    pub normalized_shape: Vec<i64>,
}
35
36pub fn layer_norm<'a, T: Borrow<super::Path<'a>>>(
37    vs: T,
38    normalized_shape: Vec<i64>,
39    config: LayerNormConfig,
40) -> LayerNorm {
41    let vs = vs.borrow();
42
43    let (ws, bs) = if config.elementwise_affine {
44        let ws = vs.var("weight", normalized_shape.as_slice(), config.ws_init);
45        let bs = vs.var("bias", normalized_shape.as_slice(), config.bs_init);
46        (Some(ws), Some(bs))
47    } else {
48        (None, None)
49    };
50
51    LayerNorm { config, ws, bs, normalized_shape }
52}
53
54impl super::module::Module for LayerNorm {
55    fn forward(&self, xs: &Tensor) -> Tensor {
56        Tensor::layer_norm(
57            xs,
58            self.normalized_shape.as_slice(),
59            self.ws.as_ref(),
60            self.bs.as_ref(),
61            self.config.eps,
62            self.config.cudnn_enabled,
63        )
64    }
65}