// ruvector_attention/sparse/mask.rs
use std::collections::HashSet;

/// A sparse attention pattern over a dense score matrix.
///
/// `indices` lists the `(row, col)` positions that are allowed to attend and
/// `shape` is the logical `(rows, cols)` size of the score matrix. A private
/// hash-set copy of `indices` backs [`AttentionMask::is_attended`] and
/// [`AttentionMask::nnz`].
///
/// Note: that set is built once in [`AttentionMask::new`]. Mutating the public
/// `indices` field afterwards does not update `is_attended`/`nnz`; build a new
/// mask instead.
#[derive(Clone, Debug)]
pub struct AttentionMask {
    /// Attended `(row, col)` positions, stored exactly as passed to `new`.
    pub indices: Vec<(usize, usize)>,
    /// Logical `(rows, cols)` shape of the dense score matrix.
    pub shape: (usize, usize),
    /// Set view of `indices` used for hash-based membership tests.
    lookup: HashSet<(usize, usize)>,
}

impl AttentionMask {
    /// Creates a mask from explicit `(row, col)` positions.
    ///
    /// `indices` is stored as given (order and any duplicates are kept) and is
    /// not checked against `shape`.
    pub fn new(indices: Vec<(usize, usize)>, shape: (usize, usize)) -> Self {
        let lookup: HashSet<_> = indices.iter().copied().collect();
        Self { indices, shape, lookup }
    }

    /// Returns `true` if position `(row, col)` is part of the mask.
    #[inline]
    pub fn is_attended(&self, row: usize, col: usize) -> bool {
        self.lookup.contains(&(row, col))
    }

    /// Masks a row-major `seq_len x seq_len` score matrix in place.
    ///
    /// Every `scores[i * seq_len + j]` whose position `(i, j)` is not attended
    /// is set to `f32::NEG_INFINITY`; attended scores are left untouched.
    /// Elements beyond the first `seq_len * seq_len` are neither read nor
    /// written. `shape` is not consulted: only membership in the mask decides.
    ///
    /// # Panics
    ///
    /// Panics if `scores.len() < seq_len * seq_len` or if that product
    /// overflows `usize`.
    pub fn apply(&self, scores: &mut [f32], seq_len: usize) {
        if seq_len == 0 {
            // Nothing to mask; also avoids `chunks_exact_mut(0)`, which panics.
            return;
        }
        let needed = seq_len
            .checked_mul(seq_len)
            .expect("seq_len * seq_len overflows usize");
        assert!(
            scores.len() >= needed,
            "scores has {} elements, but a {}x{} matrix needs {}",
            scores.len(),
            seq_len,
            seq_len,
            needed
        );
        // After the check there are at least `seq_len` complete rows.
        for (i, row) in scores.chunks_exact_mut(seq_len).take(seq_len).enumerate() {
            for (j, score) in row.iter_mut().enumerate() {
                if !self.is_attended(i, j) {
                    *score = f32::NEG_INFINITY;
                }
            }
        }
    }

    /// Builds an `n x n` banded (sliding-window) mask.
    ///
    /// Row `i` attends columns `i - window_size / 2 ..= i + window_size / 2`,
    /// clipped to `0..n`. The band therefore spans `2 * (window_size / 2) + 1`
    /// columns, so an even `window_size` behaves like `window_size + 1` and a
    /// `window_size` of 0 or 1 yields only the diagonal. Indices are produced
    /// in sorted `(row, col)` order without duplicates.
    pub fn local_window(n: usize, window_size: usize) -> Self {
        let half_window = window_size / 2;
        let indices = (0..n)
            .flat_map(|i| {
                let start = i.saturating_sub(half_window);
                // Saturate so a huge `window_size` cannot overflow.
                let end = i.saturating_add(half_window + 1).min(n);
                (start..end).map(move |j| (i, j))
            })
            .collect();
        Self::new(indices, (n, n))
    }

    /// Builds an `n x n` causal (lower-triangular) mask: row `i` attends
    /// columns `0..=i`. Indices are produced in sorted `(row, col)` order.
    pub fn causal(n: usize) -> Self {
        let indices = (0..n)
            .flat_map(|i| (0..=i).map(move |j| (i, j)))
            .collect();
        Self::new(indices, (n, n))
    }

    /// Builds an `n x n` strided mask.
    ///
    /// Every row attends columns `0, stride, 2 * stride, ...` below `n`, plus
    /// its own diagonal position. A `stride` of 0 selects no strided columns,
    /// leaving only the diagonal. Indices are sorted and deduplicated.
    pub fn strided(n: usize, stride: usize) -> Self {
        let mut indices = Vec::new();
        for i in 0..n {
            // `step_by(0)` panics, so a zero stride contributes no columns.
            if stride > 0 {
                indices.extend((0..n).step_by(stride).map(|j| (i, j)));
            }
            indices.push((i, i));
        }
        // The diagonal may coincide with a strided column; drop such repeats.
        indices.sort_unstable();
        indices.dedup();
        Self::new(indices, (n, n))
    }

