// trustformers_models/sparse_attention.rs
1//! Sparse Attention Patterns Library
2//!
3//! This module provides efficient sparse attention implementations that reduce
4//! the quadratic complexity of standard attention mechanisms. Sparse attention
5//! patterns are particularly useful for long sequences and memory-constrained
6//! scenarios.
7//!
8//! # Overview
9//!
10//! The library includes several sparse attention patterns:
11//!
12//! - **Local Attention**: Attention restricted to local windows
13//! - **Strided Attention**: Attention with fixed stride patterns
14//! - **Dilated Attention**: Attention with increasing dilation factors
15//! - **Random Attention**: Attention with random sparse patterns
16//! - **Block Sparse Attention**: Attention using block-wise sparsity (BigBird style)
17//! - **Longformer Attention**: Sliding window + global attention
18//! - **Linformer Attention**: Low-rank projection for linear complexity
19//! - **Reformer Attention**: LSH-based attention for efficient similarity search
20//!
21//! # Example
22//!
//! ```no_run
//! use trustformers_models::sparse_attention::{SparseAttention, SparseAttentionConfig, SparsePattern};
//! use trustformers_core::layers::AttentionInput;
//! use trustformers_core::tensor::Tensor;
//! use trustformers_core::traits::Layer;
//!
//! # fn main() -> trustformers_core::errors::Result<()> {
//! // Create sparse attention with a local window pattern
//! let config = SparseAttentionConfig::new()
//!     .with_pattern(SparsePattern::Local { window_size: 64 })
//!     .with_hidden_size(768)
//!     .with_num_heads(12);
//!
//! let attention = SparseAttention::new(config)?;
//! // `forward` expects a 2D (seq_len, hidden) input wrapped in `AttentionInput`.
//! let input = AttentionInput {
//!     hidden_states: Tensor::randn(&[512, 768])?,
//!     attention_mask: None,
//! };
//! let output = attention.forward(input)?;
//! # Ok(())
//! # }
//! ```
37
38use scirs2_core::Array2; // SciRS2 Integration Policy
39use std::collections::HashMap;
40use trustformers_core::errors::{tensor_op_error, Result};
41use trustformers_core::layers::AttentionInput;
42use trustformers_core::tensor::Tensor;
43use trustformers_core::traits::Layer;
44
/// Configuration for sparse attention patterns
#[derive(Debug, Clone)]
pub struct SparseAttentionConfig {
    /// Model embedding width; expected to be divisible by `num_heads`.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dropout probability. NOTE(review): stored but not consumed anywhere
    /// in this module — confirm whether dropout is applied elsewhere.
    pub dropout_prob: f32,
    /// Sparse pattern used when generating attention masks.
    pub pattern: SparsePattern,
    /// Maximum sequence length the module is configured for.
    pub max_sequence_length: usize,
    /// Block size for block-oriented patterns.
    pub block_size: usize,
    /// Whether mask caching is enabled (the cache itself lives on `SparseAttention`).
    pub use_cache: bool,
    /// Optional override for the score scale; `None` means 1/sqrt(head_dim).
    pub attention_scale: Option<f32>,
}
57
impl Default for SparseAttentionConfig {
    /// Delegates to [`SparseAttentionConfig::new`] so `Default` and `new` stay in sync.
    fn default() -> Self {
        Self::new()
    }
}
63
64impl SparseAttentionConfig {
65    pub fn new() -> Self {
66        Self {
67            hidden_size: 768,
68            num_heads: 12,
69            dropout_prob: 0.1,
70            pattern: SparsePattern::Local { window_size: 128 },
71            max_sequence_length: 4096,
72            block_size: 64,
73            use_cache: true,
74            attention_scale: None,
75        }
76    }
77
78    pub fn with_pattern(mut self, pattern: SparsePattern) -> Self {
79        self.pattern = pattern;
80        self
81    }
82
83    pub fn with_hidden_size(mut self, hidden_size: usize) -> Self {
84        self.hidden_size = hidden_size;
85        self
86    }
87
88    pub fn with_num_heads(mut self, num_heads: usize) -> Self {
89        self.num_heads = num_heads;
90        self
91    }
92
93    pub fn with_dropout(mut self, dropout_prob: f32) -> Self {
94        self.dropout_prob = dropout_prob;
95        self
96    }
97
98    pub fn with_max_length(mut self, max_sequence_length: usize) -> Self {
99        self.max_sequence_length = max_sequence_length;
100        self
101    }
102
103    pub fn with_block_size(mut self, block_size: usize) -> Self {
104        self.block_size = block_size;
105        self
106    }
107}
108
/// Sparse attention pattern types
///
/// Each variant selects one of the mask generators dispatched from
/// `SparseAttention::generate_mask`.
#[derive(Debug, Clone)]
pub enum SparsePattern {
    /// Local sliding window attention (each token sees ±`window_size / 2` neighbors)
    Local { window_size: usize },
    /// Strided attention: local window plus every position congruent to the
    /// token index modulo `stride`
    Strided { stride: usize, window_size: usize },
    /// Dilated attention: windows and step sizes grow with each dilation
    /// level up to `max_dilation`
    Dilated {
        max_dilation: usize,
        window_size: usize,
    },
    /// Random sparse attention keeping roughly `1 - sparsity_ratio` of the
    /// positions (deterministic stand-in selection, not a true RNG)
    Random { sparsity_ratio: f32 },
    /// Block sparse attention (BigBird style): diagonal, global, and
    /// pseudo-random blocks
    BlockSparse {
        block_size: usize,
        global_blocks: usize,
        random_blocks: usize,
    },
    /// Longformer-style attention (sliding window + listed global token indices)
    Longformer {
        window_size: usize,
        global_tokens: Vec<usize>,
    },
    /// Linformer-style linear attention via a low-rank projection of length
    /// `projection_dim` (mask is seq_len x projection_dim, not square)
    Linformer { projection_dim: usize },
    /// Reformer-style LSH attention (simplified bucketing in this module)
    Reformer {
        num_hashes: usize,
        bucket_size: usize,
    },
    /// Custom sparse pattern with an explicit, caller-supplied mask
    Custom { mask: SparseAttentionMask },
}
144
/// Sparse attention mask representation
///
/// Stores a COO-style (coordinate) list of positions that are allowed to
/// attend, with an additive value per position.
#[derive(Debug, Clone)]
pub struct SparseAttentionMask {
    /// (row, col) pairs for non-zero entries. Duplicates may occur because
    /// pattern generators can emit overlapping regions.
    pub indices: Vec<(usize, usize)>,
    /// Values for non-zero entries (parallel to `indices`).
    pub values: Vec<f32>,
    /// Logical dense shape, usually (seq_len, seq_len).
    pub shape: (usize, usize),
}

