Skip to main content

vyre_runtime/megakernel/
speculation.rs

1//! Runtime-side paired speculation races for megakernel dispatch.
2//!
3//! The driver crate owns the backend-neutral decision math. This module
4//! owns the megakernel/runtime bridge: every candidate rewrite is measured
5//! as a conservative/speculative pair, the faster side is recorded in the
6//! shared autotune store, and the accumulated sample window is converted
7//! into the N2 adoption verdict.
8
9use vyre_driver::autotune_store::{AutotuneRecord, AutotuneStore};
10use vyre_driver::speculate::{
11    record_speculative_variant_race, SpeculativeVariantDecision, SpeculativeVariantKeys,
12    SpeculativeVariantRace,
13};
14use vyre_driver::speculation_substrate::{
15    decide_speculation, SpeculationObservation, SpeculationVerdict,
16};
17
18/// One measured conservative/speculative dispatch pair.
19#[derive(Debug, Clone)]
20pub struct PairedSpeculationSample {
21    /// Conservative dispatch elapsed time, excluding compile/cache miss.
22    pub conservative_dispatch_ns: u64,
23    /// Speculative dispatch elapsed time, excluding compile/cache miss.
24    pub speculative_dispatch_ns: u64,
25    /// Conservative compile/cache-miss time for this pair.
26    pub conservative_compile_ns: u64,
27    /// Speculative compile/cache-miss time for this pair.
28    pub speculative_compile_ns: u64,
29    /// Autotune record attached to the conservative variant.
30    pub conservative_record: AutotuneRecord,
31    /// Autotune record attached to the speculative variant.
32    pub speculative_record: AutotuneRecord,
33}
34
35/// Result of recording one paired race.
36#[derive(Debug, Clone)]
37pub struct PairedSpeculationUpdate {
38    /// Winning per-sample cache/autotune decision.
39    pub race_decision: SpeculativeVariantDecision,
40    /// Accumulated N2 verdict for the shape.
41    pub verdict: SpeculationVerdict,
42    /// Observation fed into the verdict.
43    pub observation: SpeculationObservation,
44}
45
46/// Accumulated paired-race window for one rewrite candidate and shape.
47#[derive(Debug, Default, Clone)]
48pub struct PairedSpeculationWindow {
49    conservative: RunningMean,
50    speculative: RunningMean,
51    side_compile_cost_ns: u64,
52}
53
54impl PairedSpeculationWindow {
55    /// Empty paired-race window.
56    #[must_use]
57    pub const fn new() -> Self {
58        Self {
59            conservative: RunningMean::new(),
60            speculative: RunningMean::new(),
61            side_compile_cost_ns: 0,
62        }
63    }
64
65    /// Number of paired samples recorded.
66    #[must_use]
67    pub fn len(&self) -> u32 {
68        self.conservative.count.min(self.speculative.count)
69    }
70
71    /// True when no paired samples were recorded.
72    #[must_use]
73    pub fn is_empty(&self) -> bool {
74        self.len() == 0
75    }
76
77    /// Current observation for the N2 speculation policy.
78    #[must_use]
79    pub fn observation(&self) -> SpeculationObservation {
80        SpeculationObservation {
81            baseline_dispatches: self.conservative.count,
82            baseline_mean_ns: self.conservative.mean_ns(),
83            speculative_dispatches: self.speculative.count,
84            speculative_mean_ns: self.speculative.mean_ns(),
85            side_compile_cost_ns: self.side_compile_cost_ns,
86        }
87    }
88
89    /// Record one paired sample, update the autotune store with the
90    /// per-sample winner, and return the accumulated adoption verdict.
91    pub fn record_sample(
92        &mut self,
93        store: &mut AutotuneStore,
94        keys: SpeculativeVariantKeys<'_>,
95        sample: PairedSpeculationSample,
96    ) -> PairedSpeculationUpdate {
97        self.conservative.record(sample.conservative_dispatch_ns);
98        self.speculative.record(sample.speculative_dispatch_ns);
99        self.side_compile_cost_ns = self
100            .side_compile_cost_ns
101            .checked_add(sample.speculative_compile_ns)
102            .unwrap_or_else(|| {
103                panic!(
104                    "paired speculation side compile cost overflowed u64. Fix: reset the speculation window before accumulating more samples."
105                )
106            });
107
108        let race_decision = record_speculative_variant_race(
109            store,
110            keys,
111            SpeculativeVariantRace {
112                conservative_dispatch_ns: sample.conservative_dispatch_ns,
113                speculative_dispatch_ns: sample.speculative_dispatch_ns,
114                conservative_compile_ns: sample.conservative_compile_ns,
115                speculative_compile_ns: sample.speculative_compile_ns,
116                conservative_record: sample.conservative_record,
117                speculative_record: sample.speculative_record,
118            },
119        );
120        let observation = self.observation();
121        let verdict = decide_speculation(observation);
122        PairedSpeculationUpdate {
123            race_decision,
124            verdict,
125            observation,
126        }
127    }
128}
129
/// Overflow-checked accumulator that yields the integer mean of the
/// nanosecond values recorded into it.
#[derive(Clone, Debug, Default)]
struct RunningMean {
    /// Number of values recorded so far.
    count: u32,
    /// Sum of all recorded values, in nanoseconds.
    total_ns: u128,
}

