ruvector_attention/sparse/
mask.rs1use std::collections::HashSet;
4
/// A sparse attention mask over a `shape.0 x shape.1` score grid.
///
/// Attended `(row, col)` positions are listed in `indices`; a private hash
/// set holding the same positions answers membership queries.
///
/// Equality is structural: two masks compare equal only when their
/// `indices` (including order), `shape`, and lookup sets all match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AttentionMask {
    /// Attended `(row, col)` positions, in the order they were supplied.
    pub indices: Vec<(usize, usize)>,
    /// Logical `(rows, cols)` dimensions of the mask.
    pub shape: (usize, usize),
    /// Set of the positions in `indices`, used for membership tests.
    /// Built once at construction; mutating `indices` afterwards does not
    /// update it.
    lookup: HashSet<(usize, usize)>,
}
15
16impl AttentionMask {
17 pub fn new(indices: Vec<(usize, usize)>, shape: (usize, usize)) -> Self {
19 let lookup: HashSet<_> = indices.iter().copied().collect();
20 Self {
21 indices,
22 shape,
23 lookup,
24 }
25 }
26
27 #[inline]
29 pub fn is_attended(&self, row: usize, col: usize) -> bool {
30 self.lookup.contains(&(row, col))
31 }
32
33 pub fn apply(&self, scores: &mut [f32], seq_len: usize) {
35 for i in 0..seq_len {
36 for j in 0..seq_len {
37 if !self.is_attended(i, j) {
38 scores[i * seq_len + j] = f32::NEG_INFINITY;
39 }
40 }
41 }
42 }
43
44 pub fn local_window(n: usize, window_size: usize) -> Self {
46 let mut indices = Vec::new();
47 let half_window = window_size / 2;
48
49 for i in 0..n {
50 let start = i.saturating_sub(half_window);
51 let end = (i + half_window + 1).min(n);
52 for j in start..end {
53 indices.push((i, j));
54 }
55 }
56
57 Self::new(indices, (n, n))
58 }
59
60 pub fn causal(n: usize) -> Self {
62 let mut indices = Vec::new();
63 for i in 0..n {
64 for j in 0..=i {
65 indices.push((i, j));
66 }
67 }
68 Self::new(indices, (n, n))
69 }
70
71 pub fn strided(n: usize, stride: usize) -> Self {
73 let mut indices = Vec::new();
74 for i in 0..n {
75 for j in (0..n).step_by(stride) {
76 indices.push((i, j));
77 }
78 indices.push((i, i));
80 }
81 let mut indices: Vec<_> = indices
82 .into_iter()
83 .collect::<HashSet<_>>()
84 .into_iter()
85 .collect();
86 indices.sort();
87 Self::new(indices, (n, n))
88 }
89
90 pub fn nnz(&self) -> usize {
92 self.indices.len()
93 }
94
95 pub fn density(&self) -> f32 {
97 self.nnz() as f32 / (self.shape.0 * self.shape.1) as f32
98 }
99}
100
/// Incrementally combines sparse attention patterns into one `n x n` mask.
///
/// Each `with_*` method consumes the builder and returns it with more
/// positions added; call `build` to produce the final mask.
#[must_use = "a SparseMaskBuilder does nothing until `build` is called"]
#[derive(Clone, Debug)]
pub struct SparseMaskBuilder {
    /// Side length of the square mask being built.
    n: usize,
    /// Accumulated `(row, col)` positions; may contain duplicates until built.
    indices: Vec<(usize, usize)>,
}
106
107impl SparseMaskBuilder {
108 pub fn new(n: usize) -> Self {
109 Self {
110 n,
111 indices: Vec::new(),
112 }
113 }
114
115 pub fn with_local_window(mut self, window_size: usize) -> Self {
117 let half_window = window_size / 2;
118 for i in 0..self.n {
119 let start = i.saturating_sub(half_window);
120 let end = (i + half_window + 1).min(self.n);
121 for j in start..end {
122 self.indices.push((i, j));
123 }
124 }
125 self
126 }
127
128 pub fn with_global_tokens(mut self, global_indices: &[usize]) -> Self {
130 for i in 0..self.n {
131 for &g in global_indices {
132 if g < self.n {
133 self.indices.push((i, g));
134 self.indices.push((g, i));
135 }
136 }
137 }
138 self
139 }
140
141 pub fn with_causal(mut self) -> Self {
143 for i in 0..self.n {
144 for j in 0..=i {
145 self.indices.push((i, j));
146 }
147 }
148 self
149 }
150
151 pub fn build(self) -> AttentionMask {
153 let mut indices: Vec<_> = self
154 .indices
155 .into_iter()
156 .collect::<HashSet<_>>()
157 .into_iter()
158 .collect();
159 indices.sort();
160 AttentionMask::new(indices, (self.n, self.n))
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_window_mask() {
        let mask = AttentionMask::local_window(10, 3);

        // A window of 3 around position 5 covers columns 4..=6.
        assert!(mask.is_attended(5, 4));
        assert!(mask.is_attended(5, 5));
        assert!(mask.is_attended(5, 6));

        // Positions outside the window are masked.
        assert!(!mask.is_attended(5, 0));
        assert!(!mask.is_attended(5, 7));

        // Edge rows are clipped to the matrix bounds.
        assert!(mask.is_attended(0, 1));
        assert!(!mask.is_attended(0, 2));
    }

    #[test]
    fn test_causal_mask() {
        let mask = AttentionMask::causal(5);

        // Row 2 sees itself and everything before it.
        assert!(mask.is_attended(2, 0));
        assert!(mask.is_attended(2, 1));
        assert!(mask.is_attended(2, 2));

        // ...but nothing after it.
        assert!(!mask.is_attended(2, 3));
        assert!(!mask.is_attended(2, 4));

        // Lower triangle of a 5x5 grid: 5 * 6 / 2 = 15 entries.
        assert_eq!(mask.nnz(), 15);
        assert_eq!(mask.density(), 15.0 / 25.0);
    }

    #[test]
    fn test_strided_mask() {
        let mask = AttentionMask::strided(3, 2);

        // Columns 0 and 2 for every row, plus each row's diagonal.
        assert_eq!(
            mask.indices,
            vec![(0, 0), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 2)]
        );
    }

    #[test]
    fn test_apply_masks_unattended_scores() {
        let mask = AttentionMask::causal(3);
        let mut scores = vec![1.0_f32; 9];
        mask.apply(&mut scores, 3);

        let ninf = f32::NEG_INFINITY;
        assert_eq!(
            scores,
            vec![1.0, ninf, ninf, 1.0, 1.0, ninf, 1.0, 1.0, 1.0]
        );
    }

    #[test]
    fn test_builder() {
        let mask = SparseMaskBuilder::new(10)
            .with_local_window(3)
            .with_global_tokens(&[0])
            .build();

        // Token 0 is global: every row attends to it and it attends to every column.
        for i in 0..10 {
            assert!(mask.is_attended(i, 0));
            assert!(mask.is_attended(0, i));
        }

        // The local window is still present, and other positions stay masked.
        assert!(mask.is_attended(5, 4));
        assert!(!mask.is_attended(5, 8));
    }

    #[test]
    fn test_builder_dedups_repeated_patterns() {
        let mask = SparseMaskBuilder::new(4).with_causal().with_causal().build();

        assert_eq!(mask.indices.len(), 10);
        assert_eq!(mask.nnz(), 10);
    }
}