impl SparseAttentionMask {
    /// Creates an empty mask with the given dense shape.
    pub fn new(shape: (usize, usize)) -> Self {
        Self {
            indices: Vec::new(),
            values: Vec::new(),
            shape,
        }
    }

    /// Records a permitted attention position; out-of-bounds entries are
    /// silently ignored. Duplicate (row, col) pairs are kept in the list;
    /// the last value wins in `to_dense`.
    pub fn add_entry(&mut self, row: usize, col: usize, value: f32) {
        if row < self.shape.0 && col < self.shape.1 {
            self.indices.push((row, col));
            self.values.push(value);
        }
    }

    /// Materializes the mask as a dense matrix with `NEG_INFINITY` in
    /// masked-out positions (additive attention-bias convention).
    pub fn to_dense(&self) -> Vec<Vec<f32>> {
        let mut dense = vec![vec![f32::NEG_INFINITY; self.shape.1]; self.shape.0];
        for (i, &(row, col)) in self.indices.iter().enumerate() {
            dense[row][col] = self.values[i];
        }
        dense
    }

    /// Fraction of dense positions that are masked out.
    ///
    /// Counts *distinct* (row, col) positions: pattern generators routinely
    /// push the same position more than once (e.g. where a local window
    /// overlaps a strided lane), and counting raw list length would
    /// under-report sparsity or even drive it negative. Returns 0.0 for an
    /// empty shape instead of dividing by zero.
    pub fn sparsity(&self) -> f32 {
        let total_elements = self.shape.0 * self.shape.1;
        if total_elements == 0 {
            return 0.0;
        }
        let mut unique = self.indices.clone();
        unique.sort_unstable();
        unique.dedup();
        1.0 - (unique.len() as f32 / total_elements as f32)
    }
}
183
/// Main sparse attention implementation
///
/// Holds the Q/K/V/output projections plus the configured sparse pattern.
#[derive(Debug, Clone)]
pub struct SparseAttention {
    config: SparseAttentionConfig,
    /// Projects input hidden states to queries.
    query_projection: trustformers_core::layers::Linear,
    /// Projects input hidden states to keys.
    key_projection: trustformers_core::layers::Linear,
    /// Projects input hidden states to values.
    value_projection: trustformers_core::layers::Linear,
    /// Final projection applied to the attention output.
    output_projection: trustformers_core::layers::Linear,
    /// hidden_size / num_heads; currently unused by the kernels below.
    #[allow(dead_code)]
    head_dim: usize,
    /// Score scale: `attention_scale` override or 1/sqrt(head_dim).
    scale: f32,
    /// Cache of generated masks keyed by sequence length.
    /// NOTE(review): never written or read in this module.
    #[allow(dead_code)]
    mask_cache: HashMap<usize, SparseAttentionMask>,
}
198
199impl SparseAttention {
200    pub fn new(config: SparseAttentionConfig) -> Result<Self> {
201        let head_dim = config.hidden_size / config.num_heads;
202        let scale = config.attention_scale.unwrap_or(1.0 / (head_dim as f32).sqrt());
203
204        Ok(Self {
205            query_projection: trustformers_core::layers::Linear::new(
206                config.hidden_size,
207                config.hidden_size,
208                false,
209            ),
210            key_projection: trustformers_core::layers::Linear::new(
211                config.hidden_size,
212                config.hidden_size,
213                false,
214            ),
215            value_projection: trustformers_core::layers::Linear::new(
216                config.hidden_size,
217                config.hidden_size,
218                false,
219            ),
220            output_projection: trustformers_core::layers::Linear::new(
221                config.hidden_size,
222                config.hidden_size,
223                false,
224            ),
225            head_dim,
226            scale,
227            mask_cache: HashMap::new(),
228            config,
229        })
230    }
231
    /// Generate sparse attention mask based on the configured pattern
    ///
    /// Dispatches to the per-pattern generator. All generators produce a
    /// (sequence_length, sequence_length) mask except Linformer, whose mask
    /// is (sequence_length, projection_dim). `Custom` clones the supplied
    /// mask unchanged.
    pub fn generate_mask(&self, sequence_length: usize) -> Result<SparseAttentionMask> {
        match &self.config.pattern {
            SparsePattern::Local { window_size } => {
                self.generate_local_mask(sequence_length, *window_size)
            },
            SparsePattern::Strided {
                stride,
                window_size,
            } => self.generate_strided_mask(sequence_length, *stride, *window_size),
            SparsePattern::Dilated {
                max_dilation,
                window_size,
            } => self.generate_dilated_mask(sequence_length, *max_dilation, *window_size),
            SparsePattern::Random { sparsity_ratio } => {
                self.generate_random_mask(sequence_length, *sparsity_ratio)
            },
            SparsePattern::BlockSparse {
                block_size,
                global_blocks,
                random_blocks,
            } => self.generate_block_sparse_mask(
                sequence_length,
                *block_size,
                *global_blocks,
                *random_blocks,
            ),
            SparsePattern::Longformer {
                window_size,
                global_tokens,
            } => self.generate_longformer_mask(sequence_length, *window_size, global_tokens),
            SparsePattern::Linformer { projection_dim } => {
                self.generate_linformer_mask(sequence_length, *projection_dim)
            },
            SparsePattern::Reformer {
                num_hashes,
                bucket_size,
            } => self.generate_reformer_mask(sequence_length, *num_hashes, *bucket_size),
            SparsePattern::Custom { mask } => Ok(mask.clone()),
        }
    }
273
274    fn generate_local_mask(
275        &self,
276        seq_len: usize,
277        window_size: usize,
278    ) -> Result<SparseAttentionMask> {
279        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
280
281        for i in 0..seq_len {
282            let start = i.saturating_sub(window_size / 2);
283            let end = (i + window_size / 2 + 1).min(seq_len);
284
285            for j in start..end {
286                mask.add_entry(i, j, 0.0);
287            }
288        }
289
290        Ok(mask)
291    }
292
293    fn generate_strided_mask(
294        &self,
295        seq_len: usize,
296        stride: usize,
297        window_size: usize,
298    ) -> Result<SparseAttentionMask> {
299        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
300
301        for i in 0..seq_len {
302            // Local window
303            let start = i.saturating_sub(window_size / 2);
304            let end = (i + window_size / 2 + 1).min(seq_len);
305
306            for j in start..end {
307                mask.add_entry(i, j, 0.0);
308            }
309
310            // Strided connections
311            let mut pos = i;
312            while pos < seq_len {
313                mask.add_entry(i, pos, 0.0);
314                pos += stride;
315            }
316
317            if i >= stride {
318                let mut pos = i - stride;
319                loop {
320                    mask.add_entry(i, pos, 0.0);
321                    if pos < stride {
322                        break;
323                    }
324                    pos -= stride;
325                }
326            }
327        }
328
329        Ok(mask)
330    }
331
332    fn generate_dilated_mask(
333        &self,
334        seq_len: usize,
335        max_dilation: usize,
336        window_size: usize,
337    ) -> Result<SparseAttentionMask> {
338        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
339
340        for i in 0..seq_len {
341            for dilation in 1..=max_dilation {
342                let start = i.saturating_sub(window_size * dilation / 2);
343                let end = (i + window_size * dilation / 2 + 1).min(seq_len);
344
345                for j in (start..end).step_by(dilation) {
346                    mask.add_entry(i, j, 0.0);
347                }
348            }
349        }
350
351        Ok(mask)
352    }
353
354    fn generate_random_mask(
355        &self,
356        seq_len: usize,
357        sparsity_ratio: f32,
358    ) -> Result<SparseAttentionMask> {
359        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
360        let total_elements = seq_len * seq_len;
361        let keep_elements = (total_elements as f32 * (1.0 - sparsity_ratio)) as usize;
362
363        // Simple random selection (in real implementation, use proper RNG)
364        let mut added = 0;
365        for i in 0..seq_len {
366            for j in 0..seq_len {
367                if added < keep_elements && (i + j) % 3 == 0 {
368                    // Simple pseudo-random
369                    mask.add_entry(i, j, 0.0);
370                    added += 1;
371                }
372            }
373        }
374
375        Ok(mask)
376    }
377
378    fn generate_block_sparse_mask(
379        &self,
380        seq_len: usize,
381        block_size: usize,
382        global_blocks: usize,
383        random_blocks: usize,
384    ) -> Result<SparseAttentionMask> {
385        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
386        let num_blocks = seq_len.div_ceil(block_size);
387
388        for block_i in 0..num_blocks {
389            let start_i = block_i * block_size;
390            let end_i = (start_i + block_size).min(seq_len);
391
392            for block_j in 0..num_blocks {
393                let start_j = block_j * block_size;
394                let end_j = (start_j + block_size).min(seq_len);
395
396                // Local blocks (diagonal)
397                if block_i == block_j || block_i.abs_diff(block_j) <= 1 {
398                    for i in start_i..end_i {
399                        for j in start_j..end_j {
400                            mask.add_entry(i, j, 0.0);
401                        }
402                    }
403                }
404
405                // Global blocks
406                if block_j < global_blocks || block_i < global_blocks {
407                    for i in start_i..end_i {
408                        for j in start_j..end_j {
409                            mask.add_entry(i, j, 0.0);
410                        }
411                    }
412                }
413
414                // Random blocks (simplified)
415                if (block_i + block_j) % (num_blocks / random_blocks.max(1)) == 0 {
416                    for i in start_i..end_i {
417                        for j in start_j..end_j {
418                            mask.add_entry(i, j, 0.0);
419                        }
420                    }
421                }
422            }
423        }
424
425        Ok(mask)
426    }
427
428    fn generate_longformer_mask(
429        &self,
430        seq_len: usize,
431        window_size: usize,
432        global_tokens: &[usize],
433    ) -> Result<SparseAttentionMask> {
434        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
435
436        // Local sliding window
437        for i in 0..seq_len {
438            let start = i.saturating_sub(window_size / 2);
439            let end = (i + window_size / 2 + 1).min(seq_len);
440
441            for j in start..end {
442                mask.add_entry(i, j, 0.0);
443            }
444        }
445
446        // Global tokens can attend to all positions
447        for &global_token in global_tokens {
448            if global_token < seq_len {
449                for j in 0..seq_len {
450                    mask.add_entry(global_token, j, 0.0);
451                    mask.add_entry(j, global_token, 0.0);
452                }
453            }
454        }
455
456        Ok(mask)
457    }
458
459    fn generate_linformer_mask(
460        &self,
461        seq_len: usize,
462        projection_dim: usize,
463    ) -> Result<SparseAttentionMask> {
464        // Linformer uses low-rank projections, so we create a full mask
465        // but mark it for special handling in the attention computation
466        let mut mask = SparseAttentionMask::new((seq_len, projection_dim));
467
468        for i in 0..seq_len {
469            for j in 0..projection_dim {
470                mask.add_entry(i, j, 0.0);
471            }
472        }
473
474        Ok(mask)
475    }
476
477    fn generate_reformer_mask(
478        &self,
479        seq_len: usize,
480        num_hashes: usize,
481        bucket_size: usize,
482    ) -> Result<SparseAttentionMask> {
483        let mut mask = SparseAttentionMask::new((seq_len, seq_len));
484        let num_buckets = seq_len.div_ceil(bucket_size);
485
486        // Simplified LSH bucketing (in real implementation, use proper hash functions)
487        for hash_idx in 0..num_hashes {
488            for bucket in 0..num_buckets {
489                let start = bucket * bucket_size;
490                let end = (start + bucket_size).min(seq_len);
491
492                // All tokens in same bucket attend to each other
493                for i in start..end {
494                    for j in start..end {
495                        let hash_offset = (i + hash_idx) % seq_len;
496                        let hash_bucket = hash_offset / bucket_size;
497                        if hash_bucket == bucket {
498                            mask.add_entry(i, j, 0.0);
499                        }
500                    }
501                }
502            }
503        }
504
505        Ok(mask)
506    }
507
    /// Apply sparse attention mask to attention scores
    ///
    /// Keeps the original score at each (row, col) listed in the mask and
    /// replaces every other position with negative infinity.
    ///
    /// NOTE(review): `mask.values` is not consulted here — the mask acts as
    /// a keep/drop set only; confirm that additive bias values are meant to
    /// be ignored on this path.
    #[allow(dead_code)]
    fn apply_sparse_mask(
        &self,
        attention_scores: &Tensor,
        mask: &SparseAttentionMask,
    ) -> Result<Tensor> {
        match attention_scores {
            Tensor::F32(scores) => {
                let mut masked_scores = scores.clone();
                let shape = scores.shape();

                if shape.len() != 2 {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Attention scores must be 2D for sparse masking".to_string(),
                    ));
                }

                // Set all positions to -inf initially
                masked_scores.fill(f32::NEG_INFINITY);

                // Apply sparse mask: restore the original score at kept positions
                for &(row, col) in mask.indices.iter() {
                    if row < shape[0] && col < shape[1] {
                        masked_scores[[row, col]] = scores[[row, col]];
                    }
                }

                Ok(Tensor::F32(masked_scores))
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported tensor type for sparse attention".to_string(),
            )),
        }
    }
