Skip to main content

ruvector_attention/attention/
flash.rs

1//! FlashAttention-3 IO-aware tiled attention.
2//!
//! Implements the FlashAttention algorithm, which reduces HBM (High Bandwidth
//! Memory) accesses from O(Nd + N^2) for standard attention to O(N^2 d^2 / M),
//! where M is the SRAM size, by tiling Q, K and V into blocks and fusing the
//! softmax rescaling with the matmul accumulation.
//!
//! The key insight is that standard attention materializes the full N x N attention
//! matrix in HBM, causing O(N^2) memory. FlashAttention never materializes this
//! matrix. Instead it computes attention tile by tile with an online softmax that
//! keeps per-row running statistics: the row maximum and the running sum of
//! exponentials, from which the log-sum-exp is derived. This avoids a separate
//! normalization pass. This file is a single-threaded CPU reference of that scheme;
//! memory traffic is only modelled (counted) by [`IOStats`].
12//!
//! This module provides:
//! - [`FlashConfig`]: Configuration for block sizes, causal masking, and a dropout
//!   probability (validated and stored, but not applied by this CPU implementation)
//! - [`FlashAttention3`]: IO-aware tiled forward pass returning output + LSE
//! - [`IOStats`]: Tracking of FLOPs and memory transfer for IO analysis
//! - [`causal_block_mask`]: Per-block causal mask helper
//! - [`RingAttention`]: Simplified, single-process simulation of ring-based
//!   distributed attention across devices
18
19use crate::error::{AttentionError, AttentionResult};
20
21/// Configuration for FlashAttention tiled computation.
22#[derive(Clone, Debug)]
23pub struct FlashConfig {
24    /// Block size along the query dimension (Br).
25    pub block_size_q: usize,
26    /// Block size along the key/value dimension (Bc).
27    pub block_size_kv: usize,
28    /// Whether to apply causal masking (upper-triangular mask).
29    pub causal: bool,
30    /// Dropout probability (0.0 = no dropout). Applied conceptually but not
31    /// stochastically in this CPU implementation.
32    pub dropout_p: f32,
33}
34
35impl Default for FlashConfig {
36    fn default() -> Self {
37        Self {
38            block_size_q: 64,
39            block_size_kv: 64,
40            causal: false,
41            dropout_p: 0.0,
42        }
43    }
44}
45
46impl FlashConfig {
47    /// Creates a config with custom block sizes.
48    pub fn new(block_size_q: usize, block_size_kv: usize) -> AttentionResult<Self> {
49        if block_size_q == 0 || block_size_kv == 0 {
50            return Err(AttentionError::InvalidConfig(
51                "Block sizes must be > 0".into(),
52            ));
53        }
54        Ok(Self {
55            block_size_q,
56            block_size_kv,
57            ..Default::default()
58        })
59    }
60
61    /// Returns a causal variant of this config.
62    pub fn with_causal(mut self) -> Self {
63        self.causal = true;
64        self
65    }
66
67    /// Sets the dropout probability.
68    pub fn with_dropout(mut self, p: f32) -> AttentionResult<Self> {
69        if !(0.0..=1.0).contains(&p) {
70            return Err(AttentionError::InvalidConfig(
71                "Dropout must be in [0, 1]".into(),
72            ));
73        }
74        self.dropout_p = p;
75        Ok(self)
76    }
77}
78
/// IO statistics for comparing tiled vs naive attention.
#[derive(Clone, Debug, Default)]
pub struct IOStats {
    /// Total floating-point operations performed.
    pub total_flops: u64,
    /// Total elements read from main memory.
    pub memory_reads: u64,
    /// Total elements written to main memory.
    pub memory_writes: u64,
    /// Sequence length used for the computation.
    seq_len: usize,
    /// Head dimension used for the computation.
    head_dim: usize,
    /// Block size Q used. Recorded for diagnostics; no method reads it yet,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    block_size_q: usize,
    /// Block size KV used. Recorded for diagnostics; no method reads it yet,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    block_size_kv: usize,
}

