//! tensorlogic_trustformers — normalization.rs
//!
//! Layer-normalization primitives (LayerNorm, RMSNorm) expressed as einsum
//! graph fragments.

use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, EinsumNode};

use crate::error::{Result, TrustformerError};
/// Configuration shared by [`LayerNorm`] and [`RMSNorm`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LayerNormConfig {
    /// Size of the dimension that is normalized over. Must be non-zero
    /// (enforced by `validate`).
    pub normalized_shape: usize,
    /// Small positive constant added before the square root for numerical
    /// stability. Must be > 0 (enforced by `validate`).
    pub eps: f64,
    /// When true, the built graph includes a learnable scale (and, for
    /// LayerNorm, shift) applied to the normalized values.
    pub elementwise_affine: bool,
}
42
43impl LayerNormConfig {
44 pub fn new(normalized_shape: usize) -> Self {
46 Self {
47 normalized_shape,
48 eps: 1e-5,
49 elementwise_affine: true,
50 }
51 }
52
53 pub fn with_eps(mut self, eps: f64) -> Self {
55 self.eps = eps;
56 self
57 }
58
59 pub fn with_elementwise_affine(mut self, elementwise_affine: bool) -> Self {
61 self.elementwise_affine = elementwise_affine;
62 self
63 }
64
65 pub fn validate(&self) -> Result<()> {
67 if self.normalized_shape == 0 {
68 return Err(TrustformerError::InvalidDimension {
69 expected: 1,
70 got: 0,
71 context: "normalized_shape must be positive".to_string(),
72 });
73 }
74
75 if self.eps <= 0.0 {
76 return Err(TrustformerError::InvalidDimension {
77 expected: 1,
78 got: 0,
79 context: format!("eps must be positive, got {}", self.eps),
80 });
81 }
82
83 Ok(())
84 }
85}
86
/// Standard layer normalization: `y = (x - mean) / sqrt(var + eps)`,
/// optionally followed by a learnable scale (gamma) and shift (beta).
#[derive(Clone, Debug)]
pub struct LayerNorm {
    /// Validated configuration (see [`LayerNormConfig`]).
    pub config: LayerNormConfig,
}
93
94impl LayerNorm {
95 pub fn new(config: LayerNormConfig) -> Result<Self> {
97 config.validate()?;
98 Ok(Self { config })
99 }
100
101 pub fn build_layernorm_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
111 let mean_tensor = graph.add_tensor("ln_mean");
114 let mean_node = EinsumNode::reduce("mean", vec![2], 0, mean_tensor); graph.add_node(mean_node)?;
116
117 let centered_tensor = graph.add_tensor("ln_centered");
119 let center_node = EinsumNode::elem_binary("sub", 0, mean_tensor, centered_tensor);
120 graph.add_node(center_node)?;
121
122 let squared_tensor = graph.add_tensor("ln_squared");
125 let square_node =
126 EinsumNode::elem_binary("mul", centered_tensor, centered_tensor, squared_tensor);
127 graph.add_node(square_node)?;
128
129 let var_tensor = graph.add_tensor("ln_var");
130 let var_node = EinsumNode::reduce("mean", vec![2], squared_tensor, var_tensor);
131 graph.add_node(var_node)?;
132
133 let var_eps_tensor = graph.add_tensor("ln_var_eps");
135 let eps_const_tensor = graph.add_tensor("eps_const");
136 let eps_node = EinsumNode::elem_binary("add", var_tensor, eps_const_tensor, var_eps_tensor);
137 graph.add_node(eps_node)?;
138
139 let std_tensor = graph.add_tensor("ln_std");
141 let sqrt_node = EinsumNode::elem_unary("sqrt", var_eps_tensor, std_tensor);
142 graph.add_node(sqrt_node)?;
143
144 let normalized_tensor = graph.add_tensor("ln_normalized");
146 let norm_node =
147 EinsumNode::elem_binary("div", centered_tensor, std_tensor, normalized_tensor);
148 graph.add_node(norm_node)?;
149
150 if self.config.elementwise_affine {
152 let scaled_tensor = graph.add_tensor("ln_scaled");
154 let scale_node = EinsumNode::elem_binary("mul", normalized_tensor, 1, scaled_tensor);
155 graph.add_node(scale_node)?;
156
157 let output_tensor = graph.add_tensor("ln_output");
159 let shift_node = EinsumNode::elem_binary("add", scaled_tensor, 2, output_tensor);
160 graph.add_node(shift_node)?;
161
162 Ok(vec![output_tensor])
163 } else {
164 Ok(vec![normalized_tensor])
165 }
166 }
167
168 pub fn eps(&self) -> f64 {
170 self.config.eps
171 }
172
173 pub fn has_elementwise_affine(&self) -> bool {
175 self.config.elementwise_affine
176 }
177}
178
/// Root-mean-square normalization: `y = x / sqrt(mean(x^2) + eps)`,
/// optionally scaled by a learnable gamma. Unlike [`LayerNorm`], no mean is
/// subtracted and no shift (beta) is applied.
#[derive(Clone, Debug)]
pub struct RMSNorm {
    /// Validated configuration (shares [`LayerNormConfig`] with LayerNorm).
    pub config: LayerNormConfig,
}
191
192impl RMSNorm {
193 pub fn new(config: LayerNormConfig) -> Result<Self> {
195 config.validate()?;
196 Ok(Self { config })
197 }
198
199 pub fn build_rmsnorm_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
208 let squared_tensor = graph.add_tensor("rms_squared");
210 let square_node = EinsumNode::elem_binary("mul", 0, 0, squared_tensor);
211 graph.add_node(square_node)?;
212
213 let mean_sq_tensor = graph.add_tensor("rms_mean_sq");
215 let mean_node = EinsumNode::reduce("mean", vec![2], squared_tensor, mean_sq_tensor);
216 graph.add_node(mean_node)?;
217
218 let mean_sq_eps_tensor = graph.add_tensor("rms_mean_sq_eps");
220 let eps_const_tensor = graph.add_tensor("eps_const");
221 let eps_node =
222 EinsumNode::elem_binary("add", mean_sq_tensor, eps_const_tensor, mean_sq_eps_tensor);
223 graph.add_node(eps_node)?;
224
225 let rms_tensor = graph.add_tensor("rms");
227 let sqrt_node = EinsumNode::elem_unary("sqrt", mean_sq_eps_tensor, rms_tensor);
228 graph.add_node(sqrt_node)?;
229
230 let normalized_tensor = graph.add_tensor("rms_normalized");
232 let norm_node = EinsumNode::elem_binary("div", 0, rms_tensor, normalized_tensor);
233 graph.add_node(norm_node)?;
234
235 if self.config.elementwise_affine {
237 let output_tensor = graph.add_tensor("rms_output");
238 let scale_node = EinsumNode::elem_binary("mul", normalized_tensor, 1, output_tensor);
239 graph.add_node(scale_node)?;
240 Ok(vec![output_tensor])
241 } else {
242 Ok(vec![normalized_tensor])
243 }
244 }
245
246 pub fn eps(&self) -> f64 {
248 self.config.eps
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    // --- LayerNormConfig -------------------------------------------------

    #[test]
    fn test_layernorm_config_creation() {
        let cfg = LayerNormConfig::new(512);
        assert_eq!(cfg.normalized_shape, 512);
        assert!((cfg.eps - 1e-5).abs() < 1e-10);
        assert!(cfg.elementwise_affine);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_layernorm_config_with_eps() {
        let cfg = LayerNormConfig::new(512).with_eps(1e-6);
        assert!((cfg.eps - 1e-6).abs() < 1e-10);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_layernorm_config_without_affine() {
        let cfg = LayerNormConfig::new(512).with_elementwise_affine(false);
        assert!(!cfg.elementwise_affine);
    }

    #[test]
    fn test_invalid_config_zero_shape() {
        assert!(LayerNormConfig::new(0).validate().is_err());
    }

    #[test]
    fn test_invalid_config_negative_eps() {
        assert!(LayerNormConfig::new(512).with_eps(-1e-5).validate().is_err());
    }

    // --- LayerNorm --------------------------------------------------------

    #[test]
    fn test_layernorm_creation() {
        let layer = LayerNorm::new(LayerNormConfig::new(512)).unwrap();
        assert_eq!(layer.config.normalized_shape, 512);
        assert!(layer.has_elementwise_affine());
    }

    #[test]
    fn test_layernorm_graph_building_with_affine() {
        let layer = LayerNorm::new(LayerNormConfig::new(512)).unwrap();

        // x, gamma and beta must occupy tensor ids 0, 1 and 2.
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("gamma");
        graph.add_tensor("beta");

        let outputs = layer.build_layernorm_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_layernorm_graph_building_without_affine() {
        let cfg = LayerNormConfig::new(512).with_elementwise_affine(false);
        let layer = LayerNorm::new(cfg).unwrap();

        // Only the input tensor is needed when affine is disabled.
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = layer.build_layernorm_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_layernorm_eps() {
        let layer = LayerNorm::new(LayerNormConfig::new(512).with_eps(1e-6)).unwrap();
        assert!((layer.eps() - 1e-6).abs() < 1e-10);
    }

    // --- RMSNorm ------------------------------------------------------------

    #[test]
    fn test_rmsnorm_creation() {
        let norm = RMSNorm::new(LayerNormConfig::new(512)).unwrap();
        assert_eq!(norm.config.normalized_shape, 512);
    }

    #[test]
    fn test_rmsnorm_graph_building() {
        let norm = RMSNorm::new(LayerNormConfig::new(512)).unwrap();

        // x and gamma must occupy tensor ids 0 and 1.
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("gamma");

        let outputs = norm.build_rmsnorm_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }
}