Skip to main content

tensorlogic_trustformers/
patterns.rs

1//! Rule-based and sparse attention patterns.
2//!
3//! This module provides various attention masking patterns used in transformer
4//! architectures, including causal masks, local attention, strided patterns,
5//! and rule-based attention patterns.
6
7use tensorlogic_ir::{EinsumGraph, EinsumNode};
8
9use crate::error::{Result, TrustformerError};
10
11/// Trait for attention mask patterns
12pub trait AttentionMask {
13    /// Build a mask tensor in the einsum graph
14    ///
15    /// Returns the tensor ID of the mask with shape [batch, seq_len, seq_len]
16    /// where 0.0 = masked (no attention), 1.0 = unmasked (attend)
17    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize>;
18
19    /// Get mask type name for documentation
20    fn mask_type(&self) -> &str;
21}
22
/// Causal (autoregressive) attention mask.
///
/// A query position may only look at itself and earlier key positions:
/// `mask[i, j] = 1` when `i >= j`, and `0` otherwise.
#[derive(Clone, Debug)]
pub struct CausalMask {
    /// Number of sequences in a batch.
    pub batch_size: usize,
}

impl CausalMask {
    /// Returns a causal-mask configuration for `batch_size` sequences.
    pub fn new(batch_size: usize) -> Self {
        CausalMask { batch_size }
    }
}
39
40impl AttentionMask for CausalMask {
41    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
42        // Create causal mask: lower triangular matrix
43        let mask_tensor = graph.add_tensor("causal_mask");
44        let mask_node = EinsumNode::elem_unary(
45            format!("causal_mask_{}x{}", seq_len, seq_len),
46            0, // Placeholder for mask generation
47            mask_tensor,
48        );
49        graph.add_node(mask_node)?;
50        Ok(mask_tensor)
51    }
52
53    fn mask_type(&self) -> &str {
54        "causal"
55    }
56}
57
/// Local (windowed) attention mask.
///
/// Every position attends only to neighbours within a fixed window.
#[derive(Clone, Debug)]
pub struct LocalMask {
    /// Number of sequences in a batch.
    pub batch_size: usize,
    /// Half-width of the window: positions within ±`window_size` are visible.
    pub window_size: usize,
}

impl LocalMask {
    /// Returns a local-attention configuration with the given batch size and
    /// window half-width.
    pub fn new(batch_size: usize, window_size: usize) -> Self {
        LocalMask {
            window_size,
            batch_size,
        }
    }
}
78
79impl AttentionMask for LocalMask {
80    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
81        let mask_tensor = graph.add_tensor("local_mask");
82        let mask_node = EinsumNode::elem_unary(
83            format!("local_mask_w{}_{}x{}", self.window_size, seq_len, seq_len),
84            0,
85            mask_tensor,
86        );
87        graph.add_node(mask_node)?;
88        Ok(mask_tensor)
89    }
90
91    fn mask_type(&self) -> &str {
92        "local"
93    }
94}
95
96/// Strided attention mask
97///
98/// Attends to every k-th position (used in Sparse Transformers).
99#[derive(Clone, Debug)]
100pub struct StridedMask {
101    /// Batch size
102    pub batch_size: usize,
103    /// Stride length
104    pub stride: usize,
105}
106
107impl StridedMask {
108    /// Create a new strided attention mask
109    pub fn new(batch_size: usize, stride: usize) -> Result<Self> {
110        if stride == 0 {
111            return Err(TrustformerError::InvalidDimension {
112                expected: 1,
113                got: 0,
114                context: "stride must be positive".to_string(),
115            });
116        }
117        Ok(Self { batch_size, stride })
118    }
119}
120
121impl AttentionMask for StridedMask {
122    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
123        let mask_tensor = graph.add_tensor("strided_mask");
124        let mask_node = EinsumNode::elem_unary(
125            format!("strided_mask_s{}_{}x{}", self.stride, seq_len, seq_len),
126            0,
127            mask_tensor,
128        );
129        graph.add_node(mask_node)?;
130        Ok(mask_tensor)
131    }
132
133    fn mask_type(&self) -> &str {
134        "strided"
135    }
136}
137
138/// Block-sparse attention mask
139///
140/// Divides attention into fixed-size blocks (used in BigBird, Longformer).
141#[derive(Clone, Debug)]
142pub struct BlockSparseMask {
143    /// Batch size
144    pub batch_size: usize,
145    /// Block size
146    pub block_size: usize,
147}
148
149impl BlockSparseMask {
150    /// Create a new block-sparse attention mask
151    pub fn new(batch_size: usize, block_size: usize) -> Result<Self> {
152        if block_size == 0 {
153            return Err(TrustformerError::InvalidDimension {
154                expected: 1,
155                got: 0,
156                context: "block_size must be positive".to_string(),
157            });
158        }
159        Ok(Self {
160            batch_size,
161            block_size,
162        })
163    }
164}
165
166impl AttentionMask for BlockSparseMask {
167    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
168        let mask_tensor = graph.add_tensor("block_sparse_mask");
169        let mask_node = EinsumNode::elem_unary(
170            format!(
171                "block_sparse_mask_b{}_{}x{}",
172                self.block_size, seq_len, seq_len
173            ),
174            0,
175            mask_tensor,
176        );
177        graph.add_node(mask_node)?;
178        Ok(mask_tensor)
179    }
180
181    fn mask_type(&self) -> &str {
182        "block_sparse"
183    }
184}
185
/// Global + local attention mask (Longformer-style).
///
/// A few leading "global" tokens attend everywhere, while every other
/// position only sees a local window.
#[derive(Clone, Debug)]
pub struct GlobalLocalMask {
    /// Number of sequences in a batch.
    pub batch_size: usize,
    /// How many tokens at the start of the sequence are global.
    pub num_global_tokens: usize,
    /// Window size used for the local (non-global) positions.
    pub local_window: usize,
}

