use tensorlogic_ir::EinsumGraph;

use crate::{
    attention::MultiHeadAttention,
    config::{AttentionConfig, FeedForwardConfig},
    error::Result,
    ffn::FeedForward,
    normalization::{LayerNorm, LayerNormConfig},
};

/// Configuration for a single transformer encoder layer.
#[derive(Clone, Debug)]
pub struct EncoderLayerConfig {
    /// Multi-head self-attention configuration.
    pub attention: AttentionConfig,
    /// Position-wise feed-forward configuration.
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization configuration (shared by both norms).
    pub layer_norm: LayerNormConfig,
    /// Use pre-norm (normalize before each sublayer) instead of post-norm.
    pub pre_norm: bool,
}

impl EncoderLayerConfig {
    /// Creates a config with the given model width, head count, and FFN width.
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            layer_norm: LayerNormConfig::new(d_model),
            pre_norm: true,
        })
    }

    /// Selects pre-norm (true) or post-norm (false) sublayer ordering.
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Enables or disables causal masking in the attention sublayer.
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.attention = self.attention.with_causal(causal);
        self
    }

    /// Sets the same dropout rate on the attention and FFN sublayers.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.attention = self.attention.with_dropout(dropout);
        self.feed_forward = self.feed_forward.with_dropout(dropout);
        self
    }

    /// Validates each component config and cross-checks that they agree on `d_model`.
    pub fn validate(&self) -> Result<()> {
        self.attention.validate()?;
        self.feed_forward.validate()?;
        self.layer_norm.validate()?;

        if self.attention.d_model != self.feed_forward.d_model {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.feed_forward.d_model,
                context: "d_model mismatch between attention and FFN".to_string(),
            });
        }

        if self.attention.d_model != self.layer_norm.normalized_shape {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.layer_norm.normalized_shape,
                context: "d_model mismatch with layer norm".to_string(),
            });
        }

        Ok(())
    }
}

/// A single transformer encoder layer: self-attention plus a feed-forward
/// sublayer, each with a residual connection and layer normalization.
#[derive(Clone, Debug)]
pub struct EncoderLayer {
    /// The validated layer configuration.
    pub config: EncoderLayerConfig,
    /// Multi-head self-attention sublayer.
    pub attention: MultiHeadAttention,
    /// Position-wise feed-forward sublayer.
    pub ffn: FeedForward,
    /// Layer norm for the attention sublayer.
    pub norm1: LayerNorm,
    /// Layer norm for the feed-forward sublayer.
    pub norm2: LayerNorm,
}

impl EncoderLayer {
    /// Builds the layer, validating the config and constructing each sublayer.
    pub fn new(config: EncoderLayerConfig) -> Result<Self> {
        config.validate()?;

        let attention = MultiHeadAttention::new(config.attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            attention,
            ffn,
            norm1,
            norm2,
        })
    }

    /// Builds the einsum graph for this layer and returns the output tensor IDs.
    ///
    /// Assumes the layer input is the tensor at index 0 of `graph`.
    pub fn build_encoder_layer_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let input_tensor = 0;

        if self.config.pre_norm {
            self.build_pre_norm_encoder(graph, input_tensor)
        } else {
            self.build_post_norm_encoder(graph, input_tensor)
        }
    }

    fn build_pre_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        // Pre-norm: normalize before the attention sublayer.
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        let q_tensor = graph.add_tensor("encoder_Q");
        let k_tensor = graph.add_tensor("encoder_K");
        let v_tensor = graph.add_tensor("encoder_V");

        // Wire Q, K, and V to the normalized activations; the identity nodes
        // must be added to the graph, not merely constructed.
        let q_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, q_tensor);
        let k_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, k_tensor);
        let v_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, v_tensor);
        graph.add_node(q_node)?;
        graph.add_node(k_node)?;
        graph.add_node(v_node)?;

        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // First residual connection: input + attention output.
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Pre-norm again before the feed-forward sublayer.
        let _normed2_outputs = self.norm2.build_layernorm_graph(graph)?;

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Second residual connection: residual1 + FFN output.
        let output = graph.add_tensor("encoder_output");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, ffn_output, output);
        graph.add_node(res2_node)?;

        Ok(vec![output])
    }

    fn build_post_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // First residual connection: input + attention output.
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Post-norm: normalize after the first residual.
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Second residual connection: normed1 + FFN output.
        let residual2 = graph.add_tensor("encoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, ffn_output, residual2);
        graph.add_node(res2_node)?;

        // Final normalization produces the layer output.
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let output = normed2_outputs[0];

        Ok(vec![output])
    }
}

/// Configuration for a single transformer decoder layer.
#[derive(Clone, Debug)]
pub struct DecoderLayerConfig {
    /// Causally masked self-attention configuration.
    pub self_attention: AttentionConfig,
    /// Cross-attention (over the encoder output) configuration.
    pub cross_attention: AttentionConfig,
    /// Position-wise feed-forward configuration.
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization configuration (shared by all three norms).
    pub layer_norm: LayerNormConfig,
    /// Use pre-norm (normalize before each sublayer) instead of post-norm.
    pub pre_norm: bool,
}

impl DecoderLayerConfig {
    /// Creates a config with causal self-attention, standard cross-attention,
    /// and the given model width, head count, and FFN width.
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            self_attention: AttentionConfig::new(d_model, n_heads)?.with_causal(true),
            cross_attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            layer_norm: LayerNormConfig::new(d_model),
            pre_norm: true,
        })
    }

    /// Selects pre-norm (true) or post-norm (false) sublayer ordering.
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Sets the same dropout rate on all three sublayers.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.self_attention = self.self_attention.with_dropout(dropout);
        self.cross_attention = self.cross_attention.with_dropout(dropout);
        self.feed_forward = self.feed_forward.with_dropout(dropout);
        self
    }

    /// Validates each component config, requires causal self-attention, and
    /// cross-checks that all components agree on `d_model`.
    pub fn validate(&self) -> Result<()> {
        self.self_attention.validate()?;
        self.cross_attention.validate()?;
        self.feed_forward.validate()?;
        self.layer_norm.validate()?;

        if !self.self_attention.causal {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "Decoder self-attention must use causal masking".to_string(),
            });
        }

        if self.self_attention.d_model != self.cross_attention.d_model
            || self.self_attention.d_model != self.feed_forward.d_model
            || self.self_attention.d_model != self.layer_norm.normalized_shape
        {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.self_attention.d_model,
                got: 0,
                context: "d_model mismatch between components".to_string(),
            });
        }

        Ok(())
    }
}

/// A single transformer decoder layer: masked self-attention, cross-attention
/// over the encoder output, and a feed-forward sublayer, each with a residual
/// connection and layer normalization.
#[derive(Clone, Debug)]
pub struct DecoderLayer {
    /// The validated layer configuration.
    pub config: DecoderLayerConfig,
    /// Causally masked self-attention sublayer.
    pub self_attention: MultiHeadAttention,
    /// Cross-attention sublayer over the encoder output.
    pub cross_attention: MultiHeadAttention,
    /// Position-wise feed-forward sublayer.
    pub ffn: FeedForward,
    /// Layer norm for the self-attention sublayer.
    pub norm1: LayerNorm,
    /// Layer norm for the cross-attention sublayer.
    pub norm2: LayerNorm,
    /// Layer norm for the feed-forward sublayer.
    pub norm3: LayerNorm,
}

impl DecoderLayer {
    /// Builds the layer, validating the config and constructing each sublayer.
    pub fn new(config: DecoderLayerConfig) -> Result<Self> {
        config.validate()?;

        let self_attention = MultiHeadAttention::new(config.self_attention.clone())?;
        let cross_attention = MultiHeadAttention::new(config.cross_attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;
        let norm3 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            self_attention,
            cross_attention,
            ffn,
            norm1,
            norm2,
            norm3,
        })
    }

    /// Builds the einsum graph for this layer and returns the output tensor IDs.
    ///
    /// Assumes the decoder input is the tensor at index 0 of `graph` and the
    /// encoder output is the tensor at index 1.
    pub fn build_decoder_layer_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let decoder_input = 0;
        let encoder_output = 1;

        if self.config.pre_norm {
            self.build_pre_norm_decoder(graph, decoder_input, encoder_output)
        } else {
            self.build_post_norm_decoder(graph, decoder_input, encoder_output)
        }
    }

    fn build_pre_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize,
    ) -> Result<Vec<usize>> {
        // Pre-norm before masked self-attention.
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let _normed1 = normed1_outputs[0];

        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // First residual: decoder input + self-attention output.
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        // Pre-norm before cross-attention over the encoder output.
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let _normed2 = normed2_outputs[0];

        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Second residual: residual1 + cross-attention output.
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        // Pre-norm before the feed-forward sublayer.
        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let _normed3 = normed3_outputs[0];

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Third residual: residual2 + FFN output is the layer output.
        let output = graph.add_tensor("decoder_output");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual2, ffn_output, output);
        graph.add_node(res3_node)?;

        Ok(vec![output])
    }

    fn build_post_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize,
    ) -> Result<Vec<usize>> {
        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // First residual: decoder input + self-attention output.
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        // Post-norm after the first residual.
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Second residual: normed1 + cross-attention output.
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        // Post-norm after the second residual.
        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let normed2 = normed2_outputs[0];

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Third residual: normed2 + FFN output.
        let residual3 = graph.add_tensor("decoder_residual3");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed2, ffn_output, residual3);
        graph.add_node(res3_node)?;

        // Final normalization produces the layer output.
        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let output = normed3_outputs[0];

        Ok(vec![output])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_config_creation() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.pre_norm);
        assert!(config.validate().is_ok());
    }
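
    // A hedged sketch of the builder chain: `with_pre_norm(false)` and
    // `with_causal(true)` are the setters defined above and should compose
    // in either order.
    #[test]
    fn test_encoder_layer_config_builders_compose() {
        let config = EncoderLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false)
            .with_causal(true);
        assert!(!config.pre_norm);
        assert!(config.attention.causal);
        assert!(config.validate().is_ok());
    }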

    #[test]
    fn test_encoder_layer_config_with_dropout() {
        let config = EncoderLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_dropout(0.1);
        assert!((config.attention.dropout - 0.1).abs() < 1e-10);
        assert!((config.feed_forward.dropout - 0.1).abs() < 1e-10);
    }
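
    // Validation sketch using only the API above: swapping in an FFN config
    // with a different d_model should make `validate` fail. The replacement
    // values (256, 1024) are illustrative.
    #[test]
    fn test_encoder_layer_config_d_model_mismatch_rejected() {
        let mut config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        config.feed_forward = FeedForwardConfig::new(256, 1024);
        assert!(config.validate().is_err());
    }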

    #[test]
    fn test_encoder_layer_creation() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config).unwrap();
        assert_eq!(layer.config.attention.d_model, 512);
    }

    #[test]
    fn test_encoder_layer_graph_building() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = layer.build_encoder_layer_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }
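
    // Post-norm counterpart of the graph-building test above; a sketch that
    // exercises `build_post_norm_encoder` through the public entry point.
    #[test]
    fn test_encoder_layer_post_norm_graph_building() {
        let config = EncoderLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        let layer = EncoderLayer::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = layer.build_encoder_layer_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }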

    #[test]
    fn test_decoder_layer_config_creation() {
        let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.self_attention.d_model, 512);
        assert_eq!(config.cross_attention.d_model, 512);
        assert!(config.self_attention.causal);
        assert!(!config.cross_attention.causal);
        assert!(config.validate().is_ok());
    }
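
    // Sketch of the causal-masking guard in `DecoderLayerConfig::validate`:
    // disabling causal self-attention should be rejected.
    #[test]
    fn test_decoder_layer_config_requires_causal() {
        let mut config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
        config.self_attention = config.self_attention.clone().with_causal(false);
        assert!(config.validate().is_err());
    }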
535
536 #[test]
537 fn test_decoder_layer_creation() {
538 let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
539 let layer = DecoderLayer::new(config).unwrap();
540 assert_eq!(layer.config.self_attention.d_model, 512);
541 }
542
543 #[test]
544 fn test_decoder_layer_graph_building() {
545 let config = DecoderLayerConfig::new(512, 8, 2048).unwrap();
546 let layer = DecoderLayer::new(config).unwrap();
547
548 let mut graph = EinsumGraph::new();
549 graph.add_tensor("decoder_input");
550 graph.add_tensor("encoder_output");
551
552 let outputs = layer.build_decoder_layer_graph(&mut graph).unwrap();
553 assert_eq!(outputs.len(), 1);
554 assert!(!graph.nodes.is_empty());
555 }
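
    // Post-norm counterpart of the decoder graph test above; a sketch that
    // exercises `build_post_norm_decoder` through the public entry point.
    #[test]
    fn test_decoder_layer_post_norm_graph_building() {
        let config = DecoderLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        let layer = DecoderLayer::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("decoder_input");
        graph.add_tensor("encoder_output");

        let outputs = layer.build_decoder_layer_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }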
}