vyre_driver/program_walks/
grid.rs1use vyre_foundation::ir::Program;
4
5use crate::backend::{BackendError, DispatchConfig};
6use crate::binding::BindingPlan;
7use crate::program_walks::dispatch_element_count_for_program;
8
9pub fn infer_dispatch_grid(
20 program: &Program,
21 inputs: &[Vec<u8>],
22 config: &DispatchConfig,
23) -> Result<[u32; 3], BackendError> {
24 if let Some(grid) = config.grid_override {
25 return Ok(grid);
26 }
27 let plan = BindingPlan::from_program(program, inputs)?;
28 let element_count = dispatch_element_count_for_program(program, &plan.bindings);
29 infer_dispatch_grid_for_count(
30 element_count,
31 config
32 .workgroup_override
33 .unwrap_or(program.workgroup_size()),
34 )
35}
36
37pub fn auto_grid(
47 program: &Program,
48 backend: &dyn crate::backend::VyreBackend,
49) -> Result<[u32; 3], BackendError> {
50 crate::validation::validate_program_for_backend(backend, program, &DispatchConfig::default())?;
51 let plan = BindingPlan::build(program)?;
52 let element_count = dispatch_element_count_for_program(program, &plan.bindings);
53
54 infer_dispatch_grid_for_count(element_count, program.workgroup_size())
55}
56
57pub fn infer_dispatch_grid_for_count(
68 element_count: u32,
69 workgroup: [u32; 3],
70) -> Result<[u32; 3], BackendError> {
71 if workgroup.contains(&0) {
72 return Err(BackendError::new(
73 "workgroup dimensions must be non-zero. Fix: set Program::workgroup_size and DispatchConfig::workgroup_override to positive values.",
74 ));
75 }
76 let count = u64::from(element_count.max(1));
77 if workgroup[1] == 1 && workgroup[2] == 1 {
78 return Ok([ceil_div_u64(count, u64::from(workgroup[0]))?, 1, 1]);
79 }
80 if workgroup[2] == 1 {
81 let side = ceil_sqrt_u64(count);
82 return Ok([
83 ceil_div_u64(side, u64::from(workgroup[0]))?,
84 ceil_div_u64(
85 u64::from(ceil_div_u64(count, side)?),
86 u64::from(workgroup[1]),
87 )?,
88 1,
89 ]);
90 }
91 let side = ceil_cuberoot_u64(count);
92 let xy = side.checked_mul(side).ok_or_else(|| {
93 BackendError::new(format!(
94 "3D dispatch-grid side {side} overflows u64 square during shape planning. Fix: split the Program before GPU launch planning."
95 ))
96 })?;
97 Ok([
98 ceil_div_u64(side, u64::from(workgroup[0]))?,
99 ceil_div_u64(side, u64::from(workgroup[1]))?,
100 ceil_div_u64(u64::from(ceil_div_u64(count, xy)?), u64::from(workgroup[2]))?,
101 ])
102}
103
104fn ceil_div_u64(value: u64, divisor: u64) -> Result<u32, BackendError> {
105 let divided = value.div_ceil(divisor).max(1);
106 u32::try_from(divided).map_err(|_| {
107 BackendError::new(
108 "inferred dispatch grid dimension overflowed u32. Fix: split the Program into smaller dispatches.",
109 )
110 })
111}
112
/// Smallest `r` such that `r * r >= value`, found by integer binary search.
///
/// Inputs `0` and `1` both map to `1`. The search window is `[1, 2^32]`, which
/// covers every `u64` input because `(2^32 - 1)^2 < u64::MAX`. A probe whose
/// square overflows `u64` is treated as large enough.
fn ceil_sqrt_u64(value: u64) -> u64 {
    if value < 2 {
        return 1;
    }
    let (mut low, mut high) = (1_u64, 1_u64 << 32);
    while low != high {
        let probe = low + (high - low) / 2;
        let too_small = probe.checked_mul(probe).is_some_and(|square| square < value);
        if too_small {
            low = probe + 1;
        } else {
            high = probe;
        }
    }
    low
}
128
129fn ceil_cuberoot_u64(value: u64) -> u64 {
130 if value <= 1 {
131 return 1;
132 }
133 let mut lo = 1_u64;
134 let mut hi = 1_u64 << 22;
135 while lo < hi {
136 let mid = lo + ((hi - lo) / 2);
137 match checked_cube_u64(mid) {
138 Some(cube) if cube < value => lo = mid + 1,
139 _ => hi = mid,
140 }
141 }
142 lo
143}
144
/// Returns `value * value * value`, or `None` if either multiplication overflows `u64`.
fn checked_cube_u64(value: u64) -> Option<u64> {
    let square = value.checked_mul(value)?;
    square.checked_mul(value)
}
148
/// Result of rounding a dispatch element count up to a power of two.
///
/// Lanes in `original_count..rounded_count` are padding (the "tail"), which a
/// consumer is presumably expected to mask off; `tail_lanes` records how many
/// there are.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TailMaskPolicy {
    /// Element count requested by the caller.
    pub original_count: u32,
    /// Count actually planned: the power-of-two round-up, or `original_count`
    /// when no rounding was applied.
    pub rounded_count: u32,
    /// Number of padding lanes, `rounded_count - original_count`.
    pub tail_lanes: u32,
}

