tensorlogic_trustformers/encoder.rs

//! Transformer encoder layers.
//!
//! This module implements transformer encoder layers that combine:
//! - Multi-head self-attention
//! - Feed-forward networks
//! - Layer normalization
//! - Residual connections
//!
//! ## Transformer Encoder Layer
//!
//! Pre-normalization variant:
//! ```text
//! x' = x + MultiHeadAttention(LayerNorm(x))
//! output = x' + FFN(LayerNorm(x'))
//! ```
//!
//! Post-normalization variant:
//! ```text
//! x' = LayerNorm(x + MultiHeadAttention(x))
//! output = LayerNorm(x' + FFN(x'))
//! ```
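//!
//! ## Example (sketch)
//!
//! A minimal usage sketch of this module, marked `ignore` because it is
//! illustrative rather than a compiled doc-test; the `use` path assumes the
//! crate exposes this file as `tensorlogic_trustformers::encoder`:
//!
//! ```ignore
//! use tensorlogic_ir::EinsumGraph;
//! use tensorlogic_trustformers::encoder::{Encoder, EncoderConfig};
//!
//! // d_model = 512, 8 heads, d_ff = 2048; pre-normalization is the default.
//! let config = EncoderConfig::new(512, 8, 2048).unwrap();
//! let encoder = Encoder::new(config).unwrap();
//!
//! // Tensor 0 is the encoder input [batch, seq_len, d_model].
//! let mut graph = EinsumGraph::new();
//! graph.add_tensor("x");
//!
//! // Builds the attention + FFN + layer-norm subgraphs and returns the
//! // index of the single output tensor.
//! let outputs = encoder.build_encoder_graph(&mut graph).unwrap();
//! assert_eq!(outputs.len(), 1);
//! ```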

use tensorlogic_ir::EinsumGraph;

use crate::{
    attention::MultiHeadAttention,
    config::{AttentionConfig, FeedForwardConfig},
    error::Result,
    ffn::FeedForward,
    normalization::{LayerNorm, LayerNormConfig},
};

/// Configuration for transformer encoder layer
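///
/// Construct with [`EncoderConfig::new`] and adjust with the builder-style
/// `with_*` methods below. A small sketch of the builder flow (marked
/// `ignore`, illustrative only; the `use` path assumes this crate's
/// `encoder` module):
///
/// ```ignore
/// use tensorlogic_trustformers::encoder::EncoderConfig;
///
/// // Post-norm, causal layer with 10% dropout on both attention and FFN.
/// let config = EncoderConfig::new(512, 8, 2048)
///     .unwrap()
///     .with_pre_norm(false)
///     .with_causal(true)
///     .with_dropout(0.1);
/// assert!(config.validate().is_ok());
/// ```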
#[derive(Clone, Debug)]
pub struct EncoderConfig {
    /// Attention configuration
    pub attention: AttentionConfig,
    /// Feed-forward configuration
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization configuration
    pub layer_norm: LayerNormConfig,
    /// Whether to use pre-layer normalization (vs post)
    pub pre_norm: bool,
}

impl EncoderConfig {
    /// Create a new encoder configuration
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            layer_norm: LayerNormConfig::new(d_model),
            pre_norm: true,
        })
    }

    /// Set pre-normalization vs post-normalization
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Set causal masking
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.attention = self.attention.with_causal(causal);
        self
    }

    /// Set dropout
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.attention = self.attention.with_dropout(dropout);
        self.feed_forward = self.feed_forward.with_dropout(dropout);
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        self.attention.validate()?;
        self.feed_forward.validate()?;
        self.layer_norm.validate()?;

        // Check dimension consistency
        if self.attention.d_model != self.feed_forward.d_model {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.feed_forward.d_model,
                context: "d_model mismatch between attention and FFN".to_string(),
            });
        }

        if self.attention.d_model != self.layer_norm.normalized_shape {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.layer_norm.normalized_shape,
                context: "d_model mismatch with layer norm".to_string(),
            });
        }

        Ok(())
    }
}

/// Transformer encoder layer
#[derive(Clone, Debug)]
pub struct Encoder {
    /// Configuration
    pub config: EncoderConfig,
    /// Multi-head attention
    pub attention: MultiHeadAttention,
    /// Feed-forward network
    pub ffn: FeedForward,
    /// First layer normalization
    pub norm1: LayerNorm,
    /// Second layer normalization
    pub norm2: LayerNorm,
}

impl Encoder {
    /// Create a new encoder layer
    pub fn new(config: EncoderConfig) -> Result<Self> {
        config.validate()?;

        let attention = MultiHeadAttention::new(config.attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            attention,
            ffn,
            norm1,
            norm2,
        })
    }

    /// Build einsum graph for encoder layer
    ///
    /// Input tensors:
    /// - 0: x (input) [batch, seq_len, d_model]
    /// - 1-N: weight matrices and parameters for attention, FFN, and layer norms
    ///
    /// Output tensors:
    /// - output: [batch, seq_len, d_model]
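    ///
    /// A minimal call sketch (marked `ignore`, illustrative only); it assumes
    /// an already constructed `encoder: Encoder` and that the caller registers
    /// the input as tensor 0 before building:
    ///
    /// ```ignore
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("x"); // tensor 0: encoder input
    /// let outputs = encoder.build_encoder_graph(&mut graph)?;
    /// assert_eq!(outputs.len(), 1);
    /// ```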
    pub fn build_encoder_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let input_tensor = 0;

        if self.config.pre_norm {
            self.build_pre_norm_encoder(graph, input_tensor)
        } else {
            self.build_post_norm_encoder(graph, input_tensor)
        }
    }

    fn build_pre_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        // Step 1: First layer norm
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        // Step 2: Multi-head attention (Q, K, V all from normed input)
        // Create copies for Q, K, V
        let q_tensor = graph.add_tensor("encoder_Q");
        let k_tensor = graph.add_tensor("encoder_K");
        let v_tensor = graph.add_tensor("encoder_V");

        let q_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, q_tensor);
        let k_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, k_tensor);
        let v_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, v_tensor);
        graph.add_node(q_node)?;
        graph.add_node(k_node)?;
        graph.add_node(v_node)?;

        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // Step 3: Residual connection: x + attention_output
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Step 4: Second layer norm
        let _normed2_outputs = self.norm2.build_layernorm_graph(graph)?;

        // Step 5: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 6: Second residual connection: residual1 + ffn_output
        let output = graph.add_tensor("encoder_output");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, ffn_output, output);
        graph.add_node(res2_node)?;

        Ok(vec![output])
    }

    fn build_post_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        // Step 1: Multi-head attention
        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // Step 2: Residual connection
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Step 3: First layer norm
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        // Step 4: Feed-forward network
        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Step 5: Second residual connection
        let residual2 = graph.add_tensor("encoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, ffn_output, residual2);
        graph.add_node(res2_node)?;

        // Step 6: Second layer norm
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let output = normed2_outputs[0];

        Ok(vec![output])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_config_creation() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.pre_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_encoder_config_with_dropout() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap().with_dropout(0.1);
        assert!((config.attention.dropout - 0.1).abs() < 1e-10);
        assert!((config.feed_forward.dropout - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_encoder_config_pre_norm() {
        let config = EncoderConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        assert!(!config.pre_norm);
    }

    #[test]
    fn test_encoder_creation() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        let encoder = Encoder::new(config).unwrap();
        assert_eq!(encoder.config.attention.d_model, 512);
    }

    #[test]
    fn test_encoder_graph_building_pre_norm() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        let encoder = Encoder::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = encoder.build_encoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_encoder_graph_building_post_norm() {
        let config = EncoderConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        let encoder = Encoder::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = encoder.build_encoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_encoder_config_validation() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        assert!(config.validate().is_ok());

        // Invalid head count
        let result = EncoderConfig::new(512, 7, 2048);
        assert!(result.is_err());
    }

    #[test]
    fn test_encoder_with_causal() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap().with_causal(true);
        assert!(config.attention.causal);
314    }
315}