Skip to main content

vyre_runtime/megakernel/planner/
whole_megakernel_opt.rs

//! C1 substrate: whole-megakernel optimization domain.
//!
//! Per-arm optimization (the existing per-arm CSE/DCE, followed by fusion) is
//! conservative - it cannot see structural redundancy ACROSS arms. When two
//! arms produce the same intermediate result, the first arm could compute it
//! once and the later arm could simply read it.
//!
//! This substrate owns two pieces:
//!
//! * the *cross-arm redundancy detector*: given a per-arm sequence of
//!   `MegakernelWorkItem`s, identify pairs of arms that emit the same
//!   op→input→output triple. The match is purely structural (intervening
//!   writes to those handles are not tracked), so each pair names a candidate
//!   the dispatcher must still validate before skipping the compute;
//! * a conservative *exact-duplicate pruner* that drops later work items whose
//!   (op, input, output, param) tuple repeats an earlier item in the same batch.
//!
//! Pure substrate - no Program walk, and no allocation beyond the transient
//! seen-table, the returned report, and (only when duplicates are dropped) the
//! pruned copy of the batch. The broader rewrite (collapsing redundant arms
//! into one and rewiring downstream readers) is the Codex-side runtime work;
//! the detector just names the optimization opportunity.
19
20use crate::{megakernel::planner::MegakernelWorkItem, PipelineError};
21use rustc_hash::FxHashMap;
22use vyre_foundation::allocation::{try_reserve_hash_map_to_capacity, try_reserve_vec_to_capacity};
23
/// Widest output-handle window (counted in handle values) that the stack
/// bitmap in `output_handles_are_dense_unique` can cover; batches with more
/// items or a wider window take the hashing path instead.
const DENSE_OUTPUT_UNIQUE_BITS: usize = 1 << 12;
/// Number of `u64` words needed to hold `DENSE_OUTPUT_UNIQUE_BITS` bits.
const DENSE_OUTPUT_UNIQUE_WORDS: usize = DENSE_OUTPUT_UNIQUE_BITS / u64::BITS as usize;
26
/// Report of redundant work items found by the whole-megakernel optimizer.
///
/// Two passes in this module produce this shape:
///
/// * the cross-arm detector ([`try_detect_cross_arm_redundancy`]) records
///   `(early_arm_index, late_arm_index, op_index_in_late_arm)`. This means arm
///   `late` emits a work item whose `(op_handle, input_handle, output_handle)`
///   triple was already emitted by arm `early`;
/// * the exact-duplicate pruner
///   ([`try_prune_redundant_work_items_with_scratch_into`]) records
///   `(first_item_index, duplicate_item_index, 0)` over one flat batch.
///
/// The detector's comparison is purely structural. It does not track writes
/// that intervening arms make to the input or output handles, so a reported
/// pair is a candidate rather than a proven reuse. A runtime rewrite that drops
/// the late op and rewires its readers to the early output must first confirm
/// that nothing in between clobbered those handles.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CrossArmRedundancy {
    /// Detector: `(early_arm_index, late_arm_index, redundant_op_index_in_late_arm)`
    /// with `early_arm_index < late_arm_index`; the late arm's op is the
    /// redundant one. Pruner: `(first_item_index, duplicate_item_index, 0)`.
    pub redundant_pairs: Vec<(usize, usize, usize)>,
    /// Total redundant ops detected across the whole sequence.
    /// Both producers set it to `redundant_pairs.len()`. It is exposed
    /// separately so the dispatcher can budget telemetry without scanning the vec.
    pub total_redundant_ops: usize,
}

