Module normalization

Layer normalization for transformer models.

This module implements layer normalization, a core component of transformer architectures that stabilizes training, along with the related RMS normalization.

§Layer Normalization Formula

LN(x) = γ ⊙ (x - μ) / √(σ² + ε) + β

Where:

  • μ = mean over the feature dimension
  • σ² = variance over the feature dimension
  • γ = learnable scale parameter
  • β = learnable shift parameter
  • ε = small constant for numerical stability (default: 1e-5)
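
As a concrete illustration, here is a minimal sketch of this formula over a single feature vector in plain Rust. It is independent of this crate's API; the names `layer_norm`, `gamma`, `beta`, and `eps` are illustrative only.

```rust
/// Reference sketch of LN(x) = γ ⊙ (x - μ) / √(σ² + ε) + β for one
/// feature vector. Illustrative only, not this crate's API.
fn layer_norm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let d = x.len() as f32;
    // μ: mean over the feature dimension
    let mean = x.iter().sum::<f32>() / d;
    // σ²: (biased) variance over the feature dimension
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / d;
    let inv_std = (var + eps).sqrt().recip();
    // γ ⊙ normalized + β
    x.iter()
        .zip(gamma.iter().zip(beta.iter()))
        .map(|(v, (g, b))| g * (v - mean) * inv_std + b)
        .collect()
}
```

For instance, `layer_norm(&[1.0, 2.0, 3.0], &[1.0; 3], &[0.0; 3], 1e-5)` returns a vector with zero mean and approximately unit variance.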

§Einsum Representation

Layer norm can be expressed as a series of reductions and element-wise ops (einsum indices: b = batch, s = sequence, d = feature):

  1. Mean: reduce_mean(x, axis=-1) -> einsum("bsd->bs", x) / d
  2. Variance: reduce_mean((x - μ)², axis=-1)
  3. Normalize: (x - μ) / √(σ² + ε)
  4. Affine: γ ⊙ normalized + β
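
Applied row by row over a flattened [batch × seq, d] buffer, the same four steps might look like the following sketch (again illustrative and independent of this crate's types):

```rust
/// Staged layer norm over rows of length `d`, mirroring steps 1–4 above.
/// Illustrative sketch; names and layout are assumptions, not crate API.
fn layer_norm_rows(x: &[f32], d: usize, gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let mut out = vec![0.0f32; x.len()];
    for (row, out_row) in x.chunks(d).zip(out.chunks_mut(d)) {
        // 1. Mean: reduce over the feature axis ("bsd->bs" summed, then / d).
        let mean = row.iter().sum::<f32>() / d as f32;
        // 2. Variance: mean of squared deviations over the same axis.
        let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / d as f32;
        // 3. Normalize: (x - μ) / √(σ² + ε).
        let inv_std = (var + eps).sqrt().recip();
        // 4. Affine: γ ⊙ normalized + β.
        for ((o, v), (g, b)) in out_row.iter_mut().zip(row).zip(gamma.iter().zip(beta.iter())) {
            *o = g * (v - mean) * inv_std + b;
        }
    }
    out
}
```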

Structs§

LayerNorm
Layer normalization component
LayerNormConfig
Configuration for layer normalization
RMSNorm
RMS (Root Mean Square) normalization
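
The RMSNorm variant listed above omits the mean subtraction and normalizes by the root mean square alone. A standard formulation (whether this crate applies the learnable scale γ in the same way is an assumption) is:

RMSNorm(x) = γ ⊙ x / √(mean(x²) + ε)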