impl IOStats {
    /// Returns the ratio of an estimated naive-attention memory traffic to the
    /// tiled traffic recorded in these stats.
    ///
    /// The naive estimate is `N^2 + N*d` elements (`N` = `seq_len`,
    /// `d` = `head_dim`): the materialized attention matrix plus one `N x d`
    /// operand. The tiled figure is `memory_reads + memory_writes`. A value
    /// above 1.0 means the tiled pass moved fewer elements than the estimate.
    ///
    /// Returns 1.0 when no FLOPs were recorded, or when the recorded tiled
    /// traffic is below one element, since there is nothing to compare.
    pub fn io_ratio(&self) -> f32 {
        if self.total_flops == 0 {
            return 1.0;
        }
        let n = self.seq_len as f64;
        let d = self.head_dim as f64;
        let naive_io = n * n + n * d;
        let tiled_io = self.memory_reads as f64 + self.memory_writes as f64;
        if tiled_io < 1.0 {
            1.0
        } else {
            (naive_io / tiled_io) as f32
        }
    }

    /// Returns the same value as [`IOStats::io_ratio`].
    ///
    /// Despite its name this is a memory-IO ratio, not a FLOP ratio (naive and
    /// tiled attention perform the same FLOPs), and it is generally not close
    /// to 1.0. Kept under this name for backward compatibility.
    pub fn flop_ratio(&self) -> f32 {
        self.io_ratio()
    }

