// tensorlogic_trustformers/decoder.rs
1//! Transformer decoder layers.
2//!
3//! This module implements transformer decoder layers that combine:
4//! - Masked multi-head self-attention
5//! - Cross-attention to encoder outputs
6//! - Feed-forward networks
7//! - Layer normalization
8//! - Residual connections
9//!
10//! ## Transformer Decoder Layer
11//!
12//! Pre-normalization variant:
13//! ```text
14//! x' = x + MaskedSelfAttention(LayerNorm(x))
15//! x'' = x' + CrossAttention(LayerNorm(x'), encoder_output)
16//! output = x'' + FFN(LayerNorm(x''))
17//! ```
18//!
19//! Post-normalization variant:
20//! ```text
21//! x' = LayerNorm(x + MaskedSelfAttention(x))
22//! x'' = LayerNorm(x' + CrossAttention(x', encoder_output))
23//! output = LayerNorm(x'' + FFN(x''))
24//! ```
25
26use tensorlogic_ir::EinsumGraph;
27
28use crate::{
29    attention::MultiHeadAttention,
30    config::{AttentionConfig, FeedForwardConfig},
31    error::Result,
32    ffn::FeedForward,
33    normalization::{LayerNorm, LayerNormConfig},
34};
35
/// Configuration for transformer decoder layer.
///
/// Bundles the per-component configurations (self-attention, cross-attention,
/// feed-forward, layer normalization) plus the normalization-placement flag.
/// All components are expected to agree on `d_model`; see `validate`.
#[derive(Clone, Debug)]
pub struct DecoderConfig {
    /// Self-attention configuration (with causal masking)
    pub self_attention: AttentionConfig,
    /// Cross-attention configuration (attends to encoder output; non-causal)
    pub cross_attention: AttentionConfig,
    /// Feed-forward configuration
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization configuration (shared by all three norm instances)
    pub layer_norm: LayerNormConfig,
    /// Whether to use pre-layer normalization (`true`) or post-normalization (`false`)
    pub pre_norm: bool,
}
50
51impl DecoderConfig {
52    /// Create a new decoder configuration
53    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
54        Ok(Self {
55            self_attention: AttentionConfig::new(d_model, n_heads)?.with_causal(true),
56            cross_attention: AttentionConfig::new(d_model, n_heads)?,
57            feed_forward: FeedForwardConfig::new(d_model, d_ff),
58            layer_norm: LayerNormConfig::new(d_model),
59            pre_norm: true,
60        })
61    }
62
63    /// Set pre-normalization vs post-normalization
64    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
65        self.pre_norm = pre_norm;
66        self
67    }
68
69    /// Set dropout
70    pub fn with_dropout(mut self, dropout: f64) -> Self {
71        self.self_attention = self.self_attention.with_dropout(dropout);
72        self.cross_attention = self.cross_attention.with_dropout(dropout);
73        self.feed_forward = self.feed_forward.with_dropout(dropout);
74        self
75    }
76
77    /// Validate configuration
78    pub fn validate(&self) -> Result<()> {
79        self.self_attention.validate()?;
80        self.cross_attention.validate()?;
81        self.feed_forward.validate()?;
82        self.layer_norm.validate()?;
83
84        // Verify causal masking is enabled for self-attention
85        if !self.self_attention.causal {
86            return Err(crate::error::TrustformerError::InvalidDimension {
87                expected: 1,
88                got: 0,
89                context: "Decoder self-attention must use causal masking".to_string(),
90            });
91        }
92
93        // Check dimension consistency
94        if self.self_attention.d_model != self.cross_attention.d_model
95            || self.self_attention.d_model != self.feed_forward.d_model
96            || self.self_attention.d_model != self.layer_norm.normalized_shape
97        {
98            return Err(crate::error::TrustformerError::InvalidDimension {
99                expected: self.self_attention.d_model,
100                got: 0,
101                context: "d_model mismatch between components".to_string(),
102            });
103        }
104
105        Ok(())
106    }
107}
108
/// Transformer decoder layer.
///
/// Combines masked self-attention, cross-attention to encoder output, and a
/// feed-forward network, each sub-layer paired with its own layer norm.
#[derive(Clone, Debug)]
pub struct Decoder {
    /// Configuration used to build this layer
    pub config: DecoderConfig,
    /// Masked (causal) self-attention over the decoder input
    pub self_attention: MultiHeadAttention,
    /// Cross-attention (queries from decoder, keys/values from encoder)
    pub cross_attention: MultiHeadAttention,
    /// Feed-forward network
    pub ffn: FeedForward,
    /// First layer normalization (self-attention sub-layer)
    pub norm1: LayerNorm,
    /// Second layer normalization (cross-attention sub-layer)
    pub norm2: LayerNorm,
    /// Third layer normalization (FFN sub-layer)
    pub norm3: LayerNorm,
}
127
impl Decoder {
    /// Create a new decoder layer.
    ///
    /// Validates the configuration, then constructs the three sub-modules
    /// and three independent layer-norm instances (all built from the same
    /// `layer_norm` config so each sub-layer owns its own parameters).
    ///
    /// # Errors
    /// Propagates any error from `DecoderConfig::validate` or from the
    /// sub-module constructors.
    pub fn new(config: DecoderConfig) -> Result<Self> {
        config.validate()?;

        let self_attention = MultiHeadAttention::new(config.self_attention.clone())?;
        let cross_attention = MultiHeadAttention::new(config.cross_attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;
        let norm3 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            self_attention,
            cross_attention,
            ffn,
            norm1,
            norm2,
            norm3,
        })
    }

    /// Build einsum graph for decoder layer
    ///
    /// Input tensors:
    /// - 0: x (decoder input) [batch, tgt_len, d_model]
    /// - 1: encoder_output [batch, src_len, d_model]
    /// - 2-N: weight matrices and parameters
    ///
    /// Output tensors:
    /// - output: [batch, tgt_len, d_model]
    ///
    /// Dispatches to the pre-norm or post-norm builder based on
    /// `config.pre_norm`. Tensor indices 0 and 1 are assumed to already
    /// exist in `graph` (added by the caller).
    pub fn build_decoder_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let decoder_input = 0;
        let encoder_output = 1;

        if self.config.pre_norm {
            self.build_pre_norm_decoder(graph, decoder_input, encoder_output)
        } else {
            self.build_post_norm_decoder(graph, decoder_input, encoder_output)
        }
    }

    /// Pre-normalization variant:
    /// `x' = x + SelfAttn(LN(x)); x'' = x' + CrossAttn(LN(x'), enc); out = x'' + FFN(LN(x''))`.
    ///
    /// NOTE(review): the normed tensors (`_normed1/2/3`) and `_encoder_output`
    /// are produced but never passed to the attention/FFN sub-builders — each
    /// `build_*_graph` call takes only the graph, so the sub-graphs are not
    /// visibly wired to these intermediates. Confirm whether the sub-builders
    /// are expected to pick up their inputs implicitly (e.g. by tensor-index
    /// convention) or whether wiring is still TODO.
    fn build_pre_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize,
    ) -> Result<Vec<usize>> {
        // Step 1: First layer norm
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let _normed1 = normed1_outputs[0];

        // Step 2: Masked self-attention
        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // Step 3: First residual (x + self-attention output)
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        // Step 4: Second layer norm
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let _normed2 = normed2_outputs[0];

        // Step 5: Cross-attention (Q from decoder, K,V from encoder)
        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Step 6: Second residual (x' + cross-attention output)
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        // Step 7: Third layer norm
        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let _normed3 = normed3_outputs[0];

        // Step 8: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 9: Third residual (x'' + FFN output) — final layer output
        let output = graph.add_tensor("decoder_output");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual2, ffn_output, output);
        graph.add_node(res3_node)?;

        Ok(vec![output])
    }

    /// Post-normalization variant:
    /// `x' = LN(x + SelfAttn(x)); x'' = LN(x' + CrossAttn(x', enc)); out = LN(x'' + FFN(x''))`.
    ///
    /// NOTE(review): as in the pre-norm builder, the attention/FFN
    /// sub-builders receive only the graph, so their connection to
    /// `decoder_input`, `_encoder_output`, and the normed intermediates is
    /// not visible here — confirm the intended wiring convention.
    fn build_post_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize,
    ) -> Result<Vec<usize>> {
        // Step 1: Masked self-attention
        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // Step 2: First residual + norm
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        // Step 3: Cross-attention
        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Step 4: Second residual + norm (residual taken from normed1 here,
        // matching the post-norm formula x'' = LN(x' + CrossAttn(...)))
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let normed2 = normed2_outputs[0];

        // Step 5: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 6: Third residual + norm — final layer output
        let residual3 = graph.add_tensor("decoder_residual3");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed2, ffn_output, residual3);
        graph.add_node(res3_node)?;

        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let output = normed3_outputs[0];

        Ok(vec![output])
    }
}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard configuration shared by the tests: d_model=512, 8 heads, d_ff=2048.
    fn base_config() -> DecoderConfig {
        DecoderConfig::new(512, 8, 2048).unwrap()
    }

    #[test]
    fn test_decoder_config_creation() {
        let cfg = base_config();
        assert_eq!(cfg.self_attention.d_model, 512);
        assert_eq!(cfg.cross_attention.d_model, 512);
        assert!(cfg.self_attention.causal);
        assert!(!cfg.cross_attention.causal);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_decoder_config_with_dropout() {
        let cfg = base_config().with_dropout(0.1);
        // The builder must propagate the same rate to all three components.
        for rate in [
            cfg.self_attention.dropout,
            cfg.cross_attention.dropout,
            cfg.feed_forward.dropout,
        ] {
            assert!((rate - 0.1).abs() < 1e-10);
        }
    }

    #[test]
    fn test_decoder_config_pre_norm() {
        assert!(!base_config().with_pre_norm(false).pre_norm);
    }

    #[test]
    fn test_decoder_creation() {
        let decoder = Decoder::new(base_config()).unwrap();
        assert_eq!(decoder.config.self_attention.d_model, 512);
    }

    #[test]
    fn test_decoder_graph_building_pre_norm() {
        let decoder = Decoder::new(base_config()).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("decoder_input");
        graph.add_tensor("encoder_output");

        let outputs = decoder.build_decoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_decoder_graph_building_post_norm() {
        let decoder = Decoder::new(base_config().with_pre_norm(false)).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("decoder_input");
        graph.add_tensor("encoder_output");

        let outputs = decoder.build_decoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_decoder_config_validation() {
        assert!(base_config().validate().is_ok());
        // 7 heads do not evenly divide d_model=512, so construction must fail.
        assert!(DecoderConfig::new(512, 7, 2048).is_err());
    }

    #[test]
    fn test_decoder_requires_causal_masking() {
        let cfg = base_config();
        assert!(cfg.self_attention.causal);
        assert!(!cfg.cross_attention.causal);
    }
}