Skip to main content

vyre_driver/
dispatch_shape.rs

1//! Backend-neutral dispatch shape comparison helpers.
2//!
3//! CUDA graph replay, pipeline cache reuse, and future backend command replay
4//! all need the same answer to two questions: do these borrowed input batches
5//! have the same byte shape, and does a runtime [`DispatchConfig`] preserve the
6//! launch-relevant shape captured at compile time?
7
8use crate::{fixpoint_iterations::resolve_fixpoint_iterations, DispatchConfig};
9
/// Report whether `left` and `right` describe the same input shape.
///
/// Two input lists share a shape when they contain the same number of buffers
/// and the buffers at each position have equal byte lengths. Buffer contents
/// are never inspected.
#[must_use]
pub fn borrowed_input_shapes_match(left: &[&[u8]], right: &[&[u8]]) -> bool {
    // Differing arity can never match, whatever the individual lengths are.
    if left.len() != right.len() {
        return false;
    }
    let left_lengths = left.iter().map(|buffer| buffer.len());
    let right_lengths = right.iter().map(|buffer| buffer.len());
    left_lengths.eq(right_lengths)
}
20
21/// Return true when every borrowed-input batch item has the same shape as the
22/// first item.
23#[must_use]
24pub fn borrowed_input_batch_shapes_match(batches: &[&[&[u8]]]) -> bool {
25    let Some((first, rest)) = batches.split_first() else {
26        return true;
27    };
28    rest.iter()
29        .all(|batch| borrowed_input_shapes_match(first, batch))
30}
31
32/// Return true when a runtime dispatch config preserves a compiled launch
33/// shape.
34#[must_use]
35pub fn dispatch_configs_share_launch_shape(
36    compiled: &DispatchConfig,
37    runtime: &DispatchConfig,
38) -> bool {
39    compiled.profile == runtime.profile
40        && ulp_budgets_share_launch_shape(compiled, runtime)
41        && compiled.max_output_bytes == runtime.max_output_bytes
42        && compiled.workgroup_override == runtime.workgroup_override
43        && compiled.grid_override == runtime.grid_override
44        && fixpoint_iterations_share_launch_shape(compiled, runtime)
45        && compiled.speculation == runtime.speculation
46        && compiled.persistent_thread == runtime.persistent_thread
47        && compiled.cooperative == runtime.cooperative
48}
49
50fn fixpoint_iterations_share_launch_shape(
51    compiled: &DispatchConfig,
52    runtime: &DispatchConfig,
53) -> bool {
54    let Ok(compiled_iterations) = resolve_fixpoint_iterations(compiled, "dispatch-shape") else {
55        return false;
56    };
57    let Ok(runtime_iterations) = resolve_fixpoint_iterations(runtime, "dispatch-shape") else {
58        return false;
59    };
60    compiled_iterations == runtime_iterations
61}
62
63fn ulp_budgets_share_launch_shape(compiled: &DispatchConfig, runtime: &DispatchConfig) -> bool {
64    compiled.ulp_budget.unwrap_or(0) == runtime.ulp_budget.unwrap_or(0)
65}
66
#[cfg(test)]
mod tests {
    use super::{
        borrowed_input_batch_shapes_match, borrowed_input_shapes_match,
        dispatch_configs_share_launch_shape,
    };
    use crate::DispatchConfig;

    /// Clone `base` and apply a single modification to the copy.
    fn variant_of(
        base: &DispatchConfig,
        tweak: impl FnOnce(&mut DispatchConfig),
    ) -> DispatchConfig {
        let mut variant = base.clone();
        tweak(&mut variant);
        variant
    }

