1use serde::{Deserialize, Serialize};
20use tensorlogic_ir::{EinsumGraph, EinsumNode};
21
22use crate::{
23 config::AttentionConfig,
24 error::{Result, TrustformerError},
25};
26
/// Sparsity pattern used by a sparse attention graph.
///
/// Each variant carries the size parameter(s) of its pattern. Every size
/// parameter must be non-zero; `SparseAttentionGraphConfig::validate` rejects
/// a zero value with an `InvalidDimension` error.
///
/// NOTE(review): this type only stores the parameters. Which query/key pairs
/// each pattern actually allows is decided wherever the mask tensor is
/// produced, which is not visible here — confirm the per-variant semantics
/// against that code.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum SparsePatternType {
    /// Strided pattern; `sparsity_factor` reports it as `1 / stride`.
    Strided { stride: usize },
    /// Local window pattern; `sparsity_factor` reports it as
    /// `min(window_size, seq_len) / seq_len`.
    Local { window_size: usize },
    /// Local window combined with a list of global positions.
    GlobalLocal {
        /// Size of the local window.
        window_size: usize,
        /// Global token positions. Not checked by `validate` (duplicates and
        /// out-of-range values are accepted as-is).
        global_positions: Vec<usize>,
    },
    /// Block pattern parameterised by `block_size`.
    BlockSparse { block_size: usize },
    /// Random pattern parameterised by `num_random`
    /// (presumably the number of random keys per query — TODO confirm).
    Random { num_random: usize },
}
44
/// Configuration for a sparse attention graph: the base attention settings
/// plus the sparsity pattern to apply.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseAttentionGraphConfig {
    /// Underlying attention configuration. It is checked by `validate`, and
    /// its `d_k` is used to derive the score scaling factor when the graph is
    /// built.
    pub base_attention: AttentionConfig,
    /// Sparsity pattern and its parameters.
    pub pattern: SparsePatternType,
    /// Set to `true` by the pattern constructors and overridable through
    /// `with_exact_sparse`.
    ///
    /// NOTE(review): nothing in this file reads the flag; presumably it
    /// selects exact versus approximate sparse computation downstream —
    /// confirm against the consumer.
    pub exact_sparse: bool,
}
55
56impl SparseAttentionGraphConfig {
57 pub fn strided(base_attention: AttentionConfig, stride: usize) -> Result<Self> {
59 if stride == 0 {
60 return Err(TrustformerError::InvalidDimension {
61 expected: 1,
62 got: 0,
63 context: "stride must be positive".to_string(),
64 });
65 }
66
67 Ok(Self {
68 base_attention,
69 pattern: SparsePatternType::Strided { stride },
70 exact_sparse: true,
71 })
72 }
73
74 pub fn local(base_attention: AttentionConfig, window_size: usize) -> Result<Self> {
76 if window_size == 0 {
77 return Err(TrustformerError::InvalidDimension {
78 expected: 1,
79 got: 0,
80 context: "window_size must be positive".to_string(),
81 });
82 }
83
84 Ok(Self {
85 base_attention,
86 pattern: SparsePatternType::Local { window_size },
87 exact_sparse: true,
88 })
89 }
90
91 pub fn global_local(
93 base_attention: AttentionConfig,
94 window_size: usize,
95 global_positions: Vec<usize>,
96 ) -> Result<Self> {
97 if window_size == 0 {
98 return Err(TrustformerError::InvalidDimension {
99 expected: 1,
100 got: 0,
101 context: "window_size must be positive".to_string(),
102 });
103 }
104
105 Ok(Self {
106 base_attention,
107 pattern: SparsePatternType::GlobalLocal {
108 window_size,
109 global_positions,
110 },
111 exact_sparse: true,
112 })
113 }
114
115 pub fn block_sparse(base_attention: AttentionConfig, block_size: usize) -> Result<Self> {
117 if block_size == 0 {
118 return Err(TrustformerError::InvalidDimension {
119 expected: 1,
120 got: 0,
121 context: "block_size must be positive".to_string(),
122 });
123 }
124
125 Ok(Self {
126 base_attention,
127 pattern: SparsePatternType::BlockSparse { block_size },
128 exact_sparse: true,
129 })
130 }
131
132 pub fn with_exact_sparse(mut self, exact_sparse: bool) -> Self {
134 self.exact_sparse = exact_sparse;
135 self
136 }
137
138 pub fn validate(&self) -> Result<()> {
140 self.base_attention.validate()?;
141
142 match &self.pattern {
143 SparsePatternType::Strided { stride } => {
144 if *stride == 0 {
145 return Err(TrustformerError::InvalidDimension {
146 expected: 1,
147 got: 0,
148 context: "stride must be positive".to_string(),
149 });
150 }
151 }
152 SparsePatternType::Local { window_size } => {
153 if *window_size == 0 {
154 return Err(TrustformerError::InvalidDimension {
155 expected: 1,
156 got: 0,
157 context: "window_size must be positive".to_string(),
158 });
159 }
160 }
161 SparsePatternType::GlobalLocal {
162 window_size,
163 global_positions: _,
164 } => {
165 if *window_size == 0 {
166 return Err(TrustformerError::InvalidDimension {
167 expected: 1,
168 got: 0,
169 context: "window_size must be positive".to_string(),
170 });
171 }
172 }
173 SparsePatternType::BlockSparse { block_size } => {
174 if *block_size == 0 {
175 return Err(TrustformerError::InvalidDimension {
176 expected: 1,
177 got: 0,
178 context: "block_size must be positive".to_string(),
179 });
180 }
181 }
182 SparsePatternType::Random { num_random } => {
183 if *num_random == 0 {
184 return Err(TrustformerError::InvalidDimension {
185 expected: 1,
186 got: 0,
187 context: "num_random must be positive".to_string(),
188 });
189 }
190 }
191 }
192
193 Ok(())
194 }
195}
196
/// Builder for sparse attention einsum graphs, driven by a
/// [`SparseAttentionGraphConfig`].
#[derive(Clone, Debug)]
pub struct SparseAttentionGraph {
    /// Configuration this builder was created with. `new` validates it once,
    /// but the field is public and is not re-validated on later use.
    pub config: SparseAttentionGraphConfig,
}
203
204impl SparseAttentionGraph {
205 pub fn new(config: SparseAttentionGraphConfig) -> Result<Self> {
207 config.validate()?;
208 Ok(Self { config })
209 }
210
211 pub fn build_sparse_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
222 let scores_tensor = graph.add_tensor("sparse_attn_scores");
224 let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![0, 1], vec![scores_tensor]);
225 graph.add_node(scores_node)?;
226
227 let scale_factor = (self.config.base_attention.d_k as f64).sqrt();
229 let scale_tensor = graph.add_tensor("sparse_scale");
230 let scaled_tensor = graph.add_tensor("sparse_scaled_scores");
231 let scale_node = EinsumNode::elem_binary(
232 format!("div_scalar_{}", scale_factor),
233 scores_tensor,
234 scale_tensor,
235 scaled_tensor,
236 );
237 graph.add_node(scale_node)?;
238
239 let masked_tensor = graph.add_tensor("sparse_masked_scores");
242 let mask_node = EinsumNode::elem_binary("mul", scaled_tensor, 3, masked_tensor);
243 graph.add_node(mask_node)?;
244
245 let softmax_tensor = graph.add_tensor("sparse_attention_weights");
247 let softmax_node =
248 EinsumNode::elem_unary("sparse_softmax_k", masked_tensor, softmax_tensor);
249 graph.add_node(softmax_node)?;
250
251 let output_tensor = graph.add_tensor("sparse_attn_output");
253 let output_node =
254 EinsumNode::new("bqk,bkv->bqv", vec![softmax_tensor, 2], vec![output_tensor]);
255 graph.add_node(output_node)?;
256
257 Ok(vec![output_tensor])
258 }
259
260 pub fn sparsity_factor(&self, seq_len: usize) -> f64 {
262 match &self.config.pattern {
263 SparsePatternType::Strided { stride } => 1.0 / (*stride as f64),
264 SparsePatternType::Local { window_size } => {
265 (*window_size as f64).min(seq_len as f64) / (seq_len as f64)
266 }
267 SparsePatternType::GlobalLocal {
268 window_size,
269 global_positions,
270 } => {
271 let local_fraction = (*window_size as f64) / (seq_len as f64);
272 let global_fraction = (global_positions.len() as f64) / (seq_len as f64);
273 (local_fraction + global_fraction).min(1.0)
274 }
275 SparsePatternType::BlockSparse { block_size } => {
276 (*block_size as f64) / (seq_len as f64)
277 }
278 SparsePatternType::Random { num_random } => (*num_random as f64) / (seq_len as f64),
279 }
280 }
281
282 pub fn pattern_description(&self) -> String {
284 match &self.config.pattern {
285 SparsePatternType::Strided { stride } => {
286 format!("Strided(stride={})", stride)
287 }
288 SparsePatternType::Local { window_size } => {
289 format!("Local(window={})", window_size)
290 }
291 SparsePatternType::GlobalLocal {
292 window_size,
293 global_positions,
294 } => {
295 format!(
296 "GlobalLocal(window={}, global_tokens={})",
297 window_size,
298 global_positions.len()
299 )
300 }
301 SparsePatternType::BlockSparse { block_size } => {
302 format!("BlockSparse(block={})", block_size)
303 }
304 SparsePatternType::Random { num_random } => {
305 format!("Random(k={})", num_random)
306 }
307 }
308 }
309}
310
/// Builder for local (windowed) attention einsum graphs.
#[derive(Clone, Debug)]
pub struct LocalAttention {
    /// Attention configuration; its `d_k` determines the score scaling factor.
    pub config: AttentionConfig,
    /// Window size. `attention_span` reports `2 * window_size + 1`, so this is
    /// presumably the one-sided radius around each query position — TODO
    /// confirm against the code that fills the window mask.
    pub window_size: usize,
}
322
323impl LocalAttention {
324 pub fn new(config: AttentionConfig, window_size: usize) -> Result<Self> {
326 config.validate()?;
327
328 if window_size == 0 {
329 return Err(TrustformerError::InvalidDimension {
330 expected: 1,
331 got: 0,
332 context: "window_size must be positive".to_string(),
333 });
334 }
335
336 Ok(Self {
337 config,
338 window_size,
339 })
340 }
341
342 pub fn build_local_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
352 let scores_tensor = graph.add_tensor("local_attn_scores");
354 let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![0, 1], vec![scores_tensor]);
355 graph.add_node(scores_node)?;
356
357 let scale_factor = (self.config.d_k as f64).sqrt();
359 let scale_tensor = graph.add_tensor("local_scale");
360 let scaled_tensor = graph.add_tensor("local_scaled_scores");
361 let scale_node = EinsumNode::elem_binary(
362 format!("div_scalar_{}", scale_factor),
363 scores_tensor,
364 scale_tensor,
365 scaled_tensor,
366 );
367 graph.add_node(scale_node)?;
368
369 let window_mask_tensor = graph.add_tensor("local_window_mask");
372 let masked_tensor = graph.add_tensor("local_masked_scores");
373 let mask_node =
374 EinsumNode::elem_binary("mul", scaled_tensor, window_mask_tensor, masked_tensor);
375 graph.add_node(mask_node)?;
376
377 let softmax_tensor = graph.add_tensor("local_attention_weights");
379 let softmax_node = EinsumNode::elem_unary("softmax_k", masked_tensor, softmax_tensor);
380 graph.add_node(softmax_node)?;
381
382 let output_tensor = graph.add_tensor("local_attn_output");
384 let output_node =
385 EinsumNode::new("bqk,bkv->bqv", vec![softmax_tensor, 2], vec![output_tensor]);
386 graph.add_node(output_node)?;
387
388 Ok(vec![output_tensor])
389 }
390
391 pub fn attention_span(&self) -> usize {
393 2 * self.window_size + 1
394 }
395
396 pub fn memory_savings(&self, seq_len: usize) -> f64 {
398 let full_memory = seq_len * seq_len;
399 let sparse_memory = seq_len * self.attention_span().min(seq_len);
400 1.0 - (sparse_memory as f64 / full_memory as f64)
401 }
402}
403
#[cfg(test)]
mod tests {
    //! Unit tests for the sparse and local attention graph builders.
    //!
    //! Graph-building tests register the input tensors (`Q`, `K`, `V`, and
    //! `sparse_mask` where needed) before calling the builder, because the
    //! builders reference those tensors by fixed id.

