1use crate::{megakernel::planner::MegakernelWorkItem, PipelineError};
21use rustc_hash::FxHashMap;
22use vyre_foundation::allocation::{try_reserve_hash_map_to_capacity, try_reserve_vec_to_capacity};
23
/// Upper bound on both the batch length and the output-handle span that the
/// bitset fast path in `output_handles_are_dense_unique` will examine.
const DENSE_OUTPUT_UNIQUE_BITS: usize = 1 << 12;
/// Number of `u64` words needed to hold `DENSE_OUTPUT_UNIQUE_BITS` bits.
const DENSE_OUTPUT_UNIQUE_WORDS: usize = DENSE_OUTPUT_UNIQUE_BITS / (u64::BITS as usize);
26
/// Report of repeated megakernel work items.
///
/// Produced by [`try_detect_cross_arm_redundancy`] (repeats across fused arms)
/// and by the redundant-work pruning entry points (repeats within one batch).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CrossArmRedundancy {
    /// One triple per redundant occurrence.
    ///
    /// Cross-arm detection records `(earlier_arm, later_arm, op_index_in_later_arm)`;
    /// pruning records `(first_index, duplicate_index, 0)`.
    pub redundant_pairs: Vec<(usize, usize, usize)>,
    /// Count of redundant occurrences. The producers in this module set it to
    /// `redundant_pairs.len()`.
    pub total_redundant_ops: usize,
}

impl CrossArmRedundancy {
    /// Builds an empty report: no pairs and a zero redundant-op count.
    #[must_use]
    pub fn new() -> Self {
        Self {
            redundant_pairs: Vec::new(),
            total_redundant_ops: 0,
        }
    }

