ruvector_attention/sparse/
mask.rs

1//! Sparse mask utilities for attention patterns
2
3use std::collections::HashSet;
4
5/// Sparse mask for attention patterns
6#[derive(Clone, Debug)]
7pub struct AttentionMask {
8    /// Sparse indices as (row, col) pairs
9    pub indices: Vec<(usize, usize)>,
10    /// Shape of the full attention matrix
11    pub shape: (usize, usize),
12    /// Set for O(1) lookup
13    lookup: HashSet<(usize, usize)>,
14}
15
16impl AttentionMask {
17    /// Create a new sparse mask from indices
18    pub fn new(indices: Vec<(usize, usize)>, shape: (usize, usize)) -> Self {
19        let lookup: HashSet<_> = indices.iter().copied().collect();
20        Self { indices, shape, lookup }
21    }
22
23    /// Check if position is masked (should attend)
24    #[inline]
25    pub fn is_attended(&self, row: usize, col: usize) -> bool {
26        self.lookup.contains(&(row, col))
27    }
28
29    /// Apply mask to attention scores (set non-attended to -inf)
30    pub fn apply(&self, scores: &mut [f32], seq_len: usize) {
31        for i in 0..seq_len {
32            for j in 0..seq_len {
33                if !self.is_attended(i, j) {
34                    scores[i * seq_len + j] = f32::NEG_INFINITY;
35                }
36            }
37        }
38    }
39
40    /// Create a local window mask
41    pub fn local_window(n: usize, window_size: usize) -> Self {
42        let mut indices = Vec::new();
43        let half_window = window_size / 2;
44
45        for i in 0..n {
46            let start = i.saturating_sub(half_window);
47            let end = (i + half_window + 1).min(n);
48            for j in start..end {
49                indices.push((i, j));
50            }
51        }
52
53        Self::new(indices, (n, n))
54    }
55
56    /// Create a causal mask (lower triangular)
57    pub fn causal(n: usize) -> Self {
58        let mut indices = Vec::new();
59        for i in 0..n {
60            for j in 0..=i {
61                indices.push((i, j));
62            }
63        }
64        Self::new(indices, (n, n))
65    }
66
67    /// Create a strided mask
68    pub fn strided(n: usize, stride: usize) -> Self {
69        let mut indices = Vec::new();
70        for i in 0..n {
71            for j in (0..n).step_by(stride) {
72                indices.push((i, j));
73            }
74            // Always attend to self
75            indices.push((i, i));
76        }
77        let mut indices: Vec<_> = indices.into_iter().collect::<HashSet<_>>().into_iter().collect();
78        indices.sort();
79        Self::new(indices, (n, n))
80    }
81
82    /// Number of non-zero entries
83    pub fn nnz(&self) -> usize {
84        self.indices.len()
85    }
86
87    /// Sparsity ratio (0 = all zeros, 1 = all ones)
88    pub fn density(&self) -> f32 {
89        self.nnz() as f32 / (self.shape.0 * self.shape.1) as f32
90    }
91}
92
93/// Builder for creating sparse masks
94pub struct SparseMaskBuilder {
95    n: usize,
96    indices: Vec<(usize, usize)>,
97}
98
99impl SparseMaskBuilder {
100    pub fn new(n: usize) -> Self {
101        Self { n, indices: Vec::new() }
102    }
103
104    /// Add local window pattern
105    pub fn with_local_window(mut self, window_size: usize) -> Self {
106        let half_window = window_size / 2;
107        for i in 0..self.n {
108            let start = i.saturating_sub(half_window);
109            let end = (i + half_window + 1).min(self.n);
110            for j in start..end {
111                self.indices.push((i, j));
112            }
113        }
114        self
115    }
116
117    /// Add global tokens (all positions attend to these)
118    pub fn with_global_tokens(mut self, global_indices: &[usize]) -> Self {
119        for i in 0..self.n {
120            for &g in global_indices {
121                if g < self.n {
122                    self.indices.push((i, g));
123                    self.indices.push((g, i));
124                }
125            }
126        }
127        self
128    }
129
130    /// Add causal masking
131    pub fn with_causal(mut self) -> Self {
132        for i in 0..self.n {
133            for j in 0..=i {
134                self.indices.push((i, j));
135            }
136        }
137        self
138    }
139
140    /// Build the mask
141    pub fn build(self) -> AttentionMask {
142        let mut indices: Vec<_> = self.indices.into_iter().collect::<HashSet<_>>().into_iter().collect();
143        indices.sort();
144        AttentionMask::new(indices, (self.n, self.n))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_window_mask() {
        let mask = AttentionMask::local_window(10, 3);

        // Position 5 should attend to positions 4, 5, 6
        assert!(mask.is_attended(5, 4));
        assert!(mask.is_attended(5, 5));
        assert!(mask.is_attended(5, 6));

        // Position 5 should not attend to positions outside its window
        assert!(!mask.is_attended(5, 0));
        assert!(!mask.is_attended(5, 7));
    }

    #[test]
    fn test_local_window_clips_at_edges() {
        let mask = AttentionMask::local_window(4, 3);
        assert_eq!(
            mask.indices,
            vec![(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (2, 3), (3, 2), (3, 3)]
        );
    }

    #[test]
    fn test_causal_mask() {
        let mask = AttentionMask::causal(5);

        // Lower triangle should be attended
        assert!(mask.is_attended(2, 0));
        assert!(mask.is_attended(2, 1));
        assert!(mask.is_attended(2, 2));

        // Upper triangle should not
        assert!(!mask.is_attended(2, 3));
        assert!(!mask.is_attended(2, 4));

        // n * (n + 1) / 2 entries in total
        assert_eq!(mask.nnz(), 15);
    }

    #[test]
    fn test_strided_mask() {
        let mask = AttentionMask::strided(5, 2);

        // Row 1 attends to stride columns 0, 2, 4 plus its own diagonal
        assert!(mask.is_attended(1, 0));
        assert!(mask.is_attended(1, 1));
        assert!(mask.is_attended(1, 2));
        assert!(!mask.is_attended(1, 3));
        assert!(mask.is_attended(1, 4));
        assert_eq!(mask.nnz(), 17);

        // Strictly increasing means sorted and duplicate-free
        assert!(mask.indices.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_apply_masks_unattended_scores() {
        let mask = AttentionMask::causal(2);
        let mut scores = vec![1.0_f32, 2.0, 3.0, 4.0];
        mask.apply(&mut scores, 2);
        assert_eq!(scores, vec![1.0, f32::NEG_INFINITY, 3.0, 4.0]);
    }

    #[test]
    fn test_density() {
        let mask = AttentionMask::causal(4);
        assert_eq!(mask.density(), 10.0 / 16.0);
    }

    #[test]
    fn test_builder() {
        let mask = SparseMaskBuilder::new(10)
            .with_local_window(3)
            .with_global_tokens(&[0])
            .build();

        for i in 0..10 {
            // All positions attend to global token 0 ...
            assert!(mask.is_attended(i, 0));
            // ... and global token 0 attends to all positions
            assert!(mask.is_attended(0, i));
        }

        // Outside the local window and not involving a global token
        assert!(!mask.is_attended(5, 8));
    }

    #[test]
    fn test_builder_ignores_out_of_range_globals() {
        let mask = SparseMaskBuilder::new(3).with_global_tokens(&[7]).build();
        assert_eq!(mask.nnz(), 0);
    }

    #[test]
    fn test_builder_causal_matches_causal_mask() {
        let built = SparseMaskBuilder::new(4).with_causal().build();
        assert_eq!(built.indices, AttentionMask::causal(4).indices);
    }
}
191}