tiny_recursive_rs/layers/normalization.rs

use candle_core::{Result, Tensor, DType};

/// Root-mean-square layer normalization without a learned scale.
///
/// The computation runs in f32 for numerical stability, then the result is
/// cast back to the input dtype.
pub fn rms_norm(hidden_states: &Tensor, variance_epsilon: f64) -> Result<Tensor> {
    let input_dtype = hidden_states.dtype();

    // Upcast to f32 so the variance and rsqrt are computed in full precision.
    let hidden_states = if input_dtype != DType::F32 {
        hidden_states.to_dtype(DType::F32)?
    } else {
        hidden_states.clone()
    };

    // Mean of squares over the last (hidden) dimension.
    let variance = hidden_states.sqr()?.mean_keepdim(candle_core::D::Minus1)?;

    // x / sqrt(mean(x^2) + eps)
    let normalized = hidden_states.broadcast_div(&(variance + variance_epsilon)?.sqrt()?)?;

    // Cast back to the caller's dtype if we upcast above.
    if input_dtype != DType::F32 {
        normalized.to_dtype(input_dtype)
    } else {
        Ok(normalized)
    }
}

/// RMS normalization layer with a learned per-channel scale.
pub struct RMSNorm {
    weight: Tensor,
    eps: f64,
}

impl RMSNorm {
    /// Load the scale parameter (stored under the name "weight") from a `VarBuilder`.
    pub fn new(hidden_size: usize, eps: f64, vb: candle_nn::VarBuilder) -> Result<Self> {
        let weight = vb.get((hidden_size,), "weight")?;
        Ok(Self { weight, eps })
    }

    /// Build the layer with a fixed all-ones scale (no learned parameters).
    pub fn new_no_weight(hidden_size: usize, eps: f64, device: &candle_core::Device) -> Result<Self> {
        let weight = Tensor::ones((hidden_size,), DType::F32, device)?;
        Ok(Self { weight, eps })
    }

    /// Normalize `x` with `rms_norm`, then apply the per-channel scale.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let normalized = rms_norm(x, self.eps)?;
        normalized.broadcast_mul(&self.weight)
    }
}
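
// Illustrative usage sketch (not part of the original file): how this layer
// could be constructed from a checkpoint inside a larger model, assuming the
// scale is stored under the conventional "weight" key and that `vb` is a
// `candle_nn::VarBuilder` scoped to the parent module.
//
//     let norm = RMSNorm::new(hidden_size, 1e-6, vb.pp("norm"))?;
//     let y = norm.forward(&hidden_states)?;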

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_rms_norm_basic() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?.reshape((1, 4))?;

        let normalized = rms_norm(&x, 1e-6)?;

        // Mean of the squared normalized values should be close to 1.
        let rms = normalized.sqr()?.mean_all()?.to_scalar::<f32>()?;
        assert!((rms - 1.0).abs() < 0.1, "RMS should be close to 1.0, got {}", rms);

        Ok(())
    }

    #[test]
    fn test_rms_norm_preserves_shape() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::randn(0f32, 1.0, (2, 8, 64), &device)?;
        let normalized = rms_norm(&x, 1e-6)?;

        // Normalization only rescales along the last dim; the shape is unchanged.
        assert_eq!(x.dims(), normalized.dims());

        Ok(())
    }
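
    // Illustrative addition (not in the original file): a sketch of a test for
    // the `RMSNorm` struct itself, assuming only the candle calls already used
    // above. With the all-ones scale from `new_no_weight`, `forward` should
    // match the bare `rms_norm` function.
    #[test]
    fn test_rms_norm_struct_matches_fn() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::randn(0f32, 1.0, (2, 4, 16), &device)?;

        let norm = RMSNorm::new_no_weight(16, 1e-6, &device)?;
        let via_struct = norm.forward(&x)?;
        let via_fn = rms_norm(&x, 1e-6)?;

        // Multiplying by a unit weight should leave the normalized values untouched.
        let diff = (via_struct - via_fn)?.sqr()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-10, "struct and fn outputs should match, squared diff {}", diff);

        Ok(())
    }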
}