ruvector_attention/sparse/flash.rs

//! Flash attention - memory-efficient attention with tiled computation
//!
//! Keys and values are processed in fixed-size blocks with an online softmax,
//! so the full attention matrix is never materialized: memory is O(block_size)
//! instead of O(n²).

use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;

/// Flash attention with block-wise computation
///
/// Computes attention in tiles to minimize memory usage while maintaining numerical stability.
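///
/// # Example
///
/// A minimal usage sketch. The `ruvector_attention::...` import paths are an
/// assumption based on this file's location; adjust them to the crate tree.
///
/// ```ignore
/// use ruvector_attention::sparse::FlashAttention;
/// use ruvector_attention::traits::Attention;
///
/// let attn = FlashAttention::new(64, 16);
/// let query = vec![0.5f32; 64];
/// let keys = vec![vec![0.3f32; 64]; 128];
/// let values = vec![vec![1.0f32; 64]; 128];
/// let key_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
/// let value_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
///
/// let out = attn.compute(&query, &key_refs, &value_refs).unwrap();
/// assert_eq!(out.len(), 64);
/// ```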
pub struct FlashAttention {
    dim: usize,
    block_size: usize,
    scale: f32,
    causal: bool,
}

impl FlashAttention {
    /// Creates a new flash attention instance (no causal masking).
    pub fn new(dim: usize, block_size: usize) -> Self {
        Self {
            dim,
            block_size,
            scale: 1.0 / (dim as f32).sqrt(),
            causal: false,
        }
    }

    /// Creates flash attention with causal masking.
    ///
    /// With this single-query API the query is treated as position 0, so the
    /// mask hides every key except the first (see `compute_block_scores`).
    pub fn causal(dim: usize, block_size: usize) -> Self {
        Self {
            causal: true,
            ..Self::new(dim, block_size)
        }
    }

    /// Computes scaled dot-product scores between the query and one block of keys.
    ///
    /// `start_idx` is the global index of the block's first key.
    fn compute_block_scores(&self, query: &[f32], keys: &[&[f32]], start_idx: usize) -> Vec<f32> {
        keys.iter()
            .enumerate()
            .map(|(j, key)| {
                if self.causal && start_idx + j > 0 {
                    // Simplified causal mask: the query is assumed to sit at
                    // position 0, so every key past global index 0 is hidden
                    f32::NEG_INFINITY
                } else {
                    query
                        .iter()
                        .zip(key.iter())
                        .map(|(q, k)| q * k)
                        .sum::<f32>()
                        * self.scale
                }
            })
            .collect()
    }
}

impl Attention for FlashAttention {
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
        }
        if keys.len() != values.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: values.len(),
            });
        }
        if query.len() != self.dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.dim,
                actual: query.len(),
            });
        }

        let n = keys.len();
        let value_dim = values[0].len();

        // Online softmax with tiled computation
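        //
        // Running-state recurrence (the standard online-softmax identity):
        // with current max m, normalizer l, and unnormalized output o, a new
        // block with scores s_j updates
        //   m' = max(m, max_j s_j)
        //   l' = l * exp(m - m') + sum_j exp(s_j - m')
        //   o' = o * exp(m - m') + sum_j exp(s_j - m') * v_j
        // so `output / sum_exp` at the end equals softmax(scores) · V exactly.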
        let mut output = vec![0.0f32; value_dim];
        let mut max_so_far = f32::NEG_INFINITY;
        let mut sum_exp = 0.0f32;

        // Process keys in blocks of `block_size`
        for block_start in (0..n).step_by(self.block_size) {
            let block_end = (block_start + self.block_size).min(n);
            // Borrow the block directly instead of copying the slice of refs
            let block_keys = &keys[block_start..block_end];

            // Compute attention scores for this block
            let block_scores = self.compute_block_scores(query, block_keys, block_start);

            // Find block maximum over unmasked (finite) scores
            let block_max = block_scores
                .iter()
                .copied()
                .filter(|x| x.is_finite())
                .fold(f32::NEG_INFINITY, f32::max);

            if !block_max.is_finite() {
                continue; // Skip fully masked blocks
            }

            // New running maximum
            let new_max = max_so_far.max(block_max);

            // Rescale previous accumulations by exp(old_max - new_max) so
            // they stay expressed relative to the new maximum
            if max_so_far.is_finite() {
                let rescale = (max_so_far - new_max).exp();
                sum_exp *= rescale;
                output.iter_mut().for_each(|o| *o *= rescale);
            }

            // Add contribution from this block
            for (local_idx, &score) in block_scores.iter().enumerate() {
                if score.is_finite() {
                    let exp_score = (score - new_max).exp();
                    sum_exp += exp_score;

                    let global_idx = block_start + local_idx;
                    for (j, &vj) in values[global_idx].iter().enumerate() {
                        output[j] += exp_score * vj;
                    }
                }
            }

            max_so_far = new_max;
        }

        // Final normalization; the guard avoids dividing by a vanishing sum
        if sum_exp > 1e-8 {
            output.iter_mut().for_each(|o| *o /= sum_exp);
        }

        Ok(output)
    }

    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        if let Some(m) = mask {
            if m.len() != keys.len() {
                return Err(AttentionError::DimensionMismatch {
                    expected: keys.len(),
                    actual: m.len(),
                });
            }
            // Keep only the positions whose mask entry is true
            let kept: Vec<usize> = m
                .iter()
                .enumerate()
                .filter(|(_, &keep)| keep)
                .map(|(i, _)| i)
                .collect();
            let filtered_keys: Vec<&[f32]> = kept.iter().map(|&i| keys[i]).collect();
            let filtered_values: Vec<&[f32]> = kept.iter().map(|&i| values[i]).collect();
            self.compute(query, &filtered_keys, &filtered_values)
        } else {
            self.compute(query, keys, values)
        }
    }

    fn dim(&self) -> usize {
        self.dim
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::attention::ScaledDotProductAttention;

    #[test]
    fn test_flash_attention() {
        let attention = FlashAttention::new(64, 16);

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = (0..256).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..256).map(|_| vec![1.0; 64]).collect();

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 64);
    }

    #[test]
    fn test_flash_matches_standard() {
        let dim = 32;
        let flash = FlashAttention::new(dim, 8);
        let standard = ScaledDotProductAttention::new(dim);

        let query = vec![0.5; dim];
        let keys: Vec<Vec<f32>> = (0..16).map(|i| vec![(i as f32) * 0.1; dim]).collect();
        let values: Vec<Vec<f32>> = (0..16).map(|i| vec![(i as f32) * 0.2; dim]).collect();

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let flash_result = flash.compute(&query, &keys_refs, &values_refs).unwrap();
        let standard_result = standard.compute(&query, &keys_refs, &values_refs).unwrap();

        // Results should be approximately equal
        for (f, s) in flash_result.iter().zip(standard_result.iter()) {
            assert!((f - s).abs() < 1e-4, "Flash: {}, Standard: {}", f, s);
        }
    }

    #[test]
    fn test_causal_flash() {
        let attention = FlashAttention::causal(32, 8);

        let query = vec![1.0; 32];
        let keys = vec![vec![0.5; 32]; 20];
        let values = vec![vec![1.0; 32]; 20];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 32);
    }
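
    // Additional sketch test: `compute_with_mask` filters the masked-out
    // positions and recomputes, so it should agree with calling `compute`
    // on the kept subset directly.
    #[test]
    fn test_mask_matches_filtered() {
        let dim = 16;
        let attention = FlashAttention::new(dim, 4);

        let query = vec![0.5; dim];
        let keys: Vec<Vec<f32>> = (0..8).map(|i| vec![(i as f32) * 0.1; dim]).collect();
        let values: Vec<Vec<f32>> = (0..8).map(|i| vec![(i as f32) * 0.2; dim]).collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        // Keep only the even positions
        let mask: Vec<bool> = (0..8).map(|i| i % 2 == 0).collect();
        let masked = attention
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(&mask))
            .unwrap();

        let kept_keys: Vec<&[f32]> = keys_refs.iter().step_by(2).copied().collect();
        let kept_values: Vec<&[f32]> = values_refs.iter().step_by(2).copied().collect();
        let direct = attention.compute(&query, &kept_keys, &kept_values).unwrap();

        for (a, b) in masked.iter().zip(direct.iter()) {
            assert!((a - b).abs() < 1e-5, "masked: {}, direct: {}", a, b);
        }
    }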
}