
scirs2_neural/attention/sparse/types.rs

//! Types for sparse attention patterns (BigBird / Longformer-style).
//!
//! This module defines the enumerations and configuration structs that govern
//! which attention patterns are used when operating on long sequences.

use std::fmt;

/// Sparse-attention pattern variants.
///
/// Controls which query–key pairs are evaluated during attention. All
/// patterns give sub-quadratic complexity for long sequences (cost grows
/// linearly in sequence length for fixed pattern parameters).
///
/// ```rust
/// use scirs2_neural::attention::sparse::SparsePattern;
/// let p = SparsePattern::LocalWindow;
/// assert_eq!(p, SparsePattern::LocalWindow);
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsePattern {
    /// Sliding-window local attention: each token attends to the `window_size`
    /// tokens immediately before and after it (Longformer local attention).
    LocalWindow,

    /// Global + local attention: a set of designated *global* tokens attend to
    /// all positions and all positions attend to global tokens; non-global
    /// tokens additionally use a local window (Longformer full model).
    GlobalLocal,

    /// Uniform-random attention pattern: each token attends to `n_random`
    /// randomly chosen positions in addition to its local window. Used in the
    /// BigBird architecture.
    Random,
    /// Block-sparse attention (BigBird): the sequence is divided into blocks of
    /// size `block_size` and each block attends to itself, `n_random` randomly
    /// chosen blocks, and any designated global tokens.
    BlockSparse,

    /// Equivalent to [`LocalWindow`](SparsePattern::LocalWindow): a sliding
    /// window with no global tokens.
    Sliding,
}

impl fmt::Display for SparsePattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SparsePattern::LocalWindow => write!(f, "LocalWindow"),
            SparsePattern::GlobalLocal => write!(f, "GlobalLocal"),
            SparsePattern::Random => write!(f, "Random"),
            SparsePattern::BlockSparse => write!(f, "BlockSparse"),
            SparsePattern::Sliding => write!(f, "Sliding"),
        }
    }
}

/// Configuration for sparse attention.
///
/// # Defaults
///
/// ```rust
/// use scirs2_neural::attention::sparse::{SparseAttentionConfig, SparsePattern};
/// let cfg = SparseAttentionConfig::default();
/// assert_eq!(cfg.pattern, SparsePattern::LocalWindow);
/// assert_eq!(cfg.window_size, 64);
/// assert_eq!(cfg.n_global_tokens, 0);
/// assert_eq!(cfg.n_random, 3);
/// assert_eq!(cfg.block_size, 64);
/// assert_eq!(cfg.n_heads, 8);
/// assert_eq!(cfg.head_dim, 64);
/// ```
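///
/// # Customizing
///
/// All fields are public, so a default config can be adjusted in place. (A
/// minimal sketch; the values below are arbitrary, not recommendations.)
///
/// ```rust
/// use scirs2_neural::attention::sparse::{SparseAttentionConfig, SparsePattern};
/// let mut cfg = SparseAttentionConfig::default();
/// cfg.pattern = SparsePattern::BlockSparse;
/// cfg.window_size = 128;
/// assert_eq!(cfg.pattern, SparsePattern::BlockSparse);
/// assert_eq!(cfg.window_size, 128);
/// ```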
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SparseAttentionConfig {
    /// Which attention pattern to use.
    pub pattern: SparsePattern,

    /// Half-width of the local attention window (tokens on each side).
    ///
    /// Token `i` attends to positions `max(0, i - w)..=min(n - 1, i + w)`.
    pub window_size: usize,

    /// Number of *global* token slots prepended to every sequence.
    ///
    /// Global tokens attend to **all** positions and all positions attend back
    /// to them (used for CLS tokens, task prefixes, etc.).
    pub n_global_tokens: usize,

    /// Number of uniformly-random extra positions each token attends to
    /// (BigBird random component).
    pub n_random: usize,

    /// Block size for BigBird block-sparse patterns.
    ///
    /// The sequence is partitioned into non-overlapping blocks of this size.
    pub block_size: usize,

    /// Number of attention heads.
    pub n_heads: usize,

    /// Dimensionality of each attention head.
    pub head_dim: usize,
}

impl Default for SparseAttentionConfig {
    fn default() -> Self {
        Self {
            pattern: SparsePattern::LocalWindow,
            window_size: 64,
            n_global_tokens: 0,
            n_random: 3,
            block_size: 64,
            n_heads: 8,
            head_dim: 64,
        }
    }
}

impl fmt::Display for SparseAttentionConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "SparseAttentionConfig(pattern={}, window={}, n_global={}, n_random={}, block={}, heads={}, head_dim={})",
            self.pattern,
            self.window_size,
            self.n_global_tokens,
            self.n_random,
            self.block_size,
            self.n_heads,
            self.head_dim,
        )
    }
}

/// Precomputed sparse-attention mask.
///
/// `attend_to[i]` is a sorted, deduplicated list of position indices that
/// query `i` is allowed to attend to.
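///
/// # Example
///
/// A hand-built mask for a hypothetical 3-token sequence in which each token
/// attends to itself and its right neighbor (assuming the same re-export path
/// as the examples above):
///
/// ```rust
/// use scirs2_neural::attention::sparse::SparseAttentionMask;
/// let mask = SparseAttentionMask {
///     seq_len: 3,
///     attend_to: vec![vec![0, 1], vec![1, 2], vec![2]],
/// };
/// assert_eq!(mask.n_pairs(), 5);
/// assert!((mask.density() - 5.0 / 9.0).abs() < 1e-12);
/// ```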
#[derive(Debug, Clone)]
pub struct SparseAttentionMask {
    /// Length of the sequence this mask was built for.
    pub seq_len: usize,

    /// For each query position `i`, the set of key positions `i` attends to.
    ///
    /// Entries are sorted in ascending order and deduplicated.
    pub attend_to: Vec<Vec<usize>>,
}

impl SparseAttentionMask {
    /// Total number of attended pairs across all query positions.
    pub fn n_pairs(&self) -> usize {
        self.attend_to.iter().map(|v| v.len()).sum()
    }

    /// Fraction of attended pairs relative to the fully dense case `seq_len²`.
    pub fn density(&self) -> f64 {
        if self.seq_len == 0 {
            return 0.0;
        }
        self.n_pairs() as f64 / (self.seq_len * self.seq_len) as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_attention_config_default() {
        let cfg = SparseAttentionConfig::default();
        assert_eq!(cfg.pattern, SparsePattern::LocalWindow);
        assert_eq!(cfg.window_size, 64);
        assert_eq!(cfg.n_global_tokens, 0);
        assert_eq!(cfg.n_random, 3);
        assert_eq!(cfg.block_size, 64);
        assert_eq!(cfg.n_heads, 8);
        assert_eq!(cfg.head_dim, 64);
    }
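
    // A quick check of the `Display` impl above; the expected string follows
    // directly from its format string and the documented defaults.
    #[test]
    fn sparse_attention_config_display() {
        let cfg = SparseAttentionConfig::default();
        assert_eq!(
            cfg.to_string(),
            "SparseAttentionConfig(pattern=LocalWindow, window=64, n_global=0, \
             n_random=3, block=64, heads=8, head_dim=64)"
        );
    }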

    #[test]
    fn sparse_pattern_display() {
        assert_eq!(SparsePattern::LocalWindow.to_string(), "LocalWindow");
        assert_eq!(SparsePattern::BlockSparse.to_string(), "BlockSparse");
    }

    #[test]
    fn mask_density_empty() {
        let mask = SparseAttentionMask {
            seq_len: 0,
            attend_to: Vec::new(),
        };
        assert!((mask.density() - 0.0).abs() < 1e-12);
    }

    #[test]
    fn mask_density_full() {
        // Every position attends to every position (itself included):
        // density == 1.
        let seq_len = 4;
        let attend_to = vec![vec![0, 1, 2, 3]; seq_len];
        let mask = SparseAttentionMask { seq_len, attend_to };
        assert!((mask.density() - 1.0).abs() < 1e-10);
    }
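
    // A hedged sketch: the real local-window mask builder lives outside this
    // module, so this test hand-builds the mask from the rule documented on
    // `window_size` (token i attends to max(0, i-w)..=min(n-1, i+w), w = 1)
    // and only checks the bookkeeping in `n_pairs` and `density`.
    #[test]
    fn mask_density_local_window() {
        let seq_len = 5;
        let attend_to: Vec<Vec<usize>> = (0..seq_len)
            .map(|i| (i.saturating_sub(1)..=(i + 1).min(seq_len - 1)).collect())
            .collect();
        let mask = SparseAttentionMask { seq_len, attend_to };
        // Three interior tokens see 3 positions each; the two edge tokens see 2.
        assert_eq!(mask.n_pairs(), 3 * 3 + 2 * 2);
        assert!((mask.density() - 13.0 / 25.0).abs() < 1e-12);
    }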
}