tensorlogic_trustformers/encoder.rs

use tensorlogic_ir::EinsumGraph;

use crate::{
    attention::MultiHeadAttention,
    config::{AttentionConfig, FeedForwardConfig},
    error::Result,
    ffn::FeedForward,
    normalization::{LayerNorm, LayerNormConfig},
};
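/// Configuration for a single transformer encoder layer: attention,
/// feed-forward, and layer-norm settings plus the choice of pre- or
/// post-norm sub-layer ordering.
///
/// A minimal builder-style usage sketch; the import path is an assumption
/// about how this module is exposed, so the block is marked `ignore` rather
/// than compiled as a doctest:
///
/// ```ignore
/// use tensorlogic_trustformers::encoder::EncoderConfig;
///
/// // 512-dim model, 8 heads, 2048-dim feed-forward; pre-norm is the default.
/// let config = EncoderConfig::new(512, 8, 2048)
///     .unwrap()
///     .with_dropout(0.1)
///     .with_causal(false);
/// assert!(config.validate().is_ok());
/// ```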
#[derive(Clone, Debug)]
pub struct EncoderConfig {
    /// Multi-head attention settings (d_model, n_heads, dropout, causal mask).
    pub attention: AttentionConfig,
    /// Feed-forward network settings (d_model, d_ff, dropout).
    pub feed_forward: FeedForwardConfig,
    /// Layer normalization settings, shared by both norms in the layer.
    pub layer_norm: LayerNormConfig,
    /// If true, normalize before each sub-layer (pre-norm); otherwise after (post-norm).
    pub pre_norm: bool,
}

impl EncoderConfig {
    /// Creates an encoder layer config for the given model width, head count,
    /// and feed-forward width. Pre-norm is enabled by default.
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            layer_norm: LayerNormConfig::new(d_model),
            pre_norm: true,
        })
    }

    /// Selects pre-norm (true) or post-norm (false) sub-layer ordering.
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Enables or disables the causal attention mask.
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.attention = self.attention.with_causal(causal);
        self
    }

    /// Applies the same dropout rate to both the attention and FFN configs.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.attention = self.attention.with_dropout(dropout);
        self.feed_forward = self.feed_forward.with_dropout(dropout);
        self
    }

    /// Validates the sub-configs and checks that they agree on d_model.
    pub fn validate(&self) -> Result<()> {
        self.attention.validate()?;
        self.feed_forward.validate()?;
        self.layer_norm.validate()?;

        if self.attention.d_model != self.feed_forward.d_model {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.feed_forward.d_model,
                context: "d_model mismatch between attention and FFN".to_string(),
            });
        }

        if self.attention.d_model != self.layer_norm.normalized_shape {
            return Err(crate::error::TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.layer_norm.normalized_shape,
                context: "d_model mismatch with layer norm".to_string(),
            });
        }

        Ok(())
    }
}

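/// A single transformer encoder layer: multi-head self-attention plus a
/// feed-forward network, each with a residual connection and layer
/// normalization placed according to `EncoderConfig::pre_norm`.
///
/// A minimal graph-building sketch; the import paths are assumptions about
/// how these modules are exposed, so the block is marked `ignore` rather
/// than compiled as a doctest:
///
/// ```ignore
/// use tensorlogic_ir::EinsumGraph;
/// use tensorlogic_trustformers::encoder::{Encoder, EncoderConfig};
///
/// let encoder = Encoder::new(EncoderConfig::new(512, 8, 2048).unwrap()).unwrap();
///
/// // The layer reads its input from tensor index 0, so register it first.
/// let mut graph = EinsumGraph::new();
/// graph.add_tensor("x");
///
/// let outputs = encoder.build_encoder_graph(&mut graph).unwrap();
/// assert_eq!(outputs.len(), 1);
/// ```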
#[derive(Clone, Debug)]
pub struct Encoder {
    /// The validated configuration this layer was built from.
    pub config: EncoderConfig,
    /// Multi-head self-attention sub-layer.
    pub attention: MultiHeadAttention,
    /// Feed-forward sub-layer.
    pub ffn: FeedForward,
    /// First layer norm (before attention in pre-norm; after the attention residual in post-norm).
    pub norm1: LayerNorm,
    /// Second layer norm (before the FFN in pre-norm; after the FFN residual in post-norm).
    pub norm2: LayerNorm,
}
impl Encoder {
    /// Validates the config and constructs the attention, FFN, and the two
    /// layer-norm sub-modules.
    pub fn new(config: EncoderConfig) -> Result<Self> {
        config.validate()?;

        let attention = MultiHeadAttention::new(config.attention.clone())?;
        let ffn = FeedForward::new(config.feed_forward.clone())?;
        let norm1 = LayerNorm::new(config.layer_norm.clone())?;
        let norm2 = LayerNorm::new(config.layer_norm.clone())?;

        Ok(Self {
            config,
            attention,
            ffn,
            norm1,
            norm2,
        })
    }

    /// Appends this encoder layer's subgraph to `graph`, reading its input
    /// from tensor index 0 and returning the indices of the output tensors.
    pub fn build_encoder_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let input_tensor = 0;

        if self.config.pre_norm {
            self.build_pre_norm_encoder(graph, input_tensor)
        } else {
            self.build_post_norm_encoder(graph, input_tensor)
        }
    }

    /// Pre-norm variant: layer norm ahead of each sub-layer, with a residual
    /// add after the attention and after the FFN.
    fn build_pre_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        let q_tensor = graph.add_tensor("encoder_Q");
        let k_tensor = graph.add_tensor("encoder_K");
        let v_tensor = graph.add_tensor("encoder_V");

        // Identity projections from the normalized input to Q/K/V. Note that
        // these nodes are constructed here but not added to the graph.
        let _q_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, q_tensor);
        let _k_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, k_tensor);
        let _v_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, v_tensor);

        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // First residual connection: input + attention output.
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        // Second layer norm ahead of the FFN; its output is not consumed
        // directly in this function.
        let _normed2_outputs = self.norm2.build_layernorm_graph(graph)?;

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Second residual connection: first residual + FFN output.
        let output = graph.add_tensor("encoder_output");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", residual1, ffn_output, output);
        graph.add_node(res2_node)?;

        Ok(vec![output])
    }

    /// Post-norm variant: residual add after each sub-layer, followed by a
    /// layer norm.
    fn build_post_norm_encoder(
        &self,
        graph: &mut EinsumGraph,
        input_tensor: usize,
    ) -> Result<Vec<usize>> {
        let attn_outputs = self.attention.build_mha_graph(graph)?;
        let attn_output = attn_outputs[0];

        // First residual connection, then the first layer norm.
        let residual1 = graph.add_tensor("encoder_residual1");
        let res1_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
        graph.add_node(res1_node)?;

        let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
        let normed1 = normed1_outputs[0];

        let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
        let ffn_output = ffn_outputs[0];

        // Second residual connection, then the second layer norm.
        let residual2 = graph.add_tensor("encoder_residual2");
        let res2_node =
            tensorlogic_ir::EinsumNode::elem_binary("add", normed1, ffn_output, residual2);
        graph.add_node(res2_node)?;

        let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
        let output = normed2_outputs[0];

        Ok(vec![output])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_config_creation() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.pre_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_encoder_config_with_dropout() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap().with_dropout(0.1);
        assert!((config.attention.dropout - 0.1).abs() < 1e-10);
        assert!((config.feed_forward.dropout - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_encoder_config_pre_norm() {
        let config = EncoderConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        assert!(!config.pre_norm);
    }

    #[test]
    fn test_encoder_creation() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        let encoder = Encoder::new(config).unwrap();
        assert_eq!(encoder.config.attention.d_model, 512);
    }

    #[test]
    fn test_encoder_graph_building_pre_norm() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        let encoder = Encoder::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = encoder.build_encoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_encoder_graph_building_post_norm() {
        let config = EncoderConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        let encoder = Encoder::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");

        let outputs = encoder.build_encoder_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_encoder_config_validation() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap();
        assert!(config.validate().is_ok());

        // 512 is not evenly divisible by 7 heads, so construction should fail.
        let result = EncoderConfig::new(512, 7, 2048);
        assert!(result.is_err());
    }

    #[test]
    fn test_encoder_with_causal() {
        let config = EncoderConfig::new(512, 8, 2048).unwrap().with_causal(true);
        assert!(config.attention.causal);
    }
}