Skip to main content

vyre_driver/program_walks/
grid.rs

1//! Backend-neutral dispatch-grid inference.
2
3use vyre_foundation::ir::Program;
4
5use crate::backend::{BackendError, DispatchConfig};
6use crate::binding::BindingPlan;
7use crate::program_walks::dispatch_element_count_for_program;
8
9/// Infer a concrete workgroup grid from a program ABI and dispatch inputs.
10///
11/// Explicit [`DispatchConfig::grid_override`] always wins. Otherwise this uses
12/// the largest non-shared binding element count as the logical lane count and
13/// derives a deterministic 1D/2D/3D grid from the effective workgroup shape.
14///
15/// # Errors
16///
17/// Returns when the program/input ABI cannot be planned or when inferred grid
18/// dimensions overflow `u32`.
19pub fn infer_dispatch_grid(
20    program: &Program,
21    inputs: &[Vec<u8>],
22    config: &DispatchConfig,
23) -> Result<[u32; 3], BackendError> {
24    if let Some(grid) = config.grid_override {
25        return Ok(grid);
26    }
27    let plan = BindingPlan::from_program(program, inputs)?;
28    let element_count = dispatch_element_count_for_program(program, &plan.bindings);
29    infer_dispatch_grid_for_count(
30        element_count,
31        config
32            .workgroup_override
33            .unwrap_or(program.workgroup_size()),
34    )
35}
36
37/// Infer a grid size for a program based on its largest statically-known
38/// non-shared binding and its workgroup size.
39///
40/// Bench cases and backends can use this when no explicit grid_override is provided.
41///
42/// # Errors
43///
44/// Returns when the program ABI cannot be planned or if inferred dimensions
45/// overflow `u32`.
46pub fn auto_grid(
47    program: &Program,
48    backend: &dyn crate::backend::VyreBackend,
49) -> Result<[u32; 3], BackendError> {
50    crate::validation::validate_program_for_backend(backend, program, &DispatchConfig::default())?;
51    let plan = BindingPlan::build(program)?;
52    let element_count = dispatch_element_count_for_program(program, &plan.bindings);
53
54    infer_dispatch_grid_for_count(element_count, program.workgroup_size())
55}
56
57/// Infer a launch grid for a known logical element count and workgroup shape.
58///
59/// 1D kernels use a standard ceil-div over X lanes. 2D/3D kernels use a
60/// square/cube-ish decomposition so common matrix-style programs with
61/// `count = rows * cols` do not need driver-specific manual launch policy.
62///
63/// # Errors
64///
65/// Returns if any workgroup axis is zero or an inferred grid axis cannot fit
66/// in `u32`.
67pub fn infer_dispatch_grid_for_count(
68    element_count: u32,
69    workgroup: [u32; 3],
70) -> Result<[u32; 3], BackendError> {
71    if workgroup.contains(&0) {
72        return Err(BackendError::new(
73            "workgroup dimensions must be non-zero. Fix: set Program::workgroup_size and DispatchConfig::workgroup_override to positive values.",
74        ));
75    }
76    let count = u64::from(element_count.max(1));
77    if workgroup[1] == 1 && workgroup[2] == 1 {
78        return Ok([ceil_div_u64(count, u64::from(workgroup[0]))?, 1, 1]);
79    }
80    if workgroup[2] == 1 {
81        let side = ceil_sqrt_u64(count);
82        return Ok([
83            ceil_div_u64(side, u64::from(workgroup[0]))?,
84            ceil_div_u64(
85                u64::from(ceil_div_u64(count, side)?),
86                u64::from(workgroup[1]),
87            )?,
88            1,
89        ]);
90    }
91    let side = ceil_cuberoot_u64(count);
92    let xy = side.checked_mul(side).ok_or_else(|| {
93        BackendError::new(format!(
94            "3D dispatch-grid side {side} overflows u64 square during shape planning. Fix: split the Program before GPU launch planning."
95        ))
96    })?;
97    Ok([
98        ceil_div_u64(side, u64::from(workgroup[0]))?,
99        ceil_div_u64(side, u64::from(workgroup[1]))?,
100        ceil_div_u64(u64::from(ceil_div_u64(count, xy)?), u64::from(workgroup[2]))?,
101    ])
102}
103
104fn ceil_div_u64(value: u64, divisor: u64) -> Result<u32, BackendError> {
105    let divided = value.div_ceil(divisor).max(1);
106    u32::try_from(divided).map_err(|_| {
107        BackendError::new(
108            "inferred dispatch grid dimension overflowed u32. Fix: split the Program into smaller dispatches.",
109        )
110    })
111}
112
/// Integer ceiling square root: the smallest `r >= 1` with `r * r >= value`.
///
/// Inputs of 0 and 1 both yield 1. The search is pure integer bisection, so
/// the result is exact and deterministic across the whole `u64` range; the
/// answer never exceeds `1 << 32`.
fn ceil_sqrt_u64(value: u64) -> u64 {
    if value <= 1 {
        return 1;
    }
    // Bracket invariant: `too_small` squared is below `value`, `enough`
    // squared reaches it (or no longer fits in u64).
    let mut too_small = 0_u64;
    let mut enough = 1_u64 << 32;
    while enough - too_small > 1 {
        let probe = too_small + (enough - too_small) / 2;
        let reaches_value = match probe.checked_mul(probe) {
            Some(square) => square >= value,
            None => true,
        };
        if reaches_value {
            enough = probe;
        } else {
            too_small = probe;
        }
    }
    enough
}
128
129fn ceil_cuberoot_u64(value: u64) -> u64 {
130    if value <= 1 {
131        return 1;
132    }
133    let mut lo = 1_u64;
134    let mut hi = 1_u64 << 22;
135    while lo < hi {
136        let mid = lo + ((hi - lo) / 2);
137        match checked_cube_u64(mid) {
138            Some(cube) if cube < value => lo = mid + 1,
139            _ => hi = mid,
140        }
141    }
142    lo
143}
144
/// Cube `value`, yielding `None` when the result does not fit in `u64`.
fn checked_cube_u64(value: u64) -> Option<u64> {
    value.checked_pow(3)
}
148
// ---------------------------------------------------------------------------
// N6 power-of-2 dispatch grid coercion + tail-mask
// ---------------------------------------------------------------------------
152
/// Outcome of rounding a logical element count up to the next power of two.
///
/// A backend that opts into the N6 substrate dispatches `rounded_count` lanes,
/// so every workgroup has the same shape and the last workgroup has no
/// boundary divergence. The kernel then guards each store with the tail-mask
/// predicate `lane_id < original_count`, which turns the stores of lanes past
/// the original count into no-ops.
///
/// This pays off for attention/softmax/reduce shapes whose workload is not a
/// multiple of the workgroup size. Without coercion the final workgroup runs
/// with masked-out lanes that still incur scheduling cost; with coercion all
/// workgroups are identical and the surplus lanes are skipped by the
/// predicate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TailMaskPolicy {
    /// Element count the caller asked for.
    pub original_count: u32,
    /// Element count after rounding up to a power of two. Identical to
    /// `original_count` when that value is already a power of two.
    pub rounded_count: u32,
    /// Shorthand for `rounded_count - original_count`: the number of trailing
    /// lanes the kernel must predicate off.
    pub tail_lanes: u32,
}

