Skip to main content

vyre_driver/
command_reuse_policy.rs

1//! D4 substrate: pre-recorded command reuse policy.
2//!
3//! When the same dispatch shape repeats (same Program, same binding
4//! handles, same workgroup, same workload count), backends can record
5//! the launch sequence once and replay it through their native command
6//! reuse primitive. This eliminates per-launch driver API overhead.
7//!
8//! Pure decision: given a dispatch repetition count and the measured
9//! per-launch overhead vs command-record overhead, should the
10//! dispatcher record-and-replay or just launch normally?
11//!
12//! This sits next to D1 (persistent kernels). Persistent mode wins
13//! for unpredictable batches of small kernels; command reuse wins for
14//! REPEATED dispatches of the same shape.
15
/// Measurements fed into [`decide_command_reuse`].
///
/// All overheads are expressed in nanoseconds; the struct is plain
/// `Copy` data so callers can build it inline per decision.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CommandReuseInputs {
    /// How many times the identical dispatch shape (same Program,
    /// bindings and workload count) is going to be issued.
    pub repeat_count: u32,
    /// Driver API cost, in nanoseconds, paid by every plain launch.
    /// This is the same figure the persistent-kernel policy uses.
    pub per_launch_overhead_ns: u64,
    /// Cost, in nanoseconds, of recording the native command
    /// sequence. Paid once, up front.
    pub record_overhead_ns: u64,
    /// Cost, in nanoseconds, of each replay through the native
    /// command-reuse primitive.
    pub replay_overhead_ns: u64,
}
30
/// Outcome of [`decide_command_reuse`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CommandReuseDecision {
    /// Issue every dispatch as an ordinary launch: recording would not
    /// pay for itself (too few repeats, or replay is no cheaper than
    /// a plain launch).
    PlainLaunches,
    /// Record the command sequence once and serve the repeated
    /// dispatches by replaying it. Carries the predicted gain over
    /// plain launches so callers can report it as telemetry.
    RecordAndReplay {
        /// Predicted nanoseconds saved compared with plain launches:
        /// `repeat_count * (per_launch - replay) - record`. The
        /// decision function only produces this variant when that
        /// value is strictly positive.
        savings_ns: u128,
    },
}
45
46/// Decide whether to record a command sequence once and replay it for
47/// the remaining `repeat_count - 1` dispatches.
48///
49/// Plain cost:    `repeat * per_launch_ovh`
50/// Reuse cost:    `record_ovh + repeat * replay_ovh`
51/// Reuse wins iff `repeat * (per_launch_ovh - replay_ovh) > record_ovh`.
52#[must_use]
53pub fn decide_command_reuse(inputs: CommandReuseInputs) -> CommandReuseDecision {
54    if inputs.repeat_count <= 1 {
55        return CommandReuseDecision::PlainLaunches;
56    }
57    if inputs.per_launch_overhead_ns <= inputs.replay_overhead_ns {
58        // Replay is not actually cheaper than plain launch.
59        // recording costs us bytes for nothing.
60        return CommandReuseDecision::PlainLaunches;
61    }
62    let per_call_savings =
63        u128::from(inputs.per_launch_overhead_ns) - u128::from(inputs.replay_overhead_ns);
64    let total_call_savings = u128::from(inputs.repeat_count) * per_call_savings;
65    let record_overhead_ns = u128::from(inputs.record_overhead_ns);
66    if total_call_savings <= record_overhead_ns {
67        return CommandReuseDecision::PlainLaunches;
68    }
69    let savings_ns = total_call_savings - record_overhead_ns;
70    CommandReuseDecision::RecordAndReplay { savings_ns }
71}
72
#[cfg(test)]
mod tests {
    //! Unit tests for the command-reuse decision, plus a source-level
    //! guard that the policy keeps using exact widened arithmetic.

    use super::*;

    /// Shorthand constructor: `(repeat, per-launch, record, replay)`.
    fn inp(rep: u32, launch: u64, record: u64, replay: u64) -> CommandReuseInputs {
        CommandReuseInputs {
            repeat_count: rep,
            per_launch_overhead_ns: launch,
            record_overhead_ns: record,
            replay_overhead_ns: replay,
        }
    }

    #[test]
    fn single_dispatch_is_plain() {
        // No repetition → recording wastes work.
        assert_eq!(
            decide_command_reuse(inp(1, 5_000, 25_000, 500)),
            CommandReuseDecision::PlainLaunches
        );
    }

