1use crate::error::{AttentionError, AttentionResult};
20
/// Configuration for the tiled (Flash-style) attention forward pass.
#[derive(Clone, Debug)]
pub struct FlashConfig {
    /// Number of query rows per tile (Br in the FlashAttention papers).
    pub block_size_q: usize,
    /// Number of key/value rows per tile (Bc).
    pub block_size_kv: usize,
    /// When true, query i only attends to key positions j <= i.
    pub causal: bool,
    /// Dropout probability in [0, 1].
    /// NOTE(review): validated by `with_dropout` but not applied anywhere in
    /// the visible forward pass — confirm whether this is intentional.
    pub dropout_p: f32,
}
34
35impl Default for FlashConfig {
36 fn default() -> Self {
37 Self {
38 block_size_q: 64,
39 block_size_kv: 64,
40 causal: false,
41 dropout_p: 0.0,
42 }
43 }
44}
45
46impl FlashConfig {
47 pub fn new(block_size_q: usize, block_size_kv: usize) -> AttentionResult<Self> {
49 if block_size_q == 0 || block_size_kv == 0 {
50 return Err(AttentionError::InvalidConfig(
51 "Block sizes must be > 0".into(),
52 ));
53 }
54 Ok(Self {
55 block_size_q,
56 block_size_kv,
57 ..Default::default()
58 })
59 }
60
61 pub fn with_causal(mut self) -> Self {
63 self.causal = true;
64 self
65 }
66
67 pub fn with_dropout(mut self, p: f32) -> AttentionResult<Self> {
69 if !(0.0..=1.0).contains(&p) {
70 return Err(AttentionError::InvalidConfig(
71 "Dropout must be in [0, 1]".into(),
72 ));
73 }
74 self.dropout_p = p;
75 Ok(self)
76 }
77}
78
/// Counters accumulated during a tiled attention pass.
#[derive(Clone, Debug, Default)]
pub struct IOStats {
    /// Total floating-point operations (score dot products + output updates).
    pub total_flops: u64,
    /// Scalar elements loaded (Q tile plus K and V tiles per inner step).
    pub memory_reads: u64,
    /// Scalar elements stored (one output tile per query block).
    pub memory_writes: u64,
    // Longest of the query/key sequence lengths; feeds the naive-IO estimate.
    seq_len: usize,
    // Per-head feature dimension.
    head_dim: usize,
    // Tile sizes are recorded for debugging but not read anywhere yet.
    #[allow(dead_code)]
    block_size_q: usize,
    #[allow(dead_code)]
    block_size_kv: usize,
}
99
100impl IOStats {
101 pub fn flop_ratio(&self) -> f32 {
104 if self.total_flops == 0 {
105 return 1.0;
106 }
107 let n = self.seq_len as f64;
111 let d = self.head_dim as f64;
112 let naive_io = n * n + n * d; let tiled_io = self.memory_reads as f64 + self.memory_writes as f64;
114 if tiled_io < 1.0 {
115 return 1.0;
116 }
117 (naive_io / tiled_io) as f32
118 }
119
120 pub fn memory_complexity(&self) -> &'static str {
123 "O(N)"
124 }
125
126 pub fn naive_memory_complexity(&self) -> &'static str {
128 "O(N^2)"
129 }
130}
131
/// Tiled attention with online softmax (FlashAttention-style).
/// Stateless marker type; all entry points are associated functions.
pub struct FlashAttention3;
138
/// Result of a `FlashAttention3::forward` call.
#[derive(Clone, Debug)]
pub struct FlashOutput {
    /// Attention output, one row of `head_dim` values per query.
    pub output: Vec<Vec<f32>>,
    /// Per-query log-sum-exp of the scaled scores
    /// (`NEG_INFINITY` for fully masked rows).
    pub lse: Vec<f32>,
    /// IO/FLOP counters gathered during the pass.
    pub stats: IOStats,
}
149
150impl FlashAttention3 {
151 pub fn forward(
169 q: &[Vec<f32>],
170 k: &[Vec<f32>],
171 v: &[Vec<f32>],
172 config: &FlashConfig,
173 ) -> AttentionResult<FlashOutput> {
174 if q.is_empty() {
175 return Err(AttentionError::EmptyInput("queries".into()));
176 }
177 if k.is_empty() || v.is_empty() {
178 return Err(AttentionError::EmptyInput("keys or values".into()));
179 }
180 if k.len() != v.len() {
181 return Err(AttentionError::DimensionMismatch {
182 expected: k.len(),
183 actual: v.len(),
184 });
185 }
186 let d = q[0].len();
187 if d == 0 {
188 return Err(AttentionError::InvalidConfig("Dimension must be > 0".into()));
189 }
190 let scale = 1.0 / (d as f32).sqrt();
191 let n_q = q.len();
192 let n_kv = k.len();
193 let br = config.block_size_q;
194 let bc = config.block_size_kv;
195
196 let mut output = vec![vec![0.0f32; d]; n_q];
197 let mut lse = vec![f32::NEG_INFINITY; n_q];
198 let mut row_max = vec![f32::NEG_INFINITY; n_q];
199 let mut row_sum = vec![0.0f32; n_q];
200
201 let mut stats = IOStats {
202 seq_len: n_q.max(n_kv),
203 head_dim: d,
204 block_size_q: br,
205 block_size_kv: bc,
206 ..Default::default()
207 };
208
209 for qi_start in (0..n_q).step_by(br) {
211 let qi_end = (qi_start + br).min(n_q);
212
213 for kj_start in (0..n_kv).step_by(bc) {
215 let kj_end = (kj_start + bc).min(n_kv);
216
217 stats.memory_reads += ((qi_end - qi_start) * d
219 + (kj_end - kj_start) * d * 2) as u64;
220
221 for qi in qi_start..qi_end {
223 let mut block_scores = Vec::with_capacity(kj_end - kj_start);
225 for kj in kj_start..kj_end {
226 let mut dot = 0.0f32;
227 for dd in 0..d {
228 dot += q[qi][dd] * k[kj][dd];
229 }
230 let mut score = dot * scale;
231
232 if config.causal && kj > qi {
234 score = f32::NEG_INFINITY;
235 }
236 block_scores.push(score);
237 stats.total_flops += (2 * d) as u64; }
239
240 let m_ij = block_scores
242 .iter()
243 .copied()
244 .fold(f32::NEG_INFINITY, f32::max);
245
246 if !m_ij.is_finite() {
247 continue; }
249
250 let exp_scores: Vec<f32> =
252 block_scores.iter().map(|&s| (s - m_ij).exp()).collect();
253 let l_ij: f32 = exp_scores
254 .iter()
255 .filter(|x| x.is_finite())
256 .sum();
257
258 let m_old = row_max[qi];
260 let m_new = m_old.max(m_ij);
261
262 let exp_old = if m_old.is_finite() {
263 (m_old - m_new).exp()
264 } else {
265 0.0
266 };
267 let exp_new = (m_ij - m_new).exp();
268
269 let l_new = exp_old * row_sum[qi] + exp_new * l_ij;
270
271 if l_new > 0.0 {
275 let inv_l_new = 1.0 / l_new;
276 let scale_old = exp_old * row_sum[qi] * inv_l_new;
277 let scale_new = exp_new * inv_l_new;
278
279 for dd in 0..d {
280 let mut pv = 0.0f32;
281 for (local_j, kj) in (kj_start..kj_end).enumerate() {
282 if exp_scores[local_j].is_finite() {
283 pv += exp_scores[local_j] * v[kj][dd];
284 }
285 }
286 output[qi][dd] =
287 scale_old * output[qi][dd] + scale_new * pv;
288 stats.total_flops += (2 * (kj_end - kj_start)) as u64;
289 }
290 }
291
292 row_max[qi] = m_new;
293 row_sum[qi] = l_new;
294 }
295 }
296
297 stats.memory_writes += ((qi_end - qi_start) * d) as u64;
299 }
300
301 for i in 0..n_q {
303 if row_sum[i] > 0.0 && row_max[i].is_finite() {
304 lse[i] = row_max[i] + row_sum[i].ln();
305 }
306 }
307
308 Ok(FlashOutput {
309 output,
310 lse,
311 stats,
312 })
313 }
314}
315
/// Builds the boolean causal mask for a (query-block x key-block) tile.
///
/// `mask[a][b]` is `true` when query row `qi_start + a` may attend to key
/// column `kj_start + b`, i.e. when the key index does not exceed the query
/// index.
pub fn causal_block_mask(
    qi_start: usize,
    qi_end: usize,
    kj_start: usize,
    kj_end: usize,
) -> Vec<Vec<bool>> {
    (qi_start..qi_end)
        .map(|qi| (kj_start..kj_end).map(|kj| kj <= qi).collect())
        .collect()
}
336
/// Simulated ring-style distributed attention: each "device" keeps its Q
/// shard resident and rotates K/V shards around the ring, merging the partial
/// softmax results via their log-sum-exp values.
pub struct RingAttention;
343
/// Per-device result of a ring-attention pass.
#[derive(Clone, Debug)]
pub struct RingDeviceOutput {
    /// Attention output for this device's query shard.
    pub output: Vec<Vec<f32>>,
    /// Per-query combined log-sum-exp across all K/V shards.
    pub lse: Vec<f32>,
    /// Number of simulated K/V shard transfers (ring hops) this device made.
    pub transfers: usize,
}
354
355impl RingAttention {
356 pub fn ring_forward(
368 q_shards: &[Vec<Vec<f32>>],
369 k_shards: &[Vec<Vec<f32>>],
370 v_shards: &[Vec<Vec<f32>>],
371 ) -> AttentionResult<Vec<RingDeviceOutput>> {
372 let num_devices = q_shards.len();
373 if num_devices == 0 {
374 return Err(AttentionError::EmptyInput("shards".into()));
375 }
376 if k_shards.len() != num_devices || v_shards.len() != num_devices {
377 return Err(AttentionError::DimensionMismatch {
378 expected: num_devices,
379 actual: k_shards.len().min(v_shards.len()),
380 });
381 }
382
383 let config = FlashConfig {
384 block_size_q: 32,
385 block_size_kv: 32,
386 causal: false,
387 dropout_p: 0.0,
388 };
389
390 let mut results = Vec::with_capacity(num_devices);
391
392 for device_id in 0..num_devices {
394 let local_q = &q_shards[device_id];
395 if local_q.is_empty() {
396 return Err(AttentionError::EmptyInput(
397 format!("Q shard on device {device_id}"),
398 ));
399 }
400 let d = local_q[0].len();
401 let n_q = local_q.len();
402
403 let mut output = vec![vec![0.0f32; d]; n_q];
404 let mut row_max = vec![f32::NEG_INFINITY; n_q];
405 let mut row_sum = vec![0.0f32; n_q];
406 let mut lse = vec![f32::NEG_INFINITY; n_q];
407 let mut transfers = 0usize;
408
409 for step in 0..num_devices {
411 let kv_idx = (device_id + step) % num_devices;
412 if step > 0 {
413 transfers += 1; }
415
416 let partial = FlashAttention3::forward(
417 local_q,
418 &k_shards[kv_idx],
419 &v_shards[kv_idx],
420 &config,
421 )?;
422
423 for qi in 0..n_q {
425 let m_partial = if partial.lse[qi].is_finite() {
426 partial.lse[qi]
429 } else {
430 continue;
431 };
432
433 let m_old = row_max[qi];
434 let m_new = m_old.max(m_partial);
435
436 let exp_old = if m_old.is_finite() {
437 (m_old - m_new).exp()
438 } else {
439 0.0
440 };
441 let exp_partial = (m_partial - m_new).exp();
442
443 let l_partial = if partial.lse[qi].is_finite() {
447 partial.lse[qi].exp()
448 } else {
449 0.0
450 };
451 let l_old = row_sum[qi];
452
453 let l_new = exp_old * l_old + exp_partial * l_partial;
454
455 if l_new > 0.0 {
456 let inv_l = 1.0 / l_new;
457 for dd in 0..d {
458 output[qi][dd] = (exp_old * l_old * output[qi][dd]
459 + exp_partial * l_partial * partial.output[qi][dd])
460 * inv_l;
461 }
462 }
463
464 row_max[qi] = m_new;
465 row_sum[qi] = l_new;
466 }
467 }
468
469 for qi in 0..n_q {
471 if row_sum[qi] > 0.0 && row_max[qi].is_finite() {
472 lse[qi] = row_max[qi] + row_sum[qi].ln();
473 }
474 }
475
476 results.push(RingDeviceOutput {
477 output,
478 lse,
479 transfers,
480 });
481 }
482
483 Ok(results)
484 }
485}
486
/// Reference O(N^2) attention used to validate the tiled implementation.
///
/// Materializes the full score row for every query, applies a numerically
/// stable softmax, and mixes the value rows. Panics if `q` is empty or rows
/// are ragged, matching the original helper's assumptions.
fn naive_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    causal: bool,
) -> Vec<Vec<f32>> {
    let d = q[0].len();
    let scale = 1.0 / (d as f32).sqrt();

    q.iter()
        .enumerate()
        .map(|(qi, q_row)| {
            // Scaled dot-product scores for this query, with causal masking.
            let scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(kj, k_row)| {
                    if causal && kj > qi {
                        f32::NEG_INFINITY
                    } else {
                        let dot: f32 =
                            q_row.iter().zip(k_row).map(|(a, b)| a * b).sum();
                        dot * scale
                    }
                })
                .collect();

            // Stable softmax: shift by the row max before exponentiating.
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exp_s: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
            let sum_s: f32 = exp_s.iter().sum();

            // Weighted mix of value rows.
            (0..d)
                .map(|dd| {
                    exp_s
                        .iter()
                        .zip(v)
                        .map(|(&e, v_row)| (e / sum_s) * v_row[dd])
                        .sum()
                })
                .collect()
        })
        .collect()
}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random sequence generator: entry (i, j) is
    /// sin((i+1)(j+1)*seed) * 0.5, so all values lie in [-0.5, 0.5].
    fn make_seq(n: usize, d: usize, seed: f32) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| {
                (0..d)
                    .map(|j| ((i as f32 + 1.0) * (j as f32 + 1.0) * seed).sin() * 0.5)
                    .collect()
            })
            .collect()
    }

    // Tiled forward must agree elementwise with the O(N^2) reference.
    #[test]
    fn test_forward_matches_naive() {
        let d = 16;
        let n = 12;
        let q = make_seq(n, d, 0.1);
        let k = make_seq(n, d, 0.2);
        let v = make_seq(n, d, 0.3);

        let config = FlashConfig::new(4, 4).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "row={qi} col={dd} flash={} naive={} diff={diff}",
                    flash.output[qi][dd], naive[qi][dd]);
            }
        }
    }

    // Causal masking in the tiled path must match the masked reference.
    #[test]
    fn test_causal_masking() {
        let d = 8;
        let n = 6;
        let q = make_seq(n, d, 0.4);
        let k = make_seq(n, d, 0.5);
        let v = make_seq(n, d, 0.6);

        let config = FlashConfig::new(2, 2).unwrap().with_causal();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, true);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "causal row={qi} col={dd} diff={diff}");
            }
        }
    }

    // Large score magnitudes must not overflow thanks to max-shifting.
    #[test]
    fn test_numerical_stability_large_values() {
        let d = 8;
        let n = 4;
        let q: Vec<Vec<f32>> = (0..n)
            .map(|i| vec![100.0 * (i as f32 + 1.0); d])
            .collect();
        let k = q.clone();
        let v: Vec<Vec<f32>> = (0..n).map(|i| vec![i as f32; d]).collect();

        let config = FlashConfig::new(2, 2).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        for row in &result.output {
            for &val in row {
                assert!(val.is_finite(), "Non-finite output: {val}");
            }
        }
        for &l in &result.lse {
            assert!(l.is_finite(), "Non-finite LSE: {l}");
        }
    }

    // The result must be invariant to tile sizes, including non-divisors.
    #[test]
    fn test_block_size_variations() {
        let d = 8;
        let n = 10;
        let q = make_seq(n, d, 0.7);
        let k = make_seq(n, d, 0.8);
        let v = make_seq(n, d, 0.9);

        let block_sizes = [(2, 2), (3, 5), (1, 1), (10, 10), (7, 3)];
        let naive = naive_attention(&q, &k, &v, false);

        for (bq, bk) in block_sizes {
            let config = FlashConfig::new(bq, bk).unwrap();
            let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

            for qi in 0..n {
                for dd in 0..d {
                    let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                    assert!(
                        diff < 1e-4,
                        "blocks=({bq},{bk}) row={qi} col={dd} diff={diff}"
                    );
                }
            }
        }
    }

    // Counters should be populated and the complexity strings fixed.
    #[test]
    fn test_io_stats_tracking() {
        let d = 8;
        let n = 16;
        let q = make_seq(n, d, 1.0);
        let k = make_seq(n, d, 1.1);
        let v = make_seq(n, d, 1.2);

        let config = FlashConfig::new(4, 4).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        assert!(result.stats.total_flops > 0, "FLOPs should be tracked");
        assert!(result.stats.memory_reads > 0, "Reads should be tracked");
        assert!(result.stats.memory_writes > 0, "Writes should be tracked");
        assert_eq!(result.stats.memory_complexity(), "O(N)");
        assert_eq!(result.stats.naive_memory_complexity(), "O(N^2)");

        let ratio = result.stats.flop_ratio();
        assert!(ratio > 0.0, "IO ratio should be positive");
    }

    // Ring pass: shapes, transfer count (n-1 hops), and finiteness only.
    // NOTE(review): does not pin numerical values, so merge-math regressions
    // are not caught here.
    #[test]
    fn test_ring_attention() {
        let d = 8;
        let shard_size = 4;
        let num_devices = 3;

        let q_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.1 * (dev as f32 + 1.0)))
            .collect();
        let k_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.2 * (dev as f32 + 1.0)))
            .collect();
        let v_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.3 * (dev as f32 + 1.0)))
            .collect();

        let results =
            RingAttention::ring_forward(&q_shards, &k_shards, &v_shards).unwrap();

        assert_eq!(results.len(), num_devices);
        for (dev_id, res) in results.iter().enumerate() {
            assert_eq!(res.output.len(), shard_size);
            assert_eq!(res.output[0].len(), d);
            assert_eq!(res.transfers, num_devices - 1,
                "Device {dev_id} should have {} transfers", num_devices - 1);
            for row in &res.output {
                for &val in row {
                    assert!(val.is_finite(), "Device {dev_id} has non-finite output");
                }
            }
        }
    }

    // Degenerate tiling: one tile covering the whole sequence.
    #[test]
    fn test_single_block() {
        let d = 4;
        let n = 3;
        let q = make_seq(n, d, 1.5);
        let k = make_seq(n, d, 1.6);
        let v = make_seq(n, d, 1.7);

        let config = FlashConfig::new(n, n).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-5, "single block row={qi} col={dd} diff={diff}");
            }
        }
    }

    // Longer sequence: allow a looser tolerance for accumulated f32 error.
    #[test]
    fn test_large_sequence() {
        let d = 16;
        let n = 128;
        let q = make_seq(n, d, 2.0);
        let k = make_seq(n, d, 2.1);
        let v = make_seq(n, d, 2.2);

        let config = FlashConfig::new(16, 16).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);

        let mut max_diff = 0.0f32;
        for qi in 0..n {
            for dd in 0..d {
                max_diff = max_diff.max((flash.output[qi][dd] - naive[qi][dd]).abs());
            }
        }
        assert!(max_diff < 1e-3, "Large seq max diff: {max_diff}");
    }

    // The streamed LSE must equal the directly computed log-sum-exp.
    #[test]
    fn test_lse_correctness() {
        let d = 8;
        let n = 6;
        let q = make_seq(n, d, 3.0);
        let k = make_seq(n, d, 3.1);
        let v = make_seq(n, d, 3.2);
        let scale = 1.0 / (d as f32).sqrt();

        let config = FlashConfig::new(2, 3).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();

        for qi in 0..n {
            // Recompute LSE the straightforward way for this row.
            let mut scores = Vec::with_capacity(n);
            for kj in 0..n {
                let dot: f32 = (0..d).map(|dd| q[qi][dd] * k[kj][dd]).sum();
                scores.push(dot * scale);
            }
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = scores.iter().map(|&s| (s - max_s).exp()).sum();
            let expected_lse = max_s + sum_exp.ln();

            let diff = (result.lse[qi] - expected_lse).abs();
            assert!(diff < 1e-3, "LSE row={qi} flash={} expected={expected_lse} diff={diff}",
                result.lse[qi]);
        }
    }

    // Mask helper: row a corresponds to query qi_start + a.
    #[test]
    fn test_causal_block_mask_utility() {
        let mask = causal_block_mask(2, 5, 0, 4);
        assert_eq!(mask[0], vec![true, true, true, false]);
        assert_eq!(mask[1], vec![true, true, true, true]);
        assert_eq!(mask[2], vec![true, true, true, true]);
    }

    // Empty Q, K, or V must be a typed error, never a panic.
    #[test]
    fn test_empty_input_errors() {
        let config = FlashConfig::default();
        let empty: Vec<Vec<f32>> = vec![];
        let q = vec![vec![1.0; 4]];

        assert!(FlashAttention3::forward(&empty, &q, &q, &config).is_err());
        assert!(FlashAttention3::forward(&q, &empty, &q, &config).is_err());
        assert!(FlashAttention3::forward(&q, &q, &empty, &config).is_err());
    }

    // Config constructors reject zero tile sizes and out-of-range dropout.
    #[test]
    fn test_config_validation() {
        assert!(FlashConfig::new(0, 4).is_err());
        assert!(FlashConfig::new(4, 0).is_err());
        assert!(FlashConfig::new(4, 4).is_ok());

        assert!(FlashConfig::default().with_dropout(1.5).is_err());
        assert!(FlashConfig::default().with_dropout(-0.1).is_err());
        assert!(FlashConfig::default().with_dropout(0.5).is_ok());
    }
}