1use crate::accounting::{checked_add_usize_count, ArithmeticOverflow};
9use crate::reservation_policy::{
10 reserve_typed_vec_to_capacity as reserve_vec_to_capacity, ReservationPolicy,
11};
12
/// Reservation policy handed to every `reserve_vec` call in this module.
///
/// The second argument repeats the "shard the dependency graph before barrier
/// planning" remediation that this module's error messages also use. Exactly how
/// `ReservationPolicy` consumes each string is defined in
/// `crate::reservation_policy`, which is not visible here.
const MEGAKERNEL_BARRIER_RESERVATION: ReservationPolicy = ReservationPolicy::new(
    "megakernel barrier planner",
    "shard the dependency graph before barrier planning",
);
17
/// A directed ordering edge between two megakernel waves.
///
/// The planner schedules wave `after` in a strictly later barrier group than
/// wave `before`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MegakernelWaveDependency {
    /// Producer wave index; must be `< wave_count` when planned.
    pub before: usize,
    /// Consumer wave index; must be `< wave_count` and differ from `before`.
    pub after: usize,
}
26
/// One set of waves that may run between two global barriers.
///
/// In groups produced by the planner, no wave depends on another wave of the
/// same group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MegakernelBarrierGroup {
    /// Wave indices in the order the planner scheduled them.
    pub waves: Vec<usize>,
}