impl TailMaskPolicy {
    /// Returns `true` when there are no padding lanes, so no tail masking is needed.
    #[must_use]
    pub fn is_aligned(&self) -> bool {
        matches!(self.tail_lanes, 0)
    }
}
186
187#[must_use]
199pub fn coerce_to_pow2_with_tail_mask(element_count: u32) -> TailMaskPolicy {
200 match try_coerce_to_pow2_with_tail_mask(element_count) {
201 Ok(policy) => policy,
202 Err(_error) => TailMaskPolicy {
203 original_count: element_count,
204 rounded_count: element_count,
205 tail_lanes: 0,
206 },
207 }
208}
209
210pub fn try_coerce_to_pow2_with_tail_mask(
215 element_count: u32,
216) -> Result<TailMaskPolicy, BackendError> {
217 if element_count == 0 {
218 return Ok(TailMaskPolicy {
219 original_count: 0,
220 rounded_count: 0,
221 tail_lanes: 0,
222 });
223 }
224 let rounded = next_pow2_u32_checked(element_count)?;
225 Ok(TailMaskPolicy {
226 original_count: element_count,
227 rounded_count: rounded,
228 tail_lanes: rounded - element_count,
229 })
230}
231
232fn next_pow2_u32_checked(value: u32) -> Result<u32, BackendError> {
233 if value.is_power_of_two() {
234 return Ok(value);
235 }
236 if value > (1u32 << 31) {
237 return Err(BackendError::new(format!(
238 "cannot round element_count={value} up to a power-of-two u32 grid without overflow. Fix: split the workload before grid-shape planning; do not silently saturate or fall back to an under-dispatching shape."
239 )));
240 }
241 Ok(value.next_power_of_two())
242}
243
#[cfg(test)]
mod n6_tests {
    // Unit tests for power-of-two tail-mask coercion, the integer root helpers,
    // and a source-scan guard over this file's production (pre-test) code.
    use super::*;

    // A count that is already a power of two must round to itself with no padding.
    #[test]
    fn already_pow2_is_identity_with_no_tail() {
        let p = coerce_to_pow2_with_tail_mask(64);
        assert_eq!(p.original_count, 64);
        assert_eq!(p.rounded_count, 64);
        assert_eq!(p.tail_lanes, 0);
        assert!(p.is_aligned());
    }

    // 100 rounds up to 128, leaving 28 padding lanes.
    #[test]
    fn non_pow2_rounds_up_and_reports_tail() {
        let p = coerce_to_pow2_with_tail_mask(100);
        assert_eq!(p.original_count, 100);
        assert_eq!(p.rounded_count, 128);
        assert_eq!(p.tail_lanes, 28);
        assert!(!p.is_aligned());
    }

    // 1 is 2^0, so it is already a power of two.
    #[test]
    fn one_is_pow2_no_tail() {
        let p = coerce_to_pow2_with_tail_mask(1);
        assert_eq!(p.rounded_count, 1);
        assert_eq!(p.tail_lanes, 0);
    }

    // Zero is special-cased: it stays zero rather than rounding to 1.
    #[test]
    fn zero_passes_through_with_no_tail() {
        let p = coerce_to_pow2_with_tail_mask(0);
        assert_eq!(p.rounded_count, 0);
        assert_eq!(p.tail_lanes, 0);
        assert!(p.is_aligned());
    }

    // 1_000_000_000 lies between 2^29 and 2^30, so it rounds to 2^30.
    #[test]
    fn large_value_below_2_31_rounds_normally() {
        let p = coerce_to_pow2_with_tail_mask(1_000_000_000);
        assert_eq!(p.rounded_count, 1u32 << 30);
        assert_eq!(p.tail_lanes, (1u32 << 30) - 1_000_000_000);
    }

    // Counts above 2^31 cannot be rounded within u32; the fallible API must
    // report an error whose message carries a "Fix:" hint.
    #[test]
    fn value_above_2_31_errors_instead_of_saturating() {
        let error = try_coerce_to_pow2_with_tail_mask(u32::MAX)
            .expect_err("oversized power-of-two coercion must fail loudly");
        let message = error.to_string();
        assert!(
            message.contains("Fix:"),
            "oversized grid-shape error must be actionable"
        );
    }

    // Boundary checks: 65_535^2 < 2^32 - 1, so both inputs need 65_536; and
    // 2_642_245 is the largest integer whose cube fits in u64.
    #[test]
    fn root_helpers_are_exact_at_large_boundaries() {
        assert_eq!(ceil_sqrt_u64((1_u64 << 32) - 1), 65_536);
        assert_eq!(ceil_sqrt_u64(1_u64 << 32), 65_536);
        assert_eq!(ceil_cuberoot_u64(2_642_245_u64.pow(3)), 2_642_245);
        assert_eq!(ceil_cuberoot_u64(2_642_245_u64.pow(3) - 1), 2_642_245);
    }

    // Source-scan guard: reads this file via include_str! and inspects only the
    // text before the first test-config attribute (the production code). It
    // rejects floating-point root calls and panic macros there. Any comment added
    // above the test module must therefore avoid those substrings too.
    #[test]
    fn dispatch_grid_planning_uses_integer_roots_and_typed_errors() {
        let source = include_str!("grid.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: dispatch-grid production source must precede tests");

        assert!(
            !production.contains(" as f64")
                && !production.contains(".sqrt()")
                && !production.contains(".cbrt()"),
            "Fix: dispatch-grid inference must use deterministic integer root arithmetic."
        );
        assert!(
            production.contains("try_coerce_to_pow2_with_tail_mask")
                && !production.contains("panic!("),
            "Fix: dispatch-grid planning should expose typed errors instead of production panics."
        );
    }
}