    /// Returns the memory complexity class of the tiled algorithm as a string.
    /// Tiled: O(N) working memory. Naive: O(N^2).
    pub fn memory_complexity(&self) -> &'static str {
        "O(N)"
    }

    /// Returns the naive attention memory complexity for comparison.
    pub fn naive_memory_complexity(&self) -> &'static str {
        "O(N^2)"
    }
}
131
132/// FlashAttention-3: IO-aware tiled attention.
133///
134/// Processes Q in blocks of Br rows and K/V in blocks of Bc rows, never
135/// materializing the full N x N attention matrix. Uses online softmax with
136/// running max and log-sum-exp to maintain numerical stability.
137pub struct FlashAttention3;
138
139/// Output of a flash attention forward pass.
140#[derive(Clone, Debug)]
141pub struct FlashOutput {
142    /// The attention output matrix, shape [num_queries, dim].
143    pub output: Vec<Vec<f32>>,
144    /// Log-sum-exp per query row (m_i + ln(l_i)), used for backward pass.
145    pub lse: Vec<f32>,
146    /// IO statistics for this computation.
147    pub stats: IOStats,
148}
149
150impl FlashAttention3 {
151    /// Computes IO-aware tiled attention.
152    ///
153    /// # Algorithm
154    ///
155    /// 1. Split Q into Tr blocks of Br rows, K/V into Tc blocks of Bc rows.
156    /// 2. For each Q block i, iterate over all K/V blocks j:
157    ///    - Compute S_ij = Q_i @ K_j^T / sqrt(d)
158    ///    - Apply causal mask if configured
159    ///    - Update running max, sum-exp, and output using online softmax
160    /// 3. Return output and log-sum-exp for backward pass.
161    ///
162    /// # Arguments
163    ///
164    /// * `q` - Query matrix, shape [n_q, d]
165    /// * `k` - Key matrix, shape [n_kv, d]
166    /// * `v` - Value matrix, shape [n_kv, d]
167    /// * `config` - Flash attention configuration
168    pub fn forward(
169        q: &[Vec<f32>],
170        k: &[Vec<f32>],
171        v: &[Vec<f32>],
172        config: &FlashConfig,
173    ) -> AttentionResult<FlashOutput> {
174        if q.is_empty() {
175            return Err(AttentionError::EmptyInput("queries".into()));
176        }
177        if k.is_empty() || v.is_empty() {
178            return Err(AttentionError::EmptyInput("keys or values".into()));
179        }
180        if k.len() != v.len() {
181            return Err(AttentionError::DimensionMismatch {
182                expected: k.len(),
183                actual: v.len(),
184            });
185        }
186        let d = q[0].len();
187        if d == 0 {
188            return Err(AttentionError::InvalidConfig(
189                "Dimension must be > 0".into(),
190            ));
191        }
192        let scale = 1.0 / (d as f32).sqrt();
193        let n_q = q.len();
194        let n_kv = k.len();
195        let br = config.block_size_q;
196        let bc = config.block_size_kv;
197
198        let mut output = vec![vec![0.0f32; d]; n_q];
199        let mut lse = vec![f32::NEG_INFINITY; n_q];
200        let mut row_max = vec![f32::NEG_INFINITY; n_q];
201        let mut row_sum = vec![0.0f32; n_q];
202
203        let mut stats = IOStats {
204            seq_len: n_q.max(n_kv),
205            head_dim: d,
206            block_size_q: br,
207            block_size_kv: bc,
208            ..Default::default()
209        };
210
211        // Outer loop: iterate over Q blocks
212        for qi_start in (0..n_q).step_by(br) {
213            let qi_end = (qi_start + br).min(n_q);
214
215            // Inner loop: iterate over K/V blocks
216            for kj_start in (0..n_kv).step_by(bc) {
217                let kj_end = (kj_start + bc).min(n_kv);
218
219                // Track memory reads: Q block + K block + V block
220                stats.memory_reads +=
221                    ((qi_end - qi_start) * d + (kj_end - kj_start) * d * 2) as u64;
222
223                // For each query row in this Q block
224                for qi in qi_start..qi_end {
225                    // Compute S_ij = Q_i @ K_j^T / sqrt(d) for each key in block
226                    let mut block_scores = Vec::with_capacity(kj_end - kj_start);
227                    for kj in kj_start..kj_end {
228                        let mut dot = 0.0f32;
229                        for dd in 0..d {
230                            dot += q[qi][dd] * k[kj][dd];
231                        }
232                        let mut score = dot * scale;
233
234                        // Apply causal mask: mask out positions where kj > qi
235                        if config.causal && kj > qi {
236                            score = f32::NEG_INFINITY;
237                        }
238                        block_scores.push(score);
239                        stats.total_flops += (2 * d) as u64; // dot product
240                    }
241
242                    // Block row-max
243                    let m_ij = block_scores
244                        .iter()
245                        .copied()
246                        .fold(f32::NEG_INFINITY, f32::max);
247
248                    if !m_ij.is_finite() {
249                        continue; // Fully masked block
250                    }
251
252                    // Exponentiate and sum
253                    let exp_scores: Vec<f32> =
254                        block_scores.iter().map(|&s| (s - m_ij).exp()).collect();
255                    let l_ij: f32 = exp_scores.iter().filter(|x| x.is_finite()).sum();
256
257                    // Online softmax rescaling
258                    let m_old = row_max[qi];
259                    let m_new = m_old.max(m_ij);
260
261                    let exp_old = if m_old.is_finite() {
262                        (m_old - m_new).exp()
263                    } else {
264                        0.0
265                    };
266                    let exp_new = (m_ij - m_new).exp();
267
268                    let l_new = exp_old * row_sum[qi] + exp_new * l_ij;
269
270                    // Rescale existing output and add new contribution
271                    // O_i = (exp(m_old - m_new) * l_old * O_i
272                    //      + exp(m_ij - m_new) * P_ij @ V_j) / l_new
273                    if l_new > 0.0 {
274                        let inv_l_new = 1.0 / l_new;
275                        let scale_old = exp_old * row_sum[qi] * inv_l_new;
276                        let scale_new = exp_new * inv_l_new;
277
278                        for dd in 0..d {
279                            let mut pv = 0.0f32;
280                            for (local_j, kj) in (kj_start..kj_end).enumerate() {
281                                if exp_scores[local_j].is_finite() {
282                                    pv += exp_scores[local_j] * v[kj][dd];
283                                }
284                            }
285                            output[qi][dd] = scale_old * output[qi][dd] + scale_new * pv;
286                            stats.total_flops += (2 * (kj_end - kj_start)) as u64;
287                        }
288                    }
289
290                    row_max[qi] = m_new;
291                    row_sum[qi] = l_new;
292                }
293            }
294
295            // Track memory writes: output block
296            stats.memory_writes += ((qi_end - qi_start) * d) as u64;
297        }
298
299        // Compute LSE = m + ln(l) for backward pass
300        for i in 0..n_q {
301            if row_sum[i] > 0.0 && row_max[i].is_finite() {
302                lse[i] = row_max[i] + row_sum[i].ln();
303            }
304        }
305
306        Ok(FlashOutput { output, lse, stats })
307    }
308}
309
/// Generates a causal mask for block (qi_start..qi_end) x (kj_start..kj_end)
/// without materializing a full N x N mask.
///
/// Entry `[r][c]` corresponds to query `qi_start + r` and key `kj_start + c`,
/// and is `true` when that key may be attended to (`kj <= qi`). As with Rust
/// ranges, an empty or inverted query range yields no rows and an empty or
/// inverted key range yields empty rows, instead of panicking.
pub fn causal_block_mask(
    qi_start: usize,
    qi_end: usize,
    kj_start: usize,
    kj_end: usize,
) -> Vec<Vec<bool>> {
    (qi_start..qi_end)
        .map(|qi| (kj_start..kj_end).map(|kj| kj <= qi).collect())
        .collect()
}
330
331/// Simplified ring attention for distributed sequence parallelism.
332///
333/// In ring attention, the sequence is sharded across devices. Each device holds
334/// a local Q shard and rotates K/V shards around a ring, accumulating partial
335/// attention using the same online softmax as FlashAttention.
336pub struct RingAttention;
337
338/// Result from a single device in ring attention.
339#[derive(Clone, Debug)]
340pub struct RingDeviceOutput {
341    /// Output for this device's Q shard.
342    pub output: Vec<Vec<f32>>,
343    /// LSE for this device's Q shard.
344    pub lse: Vec<f32>,
345    /// Number of simulated ring transfers.
346    pub transfers: usize,
347}
348
349impl RingAttention {
350    /// Runs ring attention across simulated devices.
351    ///
352    /// Each device holds a Q shard and processes all K/V shards by rotating
353    /// them around the ring. This simulates the communication pattern of
354    /// distributed ring attention.
355    ///
356    /// # Arguments
357    ///
358    /// * `q_shards` - Q shards, one per device
359    /// * `k_shards` - K shards, one per device
360    /// * `v_shards` - V shards, one per device
361    pub fn ring_forward(
362        q_shards: &[Vec<Vec<f32>>],
363        k_shards: &[Vec<Vec<f32>>],
364        v_shards: &[Vec<Vec<f32>>],
365    ) -> AttentionResult<Vec<RingDeviceOutput>> {
366        let num_devices = q_shards.len();
367        if num_devices == 0 {
368            return Err(AttentionError::EmptyInput("shards".into()));
369        }
370        if k_shards.len() != num_devices || v_shards.len() != num_devices {
371            return Err(AttentionError::DimensionMismatch {
372                expected: num_devices,
373                actual: k_shards.len().min(v_shards.len()),
374            });
375        }
376
377        let config = FlashConfig {
378            block_size_q: 32,
379            block_size_kv: 32,
380            causal: false,
381            dropout_p: 0.0,
382        };
383
384        let mut results = Vec::with_capacity(num_devices);
385
386        // Each device processes its local Q against all K/V shards
387        for device_id in 0..num_devices {
388            let local_q = &q_shards[device_id];
389            if local_q.is_empty() {
390                return Err(AttentionError::EmptyInput(format!(
391                    "Q shard on device {device_id}"
392                )));
393            }
394            let d = local_q[0].len();
395            let n_q = local_q.len();
396
397            let mut output = vec![vec![0.0f32; d]; n_q];
398            let mut row_max = vec![f32::NEG_INFINITY; n_q];
399            let mut row_sum = vec![0.0f32; n_q];
400            let mut lse = vec![f32::NEG_INFINITY; n_q];
401            let mut transfers = 0usize;
402
403            // Rotate through all K/V shards (ring communication)
404            for step in 0..num_devices {
405                let kv_idx = (device_id + step) % num_devices;
406                if step > 0 {
407                    transfers += 1; // Simulated device-to-device transfer
408                }
409
410                let partial = FlashAttention3::forward(
411                    local_q,
412                    &k_shards[kv_idx],
413                    &v_shards[kv_idx],
414                    &config,
415                )?;
416
417                // Merge partial results using online softmax
418                for qi in 0..n_q {
419                    let m_partial = if partial.lse[qi].is_finite() {
420                        // Recover max from lse: we stored lse = m + ln(l),
421                        // but for merging we use the partial output directly.
422                        partial.lse[qi]
423                    } else {
424                        continue;
425                    };
426
427                    let m_old = row_max[qi];
428                    let m_new = m_old.max(m_partial);
429
430                    let exp_old = if m_old.is_finite() {
431                        (m_old - m_new).exp()
432                    } else {
433                        0.0
434                    };
435                    let exp_partial = (m_partial - m_new).exp();
436
437                    // partial.output is already normalized, so we need to
438                    // un-normalize: partial_unnorm = partial.output * exp(partial.lse)
439                    // For simplicity, use the sum approach:
440                    let l_partial = if partial.lse[qi].is_finite() {
441                        partial.lse[qi].exp()
442                    } else {
443                        0.0
444                    };
445                    let l_old = row_sum[qi];
446
447                    let l_new = exp_old * l_old + exp_partial * l_partial;
448
449                    if l_new > 0.0 {
450                        let inv_l = 1.0 / l_new;
451                        for dd in 0..d {
452                            output[qi][dd] = (exp_old * l_old * output[qi][dd]
453                                + exp_partial * l_partial * partial.output[qi][dd])
454                                * inv_l;
455                        }
456                    }
457
458                    row_max[qi] = m_new;
459                    row_sum[qi] = l_new;
460                }
461            }
462
463            // Final LSE
464            for qi in 0..n_q {
465                if row_sum[qi] > 0.0 && row_max[qi].is_finite() {
466                    lse[qi] = row_max[qi] + row_sum[qi].ln();
467                }
468            }
469
470            results.push(RingDeviceOutput {
471                output,
472                lse,
473                transfers,
474            });
475        }
476
477        Ok(results)
478    }
479}
480
/// Computes naive (standard) attention for correctness comparison.
///
/// Materializes every score of each query row, applies a max-subtracted
/// softmax, and returns only the attention output: the probability-weighted sum
/// of `v` rows, shape `[n_q, d]` with `d = q[0].len()`. With `causal`, key `kj`
/// is masked for query `qi` when `kj > qi`.
///
/// This is a reference helper for the unit tests only. It is private and has no
/// other callers, so it is compiled only under `cfg(test)` to avoid a
/// `dead_code` warning in normal builds.
///
/// # Panics
///
/// Panics if `q` is empty, if a row of `q`/`k`/`v` is shorter than `d`, or if
/// `v` has fewer rows than `k`.
#[cfg(test)]
fn naive_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], causal: bool) -> Vec<Vec<f32>> {
    let d = q[0].len();
    let scale = 1.0 / (d as f32).sqrt();

    q.iter()
        .enumerate()
        .map(|(qi, q_row)| {
            // Materialize the full score row for this query.
            let scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(kj, k_row)| {
                    let dot = (0..d).fold(0.0f32, |acc, dd| acc + q_row[dd] * k_row[dd]);
                    if causal && kj > qi {
                        f32::NEG_INFINITY
                    } else {
                        dot * scale
                    }
                })
                .collect();

            // Max-subtracted softmax for numerical stability.
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exp_s: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
            let sum_s: f32 = exp_s.iter().sum();

            // Probability-weighted sum of the value rows.
            (0..d)
                .map(|dd| {
                    exp_s
                        .iter()
                        .enumerate()
                        .fold(0.0f32, |acc, (kj, &e)| acc + (e / sum_s) * v[kj][dd])
                })
                .collect()
        })
        .collect()
}
523
#[cfg(test)]
mod tests {
    use super::*;