impl TailMaskPolicy {
    /// Whether the count needed no rounding, i.e. there are no tail lanes and
    /// the dispatch can run without a tail-mask predicate.
    #[must_use]
    pub fn is_aligned(&self) -> bool {
        matches!(self.tail_lanes, 0)
    }
}
186
187/// N6: round `element_count` up to the next power of two. Returns a
188/// [`TailMaskPolicy`] that the lower/emit layer consumes to insert a
189/// `lane_id < original_count` predicate around each store. Pure
190/// arithmetic; no I/O.
191///
192/// `element_count == 0` is treated as 0 (rounded_count = 0, no tail).
193/// `element_count == 1` rounds to 1 (already pow2).
194/// `element_count` beyond `1 << 31` cannot be rounded inside `u32`; callers
195/// that need to distinguish that condition must use
196/// [`try_coerce_to_pow2_with_tail_mask`]. This legacy wrapper preserves the
197/// original shape on overflow instead of panicking.
198#[must_use]
199pub fn coerce_to_pow2_with_tail_mask(element_count: u32) -> TailMaskPolicy {
200    match try_coerce_to_pow2_with_tail_mask(element_count) {
201        Ok(policy) => policy,
202        Err(_error) => TailMaskPolicy {
203            original_count: element_count,
204            rounded_count: element_count,
205            tail_lanes: 0,
206        },
207    }
208}
209
210/// Fallible N6 power-of-two dispatch-grid coercion.
211///
212/// # Errors
213/// Returns when `element_count` cannot be rounded up inside `u32`.
214pub fn try_coerce_to_pow2_with_tail_mask(
215    element_count: u32,
216) -> Result<TailMaskPolicy, BackendError> {
217    if element_count == 0 {
218        return Ok(TailMaskPolicy {
219            original_count: 0,
220            rounded_count: 0,
221            tail_lanes: 0,
222        });
223    }
224    let rounded = next_pow2_u32_checked(element_count)?;
225    Ok(TailMaskPolicy {
226        original_count: element_count,
227        rounded_count: rounded,
228        tail_lanes: rounded - element_count,
229    })
230}
231
232fn next_pow2_u32_checked(value: u32) -> Result<u32, BackendError> {
233    if value.is_power_of_two() {
234        return Ok(value);
235    }
236    if value > (1u32 << 31) {
237        return Err(BackendError::new(format!(
238            "cannot round element_count={value} up to a power-of-two u32 grid without overflow. Fix: split the workload before grid-shape planning; do not silently saturate or fall back to an under-dispatching shape."
239        )));
240    }
241    Ok(value.next_power_of_two())
242}
243
#[cfg(test)]
mod n6_tests {
    use super::*;