    #[test]
    fn borrowed_input_shapes_compare_arity_and_lengths_only() {
        let left_three = [1_u8, 2, 3];
        let left_one = [4_u8];
        let right_three = [9_u8, 8, 7];
        let right_one = [6_u8];
        let right_two = [5_u8, 4];

        // Same arity and lengths, different contents.
        assert!(borrowed_input_shapes_match(
            &[&left_three, &left_one],
            &[&right_three, &right_one]
        ));
        // Fewer inputs on the right.
        assert!(!borrowed_input_shapes_match(
            &[&left_three, &left_one],
            &[&right_three]
        ));
        // Same arity, second input has a different length.
        assert!(!borrowed_input_shapes_match(
            &[&left_three, &left_one],
            &[&right_three, &right_two]
        ));
    }

    #[test]
    fn borrowed_input_batch_shapes_accept_empty_and_uniform_batches() {
        let first_two = [1_u8, 2];
        let first_three = [3_u8, 4, 5];
        let second_two = [9_u8, 8];
        let second_three = [7_u8, 6, 5];
        let odd_one = [0_u8];

        // No batch items at all.
        assert!(borrowed_input_batch_shapes_match(&[]));
        // Two items with identical per-input lengths.
        assert!(borrowed_input_batch_shapes_match(&[
            &[&first_two, &first_three],
            &[&second_two, &second_three]
        ]));
        // Second item deviates in its second input.
        assert!(!borrowed_input_batch_shapes_match(&[
            &[&first_two, &first_three],
            &[&second_two, &odd_one]
        ]));
    }

    #[test]
    fn dispatch_config_launch_shape_ignores_timeout_but_tracks_launch_fields() {
        let base = DispatchConfig::default();

        let timeout_only = variant_of(&base, |config| {
            config.timeout = Some(std::time::Duration::from_millis(1));
        });
        assert!(dispatch_configs_share_launch_shape(&base, &timeout_only));

        let changed_grid = variant_of(&base, |config| {
            config.grid_override = Some([2, 1, 1]);
        });
        assert!(!dispatch_configs_share_launch_shape(&base, &changed_grid));

        let changed_fixpoint = variant_of(&base, |config| {
            config.fixpoint_iterations = Some(2);
        });
        assert!(!dispatch_configs_share_launch_shape(
            &base,
            &changed_fixpoint
        ));
    }

    #[test]
    fn dispatch_config_launch_shape_canonicalizes_default_fixpoint_iteration() {
        let base = DispatchConfig::default();
        let explicit_one = variant_of(&base, |config| {
            config.fixpoint_iterations = Some(1);
        });

        assert!(
            dispatch_configs_share_launch_shape(&base, &explicit_one),
            "Fix: compiled pipelines must not miss cache/replay fast paths when runtime policy spells the default fixpoint iteration count explicitly."
        );
    }

    #[test]
    fn dispatch_config_launch_shape_rejects_invalid_zero_fixpoint_iteration() {
        let base = DispatchConfig::default();
        let explicit_zero = variant_of(&base, |config| {
            config.fixpoint_iterations = Some(0);
        });

        assert!(
            !dispatch_configs_share_launch_shape(&base, &explicit_zero),
            "Fix: explicit zero fixpoint iterations are invalid policy and must not be treated as a compatible launch shape."
        );
    }

    #[test]
    fn dispatch_config_launch_shape_canonicalizes_strict_ulp_budget() {
        let base = DispatchConfig::default();
        let explicit_strict = variant_of(&base, |config| {
            config.ulp_budget = Some(0);
        });

        assert!(
            dispatch_configs_share_launch_shape(&base, &explicit_strict),
            "Fix: strict ULP defaults should not force duplicate compiled dispatch shapes."
        );
    }

    #[test]
    fn dispatch_config_launch_shape_separates_relaxed_ulp_budget() {
        let base = DispatchConfig::default();
        let relaxed = variant_of(&base, |config| {
            config.ulp_budget = Some(1);
        });

        assert!(
            !dispatch_configs_share_launch_shape(&base, &relaxed),
            "Fix: relaxed ULP budgets change target intrinsic policy and need distinct dispatch shapes."
        );
    }
}