Skip to main content

vyre_runtime/megakernel/
speculation.rs

1//! Runtime-side paired speculation races for megakernel dispatch.
2//!
3//! The driver crate owns the backend-neutral decision math. This module
4//! owns the megakernel/runtime bridge: every candidate rewrite is measured
5//! as a conservative/speculative pair, the faster side is recorded in the
6//! shared autotune store, and the accumulated sample window is converted
7//! into the N2 adoption verdict.
8
9use vyre_driver::autotune_store::{AutotuneRecord, AutotuneStore};
10use vyre_driver::speculate::{
11    record_speculative_variant_race, SpeculativeVariantDecision, SpeculativeVariantKeys,
12    SpeculativeVariantRace,
13};
14use vyre_driver::speculation_substrate::{
15    decide_speculation, SpeculationObservation, SpeculationVerdict,
16};
17
18/// One measured conservative/speculative dispatch pair.
19#[derive(Debug, Clone)]
20pub struct PairedSpeculationSample {
21    /// Conservative dispatch elapsed time, excluding compile/cache miss.
22    pub conservative_dispatch_ns: u64,
23    /// Speculative dispatch elapsed time, excluding compile/cache miss.
24    pub speculative_dispatch_ns: u64,
25    /// Conservative compile/cache-miss time for this pair.
26    pub conservative_compile_ns: u64,
27    /// Speculative compile/cache-miss time for this pair.
28    pub speculative_compile_ns: u64,
29    /// Autotune record attached to the conservative variant.
30    pub conservative_record: AutotuneRecord,
31    /// Autotune record attached to the speculative variant.
32    pub speculative_record: AutotuneRecord,
33}
34
35/// Result of recording one paired race.
36#[derive(Debug, Clone)]
37pub struct PairedSpeculationUpdate {
38    /// Winning per-sample cache/autotune decision.
39    pub race_decision: SpeculativeVariantDecision,
40    /// Accumulated N2 verdict for the shape.
41    pub verdict: SpeculationVerdict,
42    /// Observation fed into the verdict.
43    pub observation: SpeculationObservation,
44}
45
46/// Accumulated paired-race window for one rewrite candidate and shape.
47#[derive(Debug, Default, Clone)]
48pub struct PairedSpeculationWindow {
49    conservative: RunningMean,
50    speculative: RunningMean,
51    side_compile_cost_ns: u64,
52}
53
54impl PairedSpeculationWindow {
55    /// Empty paired-race window.
56    #[must_use]
57    pub const fn new() -> Self {
58        Self {
59            conservative: RunningMean::new(),
60            speculative: RunningMean::new(),
61            side_compile_cost_ns: 0,
62        }
63    }
64
65    /// Number of paired samples recorded.
66    #[must_use]
67    pub fn len(&self) -> u32 {
68        self.conservative.count.min(self.speculative.count)
69    }
70
71    /// True when no paired samples were recorded.
72    #[must_use]
73    pub fn is_empty(&self) -> bool {
74        self.len() == 0
75    }
76
77    /// Current observation for the N2 speculation policy.
78    #[must_use]
79    pub fn observation(&self) -> SpeculationObservation {
80        SpeculationObservation {
81            baseline_dispatches: self.conservative.count,
82            baseline_mean_ns: self.conservative.mean_ns(),
83            speculative_dispatches: self.speculative.count,
84            speculative_mean_ns: self.speculative.mean_ns(),
85            side_compile_cost_ns: self.side_compile_cost_ns,
86        }
87    }
88
89    /// Record one paired sample, update the autotune store with the
90    /// per-sample winner, and return the accumulated adoption verdict.
91    pub fn record_sample(
92        &mut self,
93        store: &mut AutotuneStore,
94        keys: SpeculativeVariantKeys<'_>,
95        sample: PairedSpeculationSample,
96    ) -> PairedSpeculationUpdate {
97        self.conservative.record(sample.conservative_dispatch_ns);
98        self.speculative.record(sample.speculative_dispatch_ns);
99        self.side_compile_cost_ns = self
100            .side_compile_cost_ns
101            .saturating_add(sample.speculative_compile_ns);
102
103        let race_decision = record_speculative_variant_race(
104            store,
105            keys,
106            SpeculativeVariantRace {
107                conservative_dispatch_ns: sample.conservative_dispatch_ns,
108                speculative_dispatch_ns: sample.speculative_dispatch_ns,
109                conservative_compile_ns: sample.conservative_compile_ns,
110                speculative_compile_ns: sample.speculative_compile_ns,
111                conservative_record: sample.conservative_record,
112                speculative_record: sample.speculative_record,
113            },
114        );
115        let observation = self.observation();
116        let verdict = decide_speculation(observation);
117        PairedSpeculationUpdate {
118            race_decision,
119            verdict,
120            observation,
121        }
122    }
123}
124
/// Count and sum of dispatch samples, used to derive an integer mean.
///
/// Once `count` reaches `u32::MAX` the accumulator stops absorbing samples.
/// Growing `total_ns` without growing `count` would otherwise inflate the
/// reported mean without bound in a long-lived window.
#[derive(Debug, Default, Clone)]
struct RunningMean {
    /// Number of samples absorbed (never exceeds `u32::MAX`).
    count: u32,
    /// Sum of absorbed samples in nanoseconds. Bounded by
    /// `u32::MAX * u64::MAX`, which is far below `u128::MAX`.
    total_ns: u128,
}