impl RunningMean {
    /// Accumulator with no recorded values.
    const fn new() -> Self {
        Self {
            count: 0,
            total_ns: 0,
        }
    }

    /// Adds `value_ns` to the accumulator.
    ///
    /// Panics when the sample count would overflow `u32` or the running
    /// total would overflow `u128`. The count is bumped before the total is
    /// checked.
    fn record(&mut self, value_ns: u64) {
        self.count = match self.count.checked_add(1) {
            Some(bumped) => bumped,
            None => panic!(
                "paired speculation sample count overflowed u32. Fix: reset the speculation window before accumulating more samples."
            ),
        };
        self.total_ns = match self.total_ns.checked_add(u128::from(value_ns)) {
            Some(sum) => sum,
            None => panic!(
                "paired speculation total nanoseconds overflowed u128. Fix: reset the speculation window before accumulating more samples."
            ),
        };
    }

    /// Returns the truncated integer mean of the recorded values, or `0`
    /// when nothing has been recorded.
    ///
    /// Panics if the mean does not fit in `u64`.
    fn mean_ns(&self) -> u64 {
        match self.count {
            0 => 0,
            samples => match u64::try_from(self.total_ns / u128::from(samples)) {
                Ok(mean) => mean,
                Err(error) => panic!(
                    "paired speculation mean nanoseconds cannot fit u64: {error}. Fix: reset the speculation window before accumulating more samples."
                ),
            },
        }
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_driver::specialization::SpecCacheKey;
    use vyre_driver::speculate::SpeculativeVariantKind;

    /// Cache key whose hash fields are all derived from `id`.
    fn cache_key(id: u64) -> SpecCacheKey {
        SpecCacheKey {
            shader_hash: id,
            binding_sig: id << 8,
            workgroup_size: [64, 1, 1],
            spec_hash: id << 16,
        }
    }

    /// Autotune record with the given workgroup width and fixed other fields.
    fn tuning_record(workgroup: u32) -> AutotuneRecord {
        AutotuneRecord {
            workgroup_size: [workgroup, 1, 1],
            unroll: 1,
            tile: [0, 0, 0],
            recorded_at: String::from("2026-05-02"),
        }
    }

    /// Paired sample with the given dispatch times and zero compile cost.
    fn paired_sample(conservative_ns: u64, speculative_ns: u64) -> PairedSpeculationSample {
        PairedSpeculationSample {
            conservative_dispatch_ns: conservative_ns,
            speculative_dispatch_ns: speculative_ns,
            conservative_compile_ns: 0,
            speculative_compile_ns: 0,
            conservative_record: tuning_record(64),
            speculative_record: tuning_record(128),
        }
    }

    /// Records `rounds` samples built by `make_sample` into a fresh window
    /// and store, returning the store together with the last update.
    fn race(
        ids: (u64, u64),
        rounds: usize,
        make_sample: impl Fn() -> PairedSpeculationSample,
    ) -> (AutotuneStore, PairedSpeculationUpdate) {
        let conservative = cache_key(ids.0);
        let speculative = cache_key(ids.1);
        let keys = SpeculativeVariantKeys {
            conservative: &conservative,
            speculative: &speculative,
            adapter_id: "test-adapter",
        };
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();
        let final_update = (0..rounds)
            .map(|_| window.record_sample(&mut store, keys, make_sample()))
            .last()
            .expect("Fix: loop records at least one sample");
        (store, final_update)
    }

    #[test]
    fn paired_window_keeps_racing_under_threshold() {
        let (_store, update) = race((1, 2), 1, || paired_sample(100_000, 50_000));
        assert_eq!(update.verdict, SpeculationVerdict::KeepRacing);
        assert_eq!(update.observation.baseline_dispatches, 1);
        assert_eq!(update.observation.speculative_dispatches, 1);
    }

    #[test]
    fn paired_window_adopts_after_sustained_win() {
        let (store, update) = race((3, 4), 8, || paired_sample(100_000, 50_000));
        assert_eq!(update.verdict, SpeculationVerdict::Adopt);
        assert_eq!(
            update.race_decision.winner,
            SpeculativeVariantKind::Speculative
        );
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn paired_window_rejects_sustained_loss() {
        let (_store, update) = race((5, 6), 8, || paired_sample(50_000, 100_000));
        assert_eq!(update.verdict, SpeculationVerdict::Reject);
    }

    #[test]
    fn paired_window_amortizes_speculative_compile_cost() {
        let (_store, update) = race((7, 8), 8, || PairedSpeculationSample {
            speculative_compile_ns: 1_000_000,
            ..paired_sample(100_000, 50_000)
        });
        assert_eq!(update.verdict, SpeculationVerdict::Reject);
        assert_eq!(update.observation.side_compile_cost_ns, 8_000_000);
    }
}