    /// Number of distinct attended positions.
    ///
    /// Counts the lookup set rather than `indices`, so duplicate entries passed
    /// to [`AttentionMask::new`] are counted once, consistent with
    /// [`AttentionMask::is_attended`].
    pub fn nnz(&self) -> usize {
        self.lookup.len()
    }

    /// Fraction of the `shape.0 x shape.1` cells that are attended.
    ///
    /// Returns `0.0` when the shape has no cells, instead of `0 / 0 = NaN`.
    pub fn density(&self) -> f32 {
        let (rows, cols) = self.shape;
        if rows == 0 || cols == 0 {
            return 0.0;
        }
        // Multiply in f32 so a very large shape cannot overflow `usize`.
        self.nnz() as f32 / (rows as f32 * cols as f32)
    }
}

93pub struct SparseMaskBuilder {
95 n: usize,
96 indices: Vec<(usize, usize)>,
97}
98
99impl SparseMaskBuilder {
100 pub fn new(n: usize) -> Self {
101 Self { n, indices: Vec::new() }
102 }
103
104 pub fn with_local_window(mut self, window_size: usize) -> Self {
106 let half_window = window_size / 2;
107 for i in 0..self.n {
108 let start = i.saturating_sub(half_window);
109 let end = (i + half_window + 1).min(self.n);
110 for j in start..end {
111 self.indices.push((i, j));
112 }
113 }
114 self
115 }
116
117 pub fn with_global_tokens(mut self, global_indices: &[usize]) -> Self {
119 for i in 0..self.n {
120 for &g in global_indices {
121 if g < self.n {
122 self.indices.push((i, g));
123 self.indices.push((g, i));
124 }
125 }
126 }
127 self
128 }
129
130 pub fn with_causal(mut self) -> Self {
132 for i in 0..self.n {
133 for j in 0..=i {
134 self.indices.push((i, j));
135 }
136 }
137 self
138 }
139
140 pub fn build(self) -> AttentionMask {
142 let mut indices: Vec<_> = self.indices.into_iter().collect::<HashSet<_>>().into_iter().collect();
143 indices.sort();
144 AttentionMask::new(indices, (self.n, self.n))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_window_mask() {
        let mask = AttentionMask::local_window(10, 3);

        // window_size 3 => one neighbour on each side of the diagonal.
        assert!(mask.is_attended(5, 4));
        assert!(mask.is_attended(5, 5));
        assert!(mask.is_attended(5, 6));

        assert!(!mask.is_attended(5, 0));
        assert!(!mask.is_attended(5, 7));
    }

    #[test]
    fn test_local_window_clipped_at_edges() {
        let mask = AttentionMask::local_window(4, 3);
        assert_eq!(
            mask.indices,
            vec![(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (2, 3), (3, 2), (3, 3)]
        );
    }

    #[test]
    fn test_causal_mask() {
        let mask = AttentionMask::causal(5);

        // Row 2 sees itself and every earlier column...
        assert!(mask.is_attended(2, 0));
        assert!(mask.is_attended(2, 1));
        assert!(mask.is_attended(2, 2));

        // ...but no later column.
        assert!(!mask.is_attended(2, 3));
        assert!(!mask.is_attended(2, 4));
    }

    #[test]
    fn test_causal_nnz_and_density() {
        let mask = AttentionMask::causal(4);
        assert_eq!(mask.nnz(), 10);
        assert_eq!(mask.density(), 0.625);
    }

    #[test]
    fn test_strided_mask() {
        let mask = AttentionMask::strided(6, 2);
        // Each row sees columns 0, 2, 4 plus itself; odd rows gain one extra.
        assert_eq!(mask.nnz(), 21);
        assert!(mask.is_attended(3, 0));
        assert!(mask.is_attended(3, 3));
        assert!(mask.is_attended(3, 4));
        assert!(!mask.is_attended(3, 1));
        // Indices come back sorted and unique.
        assert!(mask.indices.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_apply_sets_unattended_to_neg_infinity() {
        let mask = AttentionMask::causal(3);
        let mut scores = vec![1.0_f32; 9];
        mask.apply(&mut scores, 3);
        let inf = f32::NEG_INFINITY;
        assert_eq!(scores, vec![1.0, inf, inf, 1.0, 1.0, inf, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_builder() {
        let mask = SparseMaskBuilder::new(10)
            .with_local_window(3)
            .with_global_tokens(&[0])
            .build();

        for i in 0..10 {
            // Global token 0 is attended by every row and attends every column.
            assert!(mask.is_attended(i, 0));
            assert!(mask.is_attended(0, i));
        }
        // Positions outside the window and not involving a global token stay masked.
        assert!(!mask.is_attended(5, 8));
    }

    #[test]
    fn test_builder_matches_static_constructors() {
        assert_eq!(
            SparseMaskBuilder::new(6).with_causal().build().indices,
            AttentionMask::causal(6).indices
        );
        assert_eq!(
            SparseMaskBuilder::new(6).with_local_window(3).build().indices,
            AttentionMask::local_window(6, 3).indices
        );
    }
}