//! `ruvector_attention/sparse/flash.rs`: block-wise single-query attention
//! combined with an online softmax (FlashAttention-style).

use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
/// Single-query attention that scores keys `block_size` at a time and merges
/// the blocks with a running (online) softmax, as in FlashAttention.
pub struct FlashAttention {
    /// Required length of the query vector.
    dim: usize,
    /// Number of keys scored together in one block.
    block_size: usize,
    /// Factor applied to every dot product; the constructors set it to `1 / sqrt(dim)`.
    scale: f32,
    /// Enables the causal mask applied in `compute_block_scores`.
    causal: bool,
}

impl FlashAttention {
    /// Creates a non-causal instance for `dim`-length queries that processes
    /// keys in blocks of `block_size`.
    pub fn new(dim: usize, block_size: usize) -> Self {
        let scale = 1.0 / (dim as f32).sqrt();
        Self {
            causal: false,
            scale,
            block_size,
            dim,
        }
    }

    /// Creates an instance like [`FlashAttention::new`], but with causal masking on.
    ///
    /// NOTE(review): the mask only sees key indices, never a query position,
    /// so in causal mode every key except the one at global index 0 is hidden.
    /// Confirm this is intended before relying on causal output.
    pub fn causal(dim: usize, block_size: usize) -> Self {
        Self {
            causal: true,
            ..Self::new(dim, block_size)
        }
    }