impl RunningMean {
    /// Empty accumulator; usable in `const` contexts.
    const fn new() -> Self {
        Self {
            count: 0,
            total_ns: 0,
        }
    }

    /// Absorbs one sample.
    ///
    /// Samples arriving after `count` has saturated are dropped, so `count`
    /// and `total_ns` always describe the same set of samples and the mean
    /// stays meaningful.
    fn record(&mut self, value_ns: u64) {
        if let Some(next_count) = self.count.checked_add(1) {
            self.count = next_count;
            // Cannot actually saturate given the bound above; kept defensive.
            self.total_ns = self.total_ns.saturating_add(u128::from(value_ns));
        }
    }

    /// Floor mean of absorbed samples in nanoseconds.
    ///
    /// Returns `0` when no samples were absorbed, and clamps to `u64::MAX`
    /// if the quotient does not fit in `u64`.
    fn mean_ns(&self) -> u64 {
        if self.count == 0 {
            return 0;
        }
        let mean = self.total_ns / u128::from(self.count);
        u64::try_from(mean).unwrap_or(u64::MAX)
    }
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_driver::specialization::SpecCacheKey;
    use vyre_driver::speculate::SpeculativeVariantKind;

    /// Number of paired samples each multi-sample test feeds the window.
    const ROUNDS: usize = 8;

    fn key(id: u64) -> SpecCacheKey {
        SpecCacheKey {
            shader_hash: id,
            binding_sig: id << 8,
            workgroup_size: [64, 1, 1],
            spec_hash: id << 16,
        }
    }

    fn record(workgroup: u32) -> AutotuneRecord {
        AutotuneRecord {
            workgroup_size: [workgroup, 1, 1],
            unroll: 1,
            tile: [0, 0, 0],
            recorded_at: "2026-05-02".to_string(),
        }
    }

    /// Race keys for a conservative/speculative key pair on the test adapter.
    fn race_keys<'a>(
        conservative: &'a SpecCacheKey,
        speculative: &'a SpecCacheKey,
    ) -> SpeculativeVariantKeys<'a> {
        SpeculativeVariantKeys {
            conservative,
            speculative,
            adapter_id: "test-adapter",
        }
    }

    /// Paired sample with the given dispatch times and no compile cost.
    fn sample(conservative_ns: u64, speculative_ns: u64) -> PairedSpeculationSample {
        PairedSpeculationSample {
            conservative_dispatch_ns: conservative_ns,
            speculative_dispatch_ns: speculative_ns,
            conservative_compile_ns: 0,
            speculative_compile_ns: 0,
            conservative_record: record(64),
            speculative_record: record(128),
        }
    }

    #[test]
    fn paired_window_keeps_racing_under_threshold() {
        let (conservative, speculative) = (key(1), key(2));
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        let update = window.record_sample(
            &mut store,
            race_keys(&conservative, &speculative),
            sample(100_000, 50_000),
        );

        assert_eq!(update.verdict, SpeculationVerdict::KeepRacing);
        assert_eq!(update.observation.baseline_dispatches, 1);
        assert_eq!(update.observation.speculative_dispatches, 1);
    }

    #[test]
    fn paired_window_adopts_after_sustained_win() {
        let (conservative, speculative) = (key(3), key(4));
        let keys = race_keys(&conservative, &speculative);
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        let update = (0..ROUNDS)
            .map(|_| window.record_sample(&mut store, keys, sample(100_000, 50_000)))
            .last()
            .expect("Fix: loop records at least one sample");

        assert_eq!(update.verdict, SpeculationVerdict::Adopt);
        assert_eq!(
            update.race_decision.winner,
            SpeculativeVariantKind::Speculative
        );
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn paired_window_rejects_sustained_loss() {
        let (conservative, speculative) = (key(5), key(6));
        let keys = race_keys(&conservative, &speculative);
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        let verdict = (0..ROUNDS).fold(SpeculationVerdict::KeepRacing, |_, _| {
            window
                .record_sample(&mut store, keys, sample(50_000, 100_000))
                .verdict
        });

        assert_eq!(verdict, SpeculationVerdict::Reject);
    }

    #[test]
    fn paired_window_amortizes_speculative_compile_cost() {
        let (conservative, speculative) = (key(7), key(8));
        let keys = race_keys(&conservative, &speculative);
        let mut store = AutotuneStore::default();
        let mut window = PairedSpeculationWindow::new();

        // Speculative side wins every dispatch but pays a heavy compile each time.
        let expensive_compile = || PairedSpeculationSample {
            speculative_compile_ns: 1_000_000,
            ..sample(100_000, 50_000)
        };
        let update = (0..ROUNDS)
            .map(|_| window.record_sample(&mut store, keys, expensive_compile()))
            .last()
            .expect("Fix: loop records at least one sample");

        assert_eq!(update.verdict, SpeculationVerdict::Reject);
        assert_eq!(update.observation.side_compile_cost_ns, 8_000_000);
    }
}