    use super::*;

    #[test]
    fn test_strided_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::strided(base, 4).expect("unwrap");
        assert!(matches!(
            config.pattern,
            SparsePatternType::Strided { stride: 4 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_local_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::local(base, 128).expect("unwrap");
        assert!(matches!(
            config.pattern,
            SparsePatternType::Local { window_size: 128 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_global_local_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let global_positions = vec![0, 1, 2];
        let config =
            SparseAttentionGraphConfig::global_local(base, 64, global_positions).expect("unwrap");
        assert_eq!(
            config.pattern,
            SparsePatternType::GlobalLocal {
                window_size: 64,
                global_positions: vec![0, 1, 2],
            }
        );
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_block_sparse_config() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::block_sparse(base, 64).expect("unwrap");
        assert!(matches!(
            config.pattern,
            SparsePatternType::BlockSparse { block_size: 64 }
        ));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_with_exact_sparse() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::strided(base, 4).expect("unwrap");
        // Constructors default the flag to true; the builder method overrides it.
        assert!(config.exact_sparse);
        let config = config.with_exact_sparse(false);
        assert!(!config.exact_sparse);
    }

    #[test]
    fn test_sparse_attention_creation() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::strided(base, 2).expect("unwrap");
        let attn = SparseAttentionGraph::new(config).expect("unwrap");
        assert_eq!(attn.sparsity_factor(1024), 0.5);
    }

    #[test]
    fn test_sparse_attention_graph_building() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");
        let config = SparseAttentionGraphConfig::local(base, 128).expect("unwrap");
        let attn = SparseAttentionGraph::new(config).expect("unwrap");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");
        graph.add_tensor("sparse_mask");

        let outputs = attn
            .build_sparse_attention_graph(&mut graph)
            .expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_local_attention_creation() {
        let config = AttentionConfig::new(512, 8).expect("unwrap");
        let local = LocalAttention::new(config, 64).expect("unwrap");
        assert_eq!(local.window_size, 64);
        assert_eq!(local.attention_span(), 129);
    }

    #[test]
    fn test_local_attention_graph_building() {
        let config = AttentionConfig::new(512, 8).expect("unwrap");
        let local = LocalAttention::new(config, 64).expect("unwrap");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = local
            .build_local_attention_graph(&mut graph)
            .expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_sparsity_factors() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");

        let strided = SparseAttentionGraphConfig::strided(base.clone(), 4).expect("unwrap");
        let attn = SparseAttentionGraph::new(strided).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 0.25).abs() < 1e-10);

        let local = SparseAttentionGraphConfig::local(base.clone(), 128).expect("unwrap");
        let attn = SparseAttentionGraph::new(local).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 0.125).abs() < 1e-10);

        let block = SparseAttentionGraphConfig::block_sparse(base.clone(), 64).expect("unwrap");
        let attn = SparseAttentionGraph::new(block).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 0.0625).abs() < 1e-10);

        // 64-wide window plus 3 global tokens over 1024 positions.
        let global_local =
            SparseAttentionGraphConfig::global_local(base, 64, vec![0, 1, 2]).expect("unwrap");
        let attn = SparseAttentionGraph::new(global_local).expect("unwrap");
        assert!((attn.sparsity_factor(1024) - 67.0 / 1024.0).abs() < 1e-10);
    }

    #[test]
    fn test_memory_savings() {
        let config = AttentionConfig::new(512, 8).expect("unwrap");
        let local = LocalAttention::new(config, 64).expect("unwrap");

        // Span is 129 of 1024 positions per row, so savings are 1 - 129/1024.
        let savings = local.memory_savings(1024);
        assert!(savings > 0.87 && savings < 0.88);
    }

    #[test]
    fn test_pattern_descriptions() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");

        let strided = SparseAttentionGraphConfig::strided(base.clone(), 4).expect("unwrap");
        let attn = SparseAttentionGraph::new(strided).expect("unwrap");
        assert_eq!(attn.pattern_description(), "Strided(stride=4)");

        let local = SparseAttentionGraphConfig::local(base.clone(), 128).expect("unwrap");
        let attn = SparseAttentionGraph::new(local).expect("unwrap");
        assert_eq!(attn.pattern_description(), "Local(window=128)");

        let global_local = SparseAttentionGraphConfig::global_local(base.clone(), 64, vec![0, 1, 2])
            .expect("unwrap");
        let attn = SparseAttentionGraph::new(global_local).expect("unwrap");
        assert_eq!(
            attn.pattern_description(),
            "GlobalLocal(window=64, global_tokens=3)"
        );

        let block = SparseAttentionGraphConfig::block_sparse(base, 64).expect("unwrap");
        let attn = SparseAttentionGraph::new(block).expect("unwrap");
        assert_eq!(attn.pattern_description(), "BlockSparse(block=64)");
    }

    #[test]
    fn test_invalid_configs() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");

        let result = SparseAttentionGraphConfig::strided(base.clone(), 0);
        assert!(result.is_err());

        let result = SparseAttentionGraphConfig::local(base.clone(), 0);
        assert!(result.is_err());

        let result = SparseAttentionGraphConfig::global_local(base.clone(), 0, vec![0]);
        assert!(result.is_err());

        let result = SparseAttentionGraphConfig::block_sparse(base.clone(), 0);
        assert!(result.is_err());

        let result = LocalAttention::new(base, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_rejects_hand_built_zero_random() {
        let base = AttentionConfig::new(512, 8).expect("unwrap");

        // The fields are public, so a zero-sized pattern can be built without
        // a constructor; `validate` and `SparseAttentionGraph::new` must
        // still reject it.
        let config = SparseAttentionGraphConfig {
            base_attention: base,
            pattern: SparsePatternType::Random { num_random: 0 },
            exact_sparse: true,
        };
        assert!(config.validate().is_err());
        assert!(SparseAttentionGraph::new(config).is_err());
    }
}