Skip to main content

scry_gpu/
dispatch.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Shader dispatch configuration and execution.
3
/// Configuration for a compute dispatch.
///
/// The simple path ([`Device::dispatch`]) covers most cases.
/// Reach for `DispatchConfig` when you need control over the entry point,
/// the number of workgroups dispatched, or push constants.
///
/// Every field is either a shared borrow or plain integer data, so the
/// config is cheap to copy, compare and debug-print.
///
/// [`Device::dispatch`]: crate::Device::dispatch
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchConfig<'a> {
    /// Shader source (WGSL).
    pub shader: &'a str,

    /// Entry point name. Defaults to `"main"` if `None`.
    pub entry_point: Option<&'a str>,

    /// Number of workgroups to dispatch along each axis, as `[x, y, z]`.
    ///
    /// If `None`, the crate auto-calculates the count from `invocations`
    /// and the shader's declared `@workgroup_size`.
    pub workgroups: Option<[u32; 3]>,

    /// Total invocations requested. Used to auto-calculate the workgroup
    /// dispatch count when `workgroups` is `None`.
    pub invocations: u32,

    /// Optional push constant data (raw bytes, must match shader layout).
    pub push_constants: Option<&'a [u8]>,
}
31
32impl<'a> DispatchConfig<'a> {
33    /// Create a minimal dispatch config.
34    pub const fn new(shader: &'a str, invocations: u32) -> Self {
35        Self {
36            shader,
37            entry_point: None,
38            workgroups: None,
39            invocations,
40            push_constants: None,
41        }
42    }
43
44    /// Override the entry point name (default: `"main"`).
45    pub const fn entry_point(mut self, name: &'a str) -> Self {
46        self.entry_point = Some(name);
47        self
48    }
49
50    /// Set explicit workgroup dispatch dimensions.
51    pub const fn workgroups(mut self, dims: [u32; 3]) -> Self {
52        self.workgroups = Some(dims);
53        self
54    }
55
56    /// Attach push constant data.
57    pub const fn push_constants(mut self, data: &'a [u8]) -> Self {
58        self.push_constants = Some(data);
59        self
60    }
61}
62
63/// Extract `@workgroup_size` from a parsed naga module's entry point.
64///
65/// Returns `[x, y, z]` or a default of `[64, 1, 1]` if the shader
66/// doesn't declare one.
67pub fn extract_workgroup_size(module: &naga::Module, entry: &str) -> [u32; 3] {
68    for ep in &module.entry_points {
69        if ep.name == entry {
70            let s = ep.workgroup_size;
71            return [s[0], s[1], s[2]];
72        }
73    }
74    [64, 1, 1]
75}
76
/// Calculate dispatch dimensions given total invocations and per-workgroup size.
///
/// Computes `ceil(invocations / workgroup_size[0])` for the x axis and clamps
/// the result to the Vulkan `maxComputeWorkGroupCount` limit (65535 per axis).
/// The y and z dispatch counts are always `1`; `workgroup_size[1]` and
/// `workgroup_size[2]` are not consulted.
///
/// A zero `workgroup_size[0]` is treated as `1` rather than panicking on a
/// division by zero.
///
/// Note that the clamp is silent: if more than 65535 workgroups would be
/// needed, the returned dispatch covers fewer invocations than requested.
pub fn calc_dispatch(invocations: u32, workgroup_size: [u32; 3]) -> [u32; 3] {
    // Per-axis workgroup count limit that the result is clamped to.
    const MAX_WORKGROUPS_PER_AXIS: u32 = 65_535;

    // `u32::div_ceil` panics on a zero divisor, so clamp the divisor to at
    // least one invocation per workgroup.
    let invocations_per_group = workgroup_size[0].max(1);
    let groups_x = invocations
        .div_ceil(invocations_per_group)
        .min(MAX_WORKGROUPS_PER_AXIS);

    [groups_x, 1, 1]
}