impl GlobalLocalMask {
    /// Returns a global+local mask configuration.
    pub fn new(batch_size: usize, num_global_tokens: usize, local_window: usize) -> Self {
        GlobalLocalMask {
            local_window,
            num_global_tokens,
            batch_size,
        }
    }
}
209
210impl AttentionMask for GlobalLocalMask {
211    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
212        let mask_tensor = graph.add_tensor("global_local_mask");
213        let mask_node = EinsumNode::elem_unary(
214            format!(
215                "global_local_mask_g{}_w{}_{}x{}",
216                self.num_global_tokens, self.local_window, seq_len, seq_len
217            ),
218            0,
219            mask_tensor,
220        );
221        graph.add_node(mask_node)?;
222        Ok(mask_tensor)
223    }
224
225    fn mask_type(&self) -> &str {
226        "global_local"
227    }
228}
229
/// Rule-based attention pattern.
///
/// Chooses how a rule's outcome is turned into attention weights: hard 0/1
/// masking, soft confidence-weighted masking, or a gated mix.
///
/// `PartialEq`/`Eq`/`Hash` are derived so patterns can be compared directly
/// (e.g. `mask.pattern == RulePattern::Hard`) and used as map or set keys,
/// rather than only being testable through `matches!`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum RulePattern {
    /// Hard masking: 0 or 1 based on rule satisfaction
    Hard,
    /// Soft masking: continuous values based on rule confidence
    Soft,
    /// Gated masking: learnable combination of rule and data-driven attention
    Gated,
}
242
243/// Rule-based attention mask
244#[derive(Clone, Debug)]
245pub struct RuleBasedMask {
246    /// Batch size
247    pub batch_size: usize,
248    /// Pattern type
249    pub pattern: RulePattern,
250    /// Rule specification (opaque for now)
251    pub rule_spec: String,
252}
253
254impl RuleBasedMask {
255    /// Create a new rule-based mask
256    pub fn new(batch_size: usize, pattern: RulePattern, rule_spec: String) -> Self {
257        Self {
258            batch_size,
259            pattern,
260            rule_spec,
261        }
262    }
263
264    /// Create a hard rule-based mask
265    pub fn hard(batch_size: usize, rule_spec: String) -> Self {
266        Self::new(batch_size, RulePattern::Hard, rule_spec)
267    }
268
269    /// Create a soft rule-based mask
270    pub fn soft(batch_size: usize, rule_spec: String) -> Self {
271        Self::new(batch_size, RulePattern::Soft, rule_spec)
272    }
273
274    /// Create a gated rule-based mask
275    pub fn gated(batch_size: usize, rule_spec: String) -> Self {
276        Self::new(batch_size, RulePattern::Gated, rule_spec)
277    }
278}
279
280impl AttentionMask for RuleBasedMask {
281    fn build_mask(&self, graph: &mut EinsumGraph, seq_len: usize) -> Result<usize> {
282        let pattern_name = match self.pattern {
283            RulePattern::Hard => "hard",
284            RulePattern::Soft => "soft",
285            RulePattern::Gated => "gated",
286        };
287
288        let mask_tensor = graph.add_tensor(format!("rule_mask_{}", pattern_name));
289        let mask_node = EinsumNode::elem_unary(
290            format!("rule_mask_{}_{}x{}", pattern_name, seq_len, seq_len),
291            0,
292            mask_tensor,
293        );
294        graph.add_node(mask_node)?;
295        Ok(mask_tensor)
296    }
297
298    fn mask_type(&self) -> &str {
299        match self.pattern {
300            RulePattern::Hard => "rule_hard",
301            RulePattern::Soft => "rule_soft",
302            RulePattern::Gated => "rule_gated",
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `mask` into a fresh graph and fails the test if that errors.
    fn assert_builds<M: AttentionMask>(mask: &M, seq_len: usize) {
        let mut graph = EinsumGraph::new();
        let outcome = mask.build_mask(&mut graph, seq_len);
        assert!(
            outcome.is_ok(),
            "{} mask failed to build for seq_len={}",
            mask.mask_type(),
            seq_len
        );
    }

    #[test]
    fn test_causal_mask_creation() {
        let causal = CausalMask::new(4);
        assert_eq!(causal.batch_size, 4);
        assert_eq!(causal.mask_type(), "causal");
    }

    #[test]
    fn test_causal_mask_build() {
        assert_builds(&CausalMask::new(4), 10);
    }

    #[test]
    fn test_local_mask_creation() {
        let local = LocalMask::new(4, 3);
        assert_eq!((local.batch_size, local.window_size), (4, 3));
        assert_eq!(local.mask_type(), "local");
    }

    #[test]
    fn test_local_mask_build() {
        assert_builds(&LocalMask::new(4, 5), 20);
    }

    #[test]
    fn test_strided_mask_creation() {
        let strided = StridedMask::new(4, 2).unwrap();
        assert_eq!((strided.batch_size, strided.stride), (4, 2));
        assert_eq!(strided.mask_type(), "strided");
    }

    #[test]
    fn test_strided_mask_invalid_stride() {
        assert!(StridedMask::new(4, 0).is_err());
    }

    #[test]
    fn test_strided_mask_build() {
        assert_builds(&StridedMask::new(4, 3).unwrap(), 15);
    }

    #[test]
    fn test_block_sparse_mask_creation() {
        let block = BlockSparseMask::new(4, 8).unwrap();
        assert_eq!((block.batch_size, block.block_size), (4, 8));
        assert_eq!(block.mask_type(), "block_sparse");
    }

    #[test]
    fn test_block_sparse_mask_invalid_size() {
        assert!(BlockSparseMask::new(4, 0).is_err());
    }

    #[test]
    fn test_block_sparse_mask_build() {
        assert_builds(&BlockSparseMask::new(4, 16).unwrap(), 64);
    }

    #[test]
    fn test_global_local_mask_creation() {
        let gl = GlobalLocalMask::new(4, 2, 5);
        assert_eq!(
            (gl.batch_size, gl.num_global_tokens, gl.local_window),
            (4, 2, 5)
        );
        assert_eq!(gl.mask_type(), "global_local");
    }

    #[test]
    fn test_global_local_mask_build() {
        assert_builds(&GlobalLocalMask::new(4, 3, 7), 50);
    }

    #[test]
    fn test_rule_based_mask_hard() {
        let rule = RuleBasedMask::hard(4, "entity_type=person".to_string());
        assert_eq!(rule.batch_size, 4);
        assert!(matches!(rule.pattern, RulePattern::Hard));
        assert_eq!(rule.mask_type(), "rule_hard");
    }

    #[test]
    fn test_rule_based_mask_soft() {
        let rule = RuleBasedMask::soft(4, "similarity>0.5".to_string());
        assert!(matches!(rule.pattern, RulePattern::Soft));
        assert_eq!(rule.mask_type(), "rule_soft");
    }

    #[test]
    fn test_rule_based_mask_gated() {
        let rule = RuleBasedMask::gated(4, "weighted_rule".to_string());
        assert!(matches!(rule.pattern, RulePattern::Gated));
        assert_eq!(rule.mask_type(), "rule_gated");
    }

    #[test]
    fn test_rule_based_mask_build() {
        assert_builds(&RuleBasedMask::hard(4, "test_rule".to_string()), 32);
    }
}