tensorlogic_trustformers/
patterns.rs1use tensorlogic_ir::{EinsumGraph, EinsumNode};
8
9use crate::error::{Result, TrustformerError};
10
11pub trait AttentionMask {
13 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize>;
18
19 fn mask_type(&self) -> &str;
21}
22
/// Configuration for a causal attention mask (reported as `"causal"`).
#[derive(Debug, Clone)]
pub struct CausalMask {
    /// Batch size this mask was configured with.
    pub batch_size: usize,
}

impl CausalMask {
    /// Creates a causal mask configuration for `batch_size`.
    ///
    /// No validation is performed on `batch_size`.
    pub fn new(batch_size: usize) -> Self {
        CausalMask { batch_size }
    }
}
39
40impl AttentionMask for CausalMask {
41 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
42 let mask_tensor = graph.add_tensor("causal_mask");
44 let mask_node = EinsumNode::elem_unary(
45 format!("causal_mask_{}x{}", seq_len, seq_len),
46 0, mask_tensor,
48 );
49 graph.add_node(mask_node)?;
50 Ok(mask_tensor)
51 }
52
53 fn mask_type(&self) -> &str {
54 "causal"
55 }
56}
57
/// Configuration for a local (windowed) attention mask, reported as `"local"`.
#[derive(Debug, Clone)]
pub struct LocalMask {
    /// Batch size this mask was configured with.
    pub batch_size: usize,
    /// Window size; encoded into the generated node name.
    pub window_size: usize,
}

impl LocalMask {
    /// Creates a local mask configuration.
    ///
    /// NOTE(review): `window_size == 0` is accepted here, whereas the zero
    /// checks in `StridedMask::new` / `BlockSparseMask::new` reject their
    /// parameters — confirm a zero window is meaningful.
    pub fn new(batch_size: usize, window_size: usize) -> Self {
        LocalMask {
            batch_size,
            window_size,
        }
    }
}
78
79impl AttentionMask for LocalMask {
80 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
81 let mask_tensor = graph.add_tensor("local_mask");
82 let mask_node = EinsumNode::elem_unary(
83 format!("local_mask_w{}_{}x{}", self.window_size, seq_len, seq_len),
84 0,
85 mask_tensor,
86 );
87 graph.add_node(mask_node)?;
88 Ok(mask_tensor)
89 }
90
91 fn mask_type(&self) -> &str {
92 "local"
93 }
94}
95
96#[derive(Clone, Debug)]
100pub struct StridedMask {
101 pub batch_size: usize,
103 pub stride: usize,
105}
106
107impl StridedMask {
108 pub fn new(batch_size: usize, stride: usize) -> Result<Self> {
110 if stride == 0 {
111 return Err(TrustformerError::InvalidDimension {
112 expected: 1,
113 got: 0,
114 context: "stride must be positive".to_string(),
115 });
116 }
117 Ok(Self { batch_size, stride })
118 }
119}
120
121impl AttentionMask for StridedMask {
122 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
123 let mask_tensor = graph.add_tensor("strided_mask");
124 let mask_node = EinsumNode::elem_unary(
125 format!("strided_mask_s{}_{}x{}", self.stride, seq_len, seq_len),
126 0,
127 mask_tensor,
128 );
129 graph.add_node(mask_node)?;
130 Ok(mask_tensor)
131 }
132
133 fn mask_type(&self) -> &str {
134 "strided"
135 }
136}
137
138#[derive(Clone, Debug)]
142pub struct BlockSparseMask {
143 pub batch_size: usize,
145 pub block_size: usize,
147}
148
149impl BlockSparseMask {
150 pub fn new(batch_size: usize, block_size: usize) -> Result<Self> {
152 if block_size == 0 {
153 return Err(TrustformerError::InvalidDimension {
154 expected: 1,
155 got: 0,
156 context: "block_size must be positive".to_string(),
157 });
158 }
159 Ok(Self {
160 batch_size,
161 block_size,
162 })
163 }
164}
165
166impl AttentionMask for BlockSparseMask {
167 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
168 let mask_tensor = graph.add_tensor("block_sparse_mask");
169 let mask_node = EinsumNode::elem_unary(
170 format!(
171 "block_sparse_mask_b{}_{}x{}",
172 self.block_size, seq_len, seq_len
173 ),
174 0,
175 mask_tensor,
176 );
177 graph.add_node(mask_node)?;
178 Ok(mask_tensor)
179 }
180
181 fn mask_type(&self) -> &str {
182 "block_sparse"
183 }
184}
185
/// Configuration combining global tokens with a local window, reported as
/// `"global_local"`.
#[derive(Debug, Clone)]
pub struct GlobalLocalMask {
    /// Batch size this mask was configured with.
    pub batch_size: usize,
    /// Number of global tokens; encoded into the generated node name.
    pub num_global_tokens: usize,
    /// Local window size; encoded into the generated node name.
    pub local_window: usize,
}