    /// A count that is already a power of two comes back unchanged.
    #[test]
    fn already_pow2_is_identity_with_no_tail() {
        let policy = coerce_to_pow2_with_tail_mask(64);
        let TailMaskPolicy {
            original_count,
            rounded_count,
            tail_lanes,
        } = policy;
        assert_eq!(original_count, 64);
        assert_eq!(rounded_count, 64);
        assert_eq!(tail_lanes, 0);
        assert!(policy.is_aligned());
    }

    /// 100 rounds up to 128, leaving 28 lanes to mask off.
    #[test]
    fn non_pow2_rounds_up_and_reports_tail() {
        let policy = coerce_to_pow2_with_tail_mask(100);
        let TailMaskPolicy {
            original_count,
            rounded_count,
            tail_lanes,
        } = policy;
        assert_eq!(original_count, 100);
        assert_eq!(rounded_count, 128);
        assert_eq!(tail_lanes, 28);
        assert!(!policy.is_aligned());
    }

    /// One is a power of two, so it needs no rounding.
    #[test]
    fn one_is_pow2_no_tail() {
        let policy = coerce_to_pow2_with_tail_mask(1);
        assert_eq!(policy.rounded_count, 1);
        assert_eq!(policy.tail_lanes, 0);
    }

    /// Zero stays zero instead of being rounded up to one.
    #[test]
    fn zero_passes_through_with_no_tail() {
        let policy = coerce_to_pow2_with_tail_mask(0);
        assert_eq!(policy.rounded_count, 0);
        assert_eq!(policy.tail_lanes, 0);
        assert!(policy.is_aligned());
    }

    /// Large counts that still have a `u32` power-of-two ceiling round normally.
    #[test]
    fn large_value_below_2_31_rounds_normally() {
        let requested = 1_000_000_000_u32;
        // 2^30 = 1_073_741_824
        let expected_rounded = 1u32 << 30;
        let policy = coerce_to_pow2_with_tail_mask(requested);
        assert_eq!(policy.rounded_count, expected_rounded);
        assert_eq!(policy.tail_lanes, expected_rounded - requested);
    }

    /// The fallible entry point must report overflow with an actionable message.
    #[test]
    fn value_above_2_31_errors_instead_of_saturating() {
        let message = try_coerce_to_pow2_with_tail_mask(u32::MAX)
            .expect_err("oversized power-of-two coercion must fail loudly")
            .to_string();
        assert!(
            message.contains("Fix:"),
            "oversized grid-shape error must be actionable"
        );
    }

    /// The integer root helpers are exact right at and just below large boundaries.
    #[test]
    fn root_helpers_are_exact_at_large_boundaries() {
        let sqrt_cases = [((1_u64 << 32) - 1, 65_536_u64), (1_u64 << 32, 65_536_u64)];
        for (input, expected) in sqrt_cases {
            assert_eq!(ceil_sqrt_u64(input), expected);
        }

        let largest_cube = 2_642_245_u64.pow(3);
        for input in [largest_cube, largest_cube - 1] {
            assert_eq!(ceil_cuberoot_u64(input), 2_642_245);
        }
    }

    /// Source-level guard: the production half of this file must stay on
    /// integer root arithmetic and typed errors.
    #[test]
    fn dispatch_grid_planning_uses_integer_roots_and_typed_errors() {
        let source = include_str!("grid.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: dispatch-grid production source must precede tests");

        for float_marker in [" as f64", ".sqrt()", ".cbrt()"] {
            assert!(
                !production.contains(float_marker),
                "Fix: dispatch-grid inference must use deterministic integer root arithmetic."
            );
        }

        let exposes_fallible_api = production.contains("try_coerce_to_pow2_with_tail_mask");
        let has_production_panic = production.contains("panic!(");
        assert!(
            exposes_fallible_api && !has_production_panic,
            "Fix: dispatch-grid planning should expose typed errors instead of production panics."
        );
    }
}