ruvector_attention/sparse/mask.rs

//! Sparse mask utilities for attention patterns.
2
3use std::collections::HashSet;
4
/// Sparse boolean mask over an attention score matrix.
///
/// Each `(row, col)` pair in [`indices`](Self::indices) marks a position that is
/// allowed to attend; every other position is overwritten by [`apply`](Self::apply).
#[derive(Clone, Debug)]
pub struct AttentionMask {
    /// Attended positions as `(row, col)` pairs.
    ///
    /// [`new`](Self::new) stores them exactly as given; the other constructors
    /// produce them in row-major order without duplicates.
    ///
    /// Note: [`is_attended`](Self::is_attended) and [`nnz`](Self::nnz) consult a
    /// set that is built once in `new`, so mutating this field afterwards is not
    /// reflected by them.
    pub indices: Vec<(usize, usize)>,
    /// Shape `(rows, cols)` of the full (dense) attention matrix.
    pub shape: (usize, usize),
    /// Set copy of `indices` for average O(1) membership tests.
    lookup: HashSet<(usize, usize)>,
}

impl AttentionMask {
    /// Create a mask from explicit `(row, col)` indices and the full matrix `shape`.
    ///
    /// `indices` is stored as given (order and duplicates preserved). Duplicates do
    /// not affect [`is_attended`](Self::is_attended) or [`nnz`](Self::nnz), which use
    /// a deduplicated set built here.
    pub fn new(indices: Vec<(usize, usize)>, shape: (usize, usize)) -> Self {
        let lookup: HashSet<_> = indices.iter().copied().collect();
        Self {
            indices,
            shape,
            lookup,
        }
    }

    /// Return `true` if position `(row, col)` is attended, i.e. left untouched by
    /// [`apply`](Self::apply).
    #[inline]
    pub fn is_attended(&self, row: usize, col: usize) -> bool {
        self.lookup.contains(&(row, col))
    }

    /// Mask a row-major `seq_len x seq_len` score matrix in place.
    ///
    /// Every entry `scores[i * seq_len + j]` whose `(i, j)` is not attended is set to
    /// `f32::NEG_INFINITY` (which a subsequent softmax maps to zero weight); attended
    /// entries are left unchanged. Only the first `seq_len * seq_len` elements are
    /// touched. `seq_len == 0` is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `scores.len() < seq_len * seq_len`, or if that product overflows
    /// `usize`. The check happens before any element is written.
    pub fn apply(&self, scores: &mut [f32], seq_len: usize) {
        if seq_len == 0 {
            return;
        }
        let total = seq_len
            .checked_mul(seq_len)
            .expect("seq_len * seq_len overflows usize");
        assert!(
            scores.len() >= total,
            "scores has {} elements, expected at least seq_len * seq_len = {}",
            scores.len(),
            total
        );
        // Validating up front avoids a partial write before a panic; walking whole
        // rows removes the per-element bounds checks of `scores[i * seq_len + j]`.
        for (i, row) in scores[..total].chunks_exact_mut(seq_len).enumerate() {
            for (j, score) in row.iter_mut().enumerate() {
                if !self.is_attended(i, j) {
                    *score = f32::NEG_INFINITY;
                }
            }
        }
    }

    /// Create an `n x n` local-window mask.
    ///
    /// Row `i` attends to every column within `window_size / 2` positions of `i`,
    /// clipped to `0..n`. The window is symmetric, so its full width is
    /// `2 * (window_size / 2) + 1`. An even `window_size` therefore behaves like the
    /// next odd size (e.g. `4` covers 5 columns), and `0` attends only to the diagonal.
    pub fn local_window(n: usize, window_size: usize) -> Self {
        let half_window = window_size / 2;
        let mut indices = Vec::new();

        for i in 0..n {
            let start = i.saturating_sub(half_window);
            // Saturating adds keep huge window sizes from overflowing; `min` clips to n.
            let end = i.saturating_add(half_window).saturating_add(1).min(n);
            indices.extend((start..end).map(|j| (i, j)));
        }

        Self::new(indices, (n, n))
    }

