tensorlogic_trustformers/layers.rs

//! Complete transformer encoder and decoder layers.
//!
//! This module implements full transformer layers that combine:
//! - Multi-head attention
//! - Feed-forward networks
//! - Layer normalization
//! - Residual connections
//!
//! ## Transformer Encoder Layer
//!
//! ```text
//! x' = LayerNorm(x + MultiHeadAttention(x, x, x))
//! output = LayerNorm(x' + FFN(x'))
//! ```
//!
//! ## Transformer Decoder Layer
//!
//! ```text
//! x' = LayerNorm(x + MaskedMultiHeadAttention(x, x, x))
//! x'' = LayerNorm(x' + CrossAttention(x', enc_output, enc_output))
//! output = LayerNorm(x'' + FFN(x''))
//! ```
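//!
//! The equations above are the post-norm formulation. With `pre_norm = true`
//! (the default), normalization is applied before each sub-layer instead, as
//! in the `build_pre_norm_*` builders below; for the encoder this reads:
//!
//! ```text
//! x'     = x  + MultiHeadAttention(LayerNorm(x), ...)
//! output = x' + FFN(LayerNorm(x'))
//! ```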

use tensorlogic_ir::EinsumGraph;

use crate::{
    attention::MultiHeadAttention,
    config::{AttentionConfig, FeedForwardConfig},
    error::Result,
    ffn::FeedForward,
    normalization::{LayerNorm, LayerNormConfig},
};

/// Configuration for a complete transformer encoder layer
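///
/// A minimal construction sketch, mirroring the unit tests below. The fence is
/// `ignore` because the example assumes this module is exposed as
/// `tensorlogic_trustformers::layers`:
///
/// ```ignore
/// use tensorlogic_trustformers::layers::EncoderLayerConfig;
///
/// let config = EncoderLayerConfig::new(512, 8, 2048)
///     .unwrap()
///     .with_pre_norm(true)
///     .with_dropout(0.1);
/// assert!(config.validate().is_ok());
/// ```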
#[derive(Clone, Debug)]
pub struct EncoderLayerConfig {
    /// Attention configuration
    pub attention: AttentionConfig,
    /// Feed-forward configuration
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization configuration
    pub layer_norm: LayerNormConfig,
    /// Whether to use pre-layer normalization (vs post)
    pub pre_norm: bool,
}

impl EncoderLayerConfig {
    /// Create a new encoder layer configuration
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            layer_norm: LayerNormConfig::new(d_model),
            pre_norm: true,
        })
    }

    /// Set pre-normalization vs post-normalization
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Set causal masking
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.attention = self.attention.with_causal(causal);
        self
    }

    /// Set dropout
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.attention = self.attention.with_dropout(dropout);
        self.feed_forward = self.feed_forward.with_dropout(dropout);
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        self.attention.validate()?;
        self.feed_forward.validate()?;
        self.layer_norm.validate()?;

        // Check dimension consistency
        if self.attention.d_model != self.feed_forward.d_model {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.feed_forward.d_model,
                context: "d_model mismatch between attention and FFN".to_string(),
            });
        }

        if self.attention.d_model != self.layer_norm.normalized_shape {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.layer_norm.normalized_shape,
                context: "d_model mismatch with layer norm".to_string(),
            });
        }

        Ok(())
    }
}

/// Transformer encoder layer
#[derive(Clone, Debug)]
pub struct EncoderLayer {
    /// Configuration
    pub config: EncoderLayerConfig,
    /// Multi-head attention
    pub attention: MultiHeadAttention,
    /// Feed-forward network
    pub ffn: FeedForward,
    /// First layer normalization
    pub norm1: LayerNorm,
    /// Second layer normalization
    pub norm2: LayerNorm,
}

impl EncoderLayer {
    /// Create a new encoder layer
    pub fn new(config: EncoderLayerConfig) -> Result<Self> {
        config.validate()?;

        let attention = MultiHeadAttention::new(config.attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            attention,
            ffn,
            norm1,
            norm2,
        })
    }

    /// Build einsum graph for encoder layer
    ///
    /// Input tensors:
    /// - 0: x (input) [batch, seq_len, d_model]
    /// - 1-N: weight matrices and parameters for attention, FFN, and layer norms
    ///
    /// Output tensors:
    /// - output: [batch, seq_len, d_model]
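    ///
    /// A minimal usage sketch, mirroring the unit test below (fenced as
    /// `ignore` since tensor wiring beyond index 0 depends on the surrounding
    /// model builder):
    ///
    /// ```ignore
    /// let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
    /// let layer = EncoderLayer::new(config).unwrap();
    ///
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("x"); // tensor 0: the layer input
    /// let outputs = layer.build_encoder_layer_graph(&mut graph).unwrap();
    /// assert_eq!(outputs.len(), 1);
    /// ```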
    pub fn build_encoder_layer_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let input_tensor = 0;

        if self.config.pre_norm {
            // Pre-LN: LN(x) -> Attention -> Add -> LN -> FFN -> Add
            self.build_pre_norm_encoder(graph, input_tensor)
        } else {
            // Post-LN: Attention -> Add -> LN -> FFN -> Add -> LN
            self.build_post_norm_encoder(graph, input_tensor)
        }
    }

    fn build_pre_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        // Step 1: First layer norm
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        // Step 2: Multi-head attention (Q, K, V all from normed input)
        // Route the normalized input into dedicated Q, K, V tensors
        let q_tensor = graph.add_tensor("encoder_Q");
        let k_tensor = graph.add_tensor("encoder_K");
        let v_tensor = graph.add_tensor("encoder_V");

        // Add identity copies of the normalized input for Q, K, V to the graph
        graph.add_node(tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, q_tensor))?;
        graph.add_node(tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, k_tensor))?;
        graph.add_node(tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, v_tensor))?;

        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // Step 3: Residual connection: x + attention_output
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Step 4: Second layer norm
        let _normed2_outputs = self.norm2.build_layernorm_graph(graph)?;

        // Step 5: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 6: Second residual connection: residual1 + ffn_output
        let output = graph.add_tensor("encoder_output");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, ffn_output, output);
        graph.add_node(res2_node)?;

        Ok(vec![output])
    }

    fn build_post_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        // Step 1: Multi-head attention
        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // Step 2: Residual connection
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Step 3: First layer norm
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        // Step 4: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 5: Second residual connection
        let residual2 = graph.add_tensor("encoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, ffn_output, residual2);
        graph.add_node(res2_node)?;

        // Step 6: Second layer norm
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let output = normed2_outputs[0];

        Ok(vec![output])
    }
}

/// Configuration for a complete transformer decoder layer
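///
/// A minimal construction sketch, mirroring the unit tests below (`ignore`
/// fence because the example assumes the `tensorlogic_trustformers::layers`
/// module path):
///
/// ```ignore
/// use tensorlogic_trustformers::layers::DecoderLayerConfig;
///
/// let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
/// assert!(config.self_attention.causal);   // masked self-attention by default
/// assert!(!config.cross_attention.causal); // cross-attention stays unmasked
/// assert!(config.validate().is_ok());
/// ```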
#[derive(Clone, Debug)]
pub struct DecoderLayerConfig {
    /// Self-attention configuration (with causal masking)
    pub self_attention: AttentionConfig,
    /// Cross-attention configuration
    pub cross_attention: AttentionConfig,
    /// Feed-forward configuration
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization configuration
    pub layer_norm: LayerNormConfig,
    /// Whether to use pre-layer normalization
    pub pre_norm: bool,
}

impl DecoderLayerConfig {
    /// Create a new decoder layer configuration
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            self_attention: AttentionConfig::new(d_model, n_heads)?.with_causal(true),
            cross_attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            layer_norm: LayerNormConfig::new(d_model),
            pre_norm: true,
        })
    }

    /// Set pre-normalization vs post-normalization
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Set dropout
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.self_attention = self.self_attention.with_dropout(dropout);
        self.cross_attention = self.cross_attention.with_dropout(dropout);
        self.feed_forward = self.feed_forward.with_dropout(dropout);
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        self.self_attention.validate()?;
        self.cross_attention.validate()?;
        self.feed_forward.validate()?;
        self.layer_norm.validate()?;

        // Verify causal masking is enabled for self-attention
        if !self.self_attention.causal {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "Decoder self-attention must use causal masking".to_string(),
            });
        }

        // Check dimension consistency
        if self.self_attention.d_model != self.cross_attention.d_model
            || self.self_attention.d_model != self.feed_forward.d_model
            || self.self_attention.d_model != self.layer_norm.normalized_shape
        {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.self_attention.d_model,
                got: 0,
                context: "d_model mismatch between components".to_string(),
            });
        }

        Ok(())
    }
}

/// Transformer decoder layer
#[derive(Clone, Debug)]
pub struct DecoderLayer {
    /// Configuration
    pub config: DecoderLayerConfig,
    /// Masked self-attention
    pub self_attention: MultiHeadAttention,
    /// Cross-attention
    pub cross_attention: MultiHeadAttention,
    /// Feed-forward network
    pub ffn: FeedForward,
    /// First layer normalization (self-attention)
    pub norm1: LayerNorm,
    /// Second layer normalization (cross-attention)
    pub norm2: LayerNorm,
    /// Third layer normalization (FFN)
    pub norm3: LayerNorm,
}

impl DecoderLayer {
    /// Create a new decoder layer
    pub fn new(config: DecoderLayerConfig) -> Result<Self> {
        config.validate()?;

        let self_attention = MultiHeadAttention::new(config.self_attention.clone())?;
        let cross_attention = MultiHeadAttention::new(config.cross_attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;
        let norm3 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            self_attention,
            cross_attention,
            ffn,
            norm1,
            norm2,
            norm3,
        })
    }

    /// Build einsum graph for decoder layer
    ///
    /// Input tensors:
    /// - 0: x (decoder input) [batch, tgt_len, d_model]
    /// - 1: encoder_output [batch, src_len, d_model]
    /// - 2-N: weight matrices and parameters
    ///
    /// Output tensors:
    /// - output: [batch, tgt_len, d_model]
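    ///
    /// A minimal usage sketch, mirroring the unit test below (`ignore` fence:
    /// tensor wiring beyond indices 0 and 1 depends on the surrounding model
    /// builder):
    ///
    /// ```ignore
    /// let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
    /// let layer = DecoderLayer::new(config).unwrap();
    ///
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("decoder_input");  // tensor 0
    /// graph.add_tensor("encoder_output"); // tensor 1
    /// let outputs = layer.build_decoder_layer_graph(&mut graph).unwrap();
    /// assert_eq!(outputs.len(), 1);
    /// ```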
    pub fn build_decoder_layer_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let decoder_input = 0;
        let encoder_output = 1;

        if self.config.pre_norm {
            self.build_pre_norm_decoder(graph, decoder_input, encoder_output)
        } else {
            self.build_post_norm_decoder(graph, decoder_input, encoder_output)
        }
    }

    fn build_pre_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize, // Used implicitly in cross-attention
    ) -> Result<Vec<usize>> {
        // Step 1: First layer norm
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let _normed1 = normed1_outputs[0];

        // Step 2: Masked self-attention
        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // Step 3: First residual
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        // Step 4: Second layer norm
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let _normed2 = normed2_outputs[0];

        // Step 5: Cross-attention (Q from decoder, K,V from encoder)
        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Step 6: Second residual
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        // Step 7: Third layer norm
        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let _normed3 = normed3_outputs[0];

        // Step 8: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 9: Third residual
        let output = graph.add_tensor("decoder_output");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual2, ffn_output, output);
        graph.add_node(res3_node)?;

        Ok(vec![output])
    }

    fn build_post_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize, // Used implicitly in cross-attention
    ) -> Result<Vec<usize>> {
        // Step 1: Masked self-attention
        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // Step 2: First residual + norm
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        // Step 3: Cross-attention
        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Step 4: Second residual + norm
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let normed2 = normed2_outputs[0];

        // Step 5: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 6: Third residual + norm
        let residual3 = graph.add_tensor("decoder_residual3");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed2, ffn_output, residual3);
        graph.add_node(res3_node)?;

        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let output = normed3_outputs[0];

        Ok(vec![output])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_config_creation() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.pre_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_encoder_layer_config_with_dropout() {
        let config = EncoderLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_dropout(0.1);
        assert!((config.attention.dropout - 0.1).abs() < 1e-10);
        assert!((config.feed_forward.dropout - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_encoder_layer_creation() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config).unwrap();
        assert_eq!(layer.config.attention.d_model, 512);
    }

    #[test]
    fn test_encoder_layer_graph_building() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = layer.build_encoder_layer_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_decoder_layer_config_creation() {
        let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.self_attention.d_model, 512);
        assert_eq!(config.cross_attention.d_model, 512);
        assert!(config.self_attention.causal);
        assert!(!config.cross_attention.causal);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_decoder_layer_creation() {
        let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = DecoderLayer::new(config).unwrap();
        assert_eq!(layer.config.self_attention.d_model, 512);
    }

    #[test]
    fn test_decoder_layer_graph_building() {
        let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = DecoderLayer::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("decoder_input");
        graph.add_tensor("encoder_output");

        let outputs = layer.build_decoder_layer_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }
}