/// Result of barrier planning.
///
/// Holds the wave groups in execution order and the number of global barriers
/// that separate them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MegakernelBarrierPlan {
    /// Wave groups in execution order.
    pub groups: Vec<MegakernelBarrierGroup>,
    /// Barriers between consecutive groups. The planner sets this to
    /// `groups.len() - 1`, or `0` when there are no groups.
    pub global_barriers: usize,
}
42
/// Caller-owned working storage for `plan_megakernel_barriers_with_scratch`.
///
/// It holds the compressed-sparse-row (CSR) adjacency and the per-layer ready
/// lists. Callers can keep one instance and pass it to repeated planning calls
/// so the buffers' capacity is reused.
#[derive(Default, Debug)]
pub struct MegakernelBarrierScratch {
    /// Per-wave out-degree. It is later reused as the CSR write cursor while
    /// targets are scattered.
    outgoing_counts: Vec<usize>,
    /// Per-wave count of incoming dependencies not yet satisfied.
    indegree: Vec<usize>,
    /// CSR row offsets into `outgoing_targets` (`wave_count + 1` entries).
    outgoing_offsets: Vec<usize>,
    /// CSR column array: the consumer wave of each dependency, grouped by producer.
    outgoing_targets: Vec<usize>,
    /// Waves ready in the layer currently being scheduled.
    ready: Vec<usize>,
    /// Waves that become ready for the following layer.
    next_ready: Vec<usize>,
}
58
59impl MegakernelBarrierScratch {
60 pub fn try_with_capacity(
68 wave_count: usize,
69 dependency_count: usize,
70 ) -> Result<Self, MegakernelBarrierPlanError> {
71 let mut scratch = Self::default();
72 scratch.try_reserve_shape(wave_count, dependency_count)?;
73 Ok(scratch)
74 }
75
76 fn try_reserve_shape(
77 &mut self,
78 wave_count: usize,
79 dependency_count: usize,
80 ) -> Result<(), MegakernelBarrierPlanError> {
81 let offset_capacity =
82 wave_count
83 .checked_add(1)
84 .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
85 field: "barrier scratch wave offsets",
86 })?;
87 reserve_vec(&mut self.outgoing_counts, wave_count, "outgoing counts")?;
88 reserve_vec(&mut self.indegree, wave_count, "indegree")?;
89 reserve_vec(
90 &mut self.outgoing_offsets,
91 offset_capacity,
92 "outgoing offsets",
93 )?;
94 reserve_vec(
95 &mut self.outgoing_targets,
96 dependency_count,
97 "outgoing targets",
98 )?;
99 reserve_vec(&mut self.ready, wave_count, "ready wave layer")?;
100 reserve_vec(&mut self.next_ready, wave_count, "next ready wave layer")?;
101 Ok(())
102 }
103
104 #[must_use]
106 pub fn wave_capacity(&self) -> usize {
107 let offset_wave_capacity = if self.outgoing_offsets.capacity() == 0 {
108 0
109 } else {
110 self.outgoing_offsets.capacity() - 1
111 };
112 self.outgoing_counts
113 .capacity()
114 .min(self.indegree.capacity())
115 .min(offset_wave_capacity)
116 }
117
118 #[must_use]
120 pub fn dependency_capacity(&self) -> usize {
121 self.outgoing_targets.capacity()
122 }
123}
124
/// Errors reported by the megakernel barrier planner.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MegakernelBarrierPlanError {
    /// A dependency names a wave index outside `0..wave_count`.
    InvalidWave {
        /// Number of waves the planner was asked to schedule.
        wave_count: usize,
        /// Producer index of the offending dependency.
        before: usize,
        /// Consumer index of the offending dependency.
        after: usize,
    },
    /// A dependency whose producer and consumer are the same wave.
    SelfDependency {
        /// The wave that depends on itself.
        wave: usize,
    },
    /// The dependency graph contains a cycle, so some waves never became ready.
    Cycle {
        /// Number of waves left unscheduled.
        unscheduled_waves: usize,
    },
    /// A checked `usize` computation overflowed.
    ByteCountOverflow {
        /// Name of the quantity being computed.
        field: &'static str,
    },
    /// A fallible storage reservation failed.
    StorageReserveFailed {
        /// Label of the buffer being reserved.
        field: &'static str,
        /// Number of entries requested.
        requested: usize,
        /// Description of the underlying reservation failure.
        message: String,
    },
}
162
163impl ArithmeticOverflow for MegakernelBarrierPlanError {
164 fn arithmetic_overflow(field: &'static str) -> Self {
165 Self::ByteCountOverflow { field }
166 }
167}
168
169impl std::fmt::Display for MegakernelBarrierPlanError {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 match self {
172 Self::InvalidWave {
173 wave_count,
174 before,
175 after,
176 } => write!(
177 f,
178 "megakernel dependency references invalid wave before={before} after={after} for wave_count={wave_count}. Fix: emit dependencies only over normalized wave indices."
179 ),
180 Self::SelfDependency { wave } => write!(
181 f,
182 "megakernel wave {wave} depends on itself. Fix: remove the self-edge or split the wave into distinct producer/consumer phases."
183 ),
184 Self::Cycle { unscheduled_waves } => write!(
185 f,
186 "megakernel wave dependency graph contains a cycle with {unscheduled_waves} unscheduled waves. Fix: break the cyclic dataflow edge or insert an explicit iterative fixed-point kernel."
187 ),
188 Self::ByteCountOverflow { field } => write!(
189 f,
190 "megakernel barrier planner overflowed while computing {field}. Fix: shard the dependency graph before barrier planning."
191 ),
192 Self::StorageReserveFailed {
193 field,
194 requested,
195 message,
196 } => write!(
197 f,
198 "megakernel barrier planner could not reserve {requested} {field} entries: {message}. Fix: shard the dependency graph before barrier planning."
199 ),
200 }
201 }
202}
203
204impl std::error::Error for MegakernelBarrierPlanError {}
205
206pub fn plan_megakernel_barriers(
217 wave_count: usize,
218 dependencies: &[MegakernelWaveDependency],
219) -> Result<MegakernelBarrierPlan, MegakernelBarrierPlanError> {
220 let mut scratch = MegakernelBarrierScratch::try_with_capacity(wave_count, dependencies.len())?;
221 plan_megakernel_barriers_with_scratch(wave_count, dependencies, &mut scratch)
222}
223
224pub fn plan_megakernel_barriers_with_scratch(
231 wave_count: usize,
232 dependencies: &[MegakernelWaveDependency],
233 scratch: &mut MegakernelBarrierScratch,
234) -> Result<MegakernelBarrierPlan, MegakernelBarrierPlanError> {
235 scratch.try_reserve_shape(wave_count, dependencies.len())?;
236 if wave_count == 0 {
237 if !dependencies.is_empty() {
238 return Err(MegakernelBarrierPlanError::InvalidWave {
239 wave_count,
240 before: dependencies[0].before,
241 after: dependencies[0].after,
242 });
243 }
244 return Ok(MegakernelBarrierPlan {
245 global_barriers: 0,
246 groups: Vec::new(),
247 });
248 }
249 if dependencies.is_empty() {
250 let mut waves = Vec::new();
251 reserve_vec(&mut waves, wave_count, "independent wave group")?;
252 for wave in 0..wave_count {
253 waves.push(wave);
254 }
255 let mut groups = Vec::new();
256 reserve_vec(&mut groups, 1, "barrier groups")?;
257 groups.push(MegakernelBarrierGroup { waves });
258 return Ok(MegakernelBarrierPlan {
259 global_barriers: 0,
260 groups,
261 });
262 }
263
264 fill_barrier_vec_zeroed(&mut scratch.outgoing_counts, wave_count, "outgoing counts")?;
265 fill_barrier_vec_zeroed(&mut scratch.indegree, wave_count, "indegree")?;
266 for dependency in dependencies {
267 if dependency.before >= wave_count || dependency.after >= wave_count {
268 return Err(MegakernelBarrierPlanError::InvalidWave {
269 wave_count,
270 before: dependency.before,
271 after: dependency.after,
272 });
273 }
274 if dependency.before == dependency.after {
275 return Err(MegakernelBarrierPlanError::SelfDependency {
276 wave: dependency.before,
277 });
278 }
279 scratch.outgoing_counts[dependency.before] = scratch.outgoing_counts[dependency.before]
280 .checked_add(1)
281 .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
282 field: "outgoing dependency count",
283 })?;
284 scratch.indegree[dependency.after] = scratch.indegree[dependency.after]
285 .checked_add(1)
286 .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
287 field: "incoming dependency count",
288 })?;
289 }
290
291 scratch.outgoing_offsets.clear();
292 scratch.outgoing_offsets.push(0usize);
293 for count in &scratch.outgoing_counts {
294 let next = scratch
295 .outgoing_offsets
296 .last()
297 .copied()
298 .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
299 field: "outgoing offset seed",
300 })?
301 .checked_add(*count)
302 .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
303 field: "outgoing dependency offsets",
304 })?;
305 scratch.outgoing_offsets.push(next);
306 }
307 fill_barrier_vec_zeroed(
308 &mut scratch.outgoing_targets,
309 dependencies.len(),
310 "outgoing targets",
311 )?;
312 scratch
313 .outgoing_counts
314 .copy_from_slice(&scratch.outgoing_offsets[..wave_count]);
315 for dependency in dependencies {
316 let offset = scratch.outgoing_counts[dependency.before];
317 scratch.outgoing_targets[offset] = dependency.after;
318 scratch.outgoing_counts[dependency.before] =
319 offset
320 .checked_add(1)
321 .ok_or(MegakernelBarrierPlanError::ByteCountOverflow {
322 field: "outgoing target cursor",
323 })?;
324 }
325
326 scratch.ready.clear();
327 for (wave, degree) in scratch.indegree.iter().copied().enumerate() {
328 if degree == 0 {
329 scratch.ready.push(wave);
330 }
331 }
332
333 let mut scheduled = 0usize;
334 let mut groups = Vec::new();
335 reserve_vec(
336 &mut groups,
337 group_capacity_hint(wave_count, dependencies.len())?,
338 "barrier groups",
339 )?;
340 scratch.next_ready.clear();
341 while !scratch.ready.is_empty() {
342 scratch.next_ready.clear();
343 for &wave in &scratch.ready {
344 for &next in &scratch.outgoing_targets
345 [scratch.outgoing_offsets[wave]..scratch.outgoing_offsets[wave + 1]]
346 {
347 scratch.indegree[next] -= 1;
348 if scratch.indegree[next] == 0 {
349 scratch.next_ready.push(next);
350 }
351 }
352 }
353 scheduled += scratch.ready.len();
354 groups.push(MegakernelBarrierGroup {
355 waves: std::mem::take(&mut scratch.ready),
356 });
357 std::mem::swap(&mut scratch.ready, &mut scratch.next_ready);
358 }
359
360 if scheduled != wave_count {
361 return Err(MegakernelBarrierPlanError::Cycle {
362 unscheduled_waves: wave_count - scheduled,
363 });
364 }
365
366 Ok(MegakernelBarrierPlan {
367 global_barriers: if groups.is_empty() {
368 0
369 } else {
370 groups.len() - 1
371 },
372 groups,
373 })
374}
375
376fn group_capacity_hint(
377 wave_count: usize,
378 dependency_count: usize,
379) -> Result<usize, MegakernelBarrierPlanError> {
380 if wave_count == 0 {
381 Ok(0)
382 } else {
383 let dependency_layer_cap = checked_add_usize_count::<MegakernelBarrierPlanError>(
384 dependency_count,
385 1,
386 "barrier group capacity hint",
387 )?;
388 Ok(wave_count.min(dependency_layer_cap))
389 }
390}
391
392fn fill_barrier_vec_zeroed(
393 vec: &mut Vec<usize>,
394 len: usize,
395 field: &'static str,
396) -> Result<(), MegakernelBarrierPlanError> {
397 vec.clear();
398 reserve_vec(vec, len, field)?;
399 vec.extend((0..len).map(|_| 0));
400 Ok(())
401}
402
403fn reserve_vec<T>(
404 vec: &mut Vec<T>,
405 target_capacity: usize,
406 item: &'static str,
407) -> Result<(), MegakernelBarrierPlanError> {
408 reserve_vec_to_capacity(
409 MEGAKERNEL_BARRIER_RESERVATION,
410 vec,
411 target_capacity,
412 item,
413 storage_reserve_failed,
414 )
415}
416
417fn storage_reserve_failed(
418 field: &'static str,
419 requested: usize,
420 message: String,
421) -> MegakernelBarrierPlanError {
422 MegakernelBarrierPlanError::StorageReserveFailed {
423 field,
424 requested,
425 message,
426 }
427}
428
#[cfg(test)]
mod tests {
    // Behavioral tests for the barrier planner, plus source-scan guards that pin
    // this file to the fallible, CSR-based implementation. The guards build their
    // forbidden patterns with `concat!`, so the assertions never match their own text.
    use super::{
        plan_megakernel_barriers, plan_megakernel_barriers_with_scratch,
        MegakernelBarrierPlanError, MegakernelBarrierScratch, MegakernelWaveDependency,
    };