    #[test]
    fn zero_repeat_is_plain() {
        assert_eq!(
            decide_command_reuse(inp(0, 5_000, 25_000, 500)),
            CommandReuseDecision::PlainLaunches
        );
    }

    #[test]
    fn replay_no_cheaper_than_launch_is_plain() {
        // Graph replay = per-launch overhead → no savings possible.
        assert_eq!(
            decide_command_reuse(inp(1000, 5_000, 25_000, 5_000)),
            CommandReuseDecision::PlainLaunches
        );
    }

    #[test]
    fn replay_costlier_than_launch_is_plain() {
        // Replay dearer than a plain launch must not underflow; even
        // with free recording the verdict is plain launches.
        assert_eq!(
            decide_command_reuse(inp(1000, 500, 0, 5_000)),
            CommandReuseDecision::PlainLaunches
        );
    }

    #[test]
    fn small_repeat_under_amortisation_is_plain() {
        // 5 repeats × (5000 - 500) savings = 22_500; record costs 25_000.
        assert_eq!(
            decide_command_reuse(inp(5, 5_000, 25_000, 500)),
            CommandReuseDecision::PlainLaunches
        );
    }

    #[test]
    fn exact_break_even_is_plain() {
        // 10 repeats × 4_500 savings = 45_000 == record cost → no gain,
        // so the strict inequality keeps plain launches.
        assert_eq!(
            decide_command_reuse(inp(10, 5_000, 45_000, 500)),
            CommandReuseDecision::PlainLaunches
        );
        // One nanosecond cheaper recording tips the verdict.
        assert_eq!(
            decide_command_reuse(inp(10, 5_000, 44_999, 500)),
            CommandReuseDecision::RecordAndReplay { savings_ns: 1 }
        );
    }

    #[test]
    fn large_repeat_above_amortisation_picks_record_and_replay() {
        // 100 repeats × 4_500 savings = 450_000; record 25_000.
        // Net savings = 425_000.
        assert_eq!(
            decide_command_reuse(inp(100, 5_000, 25_000, 500)),
            CommandReuseDecision::RecordAndReplay {
                savings_ns: 425_000
            }
        );
    }

    #[test]
    fn savings_strictly_positive_when_record_and_replay() {
        let dec = decide_command_reuse(inp(1000, 5_000, 25_000, 500));
        match dec {
            CommandReuseDecision::RecordAndReplay { savings_ns } => assert!(savings_ns > 0),
            other => panic!("expected RecordAndReplay; got {:?}", other),
        }
    }

    #[test]
    fn widened_arithmetic_preserves_extreme_savings() {
        // u32::MAX repeats × u64-near-max savings shouldn't panic.
        let dec = decide_command_reuse(inp(u32::MAX, u64::MAX / 2, 25_000, 1));
        match dec {
            CommandReuseDecision::RecordAndReplay { savings_ns } => {
                assert_eq!(
                    savings_ns,
                    u128::from(u32::MAX) * (u128::from(u64::MAX / 2) - 1) - 25_000
                );
            }
            other => panic!("expected RecordAndReplay; got {:?}", other),
        }
    }

    #[test]
    fn command_reuse_policy_source_uses_exact_widened_arithmetic() {
        let source = include_str!("command_reuse_policy.rs");

        // The needles below are string literals inside this very file,
        // so searching the whole file would always succeed. Restrict
        // the positive checks to the text that precedes the test
        // module; the first occurrence of the marker is the attribute
        // on `mod tests` itself.
        let (production, _tests) = source
            .split_once("#[cfg(test)]")
            .expect("Fix: command-reuse policy source must contain its test-module marker.");

        // The forbidden names are assembled with `concat!` so this
        // check does not match its own source text.
        assert!(
            !source.contains(concat!("saturating", "_mul"))
                && !source.contains(concat!("saturating", "_sub")),
            "Fix: command-reuse policy must use exact widened arithmetic, not saturating replay-cost math."
        );
        assert!(
            production.contains("u128::from(inputs.per_launch_overhead_ns)")
                && production.contains("u128::from(inputs.repeat_count)")
                && production.contains("total_call_savings - record_overhead_ns"),
            "Fix: command-reuse savings must stay widened through the verdict."
        );
    }
}