545
    /// Compute sparse attention efficiently
    ///
    /// Pipeline: sparse Q·K scores -> row-wise softmax -> sparse weighted
    /// sum over values.
    ///
    /// NOTE(review): a row with no mask entries is all -inf before softmax,
    /// which yields NaN weights; `apply_sparse_attention_weights` filters
    /// NaN, so such rows come out as zero vectors — confirm intended.
    fn compute_sparse_attention(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        mask: &SparseAttentionMask,
    ) -> Result<Tensor> {
        // For sparse attention, we only compute attention for the sparse positions
        // This is a simplified implementation - in practice, this would use
        // specialized sparse matrix operations

        // Compute attention scores only for sparse positions
        let attention_scores = self.compute_sparse_scores(query, key, mask)?;

        // Apply softmax to sparse scores
        let attention_weights = attention_scores.softmax(-1)?;

        // Apply attention weights to values
        self.apply_sparse_attention_weights(&attention_weights, value, mask)
    }
567
568    fn compute_sparse_scores(
569        &self,
570        query: &Tensor,
571        key: &Tensor,
572        mask: &SparseAttentionMask,
573    ) -> Result<Tensor> {
574        // Simplified sparse score computation
575        // In practice, this would use efficient sparse matrix operations
576        match (query, key) {
577            (Tensor::F32(q), Tensor::F32(k)) => {
578                let q_shape = q.shape();
579                let k_shape = k.shape();
580
581                if q_shape.len() != 2 || k_shape.len() != 2 {
582                    return Err(tensor_op_error(
583                        "tensor_operation",
584                        "Query and key must be 2D".to_string(),
585                    ));
586                }
587
588                let seq_len = q_shape[0];
589                let head_dim = q_shape[1];
590
591                let mut scores = Array2::from_elem((seq_len, seq_len), f32::NEG_INFINITY);
592
593                // Compute scores only for sparse positions
594                for &(i, j) in &mask.indices {
595                    if i < seq_len && j < seq_len {
596                        let mut score = 0.0;
597                        for d in 0..head_dim {
598                            score += q[[i, d]] * k[[j, d]];
599                        }
600                        scores[[i, j]] = score * self.scale;
601                    }
602                }
603
604                Ok(Tensor::F32(scores.into_dyn()))
605            },
606            _ => Err(tensor_op_error(
607                "tensor_operation",
608                "Unsupported tensor types for sparse attention".to_string(),
609            )),
610        }
611    }
612
613    fn apply_sparse_attention_weights(
614        &self,
615        weights: &Tensor,
616        value: &Tensor,
617        mask: &SparseAttentionMask,
618    ) -> Result<Tensor> {
619        match (weights, value) {
620            (Tensor::F32(w), Tensor::F32(v)) => {
621                let w_shape = w.shape();
622                let v_shape = v.shape();
623
624                let seq_len = w_shape[0];
625                let head_dim = v_shape[1];
626
627                let mut output = Array2::zeros((seq_len, head_dim));
628
629                // Apply sparse attention weights
630                for &(i, j) in &mask.indices {
631                    if i < seq_len && j < seq_len {
632                        let weight = w[[i, j]];
633                        if weight != f32::NEG_INFINITY && !weight.is_nan() {
634                            for d in 0..head_dim {
635                                output[[i, d]] += weight * v[[j, d]];
636                            }
637                        }
638                    }
639                }
640
641                Ok(Tensor::F32(output.into_dyn()))
642            },
643            _ => Err(tensor_op_error(
644                "tensor_operation",
645                "Unsupported tensor types for sparse attention output".to_string(),
646            )),
647        }
648    }
649}
650
impl Layer for SparseAttention {
    type Input = AttentionInput;
    type Output = Tensor;

