Skip to main content

tensorlogic_trustformers/sparse_attention/
graph.rs

1//! Sparse attention patterns for efficient long-sequence processing.
2//!
3//! This module implements various sparse attention mechanisms that reduce
4//! the quadratic complexity of full attention by only attending to a subset
5//! of positions.
6//!
7//! ## Sparse Attention Types
8//!
9//! 1. **Fixed Pattern**: Pre-defined sparse patterns (strided, local, global)
10//! 2. **Random**: Randomly sample attention positions
11//! 3. **Learned**: Learn which positions to attend to
12//! 4. **Hybrid**: Combine multiple sparse patterns
13//!
14//! ## Complexity
15//!
16//! - Full attention: O(n²) memory and compute
//! - Sparse attention: O(n·k), where k << n is the number of key positions each query attends to
18
19use serde::{Deserialize, Serialize};
20use tensorlogic_ir::{EinsumGraph, EinsumNode};
21
22use crate::{
23    config::AttentionConfig,
24    error::{Result, TrustformerError},
25};
26
27/// Type of sparse attention pattern
28#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
29pub enum SparsePatternType {
30    /// Strided pattern: attend every k-th position
31    Strided { stride: usize },
32    /// Local window: attend to nearby positions
33    Local { window_size: usize },
34    /// Global + local: some positions attend globally, others locally
35    GlobalLocal {
36        window_size: usize,
37        global_positions: Vec<usize>,
38    },
39    /// Block sparse: divide sequence into blocks
40    BlockSparse { block_size: usize },
41    /// Random: randomly sample k positions per query
42    Random { num_random: usize },
43}
44
45/// Configuration for sparse attention graph building
46#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
47pub struct SparseAttentionGraphConfig {
48    /// Base attention configuration
49    pub base_attention: AttentionConfig,
50    /// Sparse pattern type
51    pub pattern: SparsePatternType,
52    /// Whether to use exact sparse computation or approximation
53    pub exact_sparse: bool,
54}
55
56impl SparseAttentionGraphConfig {
57    /// Create a new strided sparse attention configuration
58    pub fn strided(base_attention: AttentionConfig, stride: usize) -> Result<Self> {
59        if stride == 0 {
60            return Err(TrustformerError::InvalidDimension {
61                expected: 1,
62                got: 0,
63                context: "stride must be positive".to_string(),
64            });
65        }
66
67        Ok(Self {
68            base_attention,
69            pattern: SparsePatternType::Strided { stride },
70            exact_sparse: true,
71        })
72    }
73
74    /// Create a new local window sparse attention configuration
75    pub fn local(base_attention: AttentionConfig, window_size: usize) -> Result<Self> {
76        if window_size == 0 {
77            return Err(TrustformerError::InvalidDimension {
78                expected: 1,
79                got: 0,
80                context: "window_size must be positive".to_string(),
81            });
82        }
83
84        Ok(Self {
85            base_attention,
86            pattern: SparsePatternType::Local { window_size },
87            exact_sparse: true,
88        })
89    }
90
91    /// Create a new global-local sparse attention configuration
92    pub fn global_local(
93        base_attention: AttentionConfig,
94        window_size: usize,
95        global_positions: Vec<usize>,
96    ) -> Result<Self> {
97        if window_size == 0 {
98            return Err(TrustformerError::InvalidDimension {
99                expected: 1,
100                got: 0,
101                context: "window_size must be positive".to_string(),
102            });
103        }
104
105        Ok(Self {
106            base_attention,
107            pattern: SparsePatternType::GlobalLocal {
108                window_size,
109                global_positions,
110            },
111            exact_sparse: true,
112        })
113    }
114
115    /// Create a new block sparse attention configuration
116    pub fn block_sparse(base_attention: AttentionConfig, block_size: usize) -> Result<Self> {
117        if block_size == 0 {
118            return Err(TrustformerError::InvalidDimension {
119                expected: 1,
120                got: 0,
121                context: "block_size must be positive".to_string(),
122            });
123        }
124
125        Ok(Self {
126            base_attention,
127            pattern: SparsePatternType::BlockSparse { block_size },
128            exact_sparse: true,
129        })
130    }
131
132    /// Set whether to use exact sparse computation
133    pub fn with_exact_sparse(mut self, exact_sparse: bool) -> Self {
134        self.exact_sparse = exact_sparse;
135        self
136    }
137
138    /// Validate configuration
139    pub fn validate(&self) -> Result<()> {
140        self.base_attention.validate()?;
141
142        match &self.pattern {
143            SparsePatternType::Strided { stride } => {
144                if *stride == 0 {
145                    return Err(TrustformerError::InvalidDimension {
146                        expected: 1,
147                        got: 0,
148                        context: "stride must be positive".to_string(),
149                    });
150                }
151            }
152            SparsePatternType::Local { window_size } => {
153                if *window_size == 0 {
154                    return Err(TrustformerError::InvalidDimension {
155                        expected: 1,
156                        got: 0,
157                        context: "window_size must be positive".to_string(),
158                    });
159                }
160            }
161            SparsePatternType::GlobalLocal {
162                window_size,
163                global_positions: _,
164            } => {
165                if *window_size == 0 {
166                    return Err(TrustformerError::InvalidDimension {
167                        expected: 1,
168                        got: 0,
169                        context: "window_size must be positive".to_string(),
170                    });
171                }
172            }
173            SparsePatternType::BlockSparse { block_size } => {
174                if *block_size == 0 {
175                    return Err(TrustformerError::InvalidDimension {
176                        expected: 1,
177                        got: 0,
178                        context: "block_size must be positive".to_string(),
179                    });
180                }
181            }
182            SparsePatternType::Random { num_random } => {
183                if *num_random == 0 {
184                    return Err(TrustformerError::InvalidDimension {
185                        expected: 1,
186                        got: 0,
187                        context: "num_random must be positive".to_string(),
188                    });
189                }
190            }
191        }
192
193        Ok(())
194    }
195}
196
197/// Sparse attention graph-building component
198#[derive(Clone, Debug)]
199pub struct SparseAttentionGraph {
200    /// Configuration
201    pub config: SparseAttentionGraphConfig,
202}
203
204impl SparseAttentionGraph {
205    /// Create a new sparse attention component
206    pub fn new(config: SparseAttentionGraphConfig) -> Result<Self> {
207        config.validate()?;
208        Ok(Self { config })
209    }
210
211    /// Build einsum graph for sparse attention
212    ///
213    /// Input tensors:
214    /// - 0: Q (query) [batch, seq_len, d_model]
215    /// - 1: K (key) [batch, seq_len, d_model]
216    /// - 2: V (value) [batch, seq_len, d_model]
217    /// - 3: sparse_mask [batch, seq_q, seq_k] (sparse pattern mask)
218    ///
219    /// Output tensors:
220    /// - output: [batch, seq_len, d_model]
221    pub fn build_sparse_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
222        // Step 1: Compute attention scores
223        let scores_tensor = graph.add_tensor("sparse_attn_scores");
224        let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![0, 1], vec![scores_tensor]);
225        graph.add_node(scores_node)?;
226
227        // Step 2: Scale scores
228        let scale_factor = (self.config.base_attention.d_k as f64).sqrt();
229        let scale_tensor = graph.add_tensor("sparse_scale");
230        let scaled_tensor = graph.add_tensor("sparse_scaled_scores");
231        let scale_node = EinsumNode::elem_binary(
232            format!("div_scalar_{}", scale_factor),
233            scores_tensor,
234            scale_tensor,
235            scaled_tensor,
236        );
237        graph.add_node(scale_node)?;
238
239        // Step 3: Apply sparse mask
240        // Mask out positions not in the sparse pattern
241        let masked_tensor = graph.add_tensor("sparse_masked_scores");
242        let mask_node = EinsumNode::elem_binary("mul", scaled_tensor, 3, masked_tensor);
243        graph.add_node(mask_node)?;
244
245        // Step 4: Apply softmax (only over non-masked positions)
246        let softmax_tensor = graph.add_tensor("sparse_attention_weights");
247        let softmax_node =
248            EinsumNode::elem_unary("sparse_softmax_k", masked_tensor, softmax_tensor);
249        graph.add_node(softmax_node)?;
250
251        // Step 5: Apply attention to values
252        let output_tensor = graph.add_tensor("sparse_attn_output");
253        let output_node =
254            EinsumNode::new("bqk,bkv->bqv", vec![softmax_tensor, 2], vec![output_tensor]);
255        graph.add_node(output_node)?;
256
257        Ok(vec![output_tensor])
258    }
259
260    /// Get the sparsity factor (approximate percentage of attended positions)
261    pub fn sparsity_factor(&self, seq_len: usize) -> f64 {
262        match &self.config.pattern {
263            SparsePatternType::Strided { stride } => 1.0 / (*stride as f64),
264            SparsePatternType::Local { window_size } => {
265                (*window_size as f64).min(seq_len as f64) / (seq_len as f64)
266            }
267            SparsePatternType::GlobalLocal {
268                window_size,
269                global_positions,
270            } => {
271                let local_fraction = (*window_size as f64) / (seq_len as f64);
272                let global_fraction = (global_positions.len() as f64) / (seq_len as f64);
273                (local_fraction + global_fraction).min(1.0)
274            }
275            SparsePatternType::BlockSparse { block_size } => {
276                (*block_size as f64) / (seq_len as f64)
277            }
278            SparsePatternType::Random { num_random } => (*num_random as f64) / (seq_len as f64),
279        }
280    }
281
282    /// Get pattern description
283    pub fn pattern_description(&self) -> String {
284        match &self.config.pattern {
285            SparsePatternType::Strided { stride } => {
286                format!("Strided(stride={})", stride)
287            }
288            SparsePatternType::Local { window_size } => {
289                format!("Local(window={})", window_size)
290            }
291            SparsePatternType::GlobalLocal {
292                window_size,
293                global_positions,
294            } => {
295                format!(
296                    "GlobalLocal(window={}, global_tokens={})",
297                    window_size,
298                    global_positions.len()
299                )
300            }
301            SparsePatternType::BlockSparse { block_size } => {
302                format!("BlockSparse(block={})", block_size)
303            }
304            SparsePatternType::Random { num_random } => {
305                format!("Random(k={})", num_random)
306            }
307        }
308    }
309}
310
311/// Local attention (windowed attention)
312///
313/// Each query only attends to keys within a fixed window.
314/// This is a special case of sparse attention optimized for efficiency.
315#[derive(Clone, Debug)]
316pub struct LocalAttention {
317    /// Configuration
318    pub config: AttentionConfig,
319    /// Window size (attend to positions within ±window_size)
320    pub window_size: usize,
321}
322
323impl LocalAttention {
324    /// Create a new local attention component
325    pub fn new(config: AttentionConfig, window_size: usize) -> Result<Self> {
326        config.validate()?;
327
328        if window_size == 0 {
329            return Err(TrustformerError::InvalidDimension {
330                expected: 1,
331                got: 0,
332                context: "window_size must be positive".to_string(),
333            });
334        }
335
336        Ok(Self {
337            config,
338            window_size,
339        })
340    }
341
342    /// Build einsum graph for local attention
343    ///
344    /// Input tensors:
345    /// - 0: Q (query) [batch, seq_len, d_model]
346    /// - 1: K (key) [batch, seq_len, d_model]
347    /// - 2: V (value) [batch, seq_len, d_model]
348    ///
349    /// Output tensors:
350    /// - output: [batch, seq_len, d_model]
351    pub fn build_local_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
352        // Step 1: Compute attention scores
353        let scores_tensor = graph.add_tensor("local_attn_scores");
354        let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![0, 1], vec![scores_tensor]);
355        graph.add_node(scores_node)?;
356
357        // Step 2: Scale scores
358        let scale_factor = (self.config.d_k as f64).sqrt();
359        let scale_tensor = graph.add_tensor("local_scale");
360        let scaled_tensor = graph.add_tensor("local_scaled_scores");
361        let scale_node = EinsumNode::elem_binary(
362            format!("div_scalar_{}", scale_factor),
363            scores_tensor,
364            scale_tensor,
365            scaled_tensor,
366        );
367        graph.add_node(scale_node)?;
368
369        // Step 3: Apply local window mask
370        // Mask is generated based on position distance: |i - j| <= window_size
371        let window_mask_tensor = graph.add_tensor("local_window_mask");
372        let masked_tensor = graph.add_tensor("local_masked_scores");
373        let mask_node =
374            EinsumNode::elem_binary("mul", scaled_tensor, window_mask_tensor, masked_tensor);
375        graph.add_node(mask_node)?;
376
377        // Step 4: Apply softmax
378        let softmax_tensor = graph.add_tensor("local_attention_weights");
379        let softmax_node = EinsumNode::elem_unary("softmax_k", masked_tensor, softmax_tensor);
380        graph.add_node(softmax_node)?;
381
382        // Step 5: Apply attention to values
383        let output_tensor = graph.add_tensor("local_attn_output");
384        let output_node =
385            EinsumNode::new("bqk,bkv->bqv", vec![softmax_tensor, 2], vec![output_tensor]);
386        graph.add_node(output_node)?;
387
388        Ok(vec![output_tensor])
389    }
390
391    /// Get effective attention span
392    pub fn attention_span(&self) -> usize {
393        2 * self.window_size + 1
394    }
395
396    /// Calculate memory savings compared to full attention
397    pub fn memory_savings(&self, seq_len: usize) -> f64 {
398        let full_memory = seq_len * seq_len;
399        let sparse_memory = seq_len * self.attention_span().min(seq_len);
400        1.0 - (sparse_memory as f64 / full_memory as f64)
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-good base attention configuration shared by the tests.
    fn base_config() -> AttentionConfig {
        AttentionConfig::new(512, 8).expect("AttentionConfig::new(512, 8) should succeed")
    }

    #[test]
    fn test_strided_sparse_config() {
        let config = SparseAttentionGraphConfig::strided(base_config(), 4).expect("stride 4");
        assert!(matches!(
            config.pattern,
            SparsePatternType::Strided { stride: 4 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_local_sparse_config() {
        let config = SparseAttentionGraphConfig::local(base_config(), 128).expect("window 128");
        assert!(matches!(
            config.pattern,
            SparsePatternType::Local { window_size: 128 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_global_local_sparse_config() {
        let global_positions = vec![0, 1, 2]; // First 3 tokens attend globally
        let config = SparseAttentionGraphConfig::global_local(base_config(), 64, global_positions)
            .expect("window 64");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_block_sparse_config() {
        let config =
            SparseAttentionGraphConfig::block_sparse(base_config(), 64).expect("block 64");
        assert!(matches!(
            config.pattern,
            SparsePatternType::BlockSparse { block_size: 64 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_random_pattern_validation() {
        let valid = SparseAttentionGraphConfig {
            base_attention: base_config(),
            pattern: SparsePatternType::Random { num_random: 8 },
            exact_sparse: true,
        };
        assert!(valid.validate().is_ok());

        let invalid = SparseAttentionGraphConfig {
            base_attention: base_config(),
            pattern: SparsePatternType::Random { num_random: 0 },
            exact_sparse: true,
        };
        assert!(invalid.validate().is_err());
        assert!(SparseAttentionGraph::new(invalid).is_err());
    }

    #[test]
    fn test_with_exact_sparse() {
        let config = SparseAttentionGraphConfig::strided(base_config(), 2)
            .expect("stride 2")
            .with_exact_sparse(false);
        assert!(!config.exact_sparse);
        assert!(config.with_exact_sparse(true).exact_sparse);
    }

    #[test]
    fn test_sparse_attention_creation() {
        let config = SparseAttentionGraphConfig::strided(base_config(), 2).expect("stride 2");
        let attn = SparseAttentionGraph::new(config).expect("valid config");
        assert_eq!(attn.sparsity_factor(1024), 0.5);
    }

    #[test]
    fn test_sparse_attention_graph_building() {
        let config = SparseAttentionGraphConfig::local(base_config(), 128).expect("window 128");
        let attn = SparseAttentionGraph::new(config).expect("valid config");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");
        graph.add_tensor("sparse_mask");

        let outputs = attn
            .build_sparse_attention_graph(&mut graph)
            .expect("graph should build");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_local_attention_creation() {
        let local = LocalAttention::new(base_config(), 64).expect("window 64");
        assert_eq!(local.window_size, 64);
        assert_eq!(local.attention_span(), 129);
    }

    #[test]
    fn test_local_attention_rejects_zero_window() {
        assert!(LocalAttention::new(base_config(), 0).is_err());
    }

    #[test]
    fn test_local_attention_graph_building() {
        let local = LocalAttention::new(base_config(), 64).expect("window 64");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = local
            .build_local_attention_graph(&mut graph)
            .expect("graph should build");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_sparsity_factors() {
        // Strided: 1/stride
        let strided = SparseAttentionGraphConfig::strided(base_config(), 4).expect("stride 4");
        let attn = SparseAttentionGraph::new(strided).expect("valid config");
        assert!((attn.sparsity_factor(1024) - 0.25).abs() < 1e-10);

        // Local: window/seq_len
        let local = SparseAttentionGraphConfig::local(base_config(), 128).expect("window 128");
        let attn = SparseAttentionGraph::new(local).expect("valid config");
        assert!((attn.sparsity_factor(1024) - 0.125).abs() < 1e-10);
    }

    #[test]
    fn test_global_local_sparsity_factor() {
        // (window + number of global tokens) / seq_len
        let config = SparseAttentionGraphConfig::global_local(base_config(), 64, vec![0, 1, 2])
            .expect("window 64");
        let attn = SparseAttentionGraph::new(config).expect("valid config");
        assert!((attn.sparsity_factor(1024) - 67.0 / 1024.0).abs() < 1e-10);

        // A window wider than the sequence is capped at full coverage.
        let wide = SparseAttentionGraphConfig::global_local(base_config(), 2048, vec![0])
            .expect("window 2048");
        let attn = SparseAttentionGraph::new(wide).expect("valid config");
        assert!((attn.sparsity_factor(1024) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_memory_savings() {
        let local = LocalAttention::new(base_config(), 64).expect("window 64");

        // For seq_len=1024, window=64
        // Full: 1024*1024 = 1,048,576
        // Sparse: 1024*129 = 132,096
        // Savings: ~87.4%
        let savings = local.memory_savings(1024);
        assert!(savings > 0.87 && savings < 0.88);
    }

    #[test]
    fn test_memory_savings_when_window_covers_sequence() {
        // A span of 129 covers every position of a 100-token sequence: no savings.
        let local = LocalAttention::new(base_config(), 64).expect("window 64");
        assert!(local.memory_savings(100).abs() < 1e-12);
    }

    #[test]
    fn test_pattern_descriptions() {
        let strided = SparseAttentionGraphConfig::strided(base_config(), 4).expect("stride 4");
        let attn = SparseAttentionGraph::new(strided).expect("valid config");
        assert_eq!(attn.pattern_description(), "Strided(stride=4)");

        let local = SparseAttentionGraphConfig::local(base_config(), 128).expect("window 128");
        let attn = SparseAttentionGraph::new(local).expect("valid config");
        assert_eq!(attn.pattern_description(), "Local(window=128)");
    }

    #[test]
    fn test_remaining_pattern_descriptions() {
        let global_local =
            SparseAttentionGraphConfig::global_local(base_config(), 64, vec![0, 1, 2])
                .expect("window 64");
        let attn = SparseAttentionGraph::new(global_local).expect("valid config");
        assert_eq!(
            attn.pattern_description(),
            "GlobalLocal(window=64, global_tokens=3)"
        );

        let block = SparseAttentionGraphConfig::block_sparse(base_config(), 64).expect("block 64");
        let attn = SparseAttentionGraph::new(block).expect("valid config");
        assert_eq!(attn.pattern_description(), "BlockSparse(block=64)");

        let random = SparseAttentionGraphConfig {
            base_attention: base_config(),
            pattern: SparsePatternType::Random { num_random: 8 },
            exact_sparse: true,
        };
        let attn = SparseAttentionGraph::new(random).expect("valid config");
        assert_eq!(attn.pattern_description(), "Random(k=8)");
    }

    #[test]
    fn test_invalid_configs() {
        // Zero stride
        assert!(SparseAttentionGraphConfig::strided(base_config(), 0).is_err());

        // Zero window
        assert!(SparseAttentionGraphConfig::local(base_config(), 0).is_err());

        // Zero window for global + local
        assert!(SparseAttentionGraphConfig::global_local(base_config(), 0, vec![0]).is_err());

        // Zero block size
        assert!(SparseAttentionGraphConfig::block_sparse(base_config(), 0).is_err());
    }
}