    /// Returns `true` when no redundant pairs were recorded.
    ///
    /// Only `redundant_pairs` is inspected; `total_redundant_ops` is ignored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let pairs: &[(usize, usize, usize)] = &self.redundant_pairs;
        pairs.is_empty()
    }
}
59
60#[derive(Debug, Default)]
62pub struct RedundantWorkItemPruneScratch {
63 first_seen: FxHashMap<(u32, u32, u32, u32), usize>,
64}
65
66impl RedundantWorkItemPruneScratch {
67 pub fn clear(&mut self) {
69 self.first_seen.clear();
70 }
71
72 fn try_prepare_for_len(&mut self, len: usize) -> Result<(), PipelineError> {
73 self.first_seen.clear();
74 let retained_ceiling = len.checked_mul(4).unwrap_or(usize::MAX).max(1024);
75 if self.first_seen.capacity() > retained_ceiling {
76 self.first_seen.shrink_to(len);
77 }
78 if self.first_seen.capacity() < len {
79 try_reserve_hash_map_to_capacity(&mut self.first_seen, len).map_err(|source| {
80 PipelineError::Backend(format!(
81 "megakernel redundant-work hash reservation failed for {len} item(s): {source}. Fix: shard the work batch before pruning."
82 ))
83 })?;
84 }
85 Ok(())
86 }
87}
88
/// Panicking wrapper around [`try_detect_cross_arm_redundancy`].
///
/// Compiled only for tests or when the `legacy-infallible` feature is enabled.
///
/// # Panics
/// Panics if `try_detect_cross_arm_redundancy` returns an error.
#[must_use]
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn detect_cross_arm_redundancy(arms: &[&[MegakernelWorkItem]]) -> CrossArmRedundancy {
    match try_detect_cross_arm_redundancy(arms) {
        Ok(report) => report,
        Err(error) => panic!(
            "megakernel cross-arm redundancy detection allocation failed: {error}. Fix: split the fused arm sequence before planning."
        ),
    }
}
109
110pub fn try_detect_cross_arm_redundancy(
117 arms: &[&[MegakernelWorkItem]],
118) -> Result<CrossArmRedundancy, PipelineError> {
119 let total_ops = arms.iter().map(|arm| arm.len()).sum();
121 let mut first_seen: FxHashMap<(u32, u32, u32), usize> = FxHashMap::default();
122 reserve_hash_map(&mut first_seen, total_ops, "cross-arm first-seen")?;
123 let mut report = CrossArmRedundancy {
124 redundant_pairs: Vec::new(),
125 total_redundant_ops: 0,
126 };
127 for (arm_idx, arm) in arms.iter().enumerate() {
128 for (op_idx, item) in arm.iter().enumerate() {
129 let key = (item.op_handle, item.input_handle, item.output_handle);
130 match first_seen.get(&key) {
131 Some(&early_arm_idx) if early_arm_idx < arm_idx => {
132 reserve_redundant_pairs(&mut report.redundant_pairs, 1, "cross-arm report")?;
133 report
134 .redundant_pairs
135 .push((early_arm_idx, arm_idx, op_idx));
136 }
137 Some(_) => {
138 }
140 None => {
141 first_seen.insert(key, arm_idx);
142 }
143 }
144 }
145 }
146 report.total_redundant_ops = report.redundant_pairs.len();
147 Ok(report)
148}
149
/// Panicking wrapper around [`try_prune_redundant_work_items_into`].
///
/// Compiled only for tests or when the `legacy-infallible` feature is enabled.
///
/// # Panics
/// Panics if `try_prune_redundant_work_items_into` returns an error.
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn prune_redundant_work_items_into(
    items: &[MegakernelWorkItem],
    out: &mut Vec<MegakernelWorkItem>,
) -> CrossArmRedundancy {
    match try_prune_redundant_work_items_into(items, out) {
        Ok(report) => report,
        Err(error) => panic!(
            "megakernel redundant-work pruning allocation failed: {error}. Fix: shard the work batch before pruning."
        ),
    }
}
176
177pub fn try_prune_redundant_work_items_into(
185 items: &[MegakernelWorkItem],
186 out: &mut Vec<MegakernelWorkItem>,
187) -> Result<CrossArmRedundancy, PipelineError> {
188 let mut scratch = RedundantWorkItemPruneScratch::default();
189 try_prune_redundant_work_items_with_scratch_into(items, out, &mut scratch)
190}
191
/// Panicking wrapper around [`try_prune_redundant_work_items_with_scratch_into`].
///
/// Compiled only for tests or when the `legacy-infallible` feature is enabled.
///
/// # Panics
/// Panics if `try_prune_redundant_work_items_with_scratch_into` returns an error.
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn prune_redundant_work_items_with_scratch_into(
    items: &[MegakernelWorkItem],
    out: &mut Vec<MegakernelWorkItem>,
    scratch: &mut RedundantWorkItemPruneScratch,
) -> CrossArmRedundancy {
    match try_prune_redundant_work_items_with_scratch_into(items, out, scratch) {
        Ok(report) => report,
        Err(error) => panic!(
            "megakernel redundant-work pruning allocation failed: {error}. Fix: shard the work batch before pruning."
        ),
    }
}
210
211pub fn try_prune_redundant_work_items_with_scratch_into(
219 items: &[MegakernelWorkItem],
220 out: &mut Vec<MegakernelWorkItem>,
221 scratch: &mut RedundantWorkItemPruneScratch,
222) -> Result<CrossArmRedundancy, PipelineError> {
223 out.clear();
224
225 if output_handles_are_dense_unique(items) {
226 scratch.clear();
227 return Ok(CrossArmRedundancy::new());
228 }
229
230 scratch.try_prepare_for_len(items.len())?;
231 let mut report = CrossArmRedundancy {
232 redundant_pairs: Vec::new(),
233 total_redundant_ops: 0,
234 };
235 let mut found_duplicate = false;
236
237 for (idx, item) in items.iter().copied().enumerate() {
238 let key = (
239 item.op_handle,
240 item.input_handle,
241 item.output_handle,
242 item.param,
243 );
244 if let Some(&early_idx) = scratch.first_seen.get(&key) {
245 if !found_duplicate {
246 reserve_work_items(out, items.len().checked_sub(1).unwrap_or(0), "dedup output")?;
247 out.extend_from_slice(&items[..idx]);
248 found_duplicate = true;
249 }
250 reserve_redundant_pairs(&mut report.redundant_pairs, 1, "dedup report")?;
251 report.redundant_pairs.push((early_idx, idx, 0));
252 continue;
253 }
254 scratch.first_seen.insert(key, idx);
255 if found_duplicate {
256 out.push(item);
257 }
258 }
259
260 report.total_redundant_ops = report.redundant_pairs.len();
261 Ok(report)
262}
263
264fn reserve_hash_map<K, V>(
265 values: &mut FxHashMap<K, V>,
266 additional: usize,
267 label: &'static str,
268) -> Result<(), PipelineError>
269where
270 K: Eq + std::hash::Hash,
271{
272 if additional > 0 {
273 let capacity = values.len().checked_add(additional).ok_or_else(|| {
274 PipelineError::Backend(format!(
275 "megakernel {label} reservation overflowed for {additional} additional entry slot(s). Fix: shard the work batch before whole-megakernel optimization."
276 ))
277 })?;
278 try_reserve_hash_map_to_capacity(values, capacity).map_err(|source| {
279 PipelineError::Backend(format!(
280 "megakernel {label} reservation failed for {additional} additional entry slot(s): {source}. Fix: shard the work batch before whole-megakernel optimization."
281 ))
282 })?;
283 }
284 Ok(())
285}
286
287fn reserve_redundant_pairs(
288 values: &mut Vec<(usize, usize, usize)>,
289 additional: usize,
290 label: &'static str,
291) -> Result<(), PipelineError> {
292 values.try_reserve(additional).map_err(|source| {
293 PipelineError::Backend(format!(
294 "megakernel {label} reservation failed for {additional} additional pair slot(s): {source}. Fix: shard the work batch before whole-megakernel optimization."
295 ))
296 })
297}
298
299fn reserve_work_items(
300 values: &mut Vec<MegakernelWorkItem>,
301 capacity: usize,
302 label: &'static str,
303) -> Result<(), PipelineError> {
304 if values.capacity() < capacity {
305 try_reserve_vec_to_capacity(values, capacity).map_err(|source| {
306 PipelineError::Backend(format!(
307 "megakernel {label} reservation failed for {capacity} item slot(s): {source}. Fix: shard the work batch before whole-megakernel optimization."
308 ))
309 })?;
310 }
311 Ok(())
312}
313
314fn output_handles_are_dense_unique(items: &[MegakernelWorkItem]) -> bool {
315 if items.len() <= 1 {
316 return true;
317 }
318 if items.len() > DENSE_OUTPUT_UNIQUE_BITS {
319 return false;
320 }
321
322 let mut min = u32::MAX;
323 let mut max = 0u32;
324 for item in items {
325 min = min.min(item.output_handle);
326 max = max.max(item.output_handle);
327 }
328 let Some(range) = u64::from(max)
329 .checked_sub(u64::from(min))
330 .and_then(|value| value.checked_add(1))
331 else {
332 return false;
333 };
334 if range > DENSE_OUTPUT_UNIQUE_BITS as u64 {
335 return false;
336 }
337
338 let mut seen = [0u64; DENSE_OUTPUT_UNIQUE_WORDS];
339 for item in items {
340 let Some(delta) = item.output_handle.checked_sub(min) else {
341 return false;
342 };
343 let Ok(offset) = usize::try_from(delta) else {
344 return false;
345 };
346 let word = offset / 64;
347 let bit = 1u64
348 << (offset % 64);
349 if (seen[word] & bit) != 0 {
350 return false;
351 }
352 seen[word] |= bit;
353 }
354 true
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a work item with `param` fixed at zero.
    fn item(op: u32, inp: u32, out: u32) -> MegakernelWorkItem {
        MegakernelWorkItem {
            op_handle: op,
            input_handle: inp,
            output_handle: out,
            param: 0,
        }
    }

    /// Width of the dense fast-path window as a handle value.
    fn window() -> u32 {
        u32::try_from(DENSE_OUTPUT_UNIQUE_BITS).expect("dense window fits in u32")
    }

    #[test]
    fn empty_arms_have_no_redundancy() {
        let arms: [&[MegakernelWorkItem]; 0] = [];
        assert_eq!(
            detect_cross_arm_redundancy(&arms),
            CrossArmRedundancy::new()
        );
    }

    #[test]
    fn single_arm_with_repeats_has_no_cross_arm_redundancy() {
        let a = vec![item(1, 0, 5), item(1, 0, 5), item(2, 5, 6)];
        let arms: [&[MegakernelWorkItem]; 1] = [&a];
        let report = detect_cross_arm_redundancy(&arms);
        assert!(report.is_empty(), "intra-arm repeats are not cross-arm");
        assert_eq!(report.total_redundant_ops, 0);
    }

    #[test]
    fn identical_arms_report_full_overlap() {
        let a = vec![item(1, 0, 5), item(2, 5, 6)];
        let b = vec![item(1, 0, 5), item(2, 5, 6)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0), (0, 1, 1)]);
    }

    #[test]
    fn fully_disjoint_arms_have_no_redundancy() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(2, 7, 8)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert!(detect_cross_arm_redundancy(&arms).is_empty());
    }

    #[test]
    fn redundancy_uses_first_seen_arm_index() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(99, 0, 0)];
        let c = vec![item(1, 0, 5)];
        let d = vec![item(1, 0, 5)];
        let arms: [&[MegakernelWorkItem]; 4] = [&a, &b, &c, &d];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 2, 0), (0, 3, 0)]);
    }

    #[test]
    fn later_arm_reports_each_repeat_of_an_earlier_key() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(1, 0, 5), item(1, 0, 5)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0), (0, 1, 1)]);
        assert_eq!(report.total_redundant_ops, 2);
    }

    #[test]
    fn param_field_does_not_affect_redundancy() {
        let a = vec![MegakernelWorkItem {
            op_handle: 1,
            input_handle: 0,
            output_handle: 5,
            param: 7,
        }];
        let b = vec![MegakernelWorkItem {
            op_handle: 1,
            input_handle: 0,
            output_handle: 5,
            param: 99,
        }];

        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.total_redundant_ops, 1);
    }

    #[test]
    fn different_inputs_are_not_redundant() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(1, 1, 5)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert!(detect_cross_arm_redundancy(&arms).is_empty());
    }

    #[test]
    fn different_outputs_are_not_redundant() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(1, 0, 6)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert!(detect_cross_arm_redundancy(&arms).is_empty());
    }

    #[test]
    fn op_index_refers_to_late_arm_position() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(99, 0, 0), item(1, 0, 5), item(42, 0, 0)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 1)]);
    }

    #[test]
    fn prune_redundant_work_items_drops_later_duplicates() {
        let items = vec![
            item(1, 0, 5),
            item(2, 5, 6),
            item(1, 0, 5),
            item(3, 6, 7),
            item(2, 5, 6),
        ];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(out, vec![item(1, 0, 5), item(2, 5, 6), item(3, 6, 7)]);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 2, 0), (1, 4, 0)]);
    }

    #[test]
    fn prune_redundant_work_items_reuses_hash_scratch() {
        let items = vec![item(1, 0, 5), item(2, 5, 6), item(1, 0, 5), item(3, 6, 7)];
        let mut out = Vec::new();
        let mut scratch = RedundantWorkItemPruneScratch::default();

        let first = prune_redundant_work_items_with_scratch_into(&items, &mut out, &mut scratch);
        let retained_capacity = scratch.first_seen.capacity();
        out.clear();
        let second = prune_redundant_work_items_with_scratch_into(&items, &mut out, &mut scratch);

        assert_eq!(first, second);
        assert!(
            scratch.first_seen.capacity() >= retained_capacity,
            "hot megakernel dedupe must retain hash capacity across repeated dispatches"
        );
        assert_eq!(out, vec![item(1, 0, 5), item(2, 5, 6), item(3, 6, 7)]);
    }

    #[test]
    fn prune_scratch_releases_capacity_retained_from_a_much_larger_batch() {
        let mut scratch = RedundantWorkItemPruneScratch::default();
        let mut out = Vec::new();

        // Sparse, distinct outputs force the hash path and a large table.
        let large: Vec<_> = (0..5_000u32).map(|n| item(1, 0, n * 2)).collect();
        let report = prune_redundant_work_items_with_scratch_into(&large, &mut out, &mut scratch);
        assert!(report.is_empty());
        let large_capacity = scratch.first_seen.capacity();
        assert!(large_capacity >= large.len());

        let small = vec![item(1, 0, 5), item(1, 0, 5)];
        let report = prune_redundant_work_items_with_scratch_into(&small, &mut out, &mut scratch);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0)]);
        assert_eq!(out, vec![item(1, 0, 5)]);
        assert!(
            scratch.first_seen.capacity() < large_capacity,
            "scratch retained above the 4x/1024 ceiling must be shrunk"
        );
    }

    #[test]
    fn prune_redundant_work_items_handles_empty_input() {
        let mut out = vec![item(99, 99, 99)];

        let report = prune_redundant_work_items_into(&[], &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn prune_redundant_work_items_all_duplicates_keep_one() {
        let items = vec![item(1, 0, 5), item(1, 0, 5), item(1, 0, 5)];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(out, vec![item(1, 0, 5)]);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0), (0, 2, 0)]);
    }

    #[test]
    fn prune_redundant_work_items_preserves_order_after_first_duplicate() {
        let items = vec![
            item(1, 0, 5),
            item(2, 5, 6),
            item(1, 0, 5),
            item(3, 6, 7),
            item(4, 7, 8),
        ];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(
            out,
            vec![item(1, 0, 5), item(2, 5, 6), item(3, 6, 7), item(4, 7, 8)]
        );
        assert_eq!(report.redundant_pairs, vec![(0, 2, 0)]);
    }

    #[test]
    fn prune_redundant_work_items_leaves_output_empty_when_no_copy_needed() {
        let items = vec![item(1, 0, 5)];
        let mut out = vec![item(99, 99, 99)];

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn prune_redundant_work_items_hash_path_without_duplicates_leaves_output_empty() {
        // Unique but sparse outputs skip the dense fast path.
        let items = vec![item(1, 0, 0), item(1, 0, 10_000)];
        assert!(!output_handles_are_dense_unique(&items));
        let mut out = vec![item(99, 99, 99)];

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn prune_redundant_work_items_keeps_distinct_params() {
        let mut a = item(1, 0, 5);
        a.param = 7;
        let mut b = item(1, 0, 5);
        b.param = 99;
        let items = vec![a, b];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn fallible_prune_entry_point_succeeds_on_small_batch() {
        let items = vec![item(1, 0, 5), item(1, 0, 5)];
        let mut out = Vec::new();

        let Ok(report) = try_prune_redundant_work_items_into(&items, &mut out) else {
            panic!("a two-item batch must not fail to prune");
        };

        assert_eq!(report.redundant_pairs, vec![(0, 1, 0)]);
        assert_eq!(out, vec![item(1, 0, 5)]);
    }

    #[test]
    fn output_handles_dense_unique_accepts_single_owner_outputs() {
        let items = vec![item(1, 0, 5), item(1, 0, 6), item(1, 0, 7)];

        assert!(output_handles_are_dense_unique(&items));
    }

    #[test]
    fn output_handles_dense_unique_rejects_repeated_output() {
        let items = vec![item(1, 0, 5), item(2, 0, 5)];

        assert!(!output_handles_are_dense_unique(&items));
    }

    #[test]
    fn output_handles_dense_unique_accepts_full_window() {
        let items: Vec<_> = (0..window()).map(|out| item(1, 0, out)).collect();

        assert!(output_handles_are_dense_unique(&items));
    }

    #[test]
    fn output_handles_dense_unique_rejects_span_wider_than_window() {
        let items = vec![item(1, 0, 0), item(1, 0, window())];

        assert!(!output_handles_are_dense_unique(&items));
    }

    #[test]
    fn output_handles_dense_unique_rejects_oversized_batch() {
        let items: Vec<_> = (0..=window()).map(|out| item(1, 0, out)).collect();

        assert!(!output_handles_are_dense_unique(&items));
    }

    #[test]
    fn prune_redundant_work_items_still_catches_duplicate_with_repeated_output() {
        let items = vec![item(1, 0, 5), item(2, 0, 6), item(1, 0, 5)];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(report.total_redundant_ops, 1);
        assert_eq!(out, vec![item(1, 0, 5), item(2, 0, 6)]);
    }
}