impl CrossArmRedundancy {
    /// Empty report  -  no redundancy found.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Whether this report names any opportunity (i.e. `redundant_pairs` is non-empty).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.redundant_pairs.is_empty()
    }
}
59
60/// Reusable scratch for same-batch work-item dedupe.
61#[derive(Debug, Default)]
62pub struct RedundantWorkItemPruneScratch {
63    first_seen: FxHashMap<(u32, u32, u32, u32), usize>,
64}
65
66impl RedundantWorkItemPruneScratch {
67    /// Clear retained hash state while preserving useful capacity.
68    pub fn clear(&mut self) {
69        self.first_seen.clear();
70    }
71
72    fn try_prepare_for_len(&mut self, len: usize) -> Result<(), PipelineError> {
73        self.first_seen.clear();
74        let retained_ceiling = len.checked_mul(4).unwrap_or(usize::MAX).max(1024);
75        if self.first_seen.capacity() > retained_ceiling {
76            self.first_seen.shrink_to(len);
77        }
78        if self.first_seen.capacity() < len {
79            try_reserve_hash_map_to_capacity(&mut self.first_seen, len).map_err(|source| {
80                PipelineError::Backend(format!(
81                    "megakernel redundant-work hash reservation failed for {len} item(s): {source}. Fix: shard the work batch before pruning."
82                ))
83            })?;
84        }
85        Ok(())
86    }
87}
88
/// Detect cross-arm structural redundancy, panicking if staging storage
/// cannot be allocated.
///
/// Every work item is keyed by its `(op_handle, input_handle, output_handle)`
/// triple; `param` is treated as separate launch metadata and ignored. The
/// first arm to emit a triple owns it. Each later arm that emits the same
/// triple contributes one `(owning_arm, later_arm, op_index_in_later_arm)`
/// entry to the report. Repeats inside the owning arm are not reported.
///
/// Single pass over all work items with one hash table.
///
/// # Panics
///
/// Panics when [`try_detect_cross_arm_redundancy`] returns an error. Prefer
/// that fallible entry point outside tests.
#[must_use]
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn detect_cross_arm_redundancy(arms: &[&[MegakernelWorkItem]]) -> CrossArmRedundancy {
    match try_detect_cross_arm_redundancy(arms) {
        Ok(report) => report,
        Err(error) => panic!(
            "megakernel cross-arm redundancy detection allocation failed: {error}. Fix: split the fused arm sequence before planning."
        ),
    }
}
109
110/// Walk `arms` and detect cross-arm structural redundancy with fallible staging.
111///
112/// # Errors
113///
114/// Returns [`PipelineError::Backend`] when host hash/report storage cannot be
115/// reserved for the fused arm sequence.
116pub fn try_detect_cross_arm_redundancy(
117    arms: &[&[MegakernelWorkItem]],
118) -> Result<CrossArmRedundancy, PipelineError> {
119    // (op_handle, input_handle, output_handle) → (arm_idx, op_idx)
120    let total_ops = arms.iter().map(|arm| arm.len()).sum();
121    let mut first_seen: FxHashMap<(u32, u32, u32), usize> = FxHashMap::default();
122    reserve_hash_map(&mut first_seen, total_ops, "cross-arm first-seen")?;
123    let mut report = CrossArmRedundancy {
124        redundant_pairs: Vec::new(),
125        total_redundant_ops: 0,
126    };
127    for (arm_idx, arm) in arms.iter().enumerate() {
128        for (op_idx, item) in arm.iter().enumerate() {
129            let key = (item.op_handle, item.input_handle, item.output_handle);
130            match first_seen.get(&key) {
131                Some(&early_arm_idx) if early_arm_idx < arm_idx => {
132                    reserve_redundant_pairs(&mut report.redundant_pairs, 1, "cross-arm report")?;
133                    report
134                        .redundant_pairs
135                        .push((early_arm_idx, arm_idx, op_idx));
136                }
137                Some(_) => {
138                    // Same arm  -  not a cross-arm redundancy.
139                }
140                None => {
141                    first_seen.insert(key, arm_idx);
142                }
143            }
144        }
145    }
146    report.total_redundant_ops = report.redundant_pairs.len();
147    Ok(report)
148}
149
/// Produce a copy of `items` with later exact duplicates removed, panicking if
/// staging storage cannot be allocated.
///
/// This is the runtime-safe rewrite for the opportunity named by
/// [`detect_cross_arm_redundancy`], but it is stricter. Two items count as
/// duplicates only when their `(op_handle, input_handle, output_handle, param)`
/// tuples all match, because concrete megakernel publishers pass `param` as an
/// opcode argument. Such a duplicate writes the same result slot from the same
/// input through the same operation and argument, so it only burns queue
/// capacity and GPU cycles. The first occurrence is kept; each later one is
/// dropped and reported as `(first_index, duplicate_index, 0)`.
///
/// `out` is cleared on entry and only filled when at least one duplicate was
/// dropped. An empty `out` plus an empty report means the caller can keep
/// dispatching the original borrowed `items` without paying for a copy.
///
/// # Panics
///
/// Panics when [`try_prune_redundant_work_items_into`] returns an error.
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn prune_redundant_work_items_into(
    items: &[MegakernelWorkItem],
    out: &mut Vec<MegakernelWorkItem>,
) -> CrossArmRedundancy {
    match try_prune_redundant_work_items_into(items, out) {
        Ok(report) => report,
        Err(error) => panic!(
            "megakernel redundant-work pruning allocation failed: {error}. Fix: shard the work batch before pruning."
        ),
    }
}
176
177/// Copy `items` into `out`, dropping later exact duplicates with fallible
178/// staging.
179///
180/// # Errors
181///
182/// Returns [`PipelineError::Backend`] when host hash/report/output storage
183/// cannot be reserved for the batch.
184pub fn try_prune_redundant_work_items_into(
185    items: &[MegakernelWorkItem],
186    out: &mut Vec<MegakernelWorkItem>,
187) -> Result<CrossArmRedundancy, PipelineError> {
188    let mut scratch = RedundantWorkItemPruneScratch::default();
189    try_prune_redundant_work_items_with_scratch_into(items, out, &mut scratch)
190}
191
/// Panicking wrapper around [`try_prune_redundant_work_items_with_scratch_into`]
/// that reuses caller-owned hash scratch across calls.
///
/// Only compiled for tests and the `legacy-infallible` feature. Production
/// dispatch should call the fallible variant and handle its error.
///
/// # Panics
///
/// Panics when the fallible variant returns an error.
#[cfg(any(test, feature = "legacy-infallible"))]
pub fn prune_redundant_work_items_with_scratch_into(
    items: &[MegakernelWorkItem],
    out: &mut Vec<MegakernelWorkItem>,
    scratch: &mut RedundantWorkItemPruneScratch,
) -> CrossArmRedundancy {
    match try_prune_redundant_work_items_with_scratch_into(items, out, scratch) {
        Ok(report) => report,
        Err(error) => panic!(
            "megakernel redundant-work pruning allocation failed: {error}. Fix: shard the work batch before pruning."
        ),
    }
}
210
211/// Copy `items` into `out`, dropping later exact duplicates while reusing
212/// caller-owned hash scratch and fallible output/report staging.
213///
214/// # Errors
215///
216/// Returns [`PipelineError::Backend`] when host hash/report/output storage
217/// cannot be reserved for the batch.
218pub fn try_prune_redundant_work_items_with_scratch_into(
219    items: &[MegakernelWorkItem],
220    out: &mut Vec<MegakernelWorkItem>,
221    scratch: &mut RedundantWorkItemPruneScratch,
222) -> Result<CrossArmRedundancy, PipelineError> {
223    out.clear();
224
225    if output_handles_are_dense_unique(items) {
226        scratch.clear();
227        return Ok(CrossArmRedundancy::new());
228    }
229
230    scratch.try_prepare_for_len(items.len())?;
231    let mut report = CrossArmRedundancy {
232        redundant_pairs: Vec::new(),
233        total_redundant_ops: 0,
234    };
235    let mut found_duplicate = false;
236
237    for (idx, item) in items.iter().copied().enumerate() {
238        let key = (
239            item.op_handle,
240            item.input_handle,
241            item.output_handle,
242            item.param,
243        );
244        if let Some(&early_idx) = scratch.first_seen.get(&key) {
245            if !found_duplicate {
246                reserve_work_items(out, items.len().checked_sub(1).unwrap_or(0), "dedup output")?;
247                out.extend_from_slice(&items[..idx]);
248                found_duplicate = true;
249            }
250            reserve_redundant_pairs(&mut report.redundant_pairs, 1, "dedup report")?;
251            report.redundant_pairs.push((early_idx, idx, 0));
252            continue;
253        }
254        scratch.first_seen.insert(key, idx);
255        if found_duplicate {
256            out.push(item);
257        }
258    }
259
260    report.total_redundant_ops = report.redundant_pairs.len();
261    Ok(report)
262}
263
264fn reserve_hash_map<K, V>(
265    values: &mut FxHashMap<K, V>,
266    additional: usize,
267    label: &'static str,
268) -> Result<(), PipelineError>
269where
270    K: Eq + std::hash::Hash,
271{
272    if additional > 0 {
273        let capacity = values.len().checked_add(additional).ok_or_else(|| {
274            PipelineError::Backend(format!(
275                "megakernel {label} reservation overflowed for {additional} additional entry slot(s). Fix: shard the work batch before whole-megakernel optimization."
276            ))
277        })?;
278        try_reserve_hash_map_to_capacity(values, capacity).map_err(|source| {
279            PipelineError::Backend(format!(
280                "megakernel {label} reservation failed for {additional} additional entry slot(s): {source}. Fix: shard the work batch before whole-megakernel optimization."
281            ))
282        })?;
283    }
284    Ok(())
285}
286
287fn reserve_redundant_pairs(
288    values: &mut Vec<(usize, usize, usize)>,
289    additional: usize,
290    label: &'static str,
291) -> Result<(), PipelineError> {
292    values.try_reserve(additional).map_err(|source| {
293        PipelineError::Backend(format!(
294            "megakernel {label} reservation failed for {additional} additional pair slot(s): {source}. Fix: shard the work batch before whole-megakernel optimization."
295        ))
296    })
297}
298
299fn reserve_work_items(
300    values: &mut Vec<MegakernelWorkItem>,
301    capacity: usize,
302    label: &'static str,
303) -> Result<(), PipelineError> {
304    if values.capacity() < capacity {
305        try_reserve_vec_to_capacity(values, capacity).map_err(|source| {
306            PipelineError::Backend(format!(
307                "megakernel {label} reservation failed for {capacity} item slot(s): {source}. Fix: shard the work batch before whole-megakernel optimization."
308            ))
309        })?;
310    }
311    Ok(())
312}
313
314fn output_handles_are_dense_unique(items: &[MegakernelWorkItem]) -> bool {
315    if items.len() <= 1 {
316        return true;
317    }
318    if items.len() > DENSE_OUTPUT_UNIQUE_BITS {
319        return false;
320    }
321
322    let mut min = u32::MAX;
323    let mut max = 0u32;
324    for item in items {
325        min = min.min(item.output_handle);
326        max = max.max(item.output_handle);
327    }
328    let Some(range) = u64::from(max)
329        .checked_sub(u64::from(min))
330        .and_then(|value| value.checked_add(1))
331    else {
332        return false;
333    };
334    if range > DENSE_OUTPUT_UNIQUE_BITS as u64 {
335        return false;
336    }
337
338    let mut seen = [0u64; DENSE_OUTPUT_UNIQUE_WORDS];
339    for item in items {
340        let Some(delta) = item.output_handle.checked_sub(min) else {
341            return false;
342        };
343        let Ok(offset) = usize::try_from(delta) else {
344            return false;
345        };
346        let word = offset / 64;
347        let bit = 1u64
348            << (offset % 64);
349        if (seen[word] & bit) != 0 {
350            return false;
351        }
352        seen[word] |= bit;
353    }
354    true
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a work item with `param = 0`.
    fn item(op: u32, inp: u32, out: u32) -> MegakernelWorkItem {
        MegakernelWorkItem {
            op_handle: op,
            input_handle: inp,
            output_handle: out,
            param: 0,
        }
    }

    #[test]
    fn empty_arms_have_no_redundancy() {
        let arms: [&[MegakernelWorkItem]; 0] = [];
        assert_eq!(
            detect_cross_arm_redundancy(&arms),
            CrossArmRedundancy::new()
        );
    }

    #[test]
    fn single_arm_with_repeats_has_no_cross_arm_redundancy() {
        let a = vec![item(1, 0, 5), item(1, 0, 5), item(2, 5, 6)];
        let arms: [&[MegakernelWorkItem]; 1] = [&a];
        let report = detect_cross_arm_redundancy(&arms);
        assert!(report.is_empty(), "intra-arm repeats are not cross-arm");
        assert_eq!(report.total_redundant_ops, 0);
    }

    #[test]
    fn identical_arms_report_full_overlap() {
        let a = vec![item(1, 0, 5), item(2, 5, 6)];
        let b = vec![item(1, 0, 5), item(2, 5, 6)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0), (0, 1, 1)]);
    }

    #[test]
    fn fully_disjoint_arms_have_no_redundancy() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(2, 7, 8)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert!(detect_cross_arm_redundancy(&arms).is_empty());
    }

    #[test]
    fn redundancy_uses_first_seen_arm_index() {
        // Op appears in arms 0, 2, 3  -  both 2 and 3 should reference 0.
        let a = vec![item(1, 0, 5)];
        let b = vec![item(99, 0, 0)];
        let c = vec![item(1, 0, 5)];
        let d = vec![item(1, 0, 5)];
        let arms: [&[MegakernelWorkItem]; 4] = [&a, &b, &c, &d];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 2, 0), (0, 3, 0)]);
    }

    #[test]
    fn param_field_does_not_affect_redundancy() {
        // Same (op, in, out) triple but different param  -  still
        // cross-arm redundant by this substrate's contract.
        let a = vec![MegakernelWorkItem {
            op_handle: 1,
            input_handle: 0,
            output_handle: 5,
            param: 7,
        }];
        let b = vec![MegakernelWorkItem {
            op_handle: 1,
            input_handle: 0,
            output_handle: 5,
            param: 99,
        }];

        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.total_redundant_ops, 1);
    }

    #[test]
    fn different_inputs_are_not_redundant() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(1, 1, 5)]; // different input handle
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert!(detect_cross_arm_redundancy(&arms).is_empty());
    }

    #[test]
    fn different_outputs_are_not_redundant() {
        let a = vec![item(1, 0, 5)];
        let b = vec![item(1, 0, 6)]; // different output handle
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert!(detect_cross_arm_redundancy(&arms).is_empty());
    }

    #[test]
    fn op_index_refers_to_late_arm_position() {
        // Verify the third tuple element is the index WITHIN the
        // late arm, not a global op index.
        let a = vec![item(1, 0, 5)];
        let b = vec![item(99, 0, 0), item(1, 0, 5), item(42, 0, 0)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 1)]);
    }

    #[test]
    fn repeats_inside_late_arm_are_each_reported() {
        // Only the owning arm's repeats are skipped; every repeat in a
        // later arm recomputes the owner's triple.
        let a = vec![item(1, 0, 5)];
        let b = vec![item(1, 0, 5), item(1, 0, 5)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        let report = detect_cross_arm_redundancy(&arms);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0), (0, 1, 1)]);
        assert_eq!(report.total_redundant_ops, 2);
    }

    #[test]
    fn fallible_entry_points_match_legacy_wrappers() {
        let a = vec![item(1, 0, 5), item(2, 5, 6)];
        let b = vec![item(2, 5, 6), item(3, 6, 7)];
        let arms: [&[MegakernelWorkItem]; 2] = [&a, &b];
        assert_eq!(
            try_detect_cross_arm_redundancy(&arms).ok(),
            Some(detect_cross_arm_redundancy(&arms))
        );

        let items = vec![item(1, 0, 5), item(2, 5, 6), item(1, 0, 5)];
        let mut fallible_out = Vec::new();
        let mut legacy_out = Vec::new();
        let fallible = try_prune_redundant_work_items_into(&items, &mut fallible_out).ok();
        let legacy = prune_redundant_work_items_into(&items, &mut legacy_out);
        assert_eq!(fallible, Some(legacy));
        assert_eq!(fallible_out, legacy_out);
    }

    #[test]
    fn prune_redundant_work_items_drops_later_duplicates() {
        let items = vec![
            item(1, 0, 5),
            item(2, 5, 6),
            item(1, 0, 5),
            item(3, 6, 7),
            item(2, 5, 6),
        ];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(out, vec![item(1, 0, 5), item(2, 5, 6), item(3, 6, 7)]);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 2, 0), (1, 4, 0)]);
    }

    #[test]
    fn prune_redundant_work_items_reuses_hash_scratch() {
        let items = vec![item(1, 0, 5), item(2, 5, 6), item(1, 0, 5), item(3, 6, 7)];
        let mut out = Vec::new();
        let mut scratch = RedundantWorkItemPruneScratch::default();

        let first = prune_redundant_work_items_with_scratch_into(&items, &mut out, &mut scratch);
        let retained_capacity = scratch.first_seen.capacity();
        out.clear();
        let second = prune_redundant_work_items_with_scratch_into(&items, &mut out, &mut scratch);

        assert_eq!(first, second);
        assert!(
            scratch.first_seen.capacity() >= retained_capacity,
            "hot megakernel dedupe must retain hash capacity across repeated dispatches"
        );
        assert_eq!(out, vec![item(1, 0, 5), item(2, 5, 6), item(3, 6, 7)]);
    }

    #[test]
    fn dense_unique_outputs_skip_hash_pass_and_clear_scratch() {
        let mut out = Vec::new();
        let mut scratch = RedundantWorkItemPruneScratch::default();
        let with_duplicate = vec![item(1, 0, 5), item(2, 0, 6), item(1, 0, 5)];
        let _ = prune_redundant_work_items_with_scratch_into(&with_duplicate, &mut out, &mut scratch);
        assert!(!scratch.first_seen.is_empty());

        let distinct = vec![item(1, 0, 5), item(1, 0, 6), item(1, 0, 7)];
        let report = prune_redundant_work_items_with_scratch_into(&distinct, &mut out, &mut scratch);

        assert!(report.is_empty());
        assert!(out.is_empty());
        assert!(
            scratch.first_seen.is_empty(),
            "fast path must not leave stale keys from the previous batch"
        );
    }

    #[test]
    fn sparse_distinct_outputs_fall_back_to_hash_pass() {
        // Distinct outputs whose window exceeds the bitmap still prune correctly.
        let items = vec![item(1, 0, 0), item(1, 0, 10_000)];
        assert!(!output_handles_are_dense_unique(&items));
        let mut out = vec![item(99, 99, 99)];

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn prune_redundant_work_items_handles_empty_input() {
        let mut out = vec![item(99, 99, 99)];

        let report = prune_redundant_work_items_into(&[], &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn prune_redundant_work_items_all_duplicates_keep_one() {
        let items = vec![item(1, 0, 5), item(1, 0, 5), item(1, 0, 5)];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(out, vec![item(1, 0, 5)]);
        assert_eq!(report.total_redundant_ops, 2);
        assert_eq!(report.redundant_pairs, vec![(0, 1, 0), (0, 2, 0)]);
    }

    #[test]
    fn prune_redundant_work_items_preserves_order_after_first_duplicate() {
        let items = vec![
            item(1, 0, 5),
            item(2, 5, 6),
            item(1, 0, 5),
            item(3, 6, 7),
            item(4, 7, 8),
        ];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(
            out,
            vec![item(1, 0, 5), item(2, 5, 6), item(3, 6, 7), item(4, 7, 8)]
        );
        assert_eq!(report.redundant_pairs, vec![(0, 2, 0)]);
    }

    #[test]
    fn prune_redundant_work_items_leaves_output_empty_when_no_copy_needed() {
        let items = vec![item(1, 0, 5)];
        let mut out = vec![item(99, 99, 99)];

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn prune_redundant_work_items_keeps_distinct_params() {
        let mut a = item(1, 0, 5);
        a.param = 7;
        let mut b = item(1, 0, 5);
        b.param = 99;
        let items = vec![a, b];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert!(report.is_empty());
        assert!(out.is_empty());
    }

    #[test]
    fn output_handles_dense_unique_accepts_single_owner_outputs() {
        let items = vec![item(1, 0, 5), item(1, 0, 6), item(1, 0, 7)];

        assert!(output_handles_are_dense_unique(&items));
    }

    #[test]
    fn output_handles_dense_unique_rejects_repeated_output() {
        let items = vec![item(1, 0, 5), item(2, 0, 5)];

        assert!(!output_handles_are_dense_unique(&items));
    }

    #[test]
    fn output_handles_dense_unique_respects_bitmap_window() {
        // A window of exactly DENSE_OUTPUT_UNIQUE_BITS handles fits; one more does not.
        assert!(output_handles_are_dense_unique(&[item(1, 0, 0), item(1, 0, 4095)]));
        assert!(!output_handles_are_dense_unique(&[item(1, 0, 0), item(1, 0, 4096)]));
        assert!(output_handles_are_dense_unique(&[
            item(1, 0, u32::MAX - 1),
            item(1, 0, u32::MAX)
        ]));
    }

    #[test]
    fn prune_redundant_work_items_still_catches_duplicate_with_repeated_output() {
        let items = vec![item(1, 0, 5), item(2, 0, 6), item(1, 0, 5)];
        let mut out = Vec::new();

        let report = prune_redundant_work_items_into(&items, &mut out);

        assert_eq!(report.total_redundant_ops, 1);
        assert_eq!(out, vec![item(1, 0, 5), item(2, 0, 6)]);
    }
}