//! Transformer encoder and decoder stacks.
//!
//! This module implements complete transformer stacks by composing multiple
//! encoder or decoder layers.
//!
//! ## Transformer Encoder Stack
//!
//! ```text
//! x = input + position_encoding
//! for layer in encoder_layers:
//!     x = layer(x)
//! output = final_layer_norm(x)
//! ```
//!
//! ## Transformer Decoder Stack
//!
//! ```text
//! x = target + position_encoding
//! for layer in decoder_layers:
//!     x = layer(x, encoder_output)
//! output = final_layer_norm(x)
//! ```
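//!
//! ## Example
//!
//! A minimal sketch of wiring an encoder stack into an einsum graph. It
//! mirrors the unit tests at the bottom of this file; the hyperparameters
//! (6 layers, d_model = 512, 8 heads, d_ff = 2048, max_seq_len = 1024) are
//! illustrative only:
//!
//! ```rust,ignore
//! use tensorlogic_ir::EinsumGraph;
//!
//! let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024)?;
//! let stack = EncoderStack::new(config)?;
//!
//! let mut graph = EinsumGraph::new();
//! graph.add_tensor("x"); // tensor 0: [batch, seq_len, d_model]
//! let outputs = stack.build_encoder_stack_graph(&mut graph)?;
//! assert_eq!(outputs.len(), 1);
//! ```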

use tensorlogic_ir::EinsumGraph;

use crate::{
    error::Result,
    layers::{DecoderLayer, DecoderLayerConfig, EncoderLayer, EncoderLayerConfig},
    normalization::{LayerNorm, LayerNormConfig},
    position::{LearnedPositionEncoding, PositionEncodingConfig, SinusoidalPositionEncoding},
};

/// Configuration for transformer encoder stack
#[derive(Clone, Debug)]
pub struct EncoderStackConfig {
    /// Number of encoder layers
    pub num_layers: usize,
    /// Configuration for each encoder layer
    pub layer_config: EncoderLayerConfig,
    /// Position encoding configuration
    pub position_encoding: PositionEncodingConfig,
    /// Whether to apply final layer normalization
    pub final_layer_norm: bool,
}

impl EncoderStackConfig {
    /// Create a new encoder stack configuration
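    ///
    /// Defaults to sinusoidal position encoding with a final layer norm;
    /// both can be overridden via the builder methods below. A hedged
    /// sketch with illustrative dimensions:
    ///
    /// ```rust,ignore
    /// let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024)?
    ///     .with_learned_position_encoding()
    ///     .with_dropout(0.1);
    /// config.validate()?;
    /// ```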
    pub fn new(
        num_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
    ) -> Result<Self> {
        Ok(Self {
            num_layers,
            layer_config: EncoderLayerConfig::new(d_model, n_heads, d_ff)?,
            position_encoding: PositionEncodingConfig::sinusoidal(d_model, max_seq_len),
            final_layer_norm: true,
        })
    }

    /// Set position encoding type to learned
    pub fn with_learned_position_encoding(mut self) -> Self {
        self.position_encoding = PositionEncodingConfig::learned(
            self.position_encoding.d_model,
            self.position_encoding.max_seq_len,
        );
        self
    }

    /// Set whether to use final layer normalization
    pub fn with_final_layer_norm(mut self, final_layer_norm: bool) -> Self {
        self.final_layer_norm = final_layer_norm;
        self
    }

    /// Set dropout probability for both the layers and the position encoding
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.layer_config = self.layer_config.with_dropout(dropout);
        self.position_encoding = self.position_encoding.with_dropout(dropout);
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        if self.num_layers == 0 {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "num_layers must be positive".to_string(),
            });
        }

        self.layer_config.validate()?;
        self.position_encoding.validate()?;

        Ok(())
    }
}

/// Transformer encoder stack
#[derive(Clone, Debug)]
pub struct EncoderStack {
    /// Configuration
    pub config: EncoderStackConfig,
    /// Encoder layers
    pub layers: Vec<EncoderLayer>,
    /// Position encoding (if sinusoidal)
    pub position_encoding_sin: Option<SinusoidalPositionEncoding>,
    /// Position encoding (if learned)
    pub position_encoding_learned: Option<LearnedPositionEncoding>,
    /// Final layer normalization
    pub final_norm: Option<LayerNorm>,
}

impl EncoderStack {
    /// Create a new encoder stack
    pub fn new(config: EncoderStackConfig) -> Result<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(EncoderLayer::new(config.layer_config.clone())?);
        }

        let position_encoding_sin = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Sinusoidal { .. } => Some(
                SinusoidalPositionEncoding::new(config.position_encoding.clone())?,
            ),
            _ => None,
        };

        let position_encoding_learned = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Learned => Some(LearnedPositionEncoding::new(
                config.position_encoding.clone(),
            )?),
            _ => None,
        };

        let final_norm = if config.final_layer_norm {
            Some(LayerNorm::new(LayerNormConfig::new(
                config.layer_config.attention.d_model,
            ))?)
        } else {
            None
        };

        Ok(Self {
            config,
            layers,
            position_encoding_sin,
            position_encoding_learned,
            final_norm,
        })
    }

    /// Build einsum graph for encoder stack
    ///
    /// Input tensors:
    /// - 0: x (input) [batch, seq_len, d_model]
    /// - 1-N: all parameters for position encoding, layers, and final norm
    ///
    /// Output tensors:
    /// - output: [batch, seq_len, d_model]
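    ///
    /// A hedged usage sketch, mirroring `test_encoder_stack_graph_building`
    /// (the input tensor is registered before this method is called):
    ///
    /// ```rust,ignore
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("x"); // tensor 0
    /// let outputs = stack.build_encoder_stack_graph(&mut graph)?;
    /// ```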
    pub fn build_encoder_stack_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Step 1: Add position encoding
        let mut current_output = if let Some(ref pe_sin) = self.position_encoding_sin {
            pe_sin.build_encoding_graph(graph)?[0]
        } else if let Some(ref pe_learned) = self.position_encoding_learned {
            pe_learned.build_encoding_graph(graph)?[0]
        } else {
            0 // No position encoding; fall back to the raw input tensor (index 0)
        };

        // Step 2: Apply each encoder layer sequentially
        for (i, layer) in self.layers.iter().enumerate() {
            // Build this layer's subgraph and track its primary output
            let layer_outputs = layer.build_encoder_layer_graph(graph)?;
            current_output = layer_outputs[0];

            // Add an identity node as a named marker for the layer boundary
            let layer_marker = graph.add_tensor(format!("encoder_layer_{}_output", i));
            let marker_node =
                tensorlogic_ir::EinsumNode::elem_unary("identity", current_output, layer_marker);
            graph.add_node(marker_node)?;
            current_output = layer_marker;
        }

        // Step 3: Apply final layer normalization if configured
        if let Some(ref final_norm) = self.final_norm {
            let final_outputs = final_norm.build_layernorm_graph(graph)?;
            current_output = final_outputs[0];
        }

        Ok(vec![current_output])
    }

    /// Get number of layers
    pub fn num_layers(&self) -> usize {
        self.config.num_layers
    }
}

/// Configuration for transformer decoder stack
#[derive(Clone, Debug)]
pub struct DecoderStackConfig {
    /// Number of decoder layers
    pub num_layers: usize,
    /// Configuration for each decoder layer
    pub layer_config: DecoderLayerConfig,
    /// Position encoding configuration
    pub position_encoding: PositionEncodingConfig,
    /// Whether to apply final layer normalization
    pub final_layer_norm: bool,
}

impl DecoderStackConfig {
    /// Create a new decoder stack configuration
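    ///
    /// Defaults mirror [`EncoderStackConfig::new`]: sinusoidal position
    /// encoding plus a final layer norm. A hedged sketch with illustrative
    /// dimensions; note that decoder self-attention is causal (see
    /// `test_decoder_stack_config_creation`):
    ///
    /// ```rust,ignore
    /// let config = DecoderStackConfig::new(6, 512, 8, 2048, 1024)?;
    /// assert!(config.layer_config.self_attention.causal);
    /// ```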
    pub fn new(
        num_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
    ) -> Result<Self> {
        Ok(Self {
            num_layers,
            layer_config: DecoderLayerConfig::new(d_model, n_heads, d_ff)?,
            position_encoding: PositionEncodingConfig::sinusoidal(d_model, max_seq_len),
            final_layer_norm: true,
        })
    }

    /// Set position encoding type to learned
    pub fn with_learned_position_encoding(mut self) -> Self {
        self.position_encoding = PositionEncodingConfig::learned(
            self.position_encoding.d_model,
            self.position_encoding.max_seq_len,
        );
        self
    }

    /// Set whether to use final layer normalization
    pub fn with_final_layer_norm(mut self, final_layer_norm: bool) -> Self {
        self.final_layer_norm = final_layer_norm;
        self
    }

    /// Set dropout probability for both the layers and the position encoding
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.layer_config = self.layer_config.with_dropout(dropout);
        self.position_encoding = self.position_encoding.with_dropout(dropout);
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        if self.num_layers == 0 {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "num_layers must be positive".to_string(),
            });
        }

        self.layer_config.validate()?;
        self.position_encoding.validate()?;

        Ok(())
    }
}

/// Transformer decoder stack
#[derive(Clone, Debug)]
pub struct DecoderStack {
    /// Configuration
    pub config: DecoderStackConfig,
    /// Decoder layers
    pub layers: Vec<DecoderLayer>,
    /// Position encoding (if sinusoidal)
    pub position_encoding_sin: Option<SinusoidalPositionEncoding>,
    /// Position encoding (if learned)
    pub position_encoding_learned: Option<LearnedPositionEncoding>,
    /// Final layer normalization
    pub final_norm: Option<LayerNorm>,
}

impl DecoderStack {
    /// Create a new decoder stack
    pub fn new(config: DecoderStackConfig) -> Result<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(DecoderLayer::new(config.layer_config.clone())?);
        }

        let position_encoding_sin = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Sinusoidal { .. } => Some(
                SinusoidalPositionEncoding::new(config.position_encoding.clone())?,
            ),
            _ => None,
        };

        let position_encoding_learned = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Learned => Some(LearnedPositionEncoding::new(
                config.position_encoding.clone(),
            )?),
            _ => None,
        };

        let final_norm = if config.final_layer_norm {
            Some(LayerNorm::new(LayerNormConfig::new(
                config.layer_config.self_attention.d_model,
            ))?)
        } else {
            None
        };

        Ok(Self {
            config,
            layers,
            position_encoding_sin,
            position_encoding_learned,
            final_norm,
        })
    }

    /// Build einsum graph for decoder stack
    ///
    /// Input tensors:
    /// - 0: x (target input) [batch, tgt_len, d_model]
    /// - 1: encoder_output [batch, src_len, d_model]
    /// - 2-N: all parameters
    ///
    /// Output tensors:
    /// - output: [batch, tgt_len, d_model]
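    ///
    /// A hedged usage sketch, mirroring `test_decoder_stack_graph_building`
    /// (target and encoder output are registered first, in that order):
    ///
    /// ```rust,ignore
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("target"); // tensor 0
    /// graph.add_tensor("encoder_output"); // tensor 1
    /// let outputs = stack.build_decoder_stack_graph(&mut graph)?;
    /// ```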
    pub fn build_decoder_stack_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Step 1: Add position encoding to target
        let mut current_output = if let Some(ref pe_sin) = self.position_encoding_sin {
            pe_sin.build_encoding_graph(graph)?[0]
        } else if let Some(ref pe_learned) = self.position_encoding_learned {
            pe_learned.build_encoding_graph(graph)?[0]
        } else {
            0 // No position encoding; fall back to the raw target tensor (index 0)
        };

        // Step 2: Apply each decoder layer sequentially
        for (i, layer) in self.layers.iter().enumerate() {
            // Build this layer's subgraph and track its primary output
            let layer_outputs = layer.build_decoder_layer_graph(graph)?;
            current_output = layer_outputs[0];

            // Add an identity node as a named marker for the layer boundary
            let layer_marker = graph.add_tensor(format!("decoder_layer_{}_output", i));
            let marker_node =
                tensorlogic_ir::EinsumNode::elem_unary("identity", current_output, layer_marker);
            graph.add_node(marker_node)?;
            current_output = layer_marker;
        }

        // Step 3: Apply final layer normalization if configured
        if let Some(ref final_norm) = self.final_norm {
            let final_outputs = final_norm.build_layernorm_graph(graph)?;
            current_output = final_outputs[0];
        }

        Ok(vec![current_output])
    }

    /// Get number of layers
    pub fn num_layers(&self) -> usize {
        self.config.num_layers
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_stack_config_creation() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        assert_eq!(config.num_layers, 6);
        assert_eq!(config.layer_config.attention.d_model, 512);
        assert!(config.final_layer_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_encoder_stack_config_with_learned_pe() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024)
            .unwrap()
            .with_learned_position_encoding();
        assert!(matches!(
            config.position_encoding.encoding_type,
            crate::position::PositionEncodingType::Learned
        ));
    }

    #[test]
    fn test_encoder_stack_creation() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config).unwrap();
        assert_eq!(stack.num_layers(), 6);
        assert!(stack.position_encoding_sin.is_some());
        assert!(stack.final_norm.is_some());
    }

    #[test]
    fn test_encoder_stack_graph_building() {
        let config = EncoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = stack.build_encoder_stack_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_decoder_stack_config_creation() {
        let config = DecoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        assert_eq!(config.num_layers, 6);
        assert_eq!(config.layer_config.self_attention.d_model, 512);
        assert!(config.layer_config.self_attention.causal);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_decoder_stack_creation() {
        let config = DecoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        let stack = DecoderStack::new(config).unwrap();
        assert_eq!(stack.num_layers(), 6);
        assert!(stack.position_encoding_sin.is_some());
        assert!(stack.final_norm.is_some());
    }

    #[test]
    fn test_decoder_stack_graph_building() {
        let config = DecoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let stack = DecoderStack::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("target");
        graph.add_tensor("encoder_output");

        let outputs = stack.build_decoder_stack_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_invalid_zero_layers() {
        let result = EncoderStackConfig::new(0, 512, 8, 2048, 1024);
        // The constructor itself may succeed; validation must reject zero layers
        if let Ok(config) = result {
            assert!(config.validate().is_err());
        }
    }
}