    /// Create an `n x n` causal (lower-triangular) mask: row `i` attends to
    /// columns `0..=i`.
    pub fn causal(n: usize) -> Self {
        // Exactly n * (n + 1) / 2 entries are produced.
        let mut indices = Vec::with_capacity(n * (n + 1) / 2);
        for i in 0..n {
            indices.extend((0..=i).map(|j| (i, j)));
        }
        Self::new(indices, (n, n))
    }

    /// Create an `n x n` strided mask.
    ///
    /// Every row attends to the columns `0, stride, 2 * stride, …` below `n`, plus
    /// its own diagonal position. Indices are sorted row-major with no duplicates.
    ///
    /// # Panics
    ///
    /// Panics if `stride == 0` while `n > 0`.
    pub fn strided(n: usize, stride: usize) -> Self {
        assert!(n == 0 || stride != 0, "stride must be non-zero");

        let mut indices = Vec::new();
        for i in 0..n {
            indices.extend((0..n).step_by(stride).map(|j| (i, j)));
            // Always attend to self
            indices.push((i, i));
        }
        // Sorting and then dropping adjacent repeats removes the self entry wherever
        // it coincides with a strided column, without a temporary HashSet.
        indices.sort_unstable();
        indices.dedup();

        Self::new(indices, (n, n))
    }

    /// Number of distinct attended positions.
    pub fn nnz(&self) -> usize {
        // Use the deduplicated set so duplicates passed to `new` are not double
        // counted, keeping this consistent with `is_attended`.
        self.lookup.len()
    }

