tiny_recursive_rs/layers/
normalization.rs

/// RMS Layer Normalization
///
/// Based on the Python implementation: `rms_norm` function in layers.py
use candle_core::{Result, Tensor, DType};

/// RMS Normalization function
///
/// Normalizes by root mean square with a small epsilon for numerical stability.
/// The computation is done in f32 for precision, then cast back to the original dtype.
///
/// # Arguments
/// * `hidden_states` - Input tensor
/// * `variance_epsilon` - Small constant for numerical stability (typically 1e-6)
///
/// # Returns
/// Normalized tensor with same shape and dtype as input
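///
/// # Example
///
/// Illustrative usage sketch (marked `ignore` since the import path for
/// `rms_norm` depends on how this module is exposed by the crate):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let device = Device::Cpu;
/// let x = Tensor::randn(0f32, 1.0, (2, 8, 64), &device)?;
/// // Each vector along the last dimension now has approximately unit RMS.
/// let normalized = rms_norm(&x, 1e-6)?;
/// ```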
pub fn rms_norm(hidden_states: &Tensor, variance_epsilon: f64) -> Result<Tensor> {
    let input_dtype = hidden_states.dtype();

    // Convert to f32 for precision
    let hidden_states = if input_dtype != DType::F32 {
        hidden_states.to_dtype(DType::F32)?
    } else {
        hidden_states.clone()
    };

    // Compute variance: mean of squares along last dimension
    let variance = hidden_states.sqr()?.mean_keepdim(candle_core::D::Minus1)?;

    // Normalize: x * rsqrt(variance + epsilon)
    let normalized = hidden_states.broadcast_div(
        &(variance + variance_epsilon)?.sqrt()?
    )?;

    // Convert back to original dtype
    if input_dtype != DType::F32 {
        normalized.to_dtype(input_dtype)
    } else {
        Ok(normalized)
    }
}

/// RMS Normalization layer with learnable scale
pub struct RMSNorm {
    weight: Tensor,
    eps: f64,
}

impl RMSNorm {
    pub fn new(hidden_size: usize, eps: f64, vb: candle_nn::VarBuilder) -> Result<Self> {
        // Fetch the learnable scale weight of shape [hidden_size] from the VarBuilder
        let weight = vb.get((hidden_size,), "weight")?;
        Ok(Self { weight, eps })
    }

    /// Create RMSNorm without learnable parameters (just the function)
    pub fn new_no_weight(hidden_size: usize, eps: f64, device: &candle_core::Device) -> Result<Self> {
        let weight = Tensor::ones((hidden_size,), DType::F32, device)?;
        Ok(Self { weight, eps })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let normalized = rms_norm(x, self.eps)?;
        // Apply learnable scale
        normalized.broadcast_mul(&self.weight)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, DType};

    #[test]
    fn test_rms_norm_basic() -> Result<()> {
        let device = Device::Cpu;

        // Create a simple tensor [1, 2, 3, 4]
        let x = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?.reshape((1, 4))?;

        // Apply RMS norm
        let normalized = rms_norm(&x, 1e-6)?;

        // After normalization the mean of squares (and hence the RMS) should be close to 1.0
        let mean_sq = normalized.sqr()?.mean_all()?.to_scalar::<f32>()?;
        assert!((mean_sq - 1.0).abs() < 0.1, "mean square should be close to 1.0, got {}", mean_sq);

        Ok(())
    }
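
    // Sanity-check sketch: rms_norm casts to f32 internally, so the output
    // should come back in the caller's dtype. Assumes f16 is available on the
    // CPU backend of candle.
    #[test]
    fn test_rms_norm_preserves_dtype() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::randn(0f32, 1.0, (2, 4), &device)?.to_dtype(DType::F16)?;
        let normalized = rms_norm(&x, 1e-6)?;

        assert_eq!(normalized.dtype(), DType::F16);

        Ok(())
    }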

    #[test]
    fn test_rms_norm_preserves_shape() -> Result<()> {
        let device = Device::Cpu;

        // Test with a batched 3D shape
        let x = Tensor::randn(0f32, 1.0, (2, 8, 64), &device)?;
        let normalized = rms_norm(&x, 1e-6)?;

        assert_eq!(x.dims(), normalized.dims());

        Ok(())
    }
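
    // Sanity-check sketch: with `new_no_weight` the scale weight is all ones,
    // so the layer's forward() should match the bare rms_norm function.
    #[test]
    fn test_rmsnorm_layer_matches_function() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::randn(0f32, 1.0, (2, 8), &device)?;
        let layer = RMSNorm::new_no_weight(8, 1e-6, &device)?;

        let from_layer = layer.forward(&x)?;
        let from_fn = rms_norm(&x, 1e-6)?;

        // With a unit weight the two outputs should agree exactly (up to float rounding).
        let diff = (from_layer - from_fn)?.sqr()?.mean_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-10, "layer and function outputs differ, mean squared diff {}", diff);

        Ok(())
    }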
}