1use crate::error::{AttentionError, AttentionResult};
20
21#[derive(Clone, Debug)]
23pub struct FlashConfig {
24 pub block_size_q: usize,
26 pub block_size_kv: usize,
28 pub causal: bool,
30 pub dropout_p: f32,
33}
34
35impl Default for FlashConfig {
36 fn default() -> Self {
37 Self {
38 block_size_q: 64,
39 block_size_kv: 64,
40 causal: false,
41 dropout_p: 0.0,
42 }
43 }
44}
45
46impl FlashConfig {
47 pub fn new(block_size_q: usize, block_size_kv: usize) -> AttentionResult<Self> {
49 if block_size_q == 0 || block_size_kv == 0 {
50 return Err(AttentionError::InvalidConfig(
51 "Block sizes must be > 0".into(),
52 ));
53 }
54 Ok(Self {
55 block_size_q,
56 block_size_kv,
57 ..Default::default()
58 })
59 }
60
61 pub fn with_causal(mut self) -> Self {
63 self.causal = true;
64 self
65 }
66
67 pub fn with_dropout(mut self, p: f32) -> AttentionResult<Self> {
69 if !(0.0..=1.0).contains(&p) {
70 return Err(AttentionError::InvalidConfig(
71 "Dropout must be in [0, 1]".into(),
72 ));
73 }
74 self.dropout_p = p;
75 Ok(self)
76 }
77}
78
/// Operation and memory-traffic counters for one attention call.
///
/// The public counters are plain tallies that the producer increments; the
/// private fields record the problem shape the tallies refer to.
#[derive(Clone, Debug, Default)]
pub struct IOStats {
    /// Floating-point operations counted so far.
    pub total_flops: u64,
    /// Elements counted as read.
    pub memory_reads: u64,
    /// Elements counted as written.
    pub memory_writes: u64,
    /// Sequence length used as `N` in [`IOStats::flop_ratio`].
    seq_len: usize,
    /// Per-row feature length used as `d` in [`IOStats::flop_ratio`].
    head_dim: usize,
    // Stored for reference only; no method in this impl reads it.
    #[allow(dead_code)]
    block_size_q: usize,
    // Stored for reference only; no method in this impl reads it.
    #[allow(dead_code)]
    block_size_kv: usize,
}

impl IOStats {
    /// Ratio of an untiled traffic estimate, `N*N + N*d`, to the recorded
    /// traffic, `memory_reads + memory_writes`.
    ///
    /// Despite the name, the value compares memory traffic, not FLOP counts.
    /// Returns `1.0` when no FLOPs were recorded or when the recorded traffic
    /// is below one element, so the division is never by zero.
    pub fn flop_ratio(&self) -> f32 {
        let recorded_io = self.memory_reads as f64 + self.memory_writes as f64;
        if self.total_flops == 0 || recorded_io < 1.0 {
            return 1.0;
        }
        let (n, d) = (self.seq_len as f64, self.head_dim as f64);
        let untiled_io = n * n + n * d;
        (untiled_io / recorded_io) as f32
    }