    /// Fraction of the full matrix that is attended (0 = nothing, 1 = everything).
    ///
    /// Returns `0.0` for an empty shape instead of the `NaN` a `0 / 0` division
    /// would produce.
    pub fn density(&self) -> f32 {
        match self.shape.0 * self.shape.1 {
            0 => 0.0,
            total => self.nnz() as f32 / total as f32,
        }
    }
}
100
101/// Builder for creating sparse masks
102pub struct SparseMaskBuilder {
103    n: usize,
104    indices: Vec<(usize, usize)>,
105}
106
107impl SparseMaskBuilder {
108    pub fn new(n: usize) -> Self {
109        Self {
110            n,
111            indices: Vec::new(),
112        }
113    }
114
115    /// Add local window pattern
116    pub fn with_local_window(mut self, window_size: usize) -> Self {
117        let half_window = window_size / 2;
118        for i in 0..self.n {
119            let start = i.saturating_sub(half_window);
120            let end = (i + half_window + 1).min(self.n);
121            for j in start..end {
122                self.indices.push((i, j));
123            }
124        }
125        self
126    }
127
128    /// Add global tokens (all positions attend to these)
129    pub fn with_global_tokens(mut self, global_indices: &[usize]) -> Self {
130        for i in 0..self.n {
131            for &g in global_indices {
132                if g < self.n {
133                    self.indices.push((i, g));
134                    self.indices.push((g, i));
135                }
136            }
137        }
138        self
139    }
140
141    /// Add causal masking
142    pub fn with_causal(mut self) -> Self {
143        for i in 0..self.n {
144            for j in 0..=i {
145                self.indices.push((i, j));
146            }
147        }
148        self
149    }
150
151    /// Build the mask
152    pub fn build(self) -> AttentionMask {
153        let mut indices: Vec<_> = self
154            .indices
155            .into_iter()
156            .collect::<HashSet<_>>()
157            .into_iter()
158            .collect();
159        indices.sort();
160        AttentionMask::new(indices, (self.n, self.n))
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `indices` is strictly increasing, i.e. sorted with no duplicates.
    fn assert_sorted_unique(indices: &[(usize, usize)]) {
        assert!(
            indices.windows(2).all(|w| w[0] < w[1]),
            "indices are not sorted/unique: {:?}",
            indices
        );
    }

    #[test]
    fn test_local_window_mask() {
        let mask = AttentionMask::local_window(10, 3);

        // Position 5 should attend to positions 4, 5, 6
        for j in 4..=6 {
            assert!(mask.is_attended(5, j));
        }

        // Position 5 should not attend to positions outside its window
        assert!(!mask.is_attended(5, 0));
        assert!(!mask.is_attended(5, 3));
        assert!(!mask.is_attended(5, 7));
    }

    #[test]
    fn test_local_window_clipped_at_edges() {
        let mask = AttentionMask::local_window(10, 3);

        // The first and last rows lose the half of the window outside 0..10.
        assert!(mask.is_attended(0, 0) && mask.is_attended(0, 1));
        assert!(mask.is_attended(9, 8) && mask.is_attended(9, 9));
        // 2 edge rows with 2 entries + 8 interior rows with 3 entries.
        assert_eq!(mask.nnz(), 28);
        assert_sorted_unique(&mask.indices);
    }

    #[test]
    fn test_causal_mask() {
        let mask = AttentionMask::causal(5);

        // Lower triangle should be attended
        assert!(mask.is_attended(2, 0));
        assert!(mask.is_attended(2, 1));
        assert!(mask.is_attended(2, 2));

        // Upper triangle should not
        assert!(!mask.is_attended(2, 3));
        assert!(!mask.is_attended(2, 4));

        // 1 + 2 + 3 + 4 + 5 entries.
        assert_eq!(mask.nnz(), 15);
    }

    #[test]
    fn test_strided_mask() {
        let mask = AttentionMask::strided(6, 3);

        // Row 4 attends to strided columns 0 and 3 plus itself.
        for &j in &[0, 3, 4] {
            assert!(mask.is_attended(4, j));
        }
        for &j in &[1, 2, 5] {
            assert!(!mask.is_attended(4, j));
        }

        // Rows on the stride (0 and 3) get no extra self entry: 2 * 2 + 4 * 3.
        assert_eq!(mask.nnz(), 16);
        assert_sorted_unique(&mask.indices);
    }

    #[test]
    fn test_apply_masks_upper_triangle() {
        let mask = AttentionMask::causal(3);
        let mut scores = vec![0.5_f32; 9];
        mask.apply(&mut scores, 3);

        for i in 0..3 {
            for j in 0..3 {
                let expected = if j <= i { 0.5 } else { f32::NEG_INFINITY };
                assert_eq!(scores[i * 3 + j], expected);
            }
        }
    }

    #[test]
    fn test_density() {
        // A 4x4 causal mask keeps 1 + 2 + 3 + 4 = 10 of 16 entries.
        let mask = AttentionMask::causal(4);
        assert_eq!(mask.density(), 0.625);
    }

    #[test]
    fn test_builder() {
        let mask = SparseMaskBuilder::new(10)
            .with_local_window(3)
            .with_global_tokens(&[0])
            .build();

        // Global token 0 is attended by every position and attends to every position
        for i in 0..10 {
            assert!(mask.is_attended(i, 0));
            assert!(mask.is_attended(0, i));
        }
        // Far-apart, non-global positions remain masked.
        assert!(!mask.is_attended(5, 9));
        assert_sorted_unique(&mask.indices);
    }

    #[test]
    fn test_builder_ignores_out_of_range_globals() {
        let mask = SparseMaskBuilder::new(4).with_global_tokens(&[7]).build();
        assert_eq!(mask.nnz(), 0);
        assert_eq!(mask.shape, (4, 4));
    }

    #[test]
    fn test_builder_causal_matches_causal() {
        let built = SparseMaskBuilder::new(5).with_causal().build();
        assert_eq!(built.indices, AttentionMask::causal(5).indices);
    }
}