impl GlobalLocalMask {
    /// Creates a global/local mask configuration. No parameters are validated.
    pub fn new(batch_size: usize, num_global_tokens: usize, local_window: usize) -> Self {
        GlobalLocalMask {
            batch_size,
            num_global_tokens,
            local_window,
        }
    }
}
209
210impl AttentionMask for GlobalLocalMask {
211 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
212 let mask_tensor = graph.add_tensor("global_local_mask");
213 let mask_node = EinsumNode::elem_unary(
214 format!(
215 "global_local_mask_g{}_w{}_{}x{}",
216 self.num_global_tokens, self.local_window, seq_len, seq_len
217 ),
218 0,
219 mask_tensor,
220 );
221 graph.add_node(mask_node)?;
222 Ok(mask_tensor)
223 }
224
225 fn mask_type(&self) -> &str {
226 "global_local"
227 }
228}
229
/// How a [`RuleBasedMask`] applies its rule.
///
/// The variant is used to label the generated tensor and node (`hard`,
/// `soft`, `gated`). It is a fieldless enum, so it is cheap to copy and can be
/// compared, hashed, and used as a map key directly instead of only through
/// `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RulePattern {
    /// Labelled `"hard"` / `"rule_hard"`.
    Hard,
    /// Labelled `"soft"` / `"rule_soft"`.
    Soft,
    /// Labelled `"gated"` / `"rule_gated"`.
    Gated,
}
242
243#[derive(Clone, Debug)]
245pub struct RuleBasedMask {
246 pub batch_size: usize,
248 pub pattern: RulePattern,
250 pub rule_spec: String,
252}
253
254impl RuleBasedMask {
255 pub fn new(batch_size: usize, pattern: RulePattern, rule_spec: String) -> Self {
257 Self {
258 batch_size,
259 pattern,
260 rule_spec,
261 }
262 }
263
264 pub fn hard(batch_size: usize, rule_spec: String) -> Self {
266 Self::new(batch_size, RulePattern::Hard, rule_spec)
267 }
268
269 pub fn soft(batch_size: usize, rule_spec: String) -> Self {
271 Self::new(batch_size, RulePattern::Soft, rule_spec)
272 }
273
274 pub fn gated(batch_size: usize, rule_spec: String) -> Self {
276 Self::new(batch_size, RulePattern::Gated, rule_spec)
277 }
278}
279
280impl AttentionMask for RuleBasedMask {
281 fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
282 let pattern_name = match self.pattern {
283 RulePattern::Hard => "hard",
284 RulePattern::Soft => "soft",
285 RulePattern::Gated => "gated",
286 };
287
288 let mask_tensor = graph.add_tensor(format!("rule_mask_{}", pattern_name));
289 let mask_node = EinsumNode::elem_unary(
290 format!("rule_mask_{}_{}x{}", pattern_name, seq_len, seq_len),
291 0,
292 mask_tensor,
293 );
294 graph.add_node(mask_node)?;
295 Ok(mask_tensor)
296 }
297
298 fn mask_type(&self) -> &str {
299 match self.pattern {
300 RulePattern::Hard => "rule_hard",
301 RulePattern::Soft => "rule_soft",
302 RulePattern::Gated => "rule_gated",
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    //! Unit tests for the attention-mask configurations in this module.

    use super::*;

    #[test]
    fn test_causal_mask_creation() {
        let mask = CausalMask::new(4);
        assert_eq!(mask.batch_size, 4);
        assert_eq!(mask.mask_type(), "causal");
    }

    #[test]
    fn test_causal_mask_build() {
        let mask = CausalMask::new(4);
        let mut graph = EinsumGraph::new();
        let result = mask.build_mask(&mut graph, 10);
        assert!(result.is_ok());
    }

    #[test]
    fn test_local_mask_creation() {
        let mask = LocalMask::new(4, 3);
        assert_eq!(mask.batch_size, 4);
        assert_eq!(mask.window_size, 3);
        assert_eq!(mask.mask_type(), "local");
    }

    #[test]
    fn test_local_mask_build() {
        let mask = LocalMask::new(4, 5);
        let mut graph = EinsumGraph::new();
        let result = mask.build_mask(&mut graph, 20);
        assert!(result.is_ok());
    }

    #[test]
    fn test_strided_mask_creation() {
        let mask = StridedMask::new(4, 2).unwrap();
        assert_eq!(mask.batch_size, 4);
        assert_eq!(mask.stride, 2);
        assert_eq!(mask.mask_type(), "strided");
    }

    #[test]
    fn test_strided_mask_invalid_stride() {
        let result = StridedMask::new(4, 0);
        // Pin the specific error rather than accepting any failure.
        assert!(matches!(
            result,
            Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                ..
            })
        ));
    }

    #[test]
    fn test_strided_mask_build() {
        let mask = StridedMask::new(4, 3).unwrap();
        let mut graph = EinsumGraph::new();
        let result = mask.build_mask(&mut graph, 15);
        assert!(result.is_ok());
    }

    #[test]
    fn test_block_sparse_mask_creation() {
        let mask = BlockSparseMask::new(4, 8).unwrap();
        assert_eq!(mask.batch_size, 4);
        assert_eq!(mask.block_size, 8);
        assert_eq!(mask.mask_type(), "block_sparse");
    }

    #[test]
    fn test_block_sparse_mask_invalid_size() {
        let result = BlockSparseMask::new(4, 0);
        // Pin the specific error rather than accepting any failure.
        assert!(matches!(
            result,
            Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                ..
            })
        ));
    }

    #[test]
    fn test_block_sparse_mask_build() {
        let mask = BlockSparseMask::new(4, 16).unwrap();
        let mut graph = EinsumGraph::new();
        let result = mask.build_mask(&mut graph, 64);
        assert!(result.is_ok());
    }

    #[test]
    fn test_global_local_mask_creation() {
        let mask = GlobalLocalMask::new(4, 2, 5);
        assert_eq!(mask.batch_size, 4);
        assert_eq!(mask.num_global_tokens, 2);
        assert_eq!(mask.local_window, 5);
        assert_eq!(mask.mask_type(), "global_local");
    }

    #[test]
    fn test_global_local_mask_build() {
        let mask = GlobalLocalMask::new(4, 3, 7);
        let mut graph = EinsumGraph::new();
        let result = mask.build_mask(&mut graph, 50);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rule_based_mask_hard() {
        let mask = RuleBasedMask::hard(4, "entity_type=person".to_string());
        assert_eq!(mask.batch_size, 4);
        assert!(matches!(mask.pattern, RulePattern::Hard));
        assert_eq!(mask.mask_type(), "rule_hard");
    }

    #[test]
    fn test_rule_based_mask_soft() {
        let mask = RuleBasedMask::soft(4, "similarity>0.5".to_string());
        assert!(matches!(mask.pattern, RulePattern::Soft));
        assert_eq!(mask.mask_type(), "rule_soft");
    }

    #[test]
    fn test_rule_based_mask_gated() {
        let mask = RuleBasedMask::gated(4, "weighted_rule".to_string());
        assert!(matches!(mask.pattern, RulePattern::Gated));
        assert_eq!(mask.mask_type(), "rule_gated");
    }

    #[test]
    fn test_rule_based_mask_build() {
        let mask = RuleBasedMask::hard(4, "test_rule".to_string());
        let mut graph = EinsumGraph::new();
        let result = mask.build_mask(&mut graph, 32);
        assert!(result.is_ok());
    }
}