// tensorlogic_trustformers/decoder.rs
use tensorlogic_ir::EinsumGraph;
27
28use crate::{
29 attention::MultiHeadAttention,
30 config::{AttentionConfig, FeedForwardConfig},
31 error::Result,
32 ffn::FeedForward,
33 normalization::{LayerNorm, LayerNormConfig},
34};
35
/// Configuration for a single transformer decoder layer.
#[derive(Clone, Debug)]
pub struct DecoderConfig {
    /// Self-attention over the target sequence; `new` enables causal masking
    /// on it, and `validate` rejects configs where it is not causal.
    pub self_attention: AttentionConfig,
    /// Encoder-decoder cross-attention; left non-causal by `new`.
    pub cross_attention: AttentionConfig,
    /// Position-wise feed-forward sub-layer configuration.
    pub feed_forward: FeedForwardConfig,
    /// Layer-norm configuration shared by all three norm sites in `Decoder`.
    pub layer_norm: LayerNormConfig,
    /// If true (default from `new`), use pre-norm residual ordering;
    /// otherwise post-norm.
    pub pre_norm: bool,
}
50
51impl DecoderConfig {
52 pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
54 Ok(Self {
55 self_attention: AttentionConfig::new(d_model, n_heads)?.with_causal(true),
56 cross_attention: AttentionConfig::new(d_model, n_heads)?,
57 feed_forward: FeedForwardConfig::new(d_model, d_ff),
58 layer_norm: LayerNormConfig::new(d_model),
59 pre_norm: true,
60 })
61 }
62
63 pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
65 self.pre_norm = pre_norm;
66 self
67 }
68
69 pub fn with_dropout(mut self, dropout: f64) -> Self {
71 self.self_attention = self.self_attention.with_dropout(dropout);
72 self.cross_attention = self.cross_attention.with_dropout(dropout);
73 self.feed_forward = self.feed_forward.with_dropout(dropout);
74 self
75 }
76
77 pub fn validate(&self) -> Result<()> {
79 self.self_attention.validate()?;
80 self.cross_attention.validate()?;
81 self.feed_forward.validate()?;
82 self.layer_norm.validate()?;
83
84 if !self.self_attention.causal {
86 return Err(crate::error::TrustformerError::InvalidDimension {
87 expected: 1,
88 got: 0,
89 context: "Decoder self-attention must use causal masking".to_string(),
90 });
91 }
92
93 if self.self_attention.d_model != self.cross_attention.d_model
95 || self.self_attention.d_model != self.feed_forward.d_model
96 || self.self_attention.d_model != self.layer_norm.normalized_shape
97 {
98 return Err(crate::error::TrustformerError::InvalidDimension {
99 expected: self.self_attention.d_model,
100 got: 0,
101 context: "d_model mismatch between components".to_string(),
102 });
103 }
104
105 Ok(())
106 }
107}
108
/// A single transformer decoder layer: causal self-attention,
/// encoder-decoder cross-attention, and a feed-forward network, each paired
/// with its own `LayerNorm` instance.
#[derive(Clone, Debug)]
pub struct Decoder {
    /// Configuration this layer was constructed from.
    pub config: DecoderConfig,
    /// Causal self-attention over the target sequence.
    pub self_attention: MultiHeadAttention,
    /// Cross-attention over the encoder output.
    pub cross_attention: MultiHeadAttention,
    /// Position-wise feed-forward network.
    pub ffn: FeedForward,
    /// Norm associated with the self-attention sub-layer.
    pub norm1: LayerNorm,
    /// Norm associated with the cross-attention sub-layer.
    pub norm2: LayerNorm,
    /// Norm associated with the feed-forward sub-layer.
    pub norm3: LayerNorm,
}
127
impl Decoder {
    /// Validate `config` and instantiate all sub-layers.
    ///
    /// Three independent `LayerNorm` instances are created from the same
    /// `layer_norm` config so each normalization site owns its own state.
    ///
    /// # Errors
    /// Returns any validation error from `config.validate()` or any
    /// construction error from a sub-layer.
    pub fn new(config: DecoderConfig) -> Result<Self> {
        config.validate()?;

        let self_attention = MultiHeadAttention::new(config.self_attention.clone())?;
        let cross_attention = MultiHeadAttention::new(config.cross_attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;
        let norm3 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            self_attention,
            cross_attention,
            ffn,
            norm1,
            norm2,
            norm3,
        })
    }

    /// Append this decoder layer's nodes to `graph` and return the tensor
    /// index (singleton `Vec`) of the layer output.
    ///
    /// Dispatches to the pre-norm or post-norm builder based on
    /// `config.pre_norm`.
    pub fn build_decoder_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // NOTE(review): tensor indices 0 and 1 are assumed to be the decoder
        // input and encoder output added by the caller before this call —
        // confirm against call sites (see the tests below for the expected
        // setup).
        let decoder_input = 0;
        let encoder_output = 1;

        if self.config.pre_norm {
            self.build_pre_norm_decoder(graph, decoder_input, encoder_output)
        } else {
            self.build_post_norm_decoder(graph, decoder_input, encoder_output)
        }
    }

    /// Pre-norm variant: (norm -> sub-layer -> residual add) for
    /// self-attention, cross-attention, then feed-forward.
    fn build_pre_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize,
    ) -> Result<Vec<usize>> {
        // NOTE(review): the layernorm outputs (`_normed1/2/3`) are computed
        // but never fed into the attention/FFN sub-graphs, and
        // `_encoder_output` is never wired into cross-attention —
        // `build_mha_graph`/`build_ffn_graph` take no input tensor id here.
        // Confirm whether the sub-builders connect their inputs implicitly
        // or whether this wiring is a known limitation of the graph API.
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let _normed1 = normed1_outputs[0];

        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // Residual 1: decoder input + self-attention output.
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let _normed2 = normed2_outputs[0];

        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Residual 2: previous residual + cross-attention output.
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let _normed3 = normed3_outputs[0];

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Residual 3: previous residual + FFN output = layer output.
        let output = graph.add_tensor("decoder_output");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual2, ffn_output, output);
        graph.add_node(res3_node)?;

        Ok(vec![output])
    }

    /// Post-norm variant: (sub-layer -> residual add -> norm) for
    /// self-attention, cross-attention, then feed-forward; the final norm
    /// output is the layer output.
    fn build_post_norm_decoder(
        &self,
        graph: &mut EinsumGraph,
        decoder_input: usize,
        _encoder_output: usize,
    ) -> Result<Vec<usize>> {
        // NOTE(review): as in the pre-norm builder, `_encoder_output` is
        // never passed to cross-attention — confirm the sub-builder's input
        // wiring.
        let self_attn_outputs = self.self_attention.build_mha_graph(graph)?;
        let self_attn_output = self_attn_outputs[0];

        // Residual 1: decoder input + self-attention output, then norm1.
        let residual1 = graph.add_tensor("decoder_residual1");
        let res1_node = tensorlogic_ir::EinsumNode::elem_binary(
            "add",
            decoder_input,
            self_attn_output,
            residual1,
        );
        graph.add_node(res1_node)?;

        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        let cross_attn_outputs = self.cross_attention.build_mha_graph(graph)?;
        let cross_attn_output = cross_attn_outputs[0];

        // Residual 2: normed branch + cross-attention output, then norm2.
        let residual2 = graph.add_tensor("decoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, cross_attn_output, residual2);
        graph.add_node(res2_node)?;

        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let normed2 = normed2_outputs[0];

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Residual 3: normed branch + FFN output, then norm3 = layer output.
        let residual3 = graph.add_tensor("decoder_residual3");
        let res3_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed2, ffn_output, residual3);
        graph.add_node(res3_node)?;

        let normed3_outputs = self.norm3.build_layernorm_graph(graph)?;
        let output = normed3_outputs[0];

        Ok(vec![output])
    }
}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard 512-d, 8-head, 2048-ff configuration used by most tests.
    fn base_config() -> DecoderConfig {
        DecoderConfig::new(512, 8, 2048).unwrap()
    }

    /// Graph pre-seeded with the decoder-input and encoder-output tensors
    /// that `build_decoder_graph` expects at indices 0 and 1.
    fn seeded_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.add_tensor("decoder_input");
        graph.add_tensor("encoder_output");
        graph
    }

    #[test]
    fn test_decoder_config_creation() {
        let cfg = base_config();
        assert_eq!(cfg.self_attention.d_model, 512);
        assert_eq!(cfg.cross_attention.d_model, 512);
        assert!(cfg.self_attention.causal);
        assert!(!cfg.cross_attention.causal);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_decoder_config_with_dropout() {
        let cfg = base_config().with_dropout(0.1);
        // All three dropout-carrying components must receive the same value.
        for dropout in [
            cfg.self_attention.dropout,
            cfg.cross_attention.dropout,
            cfg.feed_forward.dropout,
        ] {
            assert!((dropout - 0.1).abs() < 1e-10);
        }
    }

    #[test]
    fn test_decoder_config_pre_norm() {
        assert!(!base_config().with_pre_norm(false).pre_norm);
    }

    #[test]
    fn test_decoder_creation() {
        let decoder = Decoder::new(base_config()).unwrap();
        assert_eq!(decoder.config.self_attention.d_model, 512);
    }

    #[test]
    fn test_decoder_graph_building_pre_norm() {
        let decoder = Decoder::new(base_config()).unwrap();
        let mut graph = seeded_graph();

        let outputs = decoder.build_decoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_decoder_graph_building_post_norm() {
        let decoder = Decoder::new(base_config().with_pre_norm(false)).unwrap();
        let mut graph = seeded_graph();

        let outputs = decoder.build_decoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_decoder_config_validation() {
        assert!(base_config().validate().is_ok());

        // 512 is not divisible by 7 heads, so construction must fail.
        assert!(DecoderConfig::new(512, 7, 2048).is_err());
    }

    #[test]
    fn test_decoder_requires_causal_masking() {
        let cfg = base_config();
        assert!(cfg.self_attention.causal);
        assert!(!cfg.cross_attention.causal);
    }
}