    fn make_seq(n: usize, d: usize, seed: f32) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| {
                (0..d)
                    .map(|j| ((i as f32 + 1.0) * (j as f32 + 1.0) * seed).sin() * 0.5)
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_forward_matches_naive() {
        let d = 16;
        let n = 12;
        let q = make_seq(n, d, 0.1);
        let k = make_seq(n, d, 0.2);
        let v = make_seq(n, d, 0.3);

        let config = FlashConfig::new(4, 4).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(
                    diff < 1e-4,
                    "row={qi} col={dd} flash={} naive={} diff={diff}",
                    flash.output[qi][dd],
                    naive[qi][dd]
                );
            }
        }
    }

    #[test]
    fn test_causal_masking() {
        let d = 8;
        let n = 6;
        let q = make_seq(n, d, 0.4);
        let k = make_seq(n, d, 0.5);
        let v = make_seq(n, d, 0.6);

        let config = FlashConfig::new(2, 2).unwrap().with_causal();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, true);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "causal row={qi} col={dd} diff={diff}");
            }
        }
    }

    #[test]
    fn test_numerical_stability_large_values() {
        let d = 8;
        let n = 4;
        // Use large values that could cause overflow without stable softmax
        let q: Vec<Vec<f32>> = (0..n).map(|i| vec![100.0 * (i as f32 + 1.0); d]).collect();
        let k = q.clone();
        let v: Vec<Vec<f32>> = (0..n).map(|i| vec![i as f32; d]).collect();

        let config = FlashConfig::new(2, 2).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        // Output should contain finite values (no NaN/Inf)
        for row in &result.output {
            for &val in row {
                assert!(val.is_finite(), "Non-finite output: {val}");
            }
        }
        for &l in &result.lse {
            assert!(l.is_finite(), "Non-finite LSE: {l}");
        }
    }

    #[test]
    fn test_block_size_variations() {
        let d = 8;
        let n = 10;
        let q = make_seq(n, d, 0.7);
        let k = make_seq(n, d, 0.8);
        let v = make_seq(n, d, 0.9);

        let block_sizes = [(2, 2), (3, 5), (1, 1), (10, 10), (7, 3)];
        let naive = naive_attention(&q, &k, &v, false);

        for (bq, bk) in block_sizes {
            let config = FlashConfig::new(bq, bk).unwrap();
            let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

            for qi in 0..n {
                for dd in 0..d {
                    let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                    assert!(
                        diff < 1e-4,
                        "blocks=({bq},{bk}) row={qi} col={dd} diff={diff}"
                    );
                }
            }
        }
    }

    #[test]
    fn test_io_stats_tracking() {
        let d = 8;
        let n = 16;
        let q = make_seq(n, d, 1.0);
        let k = make_seq(n, d, 1.1);
        let v = make_seq(n, d, 1.2);

        let config = FlashConfig::new(4, 4).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        assert!(result.stats.total_flops > 0, "FLOPs should be tracked");
        assert!(result.stats.memory_reads > 0, "Reads should be tracked");
        assert!(result.stats.memory_writes > 0, "Writes should be tracked");
        assert_eq!(result.stats.memory_complexity(), "O(N)");
        assert_eq!(result.stats.naive_memory_complexity(), "O(N^2)");

        let ratio = result.stats.flop_ratio();
        assert!(ratio > 0.0, "IO ratio should be positive");
    }

    #[test]
    fn test_ring_attention() {
        let d = 8;
        let shard_size = 4;
        let num_devices = 3;

        let q_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.1 * (dev as f32 + 1.0)))
            .collect();
        let k_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.2 * (dev as f32 + 1.0)))
            .collect();
        let v_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.3 * (dev as f32 + 1.0)))
            .collect();

        let results = RingAttention::ring_forward(&q_shards, &k_shards, &v_shards).unwrap();

        assert_eq!(results.len(), num_devices);
        for (dev_id, res) in results.iter().enumerate() {
            assert_eq!(res.output.len(), shard_size);
            assert_eq!(res.output[0].len(), d);
            // Every device starts with its own K/V shard and then receives each
            // of the other (num_devices - 1) shards once.
            assert_eq!(
                res.transfers,
                num_devices - 1,
                "Device {dev_id} should have {} transfers",
                num_devices - 1
            );
            for row in &res.output {
                for &val in row {
                    assert!(val.is_finite(), "Device {dev_id} has non-finite output");
                }
            }
        }
    }

    #[test]
    fn test_single_block() {
        // When block size >= sequence length, should behave identically to naive
        let d = 4;
        let n = 3;
        let q = make_seq(n, d, 1.5);
        let k = make_seq(n, d, 1.6);
        let v = make_seq(n, d, 1.7);

        let config = FlashConfig::new(n, n).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-5, "single block row={qi} col={dd} diff={diff}");
            }
        }
    }

    #[test]
    fn test_large_sequence() {
        let d = 16;
        let n = 128;
        let q = make_seq(n, d, 2.0);
        let k = make_seq(n, d, 2.1);
        let v = make_seq(n, d, 2.2);

        let config = FlashConfig::new(16, 16).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        let mut max_diff = 0.0f32;
        for qi in 0..n {
            for dd in 0..d {
                max_diff = max_diff.max((flash.output[qi][dd] - naive[qi][dd]).abs());
            }
        }
        assert!(max_diff < 1e-3, "Large seq max diff: {max_diff}");
    }

    #[test]
    fn test_lse_correctness() {
        let d = 8;
        let n = 6;
        let q = make_seq(n, d, 3.0);
        let k = make_seq(n, d, 3.1);
        let v = make_seq(n, d, 3.2);
        let scale = 1.0 / (d as f32).sqrt();

        let config = FlashConfig::new(2, 3).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        // Verify LSE: for each query, compute log(sum(exp(scores))) manually
        for qi in 0..n {
            let mut scores = Vec::with_capacity(n);
            for kj in 0..n {
                let dot: f32 = (0..d).map(|dd| q[qi][dd] * k[kj][dd]).sum();
                scores.push(dot * scale);
            }
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = scores.iter().map(|&s| (s - max_s).exp()).sum();
            let expected_lse = max_s + sum_exp.ln();

            let diff = (result.lse[qi] - expected_lse).abs();
            assert!(
                diff < 1e-3,
                "LSE row={qi} flash={} expected={expected_lse} diff={diff}",
                result.lse[qi]
            );
        }
    }

    #[test]
    fn test_causal_block_mask_utility() {
        let mask = causal_block_mask(2, 5, 0, 4);
        // qi=2: kj 0,1,2 allowed, 3 not
        assert_eq!(mask[0], vec![true, true, true, false]);
        // qi=3: kj 0,1,2,3 allowed
        assert_eq!(mask[1], vec![true, true, true, true]);
        // qi=4: all allowed
        assert_eq!(mask[2], vec![true, true, true, true]);
    }

    #[test]
    fn test_empty_input_errors() {
        let config = FlashConfig::default();
        let empty: Vec<Vec<f32>> = vec![];
        let q = vec![vec![1.0; 4]];

        assert!(FlashAttention3::forward(&empty, &q, &q, &config).is_err());
        assert!(FlashAttention3::forward(&q, &empty, &q, &config).is_err());
        assert!(FlashAttention3::forward(&q, &q, &empty, &config).is_err());
    }

    #[test]
    fn test_kv_length_mismatch_error() {
        let config = FlashConfig::default();
        let q = vec![vec![1.0; 4]];
        let k = vec![vec![1.0; 4]; 2];
        let v = vec![vec![1.0; 4]; 3];

        // K and V must have the same number of rows; V's count is reported.
        assert!(matches!(
            FlashAttention3::forward(&q, &k, &v, &config),
            Err(AttentionError::DimensionMismatch {
                expected: 2,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_config_validation() {
        assert!(FlashConfig::new(0, 4).is_err());
        assert!(FlashConfig::new(4, 0).is_err());
        assert!(FlashConfig::new(4, 4).is_ok());

        assert!(FlashConfig::default().with_dropout(1.5).is_err());
        assert!(FlashConfig::default().with_dropout(-0.1).is_err());
        assert!(FlashConfig::default().with_dropout(0.5).is_ok());
    }
}