    /// With no dependencies, all waves land in one group and no barrier is needed.
    #[test]
    fn independent_waves_share_one_barrier_free_group() {
        let plan = plan_megakernel_barriers(4, &[])
            .expect("Fix: independent megakernel waves should not need barriers.");

        assert_eq!(plan.global_barriers, 0);
        assert_eq!(plan.groups.len(), 1);
        assert_eq!(plan.groups[0].waves, vec![0, 1, 2, 3]);
    }

    /// A linear chain 0 -> 1 -> 2 -> 3 yields one group per wave and three barriers.
    #[test]
    fn dependency_chain_requires_one_barrier_between_each_wave() {
        let plan = plan_megakernel_barriers(
            4,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
        )
        .expect("Fix: acyclic megakernel wave chain should be schedulable.");

        assert_eq!(plan.global_barriers, 3);
        assert_eq!(plan.groups[0].waves, vec![0]);
        assert_eq!(plan.groups[1].waves, vec![1]);
        assert_eq!(plan.groups[2].waves, vec![2]);
        assert_eq!(plan.groups[3].waves, vec![3]);
    }

    /// In a diamond, the two middle waves share a group, giving three groups
    /// and two barriers.
    #[test]
    fn diamond_dependencies_fuse_middle_waves() {
        let plan = plan_megakernel_barriers(
            4,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 0,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 3,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
        )
        .expect("Fix: diamond megakernel dependencies should preserve middle-wave fusion.");

        assert_eq!(plan.global_barriers, 2);
        assert_eq!(plan.groups[0].waves, vec![0]);
        assert_eq!(plan.groups[1].waves, vec![1, 2]);
        assert_eq!(plan.groups[2].waves, vec![3]);
    }

