scirs2_neural/attention/sparse/types.rs
//! Types for sparse attention patterns (BigBird / Longformer-style).
//!
//! This module defines the enumerations and configuration structs that govern
//! which attention patterns are used when operating on long sequences.

use std::fmt;

/// Sparse-attention pattern variants.
///
/// Controls which query–key pairs are evaluated during attention. All
/// patterns give sub-quadratic complexity for long sequences.
///
/// ```rust
/// use scirs2_neural::attention::sparse::SparsePattern;
/// let p = SparsePattern::LocalWindow;
/// assert_eq!(p, SparsePattern::LocalWindow);
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparsePattern {
    /// Sliding-window local attention: each token attends to the `window_size`
    /// tokens immediately before and after it (Longformer local attention).
    LocalWindow,

    /// Global + local attention: a set of designated *global* tokens attend to
    /// all positions and all positions attend to global tokens; non-global
    /// tokens additionally use a local window (Longformer full model).
    GlobalLocal,

    /// Uniform-random attention pattern: each token attends to `n_random`
    /// randomly chosen positions in addition to its local window. Used in the
    /// BigBird architecture.
    Random,

    /// Block-sparse attention (BigBird): the sequence is divided into blocks of
    /// size `block_size`, and each block attends to itself, a configurable number
    /// of randomly chosen blocks, and any designated global tokens.
    BlockSparse,

    /// Equivalent to [`LocalWindow`](SparsePattern::LocalWindow): a sliding
    /// window with no global tokens.
    Sliding,
}

impl fmt::Display for SparsePattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SparsePattern::LocalWindow => write!(f, "LocalWindow"),
            SparsePattern::GlobalLocal => write!(f, "GlobalLocal"),
            SparsePattern::Random => write!(f, "Random"),
            SparsePattern::BlockSparse => write!(f, "BlockSparse"),
            SparsePattern::Sliding => write!(f, "Sliding"),
        }
    }
}

/// Configuration for sparse attention.
///
/// # Defaults
///
/// ```rust
/// use scirs2_neural::attention::sparse::{SparseAttentionConfig, SparsePattern};
/// let cfg = SparseAttentionConfig::default();
/// assert_eq!(cfg.pattern, SparsePattern::LocalWindow);
/// assert_eq!(cfg.window_size, 64);
/// assert_eq!(cfg.n_global_tokens, 0);
/// assert_eq!(cfg.n_random, 3);
/// assert_eq!(cfg.block_size, 64);
/// assert_eq!(cfg.n_heads, 8);
/// assert_eq!(cfg.head_dim, 64);
/// ```
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SparseAttentionConfig {
    /// Which attention pattern to use.
    pub pattern: SparsePattern,

    /// Half-width of the local attention window (tokens on each side).
    ///
    /// Token `i` attends to positions `max(0, i - w)..=min(n - 1, i + w)`, where
    /// `w` is this value and `n` is the sequence length. For example, with
    /// `window_size = 2` in a sequence of length 10, token 5 attends to
    /// positions 3 through 7.
    pub window_size: usize,

    /// Number of *global* token slots prepended to every sequence.
    ///
    /// Global tokens attend to **all** positions and all positions attend back
    /// to them (used for CLS tokens, task prefixes, etc.).
    pub n_global_tokens: usize,

    /// Number of uniformly-random extra positions each token attends to
    /// (BigBird random component).
    pub n_random: usize,

    /// Block size for BigBird block-sparse patterns.
    ///
    /// The sequence is partitioned into non-overlapping blocks of this size.
    pub block_size: usize,

    /// Number of attention heads.
    pub n_heads: usize,

    /// Dimensionality of each attention head.
    pub head_dim: usize,
}

impl Default for SparseAttentionConfig {
    fn default() -> Self {
        Self {
            pattern: SparsePattern::LocalWindow,
            window_size: 64,
            n_global_tokens: 0,
            n_random: 3,
            block_size: 64,
            n_heads: 8,
            head_dim: 64,
        }
    }
}

impl fmt::Display for SparseAttentionConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "SparseAttentionConfig(pattern={}, window={}, n_global={}, n_random={}, block={}, heads={}, head_dim={})",
            self.pattern,
            self.window_size,
            self.n_global_tokens,
            self.n_random,
            self.block_size,
            self.n_heads,
            self.head_dim,
        )
    }
}

/// Precomputed sparse-attention mask.
///
/// `attend_to[i]` is a sorted, deduplicated list of position indices that
/// query `i` is allowed to attend to.
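///
/// # Examples
///
/// A small, hand-built mask, shown here only to illustrate the `attend_to`
/// layout and the [`n_pairs`](SparseAttentionMask::n_pairs) and
/// [`density`](SparseAttentionMask::density) helpers:
///
/// ```rust
/// use scirs2_neural::attention::sparse::SparseAttentionMask;
///
/// // Each query attends to itself and its immediate neighbours (window of 1).
/// let mask = SparseAttentionMask {
///     seq_len: 4,
///     attend_to: vec![vec![0, 1], vec![0, 1, 2], vec![1, 2, 3], vec![2, 3]],
/// };
/// assert_eq!(mask.n_pairs(), 10);
/// assert!((mask.density() - 10.0 / 16.0).abs() < 1e-12);
/// ```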
#[derive(Debug, Clone)]
pub struct SparseAttentionMask {
    /// Length of the sequence this mask was built for.
    pub seq_len: usize,

    /// For each query position `i`, the set of key positions `i` attends to.
    ///
    /// Entries are sorted in ascending order and deduplicated.
    pub attend_to: Vec<Vec<usize>>,
}

impl SparseAttentionMask {
    /// Total number of attended pairs across all query positions.
    pub fn n_pairs(&self) -> usize {
        self.attend_to.iter().map(|v| v.len()).sum()
    }

    /// Fraction of attended pairs relative to the fully dense case `seq_len²`.
    pub fn density(&self) -> f64 {
        if self.seq_len == 0 {
            return 0.0;
        }
        self.n_pairs() as f64 / (self.seq_len * self.seq_len) as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_attention_config_default() {
        let cfg = SparseAttentionConfig::default();
        assert_eq!(cfg.pattern, SparsePattern::LocalWindow);
        assert_eq!(cfg.window_size, 64);
        assert_eq!(cfg.n_global_tokens, 0);
        assert_eq!(cfg.n_random, 3);
        assert_eq!(cfg.block_size, 64);
        assert_eq!(cfg.n_heads, 8);
        assert_eq!(cfg.head_dim, 64);
    }

    #[test]
    fn sparse_pattern_display() {
        assert_eq!(SparsePattern::LocalWindow.to_string(), "LocalWindow");
        assert_eq!(SparsePattern::BlockSparse.to_string(), "BlockSparse");
    }

    #[test]
    fn mask_density_empty() {
        let mask = SparseAttentionMask {
            seq_len: 0,
            attend_to: Vec::new(),
        };
        assert!((mask.density() - 0.0).abs() < 1e-12);
    }

    #[test]
    fn mask_density_full() {
        // Every position attends to every other: density == 1.
        let seq_len = 4;
        let attend_to = vec![vec![0, 1, 2, 3]; seq_len];
        let mask = SparseAttentionMask { seq_len, attend_to };
        assert!((mask.density() - 1.0).abs() < 1e-10);
    }
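
    // Illustrative sketch (not the module's mask builder): constructs
    // `attend_to` exactly as the `window_size` documentation on
    // `SparseAttentionConfig` describes, i.e. token `i` attends to
    // `max(0, i - w)..=min(n - 1, i + w)`.
    #[test]
    fn local_window_formula_sketch() {
        let n = 8;
        let w = 2;
        let attend_to: Vec<Vec<usize>> = (0..n)
            .map(|i| (i.saturating_sub(w)..=usize::min(n - 1, i + w)).collect())
            .collect();
        let mask = SparseAttentionMask {
            seq_len: n,
            attend_to,
        };
        // Interior tokens see 2w + 1 positions; tokens near the edges see fewer.
        assert_eq!(mask.attend_to[4], vec![2, 3, 4, 5, 6]);
        assert_eq!(mask.attend_to[0], vec![0, 1, 2]);
        assert_eq!(mask.n_pairs(), 3 + 4 + 5 + 5 + 5 + 5 + 4 + 3);
        assert!(mask.density() < 1.0);
    }

    // Illustrative sketch of the GlobalLocal semantics documented on
    // `n_global_tokens`: global rows attend everywhere and every row attends
    // back to the global positions. The construction here is hand-rolled for
    // the test, not the library's implementation.
    #[test]
    fn global_local_rows_sketch() {
        let n = 6;
        let n_global = 1;
        let w = 1;
        let attend_to: Vec<Vec<usize>> = (0..n)
            .map(|i| {
                if i < n_global {
                    // Global token: attends to every position.
                    (0..n).collect()
                } else {
                    // Non-global token: the global positions plus a local window.
                    let mut keys: Vec<usize> = (0..n_global).collect();
                    keys.extend(i.saturating_sub(w)..=usize::min(n - 1, i + w));
                    keys.sort_unstable();
                    keys.dedup();
                    keys
                }
            })
            .collect();
        let mask = SparseAttentionMask {
            seq_len: n,
            attend_to,
        };
        // The global row is dense; every row can reach the global token.
        assert_eq!(mask.attend_to[0].len(), n);
        assert!(mask.attend_to.iter().all(|keys| keys.contains(&0)));
    }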
}