vyre_driver/
dispatch_shape.rs1use crate::{fixpoint_iterations::resolve_fixpoint_iterations, DispatchConfig};
9
/// Reports whether two lists of borrowed input buffers have the same shape.
///
/// The lists match when they hold the same number of buffers and every pair of
/// buffers at the same index has the same byte length. Buffer contents are
/// never inspected.
#[must_use]
pub fn borrowed_input_shapes_match(left: &[&[u8]], right: &[&[u8]]) -> bool {
    // Differing arity can never match; bail before walking either list.
    if left.len() != right.len() {
        return false;
    }
    left.iter()
        .zip(right)
        .all(|(lhs, rhs)| lhs.len() == rhs.len())
}
20
21#[must_use]
24pub fn borrowed_input_batch_shapes_match(batches: &[&[&[u8]]]) -> bool {
25 let Some((first, rest)) = batches.split_first() else {
26 return true;
27 };
28 rest.iter()
29 .all(|batch| borrowed_input_shapes_match(first, batch))
30}
31
32#[must_use]
35pub fn dispatch_configs_share_launch_shape(
36 compiled: &DispatchConfig,
37 runtime: &DispatchConfig,
38) -> bool {
39 compiled.profile == runtime.profile
40 && ulp_budgets_share_launch_shape(compiled, runtime)
41 && compiled.max_output_bytes == runtime.max_output_bytes
42 && compiled.workgroup_override == runtime.workgroup_override
43 && compiled.grid_override == runtime.grid_override
44 && fixpoint_iterations_share_launch_shape(compiled, runtime)
45 && compiled.speculation == runtime.speculation
46 && compiled.persistent_thread == runtime.persistent_thread
47 && compiled.cooperative == runtime.cooperative
48}
49
50fn fixpoint_iterations_share_launch_shape(
51 compiled: &DispatchConfig,
52 runtime: &DispatchConfig,
53) -> bool {
54 let Ok(compiled_iterations) = resolve_fixpoint_iterations(compiled, "dispatch-shape") else {
55 return false;
56 };
57 let Ok(runtime_iterations) = resolve_fixpoint_iterations(runtime, "dispatch-shape") else {
58 return false;
59 };
60 compiled_iterations == runtime_iterations
61}
62
63fn ulp_budgets_share_launch_shape(compiled: &DispatchConfig, runtime: &DispatchConfig) -> bool {
64 compiled.ulp_budget.unwrap_or(0) == runtime.ulp_budget.unwrap_or(0)
65}
66
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{
        borrowed_input_batch_shapes_match, borrowed_input_shapes_match,
        dispatch_configs_share_launch_shape,
    };
    use crate::DispatchConfig;

    /// Clones `base` and applies `tweak` to the copy, leaving `base` untouched.
    fn tweaked(base: &DispatchConfig, tweak: impl FnOnce(&mut DispatchConfig)) -> DispatchConfig {
        let mut config = base.clone();
        tweak(&mut config);
        config
    }

    #[test]
    fn borrowed_input_shapes_compare_arity_and_lengths_only() {
        let first = [1_u8, 2, 3];
        let second = [4_u8];
        let same_len_first = [9_u8, 8, 7];
        let same_len_second = [6_u8];
        let wrong_len = [5_u8, 4];

        let reference: [&[u8]; 2] = [&first, &second];
        assert!(borrowed_input_shapes_match(
            &reference,
            &[&same_len_first, &same_len_second]
        ));
        assert!(!borrowed_input_shapes_match(&reference, &[&same_len_first]));
        assert!(!borrowed_input_shapes_match(
            &reference,
            &[&same_len_first, &wrong_len]
        ));
    }

    #[test]
    fn borrowed_input_batch_shapes_accept_empty_and_uniform_batches() {
        let short_a = [1_u8, 2];
        let long_a = [3_u8, 4, 5];
        let short_b = [9_u8, 8];
        let long_b = [7_u8, 6, 5];
        let tiny = [0_u8];

        let batch_a: [&[u8]; 2] = [&short_a, &long_a];
        let batch_b: [&[u8]; 2] = [&short_b, &long_b];
        let mismatched: [&[u8]; 2] = [&short_b, &tiny];

        assert!(borrowed_input_batch_shapes_match(&[]));
        assert!(borrowed_input_batch_shapes_match(&[&batch_a, &batch_b]));
        assert!(!borrowed_input_batch_shapes_match(&[&batch_a, &mismatched]));
    }

    #[test]
    fn dispatch_config_launch_shape_ignores_timeout_but_tracks_launch_fields() {
        let base = DispatchConfig::default();

        let timeout_only = tweaked(&base, |config| {
            config.timeout = Some(Duration::from_millis(1));
        });
        assert!(dispatch_configs_share_launch_shape(&base, &timeout_only));

        let changed_grid = tweaked(&base, |config| config.grid_override = Some([2, 1, 1]));
        assert!(!dispatch_configs_share_launch_shape(&base, &changed_grid));

        let changed_fixpoint = tweaked(&base, |config| config.fixpoint_iterations = Some(2));
        assert!(!dispatch_configs_share_launch_shape(
            &base,
            &changed_fixpoint
        ));
    }

    #[test]
    fn dispatch_config_launch_shape_canonicalizes_default_fixpoint_iteration() {
        let base = DispatchConfig::default();
        let explicit_one = tweaked(&base, |config| config.fixpoint_iterations = Some(1));

        assert!(
            dispatch_configs_share_launch_shape(&base, &explicit_one),
            "Fix: compiled pipelines must not miss cache/replay fast paths when runtime policy spells the default fixpoint iteration count explicitly."
        );
    }

    #[test]
    fn dispatch_config_launch_shape_rejects_invalid_zero_fixpoint_iteration() {
        let base = DispatchConfig::default();
        let explicit_zero = tweaked(&base, |config| config.fixpoint_iterations = Some(0));

        assert!(
            !dispatch_configs_share_launch_shape(&base, &explicit_zero),
            "Fix: explicit zero fixpoint iterations are invalid policy and must not be treated as a compatible launch shape."
        );
    }

    #[test]
    fn dispatch_config_launch_shape_canonicalizes_strict_ulp_budget() {
        let base = DispatchConfig::default();
        let explicit_strict = tweaked(&base, |config| config.ulp_budget = Some(0));

        assert!(
            dispatch_configs_share_launch_shape(&base, &explicit_strict),
            "Fix: strict ULP defaults should not force duplicate compiled dispatch shapes."
        );
    }

    #[test]
    fn dispatch_config_launch_shape_separates_relaxed_ulp_budget() {
        let base = DispatchConfig::default();
        let relaxed = tweaked(&base, |config| config.ulp_budget = Some(1));

        assert!(
            !dispatch_configs_share_launch_shape(&base, &relaxed),
            "Fix: relaxed ULP budgets change target intrinsic policy and need distinct dispatch shapes."
        );
    }
}