    /// Returns one score per entry of `keys`.
    ///
    /// `start_idx` is the global position of `keys[0]` in the full key list.
    /// In causal mode, a key whose global position is non-zero gets
    /// `f32::NEG_INFINITY`. Every other key gets `scale * dot(query, key)`,
    /// where the dot product runs over the shorter of the two slices.
    fn compute_block_scores(&self, query: &[f32], keys: &[&[f32]], start_idx: usize) -> Vec<f32> {
        let mut scores = Vec::with_capacity(keys.len());
        for (offset, key) in keys.iter().enumerate() {
            let hidden = self.causal && start_idx + offset != 0;
            let score = if hidden {
                f32::NEG_INFINITY
            } else {
                let dot: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
                dot * self.scale
            };
            scores.push(score);
        }
        scores
    }
}
60impl Attention for FlashAttention {
61 fn compute(
62 &self,
63 query: &[f32],
64 keys: &[&[f32]],
65 values: &[&[f32]],
66 ) -> AttentionResult<Vec<f32>> {
67 if keys.is_empty() {
68 return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
69 }
70 if keys.len() != values.len() {
71 return Err(AttentionError::DimensionMismatch {
72 expected: keys.len(),
73 actual: values.len(),
74 });
75 }
76 if query.len() != self.dim {
77 return Err(AttentionError::DimensionMismatch {
78 expected: self.dim,
79 actual: query.len(),
80 });
81 }
82
83 let n = keys.len();
84 let value_dim = values[0].len();
85
86 let mut output = vec![0.0f32; value_dim];
88 let mut max_so_far = f32::NEG_INFINITY;
89 let mut sum_exp = 0.0f32;
90
91 for block_start in (0..n).step_by(self.block_size) {
93 let block_end = (block_start + self.block_size).min(n);
94 let block_keys: Vec<&[f32]> = keys[block_start..block_end].to_vec();
95
96 let block_scores = self.compute_block_scores(query, &block_keys, block_start);
98
99 let block_max = block_scores
101 .iter()
102 .copied()
103 .filter(|x| x.is_finite())
104 .fold(f32::NEG_INFINITY, f32::max);
105
106 if !block_max.is_finite() {
107 continue; }
109
110 let new_max = max_so_far.max(block_max);
112
113 if max_so_far.is_finite() {
115 let rescale = (max_so_far - new_max).exp();
116 sum_exp *= rescale;
117 output.iter_mut().for_each(|o| *o *= rescale);
118 }
119
120 for (local_idx, &score) in block_scores.iter().enumerate() {
122 if score.is_finite() {
123 let exp_score = (score - new_max).exp();
124 sum_exp += exp_score;
125
126 let global_idx = block_start + local_idx;
127 for (j, &vj) in values[global_idx].iter().enumerate() {
128 output[j] += exp_score * vj;
129 }
130 }
131 }
132
133 max_so_far = new_max;
134 }
135
136 if sum_exp > 1e-8 {
138 output.iter_mut().for_each(|o| *o /= sum_exp);
139 }
140
141 Ok(output)
142 }
143
144 fn compute_with_mask(
145 &self,
146 query: &[f32],
147 keys: &[&[f32]],
148 values: &[&[f32]],
149 mask: Option<&[bool]>,
150 ) -> AttentionResult<Vec<f32>> {
151 if let Some(m) = mask {
152 let filtered: Vec<(usize, bool)> = m
153 .iter()
154 .copied()
155 .enumerate()
156 .filter(|(_, keep)| *keep)
157 .collect();
158 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
159 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
160 self.compute(query, &filtered_keys, &filtered_values)
161 } else {
162 self.compute(query, keys, values)
163 }
164 }
165
166 fn dim(&self) -> usize {
167 self.dim
168 }
169}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attention::ScaledDotProductAttention;

    /// Borrows each owned row as a slice, the shape `Attention::compute` takes.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|row| row.as_slice()).collect()
    }

    /// Asserts equal lengths and element-wise agreement within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < tol, "actual: {}, expected: {}", a, e);
        }
    }

    #[test]
    fn test_flash_attention() {
        let attention = FlashAttention::new(64, 16);

        let query = vec![0.5f32; 64];
        let keys: Vec<Vec<f32>> = (0..256).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..256).map(|_| vec![1.0; 64]).collect();

        let result = attention
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        // Every value row is all ones, so any softmax-weighted average is too;
        // checking only the length would let a broken normalisation pass.
        assert_close(&result, &[1.0f32; 64], 1e-5);
    }

    #[test]
    fn test_flash_matches_standard() {
        let dim = 32;
        let flash = FlashAttention::new(dim, 8);
        let standard = ScaledDotProductAttention::new(dim);

        let query = vec![0.5f32; dim];
        let keys: Vec<Vec<f32>> = (0..16).map(|i| vec![(i as f32) * 0.1; dim]).collect();
        let values: Vec<Vec<f32>> = (0..16).map(|i| vec![(i as f32) * 0.2; dim]).collect();

        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let flash_result = flash.compute(&query, &keys_refs, &values_refs).unwrap();
        let standard_result = standard.compute(&query, &keys_refs, &values_refs).unwrap();

        assert_eq!(flash_result.len(), standard_result.len());
        for (f, s) in flash_result.iter().zip(standard_result.iter()) {
            assert!((f - s).abs() < 1e-4, "Flash: {}, Standard: {}", f, s);
        }
    }

    #[test]
    fn test_causal_flash() {
        let attention = FlashAttention::causal(32, 8);

        let query = vec![1.0f32; 32];
        let keys = vec![vec![0.5f32; 32]; 20];
        let values = vec![vec![1.0f32; 32]; 20];

        let result = attention
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        // All value rows are ones, so the output must be ones whichever keys
        // the causal mask leaves visible.
        assert_close(&result, &[1.0f32; 32], 1e-5);
    }

    #[test]
    fn test_block_size_does_not_change_result() {
        let dim = 8;
        let n = 23;
        let query: Vec<f32> = (0..dim).map(|j| (j as f32 * 0.3).cos()).collect();
        let keys: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..dim).map(|j| ((i * dim + j) as f32 * 0.17).sin()).collect())
            .collect();
        let values: Vec<Vec<f32>> = (0..n).map(|i| vec![i as f32, -(i as f32)]).collect();
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        // A single block covering every key is the plain, non-streamed softmax.
        let reference = FlashAttention::new(dim, n)
            .compute(&query, &keys_refs, &values_refs)
            .unwrap();
        for &block_size in &[1usize, 4, 7, 64] {
            let result = FlashAttention::new(dim, block_size)
                .compute(&query, &keys_refs, &values_refs)
                .unwrap();
            assert_close(&result, &reference, 1e-3);
        }
    }

    #[test]
    fn test_mask_matches_filtered_inputs() {
        let attention = FlashAttention::new(4, 2);
        let query: Vec<f32> = vec![0.2, -0.1, 0.4, 0.3];
        let keys: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32 * 0.25; 4]).collect();
        let values: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32, 1.0 - i as f32]).collect();
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let unmasked = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        let no_mask = attention
            .compute_with_mask(&query, &keys_refs, &values_refs, None)
            .unwrap();
        assert_close(&no_mask, &unmasked, 1e-6);

        let mask = [true, false, true, false, true];
        let masked = attention
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(&mask[..]))
            .unwrap();
        let kept_keys = vec![keys_refs[0], keys_refs[2], keys_refs[4]];
        let kept_values = vec![values_refs[0], values_refs[2], values_refs[4]];
        let expected = attention.compute(&query, &kept_keys, &kept_values).unwrap();
        assert_close(&masked, &expected, 1e-6);
    }
}