    /// Out-of-range indices, self-edges, and two-wave cycles each map to their
    /// specific error variant.
    #[test]
    fn invalid_self_and_cyclic_dependencies_fail_loudly() {
        let invalid = plan_megakernel_barriers(
            2,
            &[MegakernelWaveDependency {
                before: 0,
                after: 2,
            }],
        )
        .expect_err("Fix: invalid megakernel wave index must fail before planning.");
        assert!(matches!(
            invalid,
            MegakernelBarrierPlanError::InvalidWave { .. }
        ));

        let self_edge = plan_megakernel_barriers(
            2,
            &[MegakernelWaveDependency {
                before: 1,
                after: 1,
            }],
        )
        .expect_err("Fix: self-dependent megakernel waves must fail before planning.");
        assert_eq!(
            self_edge,
            MegakernelBarrierPlanError::SelfDependency { wave: 1 }
        );

        let cycle = plan_megakernel_barriers(
            2,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 0,
                },
            ],
        )
        .expect_err("Fix: cyclic megakernel dependencies require explicit fixed-point kernels.");
        assert_eq!(
            cycle,
            MegakernelBarrierPlanError::Cycle {
                unscheduled_waves: 2
            }
        );
    }

    /// A 1-to-1024 fan-out plans into two groups.
    ///
    /// The test then reads this file's own source and checks implementation
    /// constraints: CSR adjacency, checked arithmetic, and fallible reservation.
    #[test]
    fn barrier_planner_uses_csr_adjacency_for_wide_wave_graphs() {
        let dependencies = (1..1_025)
            .map(|after| MegakernelWaveDependency { before: 0, after })
            .collect::<Vec<_>>();
        let plan = plan_megakernel_barriers(1_025, &dependencies)
            .expect("Fix: wide megakernel dependency fanout must schedule without per-wave adjacency allocation.");

        assert_eq!(plan.global_barriers, 1);
        assert_eq!(plan.groups[0].waves, vec![0]);
        assert_eq!(plan.groups[1].waves.len(), 1_024);

        let src = include_str!("megakernel_barrier.rs");
        assert!(
            !src.contains(concat!("vec![", "Vec::new(); wave_count]")),
            "Fix: megakernel barrier planner must use contiguous CSR adjacency instead of allocating one Vec per wave."
        );
        assert!(
            !src.contains(concat!("outgoing_offsets[..wave_count]", ".to_vec()")),
            "Fix: megakernel barrier planner must reuse the counts buffer as the CSR write cursor instead of allocating an O(wave_count) cursor Vec."
        );
        assert!(
            !src.contains(concat!("Vec", "Deque")),
            "Fix: megakernel barrier planner should use contiguous current/next ready vectors, not deque queue mechanics, for wide wave layers."
        );
        assert!(
            !src.contains(concat!("saturating", "_add")),
            "Fix: megakernel barrier dependency accounting is bounded by the validated graph shape and must not hide invariant violations with saturating arithmetic."
        );
        assert!(
            src.contains("field: \"outgoing dependency count\"")
                && src.contains("field: \"incoming dependency count\"")
                && src.contains("field: \"outgoing dependency offsets\"")
                && src.contains("field: \"outgoing target cursor\""),
            "Fix: megakernel barrier CSR construction must use checked arithmetic for dependency counters, offsets, and cursors."
        );
        assert!(
            src.contains("reserve_typed_vec_to_capacity as reserve_vec_to_capacity")
                && src.contains("fn fill_barrier_vec_zeroed(")
                && src.contains("StorageReserveFailed"),
            "Fix: megakernel barrier staging must reserve through shared fallible driver staging instead of panicking under scale pressure."
        );
        assert!(
            !src.contains(concat!("Vec::with_capacity", "(wave_count)"))
                && !src.contains(concat!(".reserve", "(wave_count)"))
                && !src.contains(concat!("scratch.outgoing_counts", ".resize"))
                && !src.contains(concat!("scratch.indegree", ".resize"))
                && !src.contains(concat!("scratch.outgoing_targets", ".resize")),
            "Fix: megakernel barrier planner must not use infallible capacity growth in release topology planning."
        );
        assert!(
            !src.contains(concat!(
                "scratch.outgoing_counts[dependency.before]",
                " += 1"
            ))
                && !src.contains(concat!("scratch.indegree[dependency.after]", " += 1"))
                && !src.contains(concat!(
                    "let next = scratch.outgoing_offsets.last().copied().unwrap_or(0)",
                    " + *count"
                )),
            "Fix: megakernel barrier planning must not use unchecked usize arithmetic for CSR construction."
        );
    }

    /// Planning a small graph with scratch sized for a wide one must not shrink
    /// the reported wave or dependency capacity.
    #[test]
    fn barrier_planner_reuses_caller_owned_csr_scratch_across_shapes() {
        let mut scratch = MegakernelBarrierScratch::try_with_capacity(1_025, 1_024)
            .expect("Fix: wide reusable megakernel barrier scratch should fit");
        let wide_dependencies = (1..1_025)
            .map(|after| MegakernelWaveDependency { before: 0, after })
            .collect::<Vec<_>>();
        let wide = plan_megakernel_barriers_with_scratch(1_025, &wide_dependencies, &mut scratch)
            .expect("Fix: wide megakernel dependency fanout should plan with reusable scratch");
        let wave_capacity = scratch.wave_capacity();
        let dependency_capacity = scratch.dependency_capacity();

        assert_eq!(wide.groups[1].waves.len(), 1_024);

        let narrow = plan_megakernel_barriers_with_scratch(
            4,
            &[
                MegakernelWaveDependency {
                    before: 0,
                    after: 1,
                },
                MegakernelWaveDependency {
                    before: 1,
                    after: 2,
                },
                MegakernelWaveDependency {
                    before: 2,
                    after: 3,
                },
            ],
            &mut scratch,
        )
        .expect("Fix: narrow megakernel dependency chain should reuse larger scratch");

        assert_eq!(narrow.global_barriers, 3);
        assert!(scratch.wave_capacity() >= wave_capacity);
        assert!(scratch.dependency_capacity() >= dependency_capacity);
    }

    /// Sweeps 64 widths by 32 depths (2048 shapes) of parallel chains, reusing
    /// one scratch throughout.
    ///
    /// Each plan must have exactly `depth` groups of `width` waves each.
    #[test]
    fn generated_layered_dags_preserve_exact_barrier_depth_for_2048_shapes() {
        let mut scratch = MegakernelBarrierScratch::default();
        for width in 1usize..=64 {
            for depth in 1usize..=32 {
                let wave_count = width * depth;
                let mut dependencies = Vec::new();
                // Wave `base + slot` feeds wave `base + width + slot` in the next layer.
                for layer in 0..depth.saturating_sub(1) {
                    let base = layer * width;
                    let next = base + width;
                    for slot in 0..width {
                        dependencies.push(MegakernelWaveDependency {
                            before: base + slot,
                            after: next + slot,
                        });
                    }
                }

                let plan =
                    plan_megakernel_barriers_with_scratch(wave_count, &dependencies, &mut scratch)
                        .expect("Fix: generated layered megakernel DAG should be schedulable");

                assert_eq!(plan.groups.len(), depth);
                assert_eq!(plan.global_barriers, depth - 1);
                for group in &plan.groups {
                    assert_eq!(group.waves.len(), width);
                }
            }
        }
    }
}