    /// Runs sparse self-attention over `hidden_states`.
    ///
    /// NOTE(review): the incoming `attention_mask` is ignored here — the
    /// generated sparse pattern mask is the only masking applied. Also,
    /// `seq_len` is read from `shape()[0]`, so this path assumes a 2D
    /// (seq_len, hidden) input rather than a batched 3D tensor — confirm
    /// against callers.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let AttentionInput {
            hidden_states,
            attention_mask: _,
        } = input;

        // Project to Q, K, V
        let query = self.query_projection.forward(hidden_states.clone())?;
        let key = self.key_projection.forward(hidden_states.clone())?;
        let value = self.value_projection.forward(hidden_states)?;

        // Get sequence length
        let seq_len = match &query {
            Tensor::F32(q) => q.shape()[0],
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        // Generate sparse mask
        let mask = self.generate_mask(seq_len)?;

        // Compute sparse attention
        let attention_output = self.compute_sparse_attention(&query, &key, &value, &mask)?;

        // Final output projection
        self.output_projection.forward(attention_output)
    }
}
687
688/// Utility functions for sparse attention patterns
689pub mod utils {
690    use super::*;
691
692    /// Create a local window attention pattern
693    pub fn create_local_attention(
694        hidden_size: usize,
695        num_heads: usize,
696        window_size: usize,
697    ) -> SparseAttentionConfig {
698        SparseAttentionConfig::new()
699            .with_hidden_size(hidden_size)
700            .with_num_heads(num_heads)
701            .with_pattern(SparsePattern::Local { window_size })
702    }
703
704    /// Create a block sparse attention pattern (BigBird style)
705    pub fn create_bigbird_attention(
706        hidden_size: usize,
707        num_heads: usize,
708        block_size: usize,
709    ) -> SparseAttentionConfig {
710        SparseAttentionConfig::new()
711            .with_hidden_size(hidden_size)
712            .with_num_heads(num_heads)
713            .with_pattern(SparsePattern::BlockSparse {
714                block_size,
715                global_blocks: 2,
716                random_blocks: 2,
717            })
718    }
719
720    /// Create a Longformer-style attention pattern
721    pub fn create_longformer_attention(
722        hidden_size: usize,
723        num_heads: usize,
724        window_size: usize,
725        global_tokens: Vec<usize>,
726    ) -> SparseAttentionConfig {
727        SparseAttentionConfig::new()
728            .with_hidden_size(hidden_size)
729            .with_num_heads(num_heads)
730            .with_pattern(SparsePattern::Longformer {
731                window_size,
732                global_tokens,
733            })
734    }
735
    /// Analyze sparse attention pattern efficiency
    ///
    /// Instantiates a throwaway `SparseAttention` with default sizing just
    /// to generate the mask, then derives summary statistics from it.
    ///
    /// NOTE(review): `memory_reduction` and `compute_reduction` are both set
    /// to the raw sparsity — a first-order approximation, not a measurement.
    pub fn analyze_pattern_efficiency(
        pattern: &SparsePattern,
        sequence_length: usize,
    ) -> Result<PatternAnalysis> {
        let config = SparseAttentionConfig::new().with_pattern(pattern.clone());
        let attention = SparseAttention::new(config)?;
        let mask = attention.generate_mask(sequence_length)?;

        Ok(PatternAnalysis {
            sparsity: mask.sparsity(),
            memory_reduction: mask.sparsity(),
            compute_reduction: mask.sparsity(),
            effective_receptive_field: calculate_receptive_field(&mask),
            pattern_regularity: calculate_pattern_regularity(&mask),
        })
    }
753
754    fn calculate_receptive_field(mask: &SparseAttentionMask) -> f32 {
755        let mut total_connections = 0;
756        let mut positions_with_connections = 0;
757
758        for i in 0..mask.shape.0 {
759            let mut connections = 0;
760            for &(row, _) in &mask.indices {
761                if row == i {
762                    connections += 1;
763                }
764            }
765            if connections > 0 {
766                total_connections += connections;
767                positions_with_connections += 1;
768            }
769        }
770
771        if positions_with_connections > 0 {
772            total_connections as f32 / positions_with_connections as f32
773        } else {
774            0.0
775        }
776    }
777
778    fn calculate_pattern_regularity(mask: &SparseAttentionMask) -> f32 {
779        // Simple regularity measure: variance in connections per position
780        let mut connections_per_position = vec![0; mask.shape.0];
781
782        for &(row, _) in &mask.indices {
783            connections_per_position[row] += 1;
784        }
785
786        let mean = connections_per_position.iter().sum::<usize>() as f32 / mask.shape.0 as f32;
787        let variance =
788            connections_per_position.iter().map(|&x| (x as f32 - mean).powi(2)).sum::<f32>()
789                / mask.shape.0 as f32;
790
791        1.0 / (1.0 + variance) // Higher regularity = lower variance
792    }
793
    /// Analysis results for sparse attention patterns
    #[derive(Debug, Clone)]
    pub struct PatternAnalysis {
        /// Fraction of dense attention positions that are masked out.
        pub sparsity: f32,
        /// Estimated memory saving (currently set equal to `sparsity`).
        pub memory_reduction: f32,
        /// Estimated compute saving (currently set equal to `sparsity`).
        pub compute_reduction: f32,
        /// Mean number of attended positions per row with any connections.
        pub effective_receptive_field: f32,
        /// 1 / (1 + variance of per-row connection counts).
        pub pattern_regularity: f32,
    }
803}
804
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    /// A size-4 local window over 8 tokens yields a valid, non-dense mask.
    #[test]
    fn test_local_attention_mask() {
        let config =
            SparseAttentionConfig::new().with_pattern(SparsePattern::Local { window_size: 4 });

        let attention = SparseAttention::new(config).expect("operation failed");
        let mask = attention.generate_mask(8).expect("operation failed");

        assert_eq!(mask.shape, (8, 8));
        assert!(mask.sparsity() > 0.0);
    }

    /// Block sparse generation succeeds and produces a mask of the right shape.
    #[test]
    fn test_block_sparse_attention_mask() {
        // Use larger sequence and smaller blocks to ensure some sparsity
        let config = SparseAttentionConfig::new().with_pattern(SparsePattern::BlockSparse {
            block_size: 4,
            global_blocks: 1,
            random_blocks: 1,
        });

        let attention = SparseAttention::new(config).expect("operation failed");
        let mask = attention.generate_mask(32).expect("operation failed"); // Larger sequence for more sparsity

        assert_eq!(mask.shape, (32, 32));
        // With 8 blocks of size 4, not all blocks are covered by diagonal/global/random
        // so we should have some sparsity
        assert!(mask.sparsity() >= 0.0); // At minimum, mask is valid
    }

    /// End-to-end forward pass on a 2D (seq_len, hidden) input preserves shape.
    #[test]
    fn test_sparse_attention_forward() {
        let config = SparseAttentionConfig::new()
            .with_hidden_size(64)
            .with_num_heads(4)
            .with_pattern(SparsePattern::Local { window_size: 4 });

        let attention = SparseAttention::new(config).expect("operation failed");

        // Create dummy input
        let input = Tensor::randn(&[8, 64]).expect("operation failed");
        let attention_input = AttentionInput {
            hidden_states: input,
            attention_mask: None,
        };

        let output = attention.forward(attention_input).expect("operation failed");

        match output {
            Tensor::F32(arr) => {
                assert_eq!(arr.shape(), &[8, 64]);
            },
            _ => panic!("Expected F32 tensor"),
        }
    }

    /// Pattern statistics fall in their documented ranges for a local pattern.
    #[test]
    fn test_pattern_analysis() {
        let pattern = SparsePattern::Local { window_size: 4 };
        let analysis =
            utils::analyze_pattern_efficiency(&pattern, 16).expect("operation failed in test");

        assert!(analysis.sparsity > 0.0);
        assert!(analysis.sparsity < 1.0);
        assert!(analysis.effective_receptive_field > 0.0);
        assert!(analysis.pattern_regularity > 0.0);
    }

    /// The convenience constructors forward their size arguments.
    #[test]
    fn test_utility_functions() {
        let local_config = utils::create_local_attention(768, 12, 128);
        assert_eq!(local_config.hidden_size, 768);
        assert_eq!(local_config.num_heads, 12);

        let bigbird_config = utils::create_bigbird_attention(768, 12, 64);
        assert_eq!(bigbird_config.hidden_size, 768);

        let longformer_config = utils::create_longformer_attention(768, 12, 128, vec![0, 1]);
        assert_eq!(longformer_config.hidden_size, 768);
    }

    /// Basic mask bookkeeping: entry count, sparsity, and dense conversion.
    #[test]
    fn test_sparse_mask_operations() {
        let mut mask = SparseAttentionMask::new((4, 4));
        mask.add_entry(0, 0, 0.0);
        mask.add_entry(0, 1, 0.0);
        mask.add_entry(1, 1, 0.0);

        assert_eq!(mask.indices.len(), 3);
        assert_eq!(mask.sparsity(), 1.0 - 3.0 / 16.0);

        let dense = mask.to_dense();
        assert_eq!(dense.len(), 4);
        assert_eq!(dense[0].len(), 4);
        assert_eq!(dense[0][0], 0.0);
        assert_eq!(dense[0][1], 0.0);
        assert_eq!(dense[0][2], f32::NEG_INFINITY);
    }
}