    /// Fixed label for the tiled algorithm's memory complexity.
    pub fn memory_complexity(&self) -> &'static str {
        "O(N)"
    }

    /// Fixed label for the untiled algorithm's memory complexity.
    pub fn naive_memory_complexity(&self) -> &'static str {
        "O(N^2)"
    }
}
131
132pub struct FlashAttention3;
138
139#[derive(Clone, Debug)]
141pub struct FlashOutput {
142 pub output: Vec<Vec<f32>>,
144 pub lse: Vec<f32>,
146 pub stats: IOStats,
148}
149
150impl FlashAttention3 {
151 pub fn forward(
169 q: &[Vec<f32>],
170 k: &[Vec<f32>],
171 v: &[Vec<f32>],
172 config: &FlashConfig,
173 ) -> AttentionResult<FlashOutput> {
174 if q.is_empty() {
175 return Err(AttentionError::EmptyInput("queries".into()));
176 }
177 if k.is_empty() || v.is_empty() {
178 return Err(AttentionError::EmptyInput("keys or values".into()));
179 }
180 if k.len() != v.len() {
181 return Err(AttentionError::DimensionMismatch {
182 expected: k.len(),
183 actual: v.len(),
184 });
185 }
186 let d = q[0].len();
187 if d == 0 {
188 return Err(AttentionError::InvalidConfig(
189 "Dimension must be > 0".into(),
190 ));
191 }
192 let scale = 1.0 / (d as f32).sqrt();
193 let n_q = q.len();
194 let n_kv = k.len();
195 let br = config.block_size_q;
196 let bc = config.block_size_kv;
197
198 let mut output = vec![vec![0.0f32; d]; n_q];
199 let mut lse = vec![f32::NEG_INFINITY; n_q];
200 let mut row_max = vec![f32::NEG_INFINITY; n_q];
201 let mut row_sum = vec![0.0f32; n_q];
202
203 let mut stats = IOStats {
204 seq_len: n_q.max(n_kv),
205 head_dim: d,
206 block_size_q: br,
207 block_size_kv: bc,
208 ..Default::default()
209 };
210
211 for qi_start in (0..n_q).step_by(br) {
213 let qi_end = (qi_start + br).min(n_q);
214
215 for kj_start in (0..n_kv).step_by(bc) {
217 let kj_end = (kj_start + bc).min(n_kv);
218
219 stats.memory_reads +=
221 ((qi_end - qi_start) * d + (kj_end - kj_start) * d * 2) as u64;
222
223 for qi in qi_start..qi_end {
225 let mut block_scores = Vec::with_capacity(kj_end - kj_start);
227 for kj in kj_start..kj_end {
228 let mut dot = 0.0f32;
229 for dd in 0..d {
230 dot += q[qi][dd] * k[kj][dd];
231 }
232 let mut score = dot * scale;
233
234 if config.causal && kj > qi {
236 score = f32::NEG_INFINITY;
237 }
238 block_scores.push(score);
239 stats.total_flops += (2 * d) as u64; }
241
242 let m_ij = block_scores
244 .iter()
245 .copied()
246 .fold(f32::NEG_INFINITY, f32::max);
247
248 if !m_ij.is_finite() {
249 continue; }
251
252 let exp_scores: Vec<f32> =
254 block_scores.iter().map(|&s| (s - m_ij).exp()).collect();
255 let l_ij: f32 = exp_scores.iter().filter(|x| x.is_finite()).sum();
256
257 let m_old = row_max[qi];
259 let m_new = m_old.max(m_ij);
260
261 let exp_old = if m_old.is_finite() {
262 (m_old - m_new).exp()
263 } else {
264 0.0
265 };
266 let exp_new = (m_ij - m_new).exp();
267
268 let l_new = exp_old * row_sum[qi] + exp_new * l_ij;
269
270 if l_new > 0.0 {
274 let inv_l_new = 1.0 / l_new;
275 let scale_old = exp_old * row_sum[qi] * inv_l_new;
276 let scale_new = exp_new * inv_l_new;
277
278 for dd in 0..d {
279 let mut pv = 0.0f32;
280 for (local_j, kj) in (kj_start..kj_end).enumerate() {
281 if exp_scores[local_j].is_finite() {
282 pv += exp_scores[local_j] * v[kj][dd];
283 }
284 }
285 output[qi][dd] = scale_old * output[qi][dd] + scale_new * pv;
286 stats.total_flops += (2 * (kj_end - kj_start)) as u64;
287 }
288 }
289
290 row_max[qi] = m_new;
291 row_sum[qi] = l_new;
292 }
293 }
294
295 stats.memory_writes += ((qi_end - qi_start) * d) as u64;
297 }
298
299 for i in 0..n_q {
301 if row_sum[i] > 0.0 && row_max[i].is_finite() {
302 lse[i] = row_max[i] + row_sum[i].ln();
303 }
304 }
305
306 Ok(FlashOutput { output, lse, stats })
307 }
308}
309
/// Builds the causal visibility mask for one tile.
///
/// The result has one row per query index in `qi_start..qi_end` and one entry
/// per key index in `kj_start..kj_end`; an entry is `true` when the key index
/// is not after the query index (`kj <= qi`).
///
/// An empty or inverted range (`end <= start`) contributes no rows or no
/// columns rather than underflowing the length computation.
pub fn causal_block_mask(
    qi_start: usize,
    qi_end: usize,
    kj_start: usize,
    kj_end: usize,
) -> Vec<Vec<bool>> {
    (qi_start..qi_end)
        .map(|qi| (kj_start..kj_end).map(|kj| kj <= qi).collect())
        .collect()
}
330
331pub struct RingAttention;
337
338#[derive(Clone, Debug)]
340pub struct RingDeviceOutput {
341 pub output: Vec<Vec<f32>>,
343 pub lse: Vec<f32>,
345 pub transfers: usize,
347}
348
349impl RingAttention {
350 pub fn ring_forward(
362 q_shards: &[Vec<Vec<f32>>],
363 k_shards: &[Vec<Vec<f32>>],
364 v_shards: &[Vec<Vec<f32>>],
365 ) -> AttentionResult<Vec<RingDeviceOutput>> {
366 let num_devices = q_shards.len();
367 if num_devices == 0 {
368 return Err(AttentionError::EmptyInput("shards".into()));
369 }
370 if k_shards.len() != num_devices || v_shards.len() != num_devices {
371 return Err(AttentionError::DimensionMismatch {
372 expected: num_devices,
373 actual: k_shards.len().min(v_shards.len()),
374 });
375 }
376
377 let config = FlashConfig {
378 block_size_q: 32,
379 block_size_kv: 32,
380 causal: false,
381 dropout_p: 0.0,
382 };
383
384 let mut results = Vec::with_capacity(num_devices);
385
386 for device_id in 0..num_devices {
388 let local_q = &q_shards[device_id];
389 if local_q.is_empty() {
390 return Err(AttentionError::EmptyInput(format!(
391 "Q shard on device {device_id}"
392 )));
393 }
394 let d = local_q[0].len();
395 let n_q = local_q.len();
396
397 let mut output = vec![vec![0.0f32; d]; n_q];
398 let mut row_max = vec![f32::NEG_INFINITY; n_q];
399 let mut row_sum = vec![0.0f32; n_q];
400 let mut lse = vec![f32::NEG_INFINITY; n_q];
401 let mut transfers = 0usize;
402
403 for step in 0..num_devices {
405 let kv_idx = (device_id + step) % num_devices;
406 if step > 0 {
407 transfers += 1; }
409
410 let partial = FlashAttention3::forward(
411 local_q,
412 &k_shards[kv_idx],
413 &v_shards[kv_idx],
414 &config,
415 )?;
416
417 for qi in 0..n_q {
419 let m_partial = if partial.lse[qi].is_finite() {
420 partial.lse[qi]
423 } else {
424 continue;
425 };
426
427 let m_old = row_max[qi];
428 let m_new = m_old.max(m_partial);
429
430 let exp_old = if m_old.is_finite() {
431 (m_old - m_new).exp()
432 } else {
433 0.0
434 };
435 let exp_partial = (m_partial - m_new).exp();
436
437 let l_partial = if partial.lse[qi].is_finite() {
441 partial.lse[qi].exp()
442 } else {
443 0.0
444 };
445 let l_old = row_sum[qi];
446
447 let l_new = exp_old * l_old + exp_partial * l_partial;
448
449 if l_new > 0.0 {
450 let inv_l = 1.0 / l_new;
451 for dd in 0..d {
452 output[qi][dd] = (exp_old * l_old * output[qi][dd]
453 + exp_partial * l_partial * partial.output[qi][dd])
454 * inv_l;
455 }
456 }
457
458 row_max[qi] = m_new;
459 row_sum[qi] = l_new;
460 }
461 }
462
463 for qi in 0..n_q {
465 if row_sum[qi] > 0.0 && row_max[qi].is_finite() {
466 lse[qi] = row_max[qi] + row_sum[qi].ln();
467 }
468 }
469
470 results.push(RingDeviceOutput {
471 output,
472 lse,
473 transfers,
474 });
475 }
476
477 Ok(results)
478 }
479}
480
/// Reference attention that materialises every score row: for each query,
/// `softmax(q · kᵀ / sqrt(d)) · v`, with `d` taken from the first query row.
///
/// With `causal`, key index `kj` is masked for query index `qi` when `kj > qi`.
/// Returns an empty result for an empty `q`. Inputs are otherwise not
/// validated: rows shorter than `d`, or a `v` with fewer rows than `k`, panic.
fn naive_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], causal: bool) -> Vec<Vec<f32>> {
    // `q[0]` would panic on an empty slice; there is simply nothing to compute.
    let d = match q.first() {
        Some(row) => row.len(),
        None => return Vec::new(),
    };
    let n_kv = k.len();
    let scale = 1.0 / (d as f32).sqrt();

    q.iter()
        .enumerate()
        .map(|(qi, q_row)| {
            let scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(kj, k_row)| {
                    let mut dot = 0.0f32;
                    for (a, b) in q_row[..d].iter().zip(&k_row[..d]) {
                        dot += a * b;
                    }
                    if causal && kj > qi {
                        f32::NEG_INFINITY
                    } else {
                        dot * scale
                    }
                })
                .collect();

            // Subtract the row maximum before exponentiating for stability.
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exp_s: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
            let sum_s: f32 = exp_s.iter().sum();
            let weights: Vec<f32> = exp_s.iter().map(|&e| e / sum_s).collect();

            (0..d)
                .map(|dd| {
                    let mut val = 0.0f32;
                    for (&w, v_row) in weights.iter().zip(&v[..n_kv]) {
                        val += w * v_row[dd];
                    }
                    val
                })
                .collect()
        })
        .collect()
}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic `n × d` matrix: entry `(i, j)` is
    /// `0.5 * sin((i + 1) * (j + 1) * seed)`.
    fn make_seq(n: usize, d: usize, seed: f32) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let mut row = Vec::with_capacity(d);
            for j in 0..d {
                row.push(((i as f32 + 1.0) * (j as f32 + 1.0) * seed).sin() * 0.5);
            }
            rows.push(row);
        }
        rows
    }

    /// Row-major `(row, col)` index pairs of a `rows × cols` grid.
    fn cells(rows: usize, cols: usize) -> impl Iterator<Item = (usize, usize)> {
        (0..rows).flat_map(move |r| (0..cols).map(move |c| (r, c)))
    }

    /// Tiled output agrees with the reference on a square, non-causal problem.
    #[test]
    fn test_forward_matches_naive() {
        let (n, d) = (12, 16);
        let (q, k, v) = (make_seq(n, d, 0.1), make_seq(n, d, 0.2), make_seq(n, d, 0.3));

        let config = FlashConfig::new(4, 4).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for (qi, dd) in cells(n, d) {
            let (got, want) = (flash.output[qi][dd], naive[qi][dd]);
            let diff = (got - want).abs();
            assert!(
                diff < 1e-4,
                "row={qi} col={dd} flash={} naive={} diff={diff}",
                got,
                want
            );
        }
    }

    /// Causal tiling agrees with the causal reference.
    #[test]
    fn test_causal_masking() {
        let (n, d) = (6, 8);
        let (q, k, v) = (make_seq(n, d, 0.4), make_seq(n, d, 0.5), make_seq(n, d, 0.6));

        let config = FlashConfig::new(2, 2).unwrap().with_causal();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, true);

        for (qi, dd) in cells(n, d) {
            let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
            assert!(diff < 1e-4, "causal row={qi} col={dd} diff={diff}");
        }
    }

    /// Large-magnitude scores must not produce NaN or infinite results.
    #[test]
    fn test_numerical_stability_large_values() {
        let (n, d) = (4, 8);
        let q: Vec<Vec<f32>> = (0..n).map(|i| vec![100.0 * (i as f32 + 1.0); d]).collect();
        let k = q.clone();
        let v: Vec<Vec<f32>> = (0..n).map(|i| vec![i as f32; d]).collect();

        let config = FlashConfig::new(2, 2).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        for &val in result.output.iter().flatten() {
            assert!(val.is_finite(), "Non-finite output: {val}");
        }
        for &l in result.lse.iter() {
            assert!(l.is_finite(), "Non-finite LSE: {l}");
        }
    }

    /// The result does not depend on the tile shape, including ragged tiles.
    #[test]
    fn test_block_size_variations() {
        let (n, d) = (10, 8);
        let (q, k, v) = (make_seq(n, d, 0.7), make_seq(n, d, 0.8), make_seq(n, d, 0.9));

        let naive = naive_attention(&q, &k, &v, false);

        for (bq, bk) in [(2, 2), (3, 5), (1, 1), (10, 10), (7, 3)] {
            let config = FlashConfig::new(bq, bk).unwrap();
            let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

            for (qi, dd) in cells(n, d) {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(
                    diff < 1e-4,
                    "blocks=({bq},{bk}) row={qi} col={dd} diff={diff}"
                );
            }
        }
    }

    /// All counters are populated and the complexity labels are as expected.
    #[test]
    fn test_io_stats_tracking() {
        let (n, d) = (16, 8);
        let (q, k, v) = (make_seq(n, d, 1.0), make_seq(n, d, 1.1), make_seq(n, d, 1.2));

        let config = FlashConfig::new(4, 4).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let stats = &result.stats;

        assert!(stats.total_flops > 0, "FLOPs should be tracked");
        assert!(stats.memory_reads > 0, "Reads should be tracked");
        assert!(stats.memory_writes > 0, "Writes should be tracked");
        assert_eq!(stats.memory_complexity(), "O(N)");
        assert_eq!(stats.naive_memory_complexity(), "O(N^2)");

        let ratio = stats.flop_ratio();
        assert!(ratio > 0.0, "IO ratio should be positive");
    }

    /// Every device yields a finite, correctly shaped result and counts one
    /// transfer per remote shard.
    #[test]
    fn test_ring_attention() {
        let d = 8;
        let shard_size = 4;
        let num_devices = 3;

        // One shard per device; the seed grows with the device index.
        let shards = |base: f32| -> Vec<Vec<Vec<f32>>> {
            (0..num_devices)
                .map(|dev| make_seq(shard_size, d, base * (dev as f32 + 1.0)))
                .collect()
        };
        let q_shards = shards(0.1);
        let k_shards = shards(0.2);
        let v_shards = shards(0.3);

        let results = RingAttention::ring_forward(&q_shards, &k_shards, &v_shards).unwrap();

        assert_eq!(results.len(), num_devices);
        for (dev_id, res) in results.iter().enumerate() {
            assert_eq!(res.output.len(), shard_size);
            assert_eq!(res.output[0].len(), d);
            assert_eq!(
                res.transfers,
                num_devices - 1,
                "Device {dev_id} should have {} transfers",
                num_devices - 1
            );
            for &val in res.output.iter().flatten() {
                assert!(val.is_finite(), "Device {dev_id} has non-finite output");
            }
        }
    }

    /// A single tile covering the whole problem matches the reference tightly.
    #[test]
    fn test_single_block() {
        let (n, d) = (3, 4);
        let (q, k, v) = (make_seq(n, d, 1.5), make_seq(n, d, 1.6), make_seq(n, d, 1.7));

        let config = FlashConfig::new(n, n).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for (qi, dd) in cells(n, d) {
            let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
            assert!(diff < 1e-5, "single block row={qi} col={dd} diff={diff}");
        }
    }

    /// Accumulated rounding error stays small on a longer sequence.
    #[test]
    fn test_large_sequence() {
        let (n, d) = (128, 16);
        let (q, k, v) = (make_seq(n, d, 2.0), make_seq(n, d, 2.1), make_seq(n, d, 2.2));

        let config = FlashConfig::new(16, 16).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        let max_diff = cells(n, d)
            .map(|(qi, dd)| (flash.output[qi][dd] - naive[qi][dd]).abs())
            .fold(0.0f32, f32::max);
        assert!(max_diff < 1e-3, "Large seq max diff: {max_diff}");
    }

    /// The returned LSE equals `max + ln(sum(exp(score - max)))` per row.
    #[test]
    fn test_lse_correctness() {
        let (n, d) = (6, 8);
        let (q, k, v) = (make_seq(n, d, 3.0), make_seq(n, d, 3.1), make_seq(n, d, 3.2));
        let scale = 1.0 / (d as f32).sqrt();

        let config = FlashConfig::new(2, 3).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        for qi in 0..n {
            let scores: Vec<f32> = (0..n)
                .map(|kj| {
                    let dot: f32 = (0..d).map(|dd| q[qi][dd] * k[kj][dd]).sum();
                    dot * scale
                })
                .collect();
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = scores.iter().map(|&s| (s - max_s).exp()).sum();
            let expected_lse = max_s + sum_exp.ln();

            let diff = (result.lse[qi] - expected_lse).abs();
            assert!(
                diff < 1e-3,
                "LSE row={qi} flash={} expected={expected_lse} diff={diff}",
                result.lse[qi]
            );
        }
    }

    /// Query rows 2..5 against key columns 0..4: only row 2 has a masked column.
    #[test]
    fn test_causal_block_mask_utility() {
        let mask = causal_block_mask(2, 5, 0, 4);
        let expected = [
            vec![true, true, true, false],
            vec![true, true, true, true],
            vec![true, true, true, true],
        ];
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(mask[row], *want);
        }
    }

    /// An empty Q, K or V is rejected.
    #[test]
    fn test_empty_input_errors() {
        let config = FlashConfig::default();
        let empty: Vec<Vec<f32>> = vec![];
        let q = vec![vec![1.0; 4]];

        for (a, b, c) in [(&empty, &q, &q), (&q, &empty, &q), (&q, &q, &empty)] {
            assert!(FlashAttention3::forward(a, b, c, &config).is_err());
        }
    }

    /// Constructor and dropout setter enforce their ranges.
    #[test]
    fn test_config_validation() {
        assert!(FlashConfig::new(0, 4).is_err());
        assert!(FlashConfig::new(4, 0).is_err());
        assert!(FlashConfig::new(4, 4).is_ok());

        let base = FlashConfig::default();
        assert!(base.clone().with_dropout(1.5).is_err());
        assert!(base.clone().with_dropout(-0.1).is_err());
        assert!(base.with_dropout(0.5).is_ok());
    }
}