// Source: ruvector_attention/attention/flash.rs
//! FlashAttention-3 IO-aware tiled attention.
//!
//! Implements the FlashAttention algorithm which reduces HBM (High Bandwidth Memory)
//! reads from O(N^2 d) to O(N^2 d^2 / M) where M is SRAM size, by tiling Q, K, V
//! into blocks and fusing the softmax rescaling with the matmul accumulation.
//!
//! The key insight is that standard attention materializes the full N x N attention
//! matrix in HBM, causing O(N^2) memory. FlashAttention never materializes this
//! matrix, instead computing attention in tiles using an online softmax algorithm
//! that maintains running statistics (row-max and log-sum-exp) to avoid the
//! two-pass softmax.
//!
//! This module provides:
//! - [`FlashConfig`]: Configuration for block sizes, causal masking, and dropout
//! - [`FlashAttention3`]: IO-aware tiled forward pass returning output + LSE
//! - [`IOStats`]: Tracking of FLOPs and memory transfer for IO analysis
//! - [`RingAttention`]: Simplified ring-based distributed attention across devices
use crate::error::{AttentionError, AttentionResult};
20
21/// Configuration for FlashAttention tiled computation.
22#[derive(Clone, Debug)]
23pub struct FlashConfig {
24    /// Block size along the query dimension (Br).
25    pub block_size_q: usize,
26    /// Block size along the key/value dimension (Bc).
27    pub block_size_kv: usize,
28    /// Whether to apply causal masking (upper-triangular mask).
29    pub causal: bool,
30    /// Dropout probability (0.0 = no dropout). Applied conceptually but not
31    /// stochastically in this CPU implementation.
32    pub dropout_p: f32,
33}
34
35impl Default for FlashConfig {
36    fn default() -> Self {
37        Self {
38            block_size_q: 64,
39            block_size_kv: 64,
40            causal: false,
41            dropout_p: 0.0,
42        }
43    }
44}
45
46impl FlashConfig {
47    /// Creates a config with custom block sizes.
48    pub fn new(block_size_q: usize, block_size_kv: usize) -> AttentionResult<Self> {
49        if block_size_q == 0 || block_size_kv == 0 {
50            return Err(AttentionError::InvalidConfig(
51                "Block sizes must be > 0".into(),
52            ));
53        }
54        Ok(Self {
55            block_size_q,
56            block_size_kv,
57            ..Default::default()
58        })
59    }
60
61    /// Returns a causal variant of this config.
62    pub fn with_causal(mut self) -> Self {
63        self.causal = true;
64        self
65    }
66
67    /// Sets the dropout probability.
68    pub fn with_dropout(mut self, p: f32) -> AttentionResult<Self> {
69        if !(0.0..=1.0).contains(&p) {
70            return Err(AttentionError::InvalidConfig(
71                "Dropout must be in [0, 1]".into(),
72            ));
73        }
74        self.dropout_p = p;
75        Ok(self)
76    }
77}
78
/// IO statistics for comparing tiled vs naive attention.
#[derive(Clone, Debug, Default)]
pub struct IOStats {
    /// Total floating-point operations performed.
    pub total_flops: u64,
    /// Total elements read from main memory.
    pub memory_reads: u64,
    /// Total elements written to main memory.
    pub memory_writes: u64,
    /// Sequence length used for the computation.
    seq_len: usize,
    /// Head dimension used for the computation.
    head_dim: usize,
    /// Block size Q used.
    #[allow(dead_code)]
    block_size_q: usize,
    /// Block size KV used.
    #[allow(dead_code)]
    block_size_kv: usize,
}

