// tensorlogic_trustformers/stacks.rs

//! Encoder and decoder stacks: configurable sequences of Transformer layers
//! with position encoding and an optional final layer normalization.

use tensorlogic_ir::EinsumGraph;

use crate::{
    error::Result,
    layers::{DecoderLayer, DecoderLayerConfig, EncoderLayer, EncoderLayerConfig},
    normalization::{LayerNorm, LayerNormConfig},
    position::{LearnedPositionEncoding, PositionEncodingConfig, SinusoidalPositionEncoding},
};

/// Configuration for a stack of encoder layers.
#[derive(Clone, Debug)]
pub struct EncoderStackConfig {
    /// Number of encoder layers in the stack.
    pub num_layers: usize,
    /// Configuration shared by every encoder layer.
    pub layer_config: EncoderLayerConfig,
    /// Position encoding applied to the input embeddings.
    pub position_encoding: PositionEncodingConfig,
    /// Whether to apply a final layer normalization after the last layer.
    pub final_layer_norm: bool,
}

impl EncoderStackConfig {
    /// Create a config with sinusoidal position encoding and a final layer norm.
    pub fn new(
        num_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
    ) -> Result<Self> {
        Ok(Self {
            num_layers,
            layer_config: EncoderLayerConfig::new(d_model, n_heads, d_ff)?,
            position_encoding: PositionEncodingConfig::sinusoidal(d_model, max_seq_len),
            final_layer_norm: true,
        })
    }

    /// Switch to learned position embeddings, keeping `d_model` and `max_seq_len`.
    pub fn with_learned_position_encoding(mut self) -> Self {
        self.position_encoding = PositionEncodingConfig::learned(
            self.position_encoding.d_model,
            self.position_encoding.max_seq_len,
        );
        self
    }

    /// Enable or disable the final layer normalization.
    pub fn with_final_layer_norm(mut self, final_layer_norm: bool) -> Self {
        self.final_layer_norm = final_layer_norm;
        self
    }

    /// Set the dropout rate for both the layers and the position encoding.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.layer_config = self.layer_config.with_dropout(dropout);
        self.position_encoding = self.position_encoding.with_dropout(dropout);
        self
    }

    /// Check that the stack is non-empty and that all sub-configs are valid.
    pub fn validate(&self) -> Result<()> {
        if self.num_layers == 0 {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "num_layers must be positive".to_string(),
            });
        }

        self.layer_config.validate()?;
        self.position_encoding.validate()?;

        Ok(())
    }
}

/// A stack of encoder layers with position encoding and an optional final norm.
#[derive(Clone, Debug)]
pub struct EncoderStack {
    /// The configuration used to build the stack.
    pub config: EncoderStackConfig,
    /// The encoder layers, applied in order.
    pub layers: Vec<EncoderLayer>,
    /// Sinusoidal position encoding, if configured.
    pub position_encoding_sin: Option<SinusoidalPositionEncoding>,
    /// Learned position encoding, if configured.
    pub position_encoding_learned: Option<LearnedPositionEncoding>,
    /// Final layer normalization, if enabled.
    pub final_norm: Option<LayerNorm>,
}

impl EncoderStack {
    /// Validate the config and construct the layers, position encoding, and final norm.
    pub fn new(config: EncoderStackConfig) -> Result<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(EncoderLayer::new(config.layer_config.clone())?);
        }

        let position_encoding_sin = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Sinusoidal { .. } => Some(
                SinusoidalPositionEncoding::new(config.position_encoding.clone())?,
            ),
            _ => None,
        };

        let position_encoding_learned = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Learned => Some(LearnedPositionEncoding::new(
                config.position_encoding.clone(),
            )?),
            _ => None,
        };

        let final_norm = if config.final_layer_norm {
            Some(LayerNorm::new(LayerNormConfig::new(
                config.layer_config.attention.d_model,
            ))?)
        } else {
            None
        };

        Ok(Self {
            config,
            layers,
            position_encoding_sin,
            position_encoding_learned,
            final_norm,
        })
    }

    /// Build the einsum graph for the full encoder stack and return the output
    /// tensor indices (a single-element vector holding the final output).
    pub fn build_encoder_stack_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Start from the position-encoded input, or fall back to tensor 0
        // (the raw input) when no position encoding is configured.
        let mut current_output = if let Some(ref pe_sin) = self.position_encoding_sin {
            pe_sin.build_encoding_graph(graph)?[0]
        } else if let Some(ref pe_learned) = self.position_encoding_learned {
            pe_learned.build_encoding_graph(graph)?[0]
        } else {
            0
        };

        for (i, layer) in self.layers.iter().enumerate() {
            // Each layer consumes the previous layer's output.
            let layer_outputs = layer.build_encoder_layer_graph(graph)?;
            current_output = layer_outputs[0];

            // Insert an identity node so each layer's output is addressable by name.
            let layer_marker = graph.add_tensor(format!("encoder_layer_{}_output", i));
            let marker_node =
                tensorlogic_ir::EinsumNode::elem_unary("identity", current_output, layer_marker);
            graph.add_node(marker_node)?;
            current_output = layer_marker;
        }

        // Optionally normalize the final output.
        if let Some(ref final_norm) = self.final_norm {
            let final_outputs = final_norm.build_layernorm_graph(graph)?;
            current_output = final_outputs[0];
        }

        Ok(vec![current_output])
    }

    /// The number of encoder layers in the stack.
    pub fn num_layers(&self) -> usize {
        self.config.num_layers
    }
}

/// Configuration for a stack of decoder layers.
#[derive(Clone, Debug)]
pub struct DecoderStackConfig {
    /// Number of decoder layers in the stack.
    pub num_layers: usize,
    /// Configuration shared by every decoder layer.
    pub layer_config: DecoderLayerConfig,
    /// Position encoding applied to the target embeddings.
    pub position_encoding: PositionEncodingConfig,
    /// Whether to apply a final layer normalization after the last layer.
    pub final_layer_norm: bool,
}

impl DecoderStackConfig {
    /// Create a config with sinusoidal position encoding and a final layer norm.
    pub fn new(
        num_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
    ) -> Result<Self> {
        Ok(Self {
            num_layers,
            layer_config: DecoderLayerConfig::new(d_model, n_heads, d_ff)?,
            position_encoding: PositionEncodingConfig::sinusoidal(d_model, max_seq_len),
            final_layer_norm: true,
        })
    }

    /// Switch to learned position embeddings, keeping `d_model` and `max_seq_len`.
    pub fn with_learned_position_encoding(mut self) -> Self {
        self.position_encoding = PositionEncodingConfig::learned(
            self.position_encoding.d_model,
            self.position_encoding.max_seq_len,
        );
        self
    }

    /// Enable or disable the final layer normalization.
    pub fn with_final_layer_norm(mut self, final_layer_norm: bool) -> Self {
        self.final_layer_norm = final_layer_norm;
        self
    }

    /// Set the dropout rate for both the layers and the position encoding.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.layer_config = self.layer_config.with_dropout(dropout);
        self.position_encoding = self.position_encoding.with_dropout(dropout);
        self
    }

    /// Check that the stack is non-empty and that all sub-configs are valid.
    pub fn validate(&self) -> Result<()> {
        if self.num_layers == 0 {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "num_layers must be positive".to_string(),
            });
        }

        self.layer_config.validate()?;
        self.position_encoding.validate()?;

        Ok(())
    }
}

/// A stack of decoder layers with position encoding and an optional final norm.
#[derive(Clone, Debug)]
pub struct DecoderStack {
    /// The configuration used to build the stack.
    pub config: DecoderStackConfig,
    /// The decoder layers, applied in order.
    pub layers: Vec<DecoderLayer>,
    /// Sinusoidal position encoding, if configured.
    pub position_encoding_sin: Option<SinusoidalPositionEncoding>,
    /// Learned position encoding, if configured.
    pub position_encoding_learned: Option<LearnedPositionEncoding>,
    /// Final layer normalization, if enabled.
    pub final_norm: Option<LayerNorm>,
}

impl DecoderStack {
    /// Validate the config and construct the layers, position encoding, and final norm.
    pub fn new(config: DecoderStackConfig) -> Result<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(DecoderLayer::new(config.layer_config.clone())?);
        }

        let position_encoding_sin = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Sinusoidal { .. } => Some(
                SinusoidalPositionEncoding::new(config.position_encoding.clone())?,
            ),
            _ => None,
        };

        let position_encoding_learned = match config.position_encoding.encoding_type {
            crate::position::PositionEncodingType::Learned => Some(LearnedPositionEncoding::new(
                config.position_encoding.clone(),
            )?),
            _ => None,
        };

        let final_norm = if config.final_layer_norm {
            Some(LayerNorm::new(LayerNormConfig::new(
                config.layer_config.self_attention.d_model,
            ))?)
        } else {
            None
        };

        Ok(Self {
            config,
            layers,
            position_encoding_sin,
            position_encoding_learned,
            final_norm,
        })
    }

    /// Build the einsum graph for the full decoder stack and return the output
    /// tensor indices (a single-element vector holding the final output).
    pub fn build_decoder_stack_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Start from the position-encoded target input, or fall back to tensor 0
        // (the raw input) when no position encoding is configured.
        let mut current_output = if let Some(ref pe_sin) = self.position_encoding_sin {
            pe_sin.build_encoding_graph(graph)?[0]
        } else if let Some(ref pe_learned) = self.position_encoding_learned {
            pe_learned.build_encoding_graph(graph)?[0]
        } else {
            0
        };

        for (i, layer) in self.layers.iter().enumerate() {
            // Each layer consumes the previous layer's output.
            let layer_outputs = layer.build_decoder_layer_graph(graph)?;
            current_output = layer_outputs[0];

            // Insert an identity node so each layer's output is addressable by name.
            let layer_marker = graph.add_tensor(format!("decoder_layer_{}_output", i));
            let marker_node =
                tensorlogic_ir::EinsumNode::elem_unary("identity", current_output, layer_marker);
            graph.add_node(marker_node)?;
            current_output = layer_marker;
        }

        // Optionally normalize the final output.
        if let Some(ref final_norm) = self.final_norm {
            let final_outputs = final_norm.build_layernorm_graph(graph)?;
            current_output = final_outputs[0];
        }

        Ok(vec![current_output])
    }

    /// The number of decoder layers in the stack.
    pub fn num_layers(&self) -> usize {
        self.config.num_layers
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_stack_config_creation() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        assert_eq!(config.num_layers, 6);
        assert_eq!(config.layer_config.attention.d_model, 512);
        assert!(config.final_layer_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_encoder_stack_config_with_learned_pe() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024)
            .unwrap()
            .with_learned_position_encoding();
        assert!(matches!(
            config.position_encoding.encoding_type,
            crate::position::PositionEncodingType::Learned
        ));
    }
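
    // A usage sketch of the builder chain (added for illustration): the
    // `with_dropout` and `with_final_layer_norm` methods defined above compose
    // onto `new`. The dropout rate 0.1 is an arbitrary example value, not a
    // project default.
    #[test]
    fn test_encoder_stack_config_builder_chain() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024)
            .unwrap()
            .with_dropout(0.1)
            .with_final_layer_norm(false);
        assert!(!config.final_layer_norm);
        assert!(config.validate().is_ok());
    }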

    #[test]
    fn test_encoder_stack_creation() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config).unwrap();
        assert_eq!(stack.num_layers(), 6);
        assert!(stack.position_encoding_sin.is_some());
        assert!(stack.final_norm.is_some());
    }
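
    // A sketch of the learned-position-encoding path (added for illustration):
    // `EncoderStack::new` populates exactly one of the two position-encoding
    // fields, matching the two match arms in its constructor.
    #[test]
    fn test_encoder_stack_with_learned_pe() {
        let config = EncoderStackConfig::new(2, 512, 8, 2048, 1024)
            .unwrap()
            .with_learned_position_encoding();
        let stack = EncoderStack::new(config).unwrap();
        assert!(stack.position_encoding_sin.is_none());
        assert!(stack.position_encoding_learned.is_some());
    }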

    #[test]
    fn test_encoder_stack_graph_building() {
        let config = EncoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = stack.build_encoder_stack_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }
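
    // A sketch of the per-layer marker nodes (added for illustration): each
    // layer contributes at least its identity marker, so a two-layer stack
    // yields at least two nodes in the built graph.
    #[test]
    fn test_encoder_stack_marker_nodes() {
        let config = EncoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        stack.build_encoder_stack_graph(&mut graph).unwrap();

        assert!(graph.nodes.len() >= 2);
    }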

    #[test]
    fn test_decoder_stack_config_creation() {
        let config = DecoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        assert_eq!(config.num_layers, 6);
        assert_eq!(config.layer_config.self_attention.d_model, 512);
        assert!(config.layer_config.self_attention.causal);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_decoder_stack_creation() {
        let config = DecoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        let stack = DecoderStack::new(config).unwrap();
        assert_eq!(stack.num_layers(), 6);
        assert!(stack.position_encoding_sin.is_some());
        assert!(stack.final_norm.is_some());
    }
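
    // A sketch of the `with_final_layer_norm(false)` path on the decoder side
    // (added for illustration): the constructed stack carries no final
    // `LayerNorm` when the flag is disabled.
    #[test]
    fn test_decoder_stack_without_final_norm() {
        let config = DecoderStackConfig::new(2, 512, 8, 2048, 1024)
            .unwrap()
            .with_final_layer_norm(false);
        let stack = DecoderStack::new(config).unwrap();
        assert!(stack.final_norm.is_none());
    }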

    #[test]
    fn test_decoder_stack_graph_building() {
        let config = DecoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let stack = DecoderStack::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("target");
        graph.add_tensor("encoder_output");

        let outputs = stack.build_decoder_stack_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }
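
    // An end-to-end sketch (added for illustration): an encoder stack and a
    // decoder stack can build into one shared `EinsumGraph`, following the
    // tensor-registration conventions of the two tests above. Wiring between
    // the stacks is by tensor naming convention and is not enforced here.
    #[test]
    fn test_encoder_decoder_shared_graph() {
        let encoder_config = EncoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let decoder_config = DecoderStackConfig::new(2, 512, 8, 2048, 1024).unwrap();
        let encoder = EncoderStack::new(encoder_config).unwrap();
        let decoder = DecoderStack::new(decoder_config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        let encoder_outputs = encoder.build_encoder_stack_graph(&mut graph).unwrap();

        graph.add_tensor("target");
        graph.add_tensor("encoder_output");
        let decoder_outputs = decoder.build_decoder_stack_graph(&mut graph).unwrap();

        assert_eq!(encoder_outputs.len(), 1);
        assert_eq!(decoder_outputs.len(), 1);
    }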

    #[test]
    fn test_invalid_zero_layers() {
        let result = EncoderStackConfig::new(0, 512, 8, 2048, 1024);
        // `new` may or may not reject zero layers itself; if construction
        // succeeds, `validate` must catch it.
        if let Ok(config) = result {
            assert!(config.validate().is_err());
        }
    }
}