Layer normalization for transformer models.
This module implements layer normalization, a component of transformer architectures that is critical for stabilizing training.
§Layer Normalization Formula
LN(x) = γ ⊙ (x - μ) / √(σ² + ε) + β

Where:

- μ = mean over the feature dimension
- σ² = variance over the feature dimension
- γ = learnable scale parameter
- β = learnable shift parameter
- ε = small constant for numerical stability (default: 1e-5)
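As a concrete reference, here is a minimal, crate-independent sketch of the formula on a single feature vector. Plain slices stand in for this crate's tensor types; the function name and signature are illustrative, not this crate's API.

```rust
/// Layer norm over one feature vector: γ ⊙ (x - μ) / √(σ² + ε) + β.
/// Illustrative sketch only; not this crate's API.
fn layer_norm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let d = x.len() as f32;
    // μ: mean over the feature dimension
    let mu = x.iter().sum::<f32>() / d;
    // σ²: (biased) variance over the feature dimension
    let var = x.iter().map(|v| (v - mu) * (v - mu)).sum::<f32>() / d;
    let inv_std = 1.0 / (var + eps).sqrt();
    // Normalize, then apply the learnable affine transform
    x.iter()
        .zip(gamma.iter().zip(beta.iter()))
        .map(|(v, (g, b))| g * ((v - mu) * inv_std) + b)
        .collect()
}
```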
§Einsum Representation
Layer norm can be expressed as a series of reductions and element-wise ops:
- Mean: `reduce_mean(x, axis=-1)` → `einsum("bsd->bs", x) / d`
- Variance: `reduce_mean((x - μ)², axis=-1)`
- Normalize: `(x - μ) / √(σ² + ε)`
- Affine: `γ ⊙ normalized + β`
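To make the `axis=-1` reductions concrete, the same four steps written as an explicit loop over a flattened (batch · seq, d) buffer; the row-major layout and names are assumptions for illustration.

```rust
/// Per-row layer norm over a flattened (rows, d) buffer, mirroring the
/// "bsd->bs" reduction: each row is normalized independently.
/// Illustrative sketch only; not this crate's API.
fn layer_norm_rows(x: &mut [f32], d: usize, gamma: &[f32], beta: &[f32], eps: f32) {
    for row in x.chunks_mut(d) {
        // Mean: reduce over the feature axis
        let mu = row.iter().sum::<f32>() / d as f32;
        // Variance: mean of squared deviations
        let var = row.iter().map(|v| (v - mu) * (v - mu)).sum::<f32>() / d as f32;
        // Normalize and apply the affine parameters in one pass
        let inv_std = 1.0 / (var + eps).sqrt();
        for (v, (g, b)) in row.iter_mut().zip(gamma.iter().zip(beta.iter())) {
            *v = g * ((*v - mu) * inv_std) + b;
        }
    }
}
```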
Structs§
- LayerNorm - Layer normalization component
- LayerNormConfig - Configuration for layer normalization
- RMSNorm - RMS (Root Mean Square) normalization
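For contrast with LayerNorm above, RMSNorm drops the mean subtraction and the shift β, scaling by the root mean square alone. A minimal sketch follows; placing ε under the square root is the common convention, and the signature is illustrative, not this crate's API.

```rust
/// RMSNorm over one feature vector: γ ⊙ x / √(mean(x²) + ε).
/// Illustrative sketch only; not this crate's API.
fn rms_norm(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    // mean(x²): mean of squares over the feature dimension
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Scale by γ; no centering (μ) and no shift (β)
    x.iter()
        .zip(gamma.iter())
        .map(|(v, g)| g * (v * inv_rms))
        .collect()
}
```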