impl IOStats {
    /// Ratio of modeled naive-attention IO volume to the recorded tiled IO
    /// volume. Returns 1.0 when no work was recorded.
    ///
    /// Both algorithms perform the same FLOPs; tiling wins on memory traffic,
    /// so this compares naive attention's O(N^2 + Nd) element transfers (the
    /// materialized attention matrix plus the [N, d] operands) against the
    /// reads and writes actually counted during the tiled pass.
    pub fn flop_ratio(&self) -> f32 {
        if self.total_flops == 0 {
            return 1.0;
        }
        let n = self.seq_len as f64;
        let d = self.head_dim as f64;
        // Naive attention materializes the N x N score matrix besides Q/K/V.
        let naive_io = n * n + n * d;
        let tiled_io = self.memory_reads as f64 + self.memory_writes as f64;
        if tiled_io < 1.0 {
            1.0
        } else {
            (naive_io / tiled_io) as f32
        }
    }

    /// Working-memory complexity class of the tiled algorithm.
    pub fn memory_complexity(&self) -> &'static str {
        "O(N)"
    }

    /// Working-memory complexity class of naive attention, for comparison.
    pub fn naive_memory_complexity(&self) -> &'static str {
        "O(N^2)"
    }
}
131
/// FlashAttention-3: IO-aware tiled attention.
///
/// Walks Q in blocks of Br rows and K/V in blocks of Bc rows, so the full
/// N x N attention matrix never exists in memory. Numerical stability comes
/// from an online softmax that carries a running row-max and log-sum-exp.
pub struct FlashAttention3;
138
139/// Output of a flash attention forward pass.
140#[derive(Clone, Debug)]
141pub struct FlashOutput {
142    /// The attention output matrix, shape [num_queries, dim].
143    pub output: Vec<Vec<f32>>,
144    /// Log-sum-exp per query row (m_i + ln(l_i)), used for backward pass.
145    pub lse: Vec<f32>,
146    /// IO statistics for this computation.
147    pub stats: IOStats,
148}
149
150impl FlashAttention3 {
151    /// Computes IO-aware tiled attention.
152    ///
153    /// # Algorithm
154    ///
155    /// 1. Split Q into Tr blocks of Br rows, K/V into Tc blocks of Bc rows.
156    /// 2. For each Q block i, iterate over all K/V blocks j:
157    ///    - Compute S_ij = Q_i @ K_j^T / sqrt(d)
158    ///    - Apply causal mask if configured
159    ///    - Update running max, sum-exp, and output using online softmax
160    /// 3. Return output and log-sum-exp for backward pass.
161    ///
162    /// # Arguments
163    ///
164    /// * `q` - Query matrix, shape [n_q, d]
165    /// * `k` - Key matrix, shape [n_kv, d]
166    /// * `v` - Value matrix, shape [n_kv, d]
167    /// * `config` - Flash attention configuration
168    pub fn forward(
169        q: &[Vec<f32>],
170        k: &[Vec<f32>],
171        v: &[Vec<f32>],
172        config: &FlashConfig,
173    ) -> AttentionResult<FlashOutput> {
174        if q.is_empty() {
175            return Err(AttentionError::EmptyInput("queries".into()));
176        }
177        if k.is_empty() || v.is_empty() {
178            return Err(AttentionError::EmptyInput("keys or values".into()));
179        }
180        if k.len() != v.len() {
181            return Err(AttentionError::DimensionMismatch {
182                expected: k.len(),
183                actual: v.len(),
184            });
185        }
186        let d = q[0].len();
187        if d == 0 {
188            return Err(AttentionError::InvalidConfig("Dimension must be > 0".into()));
189        }
190        let scale = 1.0 / (d as f32).sqrt();
191        let n_q = q.len();
192        let n_kv = k.len();
193        let br = config.block_size_q;
194        let bc = config.block_size_kv;
195
196        let mut output = vec![vec![0.0f32; d]; n_q];
197        let mut lse = vec![f32::NEG_INFINITY; n_q];
198        let mut row_max = vec![f32::NEG_INFINITY; n_q];
199        let mut row_sum = vec![0.0f32; n_q];
200
201        let mut stats = IOStats {
202            seq_len: n_q.max(n_kv),
203            head_dim: d,
204            block_size_q: br,
205            block_size_kv: bc,
206            ..Default::default()
207        };
208
209        // Outer loop: iterate over Q blocks
210        for qi_start in (0..n_q).step_by(br) {
211            let qi_end = (qi_start + br).min(n_q);
212
213            // Inner loop: iterate over K/V blocks
214            for kj_start in (0..n_kv).step_by(bc) {
215                let kj_end = (kj_start + bc).min(n_kv);
216
217                // Track memory reads: Q block + K block + V block
218                stats.memory_reads += ((qi_end - qi_start) * d
219                    + (kj_end - kj_start) * d * 2) as u64;
220
221                // For each query row in this Q block
222                for qi in qi_start..qi_end {
223                    // Compute S_ij = Q_i @ K_j^T / sqrt(d) for each key in block
224                    let mut block_scores = Vec::with_capacity(kj_end - kj_start);
225                    for kj in kj_start..kj_end {
226                        let mut dot = 0.0f32;
227                        for dd in 0..d {
228                            dot += q[qi][dd] * k[kj][dd];
229                        }
230                        let mut score = dot * scale;
231
232                        // Apply causal mask: mask out positions where kj > qi
233                        if config.causal && kj > qi {
234                            score = f32::NEG_INFINITY;
235                        }
236                        block_scores.push(score);
237                        stats.total_flops += (2 * d) as u64; // dot product
238                    }
239
240                    // Block row-max
241                    let m_ij = block_scores
242                        .iter()
243                        .copied()
244                        .fold(f32::NEG_INFINITY, f32::max);
245
246                    if !m_ij.is_finite() {
247                        continue; // Fully masked block
248                    }
249
250                    // Exponentiate and sum
251                    let exp_scores: Vec<f32> =
252                        block_scores.iter().map(|&s| (s - m_ij).exp()).collect();
253                    let l_ij: f32 = exp_scores
254                        .iter()
255                        .filter(|x| x.is_finite())
256                        .sum();
257
258                    // Online softmax rescaling
259                    let m_old = row_max[qi];
260                    let m_new = m_old.max(m_ij);
261
262                    let exp_old = if m_old.is_finite() {
263                        (m_old - m_new).exp()
264                    } else {
265                        0.0
266                    };
267                    let exp_new = (m_ij - m_new).exp();
268
269                    let l_new = exp_old * row_sum[qi] + exp_new * l_ij;
270
271                    // Rescale existing output and add new contribution
272                    // O_i = (exp(m_old - m_new) * l_old * O_i
273                    //      + exp(m_ij - m_new) * P_ij @ V_j) / l_new
274                    if l_new > 0.0 {
275                        let inv_l_new = 1.0 / l_new;
276                        let scale_old = exp_old * row_sum[qi] * inv_l_new;
277                        let scale_new = exp_new * inv_l_new;
278
279                        for dd in 0..d {
280                            let mut pv = 0.0f32;
281                            for (local_j, kj) in (kj_start..kj_end).enumerate() {
282                                if exp_scores[local_j].is_finite() {
283                                    pv += exp_scores[local_j] * v[kj][dd];
284                                }
285                            }
286                            output[qi][dd] =
287                                scale_old * output[qi][dd] + scale_new * pv;
288                            stats.total_flops += (2 * (kj_end - kj_start)) as u64;
289                        }
290                    }
291
292                    row_max[qi] = m_new;
293                    row_sum[qi] = l_new;
294                }
295            }
296
297            // Track memory writes: output block
298            stats.memory_writes += ((qi_end - qi_start) * d) as u64;
299        }
300
301        // Compute LSE = m + ln(l) for backward pass
302        for i in 0..n_q {
303            if row_sum[i] > 0.0 && row_max[i].is_finite() {
304                lse[i] = row_max[i] + row_sum[i].ln();
305            }
306        }
307
308        Ok(FlashOutput {
309            output,
310            lse,
311            stats,
312        })
313    }
314}
315
/// Generates a causal mask for block (qi_start..qi_end) x (kj_start..kj_end)
/// without materializing a full N x N mask.
///
/// Returns `true` for positions that should be attended to (kj <= qi).
pub fn causal_block_mask(
    qi_start: usize,
    qi_end: usize,
    kj_start: usize,
    kj_end: usize,
) -> Vec<Vec<bool>> {
    (qi_start..qi_end)
        .map(|qi| (kj_start..kj_end).map(|kj| kj <= qi).collect())
        .collect()
}
336
/// Simplified ring attention for distributed sequence parallelism.
///
/// The sequence is sharded across devices: each device keeps a local Q shard
/// while K/V shards rotate around a ring, and partial attention is folded in
/// with the same online-softmax merge that FlashAttention uses.
pub struct RingAttention;
343
/// Result from a single device in ring attention.
#[derive(Clone, Debug)]
pub struct RingDeviceOutput {
    /// Output for this device's Q shard.
    pub output: Vec<Vec<f32>>,
    /// LSE for this device's Q shard.
    pub lse: Vec<f32>,
    /// Number of simulated ring transfers.
    pub transfers: usize,
}
354
355impl RingAttention {
356    /// Runs ring attention across simulated devices.
357    ///
358    /// Each device holds a Q shard and processes all K/V shards by rotating
359    /// them around the ring. This simulates the communication pattern of
360    /// distributed ring attention.
361    ///
362    /// # Arguments
363    ///
364    /// * `q_shards` - Q shards, one per device
365    /// * `k_shards` - K shards, one per device
366    /// * `v_shards` - V shards, one per device
367    pub fn ring_forward(
368        q_shards: &[Vec<Vec<f32>>],
369        k_shards: &[Vec<Vec<f32>>],
370        v_shards: &[Vec<Vec<f32>>],
371    ) -> AttentionResult<Vec<RingDeviceOutput>> {
372        let num_devices = q_shards.len();
373        if num_devices == 0 {
374            return Err(AttentionError::EmptyInput("shards".into()));
375        }
376        if k_shards.len() != num_devices || v_shards.len() != num_devices {
377            return Err(AttentionError::DimensionMismatch {
378                expected: num_devices,
379                actual: k_shards.len().min(v_shards.len()),
380            });
381        }
382
383        let config = FlashConfig {
384            block_size_q: 32,
385            block_size_kv: 32,
386            causal: false,
387            dropout_p: 0.0,
388        };
389
390        let mut results = Vec::with_capacity(num_devices);
391
392        // Each device processes its local Q against all K/V shards
393        for device_id in 0..num_devices {
394            let local_q = &q_shards[device_id];
395            if local_q.is_empty() {
396                return Err(AttentionError::EmptyInput(
397                    format!("Q shard on device {device_id}"),
398                ));
399            }
400            let d = local_q[0].len();
401            let n_q = local_q.len();
402
403            let mut output = vec![vec![0.0f32; d]; n_q];
404            let mut row_max = vec![f32::NEG_INFINITY; n_q];
405            let mut row_sum = vec![0.0f32; n_q];
406            let mut lse = vec![f32::NEG_INFINITY; n_q];
407            let mut transfers = 0usize;
408
409            // Rotate through all K/V shards (ring communication)
410            for step in 0..num_devices {
411                let kv_idx = (device_id + step) % num_devices;
412                if step > 0 {
413                    transfers += 1; // Simulated device-to-device transfer
414                }
415
416                let partial = FlashAttention3::forward(
417                    local_q,
418                    &k_shards[kv_idx],
419                    &v_shards[kv_idx],
420                    &config,
421                )?;
422
423                // Merge partial results using online softmax
424                for qi in 0..n_q {
425                    let m_partial = if partial.lse[qi].is_finite() {
426                        // Recover max from lse: we stored lse = m + ln(l),
427                        // but for merging we use the partial output directly.
428                        partial.lse[qi]
429                    } else {
430                        continue;
431                    };
432
433                    let m_old = row_max[qi];
434                    let m_new = m_old.max(m_partial);
435
436                    let exp_old = if m_old.is_finite() {
437                        (m_old - m_new).exp()
438                    } else {
439                        0.0
440                    };
441                    let exp_partial = (m_partial - m_new).exp();
442
443                    // partial.output is already normalized, so we need to
444                    // un-normalize: partial_unnorm = partial.output * exp(partial.lse)
445                    // For simplicity, use the sum approach:
446                    let l_partial = if partial.lse[qi].is_finite() {
447                        partial.lse[qi].exp()
448                    } else {
449                        0.0
450                    };
451                    let l_old = row_sum[qi];
452
453                    let l_new = exp_old * l_old + exp_partial * l_partial;
454
455                    if l_new > 0.0 {
456                        let inv_l = 1.0 / l_new;
457                        for dd in 0..d {
458                            output[qi][dd] = (exp_old * l_old * output[qi][dd]
459                                + exp_partial * l_partial * partial.output[qi][dd])
460                                * inv_l;
461                        }
462                    }
463
464                    row_max[qi] = m_new;
465                    row_sum[qi] = l_new;
466                }
467            }
468
469            // Final LSE
470            for qi in 0..n_q {
471                if row_sum[qi] > 0.0 && row_max[qi].is_finite() {
472                    lse[qi] = row_max[qi] + row_sum[qi].ln();
473                }
474            }
475
476            results.push(RingDeviceOutput {
477                output,
478                lse,
479                transfers,
480            });
481        }
482
483        Ok(results)
484    }
485}
486
/// Computes naive (standard) attention for correctness comparison.
///
/// Materializes the full score row per query, applies a numerically stable
/// softmax, and returns the output matrix of shape [n_q, d].
/// (Doc fix: only the output is returned — no attention-weights tuple.)
///
/// * `q` - Query matrix, shape [n_q, d]; an empty slice yields an empty result
/// * `k` - Key matrix, shape [n_kv, d]
/// * `v` - Value matrix, shape [n_kv, d]
/// * `causal` - When true, masks positions where kj > qi
fn naive_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    causal: bool,
) -> Vec<Vec<f32>> {
    // Gracefully handle an empty query set instead of panicking on q[0].
    if q.is_empty() {
        return Vec::new();
    }
    let n_q = q.len();
    let n_kv = k.len();
    let d = q[0].len();
    let scale = 1.0 / (d as f32).sqrt();

    let mut output = vec![vec![0.0f32; d]; n_q];

    for qi in 0..n_q {
        // Scaled dot-product scores against every key.
        let mut scores = Vec::with_capacity(n_kv);
        for kj in 0..n_kv {
            let mut dot = 0.0f32;
            for dd in 0..d {
                dot += q[qi][dd] * k[kj][dd];
            }
            let mut s = dot * scale;
            if causal && kj > qi {
                s = f32::NEG_INFINITY;
            }
            scores.push(s);
        }

        // Numerically stable softmax (shift by the row max).
        let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_s: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
        let sum_s: f32 = exp_s.iter().sum();

        // Probability-weighted sum of values.
        for dd in 0..d {
            let mut val = 0.0f32;
            for kj in 0..n_kv {
                val += (exp_s[kj] / sum_s) * v[kj][dd];
            }
            output[qi][dd] = val;
        }
    }

    output
}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic [n, d] matrix whose entries depend on `seed`.
    fn make_seq(n: usize, d: usize, seed: f32) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..d)
                .map(|j| ((i as f32 + 1.0) * (j as f32 + 1.0) * seed).sin() * 0.5)
                .collect();
            rows.push(row);
        }
        rows
    }

    #[test]
    fn test_forward_matches_naive() {
        let (n, d) = (12, 16);
        let q = make_seq(n, d, 0.1);
        let k = make_seq(n, d, 0.2);
        let v = make_seq(n, d, 0.3);

        let cfg = FlashConfig::new(4, 4).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "row={qi} col={dd} flash={} naive={} diff={diff}",
                    flash.output[qi][dd], naive[qi][dd]);
            }
        }
    }

    #[test]
    fn test_causal_masking() {
        let (n, d) = (6, 8);
        let q = make_seq(n, d, 0.4);
        let k = make_seq(n, d, 0.5);
        let v = make_seq(n, d, 0.6);

        let cfg = FlashConfig::new(2, 2).unwrap().with_causal();
        let flash = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();
        let naive = naive_attention(&q, &k, &v, true);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "causal row={qi} col={dd} diff={diff}");
            }
        }
    }

    #[test]
    fn test_numerical_stability_large_values() {
        let (n, d) = (4, 8);
        // Scores this large would overflow exp() without the running-max shift.
        let q: Vec<Vec<f32>> = (0..n)
            .map(|i| vec![100.0 * (i as f32 + 1.0); d])
            .collect();
        let k = q.clone();
        let v: Vec<Vec<f32>> = (0..n).map(|i| vec![i as f32; d]).collect();

        let cfg = FlashConfig::new(2, 2).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();

        for row in &result.output {
            for &val in row {
                assert!(val.is_finite(), "Non-finite output: {val}");
            }
        }
        for &l in &result.lse {
            assert!(l.is_finite(), "Non-finite LSE: {l}");
        }
    }

    #[test]
    fn test_block_size_variations() {
        let (n, d) = (10, 8);
        let q = make_seq(n, d, 0.7);
        let k = make_seq(n, d, 0.8);
        let v = make_seq(n, d, 0.9);

        let naive = naive_attention(&q, &k, &v, false);

        for &(bq, bk) in &[(2, 2), (3, 5), (1, 1), (10, 10), (7, 3)] {
            let cfg = FlashConfig::new(bq, bk).unwrap();
            let flash = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();

            for qi in 0..n {
                for dd in 0..d {
                    let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                    assert!(
                        diff < 1e-4,
                        "blocks=({bq},{bk}) row={qi} col={dd} diff={diff}"
                    );
                }
            }
        }
    }

    #[test]
    fn test_io_stats_tracking() {
        let (n, d) = (16, 8);
        let q = make_seq(n, d, 1.0);
        let k = make_seq(n, d, 1.1);
        let v = make_seq(n, d, 1.2);

        let cfg = FlashConfig::new(4, 4).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();

        assert!(result.stats.total_flops > 0, "FLOPs should be tracked");
        assert!(result.stats.memory_reads > 0, "Reads should be tracked");
        assert!(result.stats.memory_writes > 0, "Writes should be tracked");
        assert_eq!(result.stats.memory_complexity(), "O(N)");
        assert_eq!(result.stats.naive_memory_complexity(), "O(N^2)");
        assert!(result.stats.flop_ratio() > 0.0, "IO ratio should be positive");
    }

    #[test]
    fn test_ring_attention() {
        let d = 8;
        let shard_size = 4;
        let num_devices = 3;

        let shard = |base: f32| -> Vec<Vec<Vec<f32>>> {
            (0..num_devices)
                .map(|dev| make_seq(shard_size, d, base * (dev as f32 + 1.0)))
                .collect()
        };
        let q_shards = shard(0.1);
        let k_shards = shard(0.2);
        let v_shards = shard(0.3);

        let results =
            RingAttention::ring_forward(&q_shards, &k_shards, &v_shards).unwrap();

        assert_eq!(results.len(), num_devices);
        for (dev_id, res) in results.iter().enumerate() {
            assert_eq!(res.output.len(), shard_size);
            assert_eq!(res.output[0].len(), d);
            // Every device rotates through the other (num_devices - 1) shards.
            assert_eq!(res.transfers, num_devices - 1,
                "Device {dev_id} should have {} transfers", num_devices - 1);
            for row in &res.output {
                for &val in row {
                    assert!(val.is_finite(), "Device {dev_id} has non-finite output");
                }
            }
        }
    }

    #[test]
    fn test_single_block() {
        // A single block covering the whole sequence must match naive attention.
        let (n, d) = (3, 4);
        let q = make_seq(n, d, 1.5);
        let k = make_seq(n, d, 1.6);
        let v = make_seq(n, d, 1.7);

        let cfg = FlashConfig::new(n, n).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-5, "single block row={qi} col={dd} diff={diff}");
            }
        }
    }

    #[test]
    fn test_large_sequence() {
        let (n, d) = (128, 16);
        let q = make_seq(n, d, 2.0);
        let k = make_seq(n, d, 2.1);
        let v = make_seq(n, d, 2.2);

        let cfg = FlashConfig::new(16, 16).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        let mut max_diff = 0.0f32;
        for qi in 0..n {
            for dd in 0..d {
                max_diff = max_diff.max((flash.output[qi][dd] - naive[qi][dd]).abs());
            }
        }
        assert!(max_diff < 1e-3, "Large seq max diff: {max_diff}");
    }

    #[test]
    fn test_lse_correctness() {
        let (n, d) = (6, 8);
        let q = make_seq(n, d, 3.0);
        let k = make_seq(n, d, 3.1);
        let v = make_seq(n, d, 3.2);
        let scale = 1.0 / (d as f32).sqrt();

        let cfg = FlashConfig::new(2, 3).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &cfg).unwrap();

        // Recompute log(sum(exp(scores))) directly for every query row.
        for qi in 0..n {
            let scores: Vec<f32> = (0..n)
                .map(|kj| (0..d).map(|dd| q[qi][dd] * k[kj][dd]).sum::<f32>() * scale)
                .collect();
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = scores.iter().map(|&s| (s - max_s).exp()).sum();
            let expected_lse = max_s + sum_exp.ln();

            let diff = (result.lse[qi] - expected_lse).abs();
            assert!(diff < 1e-3, "LSE row={qi} flash={} expected={expected_lse} diff={diff}",
                result.lse[qi]);
        }
    }

    #[test]
    fn test_causal_block_mask_utility() {
        let mask = causal_block_mask(2, 5, 0, 4);
        // qi=2: kj 0,1,2 allowed, 3 not
        assert_eq!(mask[0], vec![true, true, true, false]);
        // qi=3: kj 0,1,2,3 allowed
        assert_eq!(mask[1], vec![true, true, true, true]);
        // qi=4: all allowed
        assert_eq!(mask[2], vec![true, true, true, true]);
    }

    #[test]
    fn test_empty_input_errors() {
        let cfg = FlashConfig::default();
        let empty: Vec<Vec<f32>> = vec![];
        let q = vec![vec![1.0; 4]];

        assert!(FlashAttention3::forward(&empty, &q, &q, &cfg).is_err());
        assert!(FlashAttention3::forward(&q, &empty, &q, &cfg).is_err());
        assert!(FlashAttention3::forward(&q, &q, &empty, &cfg).is_err());
    }

    #[test]
    fn test_config_validation() {
        assert!(FlashConfig::new(0, 4).is_err());
        assert!(FlashConfig::new(4, 0).is_err());
        assert!(FlashConfig::new(4, 4).is_ok());

        assert!(FlashConfig::default().with_dropout(1.5).is_err());
        assert!(FlashConfig::default().with_dropout(-0.1).is_err());
        assert!(FlashConfig::default().with_dropout(0.5).is_ok());
    }
}