// tensorlogic_trustformers/normalization.rs
1//! Layer normalization for transformer models.
2//!
3//! This module implements layer normalization, a critical component of
4//! transformer architectures for stabilizing training.
5//!
6//! ## Layer Normalization Formula
7//!
8//! ```text
9//! LN(x) = γ ⊙ (x - μ) / √(σ² + ε) + β
10//! ```
11//!
12//! Where:
13//! - `μ` = mean over the feature dimension
14//! - `σ²` = variance over the feature dimension
15//! - `γ` = learnable scale parameter
16//! - `β` = learnable shift parameter
17//! - `ε` = small constant for numerical stability (default: 1e-5)
18//!
19//! ## Einsum Representation
20//!
21//! Layer norm can be expressed as a series of reductions and element-wise ops:
22//! 1. Mean: `reduce_mean(x, axis=-1)` -> `einsum("bsd->bs", x) / d`
23//! 2. Variance: `reduce_mean((x - μ)², axis=-1)`
24//! 3. Normalize: `(x - μ) / √(σ² + ε)`
25//! 4. Affine: `γ ⊙ normalized + β`
26
27use serde::{Deserialize, Serialize};
28use tensorlogic_ir::{EinsumGraph, EinsumNode};
29
30use crate::error::{Result, TrustformerError};
31
/// Configuration for layer normalization.
///
/// Shared by both [`LayerNorm`] and [`RMSNorm`]; constructed via
/// [`LayerNormConfig::new`] and the builder-style `with_*` methods.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LayerNormConfig {
    /// Normalized dimension (typically d_model)
    pub normalized_shape: usize,
    /// Small constant for numerical stability (default: 1e-5)
    pub eps: f64,
    /// Whether to apply the learnable affine parameters: scale γ and,
    /// for `LayerNorm`, also shift β (`RMSNorm` uses only γ)
    pub elementwise_affine: bool,
}
42
43impl LayerNormConfig {
44    /// Create a new layer normalization configuration
45    pub fn new(normalized_shape: usize) -> Self {
46        Self {
47            normalized_shape,
48            eps: 1e-5,
49            elementwise_affine: true,
50        }
51    }
52
53    /// Set epsilon for numerical stability
54    pub fn with_eps(mut self, eps: f64) -> Self {
55        self.eps = eps;
56        self
57    }
58
59    /// Set whether to use elementwise affine transformation
60    pub fn with_elementwise_affine(mut self, elementwise_affine: bool) -> Self {
61        self.elementwise_affine = elementwise_affine;
62        self
63    }
64
65    /// Validate configuration
66    pub fn validate(&self) -> Result<()> {
67        if self.normalized_shape == 0 {
68            return Err(TrustformerError::InvalidDimension {
69                expected: 1,
70                got: 0,
71                context: "normalized_shape must be positive".to_string(),
72            });
73        }
74
75        if self.eps <= 0.0 {
76            return Err(TrustformerError::InvalidDimension {
77                expected: 1,
78                got: 0,
79                context: format!("eps must be positive, got {}", self.eps),
80            });
81        }
82
83        Ok(())
84    }
85}
86
/// Layer normalization component.
///
/// Holds a validated [`LayerNormConfig`] and emits the layer-norm einsum
/// subgraph via [`LayerNorm::build_layernorm_graph`].
#[derive(Clone, Debug)]
pub struct LayerNorm {
    /// Configuration (validated at construction time by `LayerNorm::new`)
    pub config: LayerNormConfig,
}
93
94impl LayerNorm {
95    /// Create a new layer normalization component
96    pub fn new(config: LayerNormConfig) -> Result<Self> {
97        config.validate()?;
98        Ok(Self { config })
99    }
100
101    /// Build einsum graph for layer normalization
102    ///
103    /// Input tensors:
104    /// - 0: x (input) `[batch, seq_len, d_model]`
105    /// - 1: gamma (scale) `[d_model]` (if elementwise_affine)
106    /// - 2: beta (shift) `[d_model]` (if elementwise_affine)
107    ///
108    /// Output tensors:
109    /// - output: `[batch, seq_len, d_model]` (normalized)
110    pub fn build_layernorm_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
111        // Step 1: Compute mean over feature dimension
112        // mean = reduce_mean(x, axis=-1, keepdims=True)
113        let mean_tensor = graph.add_tensor("ln_mean");
114        let mean_node = EinsumNode::reduce("mean", vec![2], 0, mean_tensor); // axis=-1 (d_model dimension)
115        graph.add_node(mean_node)?;
116
117        // Step 2: Center the input (x - mean)
118        let centered_tensor = graph.add_tensor("ln_centered");
119        let center_node = EinsumNode::elem_binary("sub", 0, mean_tensor, centered_tensor);
120        graph.add_node(center_node)?;
121
122        // Step 3: Compute variance
123        // var = reduce_mean((x - mean)^2, axis=-1, keepdims=True)
124        let squared_tensor = graph.add_tensor("ln_squared");
125        let square_node =
126            EinsumNode::elem_binary("mul", centered_tensor, centered_tensor, squared_tensor);
127        graph.add_node(square_node)?;
128
129        let var_tensor = graph.add_tensor("ln_var");
130        let var_node = EinsumNode::reduce("mean", vec![2], squared_tensor, var_tensor);
131        graph.add_node(var_node)?;
132
133        // Step 4: Add epsilon for numerical stability
134        let var_eps_tensor = graph.add_tensor("ln_var_eps");
135        let eps_const_tensor = graph.add_tensor("eps_const");
136        let eps_node = EinsumNode::elem_binary("add", var_tensor, eps_const_tensor, var_eps_tensor);
137        graph.add_node(eps_node)?;
138
139        // Step 5: Compute standard deviation (sqrt(var + eps))
140        let std_tensor = graph.add_tensor("ln_std");
141        let sqrt_node = EinsumNode::elem_unary("sqrt", var_eps_tensor, std_tensor);
142        graph.add_node(sqrt_node)?;
143
144        // Step 6: Normalize (x - mean) / std
145        let normalized_tensor = graph.add_tensor("ln_normalized");
146        let norm_node =
147            EinsumNode::elem_binary("div", centered_tensor, std_tensor, normalized_tensor);
148        graph.add_node(norm_node)?;
149
150        // Step 7: Apply affine transformation if configured
151        if self.config.elementwise_affine {
152            // Scale: gamma * normalized
153            let scaled_tensor = graph.add_tensor("ln_scaled");
154            let scale_node = EinsumNode::elem_binary("mul", normalized_tensor, 1, scaled_tensor);
155            graph.add_node(scale_node)?;
156
157            // Shift: scaled + beta
158            let output_tensor = graph.add_tensor("ln_output");
159            let shift_node = EinsumNode::elem_binary("add", scaled_tensor, 2, output_tensor);
160            graph.add_node(shift_node)?;
161
162            Ok(vec![output_tensor])
163        } else {
164            Ok(vec![normalized_tensor])
165        }
166    }
167
168    /// Get epsilon value
169    pub fn eps(&self) -> f64 {
170        self.config.eps
171    }
172
173    /// Check if using elementwise affine
174    pub fn has_elementwise_affine(&self) -> bool {
175        self.config.elementwise_affine
176    }
177}
178
/// RMS (Root Mean Square) normalization
///
/// A simplified variant of layer normalization that skips mean-centering
/// and only divides by the root mean square:
/// ```text
/// RMSNorm(x) = x / RMS(x) * γ
/// where RMS(x) = √(mean(x²) + ε)
/// ```
#[derive(Clone, Debug)]
pub struct RMSNorm {
    /// Configuration (validated at construction time by `RMSNorm::new`;
    /// only `eps` and `elementwise_affine` affect the emitted graph)
    pub config: LayerNormConfig,
}
191
192impl RMSNorm {
193    /// Create a new RMS normalization component
194    pub fn new(config: LayerNormConfig) -> Result<Self> {
195        config.validate()?;
196        Ok(Self { config })
197    }
198
199    /// Build einsum graph for RMS normalization
200    ///
201    /// Input tensors:
202    /// - 0: x (input) `[batch, seq_len, d_model]`
203    /// - 1: gamma (scale) `[d_model]` (if elementwise_affine)
204    ///
205    /// Output tensors:
206    /// - output: `[batch, seq_len, d_model]` (normalized)
207    pub fn build_rmsnorm_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
208        // Step 1: Compute x^2
209        let squared_tensor = graph.add_tensor("rms_squared");
210        let square_node = EinsumNode::elem_binary("mul", 0, 0, squared_tensor);
211        graph.add_node(square_node)?;
212
213        // Step 2: Compute mean(x^2)
214        let mean_sq_tensor = graph.add_tensor("rms_mean_sq");
215        let mean_node = EinsumNode::reduce("mean", vec![2], squared_tensor, mean_sq_tensor);
216        graph.add_node(mean_node)?;
217
218        // Step 3: Add epsilon
219        let mean_sq_eps_tensor = graph.add_tensor("rms_mean_sq_eps");
220        let eps_const_tensor = graph.add_tensor("eps_const");
221        let eps_node =
222            EinsumNode::elem_binary("add", mean_sq_tensor, eps_const_tensor, mean_sq_eps_tensor);
223        graph.add_node(eps_node)?;
224
225        // Step 4: Compute RMS = sqrt(mean(x^2) + eps)
226        let rms_tensor = graph.add_tensor("rms");
227        let sqrt_node = EinsumNode::elem_unary("sqrt", mean_sq_eps_tensor, rms_tensor);
228        graph.add_node(sqrt_node)?;
229
230        // Step 5: Normalize x / RMS
231        let normalized_tensor = graph.add_tensor("rms_normalized");
232        let norm_node = EinsumNode::elem_binary("div", 0, rms_tensor, normalized_tensor);
233        graph.add_node(norm_node)?;
234
235        // Step 6: Apply scale if configured
236        if self.config.elementwise_affine {
237            let output_tensor = graph.add_tensor("rms_output");
238            let scale_node = EinsumNode::elem_binary("mul", normalized_tensor, 1, output_tensor);
239            graph.add_node(scale_node)?;
240            Ok(vec![output_tensor])
241        } else {
242            Ok(vec![normalized_tensor])
243        }
244    }
245
246    /// Get epsilon value
247    pub fn eps(&self) -> f64 {
248        self.config.eps
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layernorm_config_creation() {
        // Defaults: eps = 1e-5, affine enabled.
        let cfg = LayerNormConfig::new(512);
        assert!(cfg.validate().is_ok());
        assert_eq!(cfg.normalized_shape, 512);
        assert!(cfg.elementwise_affine);
        assert!((cfg.eps - 1e-5).abs() < 1e-10);
    }

    #[test]
    fn test_layernorm_config_with_eps() {
        let cfg = LayerNormConfig::new(512).with_eps(1e-6);
        assert!((cfg.eps - 1e-6).abs() < 1e-10);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_layernorm_config_without_affine() {
        let cfg = LayerNormConfig::new(512).with_elementwise_affine(false);
        assert!(!cfg.elementwise_affine);
    }

    #[test]
    fn test_layernorm_creation() {
        let norm = LayerNorm::new(LayerNormConfig::new(512)).unwrap();
        assert_eq!(norm.config.normalized_shape, 512);
        assert!(norm.has_elementwise_affine());
    }

    #[test]
    fn test_layernorm_graph_building_with_affine() {
        let norm = LayerNorm::new(LayerNormConfig::new(512)).unwrap();

        // Pre-register the three expected inputs: x, gamma, beta.
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("gamma");
        graph.add_tensor("beta");

        let outputs = norm.build_layernorm_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_layernorm_graph_building_without_affine() {
        let cfg = LayerNormConfig::new(512).with_elementwise_affine(false);
        let norm = LayerNorm::new(cfg).unwrap();

        // Only the input tensor is required when affine is disabled.
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = norm.build_layernorm_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_rmsnorm_creation() {
        let norm = RMSNorm::new(LayerNormConfig::new(512)).unwrap();
        assert_eq!(norm.config.normalized_shape, 512);
    }

    #[test]
    fn test_rmsnorm_graph_building() {
        let norm = RMSNorm::new(LayerNormConfig::new(512)).unwrap();

        // RMSNorm only expects x and gamma (no beta).
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("gamma");

        let outputs = norm.build_rmsnorm_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_invalid_config_zero_shape() {
        assert!(LayerNormConfig::new(0).validate().is_err());
    }

    #[test]
    fn test_invalid_config_negative_eps() {
        assert!(LayerNormConfig::new(512).with_eps(-1e-5).validate().is_err());
    }

    #[test]
    fn test_layernorm_eps() {
        let norm = LayerNorm::new(LayerNormConfig::new(512).with_eps(1e-6)).unwrap();
        assert!((norm.eps() - 1e-6).abs() < 1e-10);
    }
}