Skip to main content

vk_graph/cmd/
compute.rs

1use {
2    super::{cmd_ref::CommandRef, pipeline::PipelineCommand},
3    crate::{driver::compute::ComputePipeline, node::AnyBufferNode},
4    ash::vk,
5    std::ops::Deref,
6};
7
8impl PipelineCommand<'_, ComputePipeline> {
9    /// Begin recording a compute pipeline command buffer.
10    pub fn record_cmd(mut self, func: impl FnOnce(ComputeCommandRef<'_>) + Send + 'static) -> Self {
11        self.record_cmd_mut(func);
12        self
13    }
14
15    /// Begin recording a compute pipeline command buffer.
16    pub fn record_cmd_mut(&mut self, func: impl FnOnce(ComputeCommandRef<'_>) + Send + 'static) {
17        let pipeline = self
18            .cmd
19            .cmd()
20            .expect_last_pipeline()
21            .expect_compute()
22            .clone();
23
24        self.cmd.push_exec(move |cmd| {
25            func(ComputeCommandRef { cmd, pipeline });
26        });
27    }
28}
29
30/// Recording interface for computing commands.
31///
32/// This structure provides a strongly-typed set of methods which allow compute shader code to be
33/// executed. An instance is provided to the closure argument of
34/// [`PipelineCommand::record_cmd`] which may be accessed by binding a [`ComputePipeline`] to a
35/// command.
36///
37/// # Examples
38///
39/// Basic usage:
40///
41/// ```no_run
42/// # use ash::vk;
43/// # use vk_graph::driver::DriverError;
44/// # use vk_graph::driver::device::{Device, DeviceInfo};
45/// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
46/// # use vk_graph::driver::shader::{Shader};
47/// # use vk_graph::Graph;
48/// # fn main() -> Result<(), DriverError> {
49/// # let device = Device::new(DeviceInfo::default())?;
50/// # let info = ComputePipelineInfo::default();
51/// # let shader = Shader::new_compute([0u8; 1].as_slice());
52/// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
53/// # let mut my_graph = Graph::default();
54/// my_graph
55///     .begin_cmd()
56///     .bind_pipeline(&my_compute_pipeline)
57///     .record_cmd(move |cmd| {
58///         // During this closure we have access to the compute functions!
59///         cmd.dispatch(64, 1, 1);
60///     });
61/// # Ok(()) }
62/// ```
63pub struct ComputeCommandRef<'a> {
64    cmd: CommandRef<'a>,
65    pipeline: ComputePipeline,
66}
67
68impl ComputeCommandRef<'_> {
69    /// [`Self::dispatch`] compute work items.
70    ///
71    /// When the command is executed, a global workgroup consisting of
72    /// `group_count_x × group_count_y × group_count_z` local workgroups is assembled.
73    ///
74    /// # Examples
75    ///
76    /// Basic usage:
77    ///
78    /// ```
79    /// # vk_shader_macros::glsl!(r#"
80    /// #version 450
81    /// #pragma shader_stage(compute)
82    ///
83    /// layout(set = 0, binding = 0, std430) restrict writeonly buffer MyBufer {
84    ///     uint my_buf[];
85    /// };
86    ///
87    /// void main() {
88    ///     // TODO
89    /// }
90    /// # "#);
91    /// ```
92    ///
93    /// ```no_run
94    /// # use ash::vk;
95    /// # use vk_graph::driver::{AccessType, DriverError};
96    /// # use vk_graph::driver::device::{Device, DeviceInfo};
97    /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
98    /// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
99    /// # use vk_graph::driver::shader::{Shader};
100    /// # use vk_graph::Graph;
101    /// # fn main() -> Result<(), DriverError> {
102    /// # let device = Device::new(DeviceInfo::default())?;
103    /// # let buf_info = BufferInfo::device_mem(8, vk::BufferUsageFlags::STORAGE_BUFFER);
104    /// # let my_buf = Buffer::create(&device, buf_info)?;
105    /// # let info = ComputePipelineInfo::default();
106    /// # let shader = Shader::new_compute([0u8; 1].as_slice());
107    /// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
108    /// # let mut my_graph = Graph::default();
109    /// # let my_buf_node = my_graph.bind_resource(my_buf);
110    /// my_graph
111    ///     .begin_cmd()
112    ///     .debug_name("fill my_buf_node with data")
113    ///     .bind_pipeline(&my_compute_pipeline)
114    ///     .shader_resource_access(0, my_buf_node, AccessType::ComputeShaderWrite)
115    ///     .record_cmd(move |cmd| {
116    ///         cmd.dispatch(128, 64, 32);
117    ///     });
118    /// # Ok(()) }
119    /// ```
120    ///
121    /// See [`vkCmdDispatch`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdDispatch.html).
122    #[profiling::function]
123    pub fn dispatch(&self, group_count_x: u32, group_count_y: u32, group_count_z: u32) -> &Self {
124        unsafe {
125            self.cmd.device.cmd_dispatch(
126                self.cmd.handle,
127                group_count_x,
128                group_count_y,
129                group_count_z,
130            );
131        }
132
133        self
134    }
135
136    /// [`Self::dispatch_base`] compute work items with non-zero base values for the workgroup IDs.
137    ///
138    /// When the command is executed, a global workgroup consisting of
139    /// `group_count_x × group_count_y × group_count_z` local workgroups is assembled, with
140    /// WorkgroupId values ranging from `[base_group*, base_group* + group_count*)` in each
141    /// component.
142    ///
143    /// [`Self::dispatch`] is equivalent to
144    /// `dispatch_base(0, 0, 0, group_count_x, group_count_y, group_count_z)`.
145    ///
146    /// See [`vkCmdDispatchBase`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdDispatchBase.html).
147    #[profiling::function]
148    pub fn dispatch_base(
149        &self,
150        base_group_x: u32,
151        base_group_y: u32,
152        base_group_z: u32,
153        group_count_x: u32,
154        group_count_y: u32,
155        group_count_z: u32,
156    ) -> &Self {
157        unsafe {
158            self.cmd.device.cmd_dispatch_base(
159                self.cmd.handle,
160                base_group_x,
161                base_group_y,
162                base_group_z,
163                group_count_x,
164                group_count_y,
165                group_count_z,
166            );
167        }
168
169        self
170    }
171
172    /// Dispatch compute work items with indirect parameters.
173    ///
174    /// `dispatch_indirect` behaves similarly to [`Self::dispatch`] except that the parameters
175    /// are read by the device from `args_buf` during execution. The parameters of the dispatch are
176    /// encoded in a [`vk::DispatchIndirectCommand`] structure taken from `args_buf` starting at
177    /// `args_offset`.
178    ///
179    /// # Examples
180    ///
181    /// Basic usage:
182    ///
183    /// ```no_run
184    /// # use ash::vk;
185    /// # use bytemuck::{bytes_of, Pod, Zeroable};
186    /// # use vk_graph::driver::{AccessType, DriverError};
187    /// # use vk_graph::driver::device::{Device, DeviceInfo};
188    /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
189    /// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
190    /// # use vk_graph::driver::shader::{Shader};
191    /// # use vk_graph::Graph;
192    /// # fn main() -> Result<(), DriverError> {
193    /// # let device = Device::new(DeviceInfo::default())?;
194    /// # let buf_info = BufferInfo::device_mem(8, vk::BufferUsageFlags::STORAGE_BUFFER);
195    /// # let my_buf = Buffer::create(&device, buf_info)?;
196    /// # let info = ComputePipelineInfo::default();
197    /// # let shader = Shader::new_compute([0u8; 1].as_slice());
198    /// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
199    /// # let mut my_graph = Graph::default();
200    /// # let my_buf_node = my_graph.bind_resource(my_buf);
201    /// # #[repr(C)]
202    /// # #[derive(Clone, Copy, Pod, Zeroable)]
203    /// # struct DispatchIndirectCommand { x: u32, y: u32, z: u32, }
204    /// let args = DispatchIndirectCommand {
205    ///     x: 1,
206    ///     y: 2,
207    ///     z: 3,
208    /// };
209    /// let data = bytes_of(&args);
210    /// let usage = vk::BufferUsageFlags::INDIRECT_BUFFER | vk::BufferUsageFlags::STORAGE_BUFFER;
211    /// let args_buf = Buffer::create_from_slice(&device, usage, data)?;
212    /// let args_buf = my_graph.bind_resource(args_buf);
213    ///
214    /// my_graph
215    ///     .begin_cmd()
216    ///     .debug_name("fill my_buf_node with data")
217    ///     .bind_pipeline(&my_compute_pipeline)
218    ///     .resource_access(args_buf, AccessType::IndirectBuffer)
219    ///     .shader_resource_access(0, my_buf_node, AccessType::ComputeShaderWrite)
220    ///     .record_cmd(move |cmd| {
221    ///         cmd.dispatch_indirect(args_buf, 0);
222    ///     });
223    /// # Ok(()) }
224    /// ```
225    ///
226    /// See [`vkCmdDispatchIndirect`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdDispatchIndirect.html).
227    #[profiling::function]
228    pub fn dispatch_indirect(
229        &self,
230        args_buf: impl Into<AnyBufferNode>,
231        args_offset: vk::DeviceSize,
232    ) -> &Self {
233        let args_buf = args_buf.into();
234        let args_buf = self.resource(args_buf);
235
236        unsafe {
237            self.cmd
238                .device
239                .cmd_dispatch_indirect(self.cmd.handle, args_buf.handle, args_offset);
240        }
241
242        self
243    }
244
245    /// Updates push constants.
246    ///
247    /// Push constants represent a high speed path to modify constant data in pipelines that is
248    /// expected to outperform memory-backed resource updates.
249    ///
250    /// Push constant values can be updated incrementally, causing shader stages to read the new
251    /// data for push constants modified by this command, while still reading the previous data for
252    /// push constants not modified by this command.
253    ///
254    /// # Device limitations
255    ///
256    /// See
257    /// [`device.physical_device.props.limits.max_push_constants_size`](vk::PhysicalDeviceLimits)
258    /// for the limits of the current device. You may also check [gpuinfo.org] for a listing of
259    /// reported limits on other devices.
260    ///
261    /// # Examples
262    ///
263    /// Basic usage:
264    ///
265    /// ```
266    /// # vk_shader_macros::glsl!(r#"
267    /// #version 450
268    /// #pragma shader_stage(compute)
269    ///
270    /// layout(push_constant) uniform PushConstants {
271    ///     layout(offset = 0) uint the_answer;
272    /// } push_constants;
273    ///
274    /// void main()
275    /// {
276    ///     // TODO: Add bindings to read/write things!
277    /// }
278    /// # "#);
279    /// ```
280    ///
281    /// ```no_run
282    /// # use ash::vk;
283    /// # use vk_graph::driver::DriverError;
284    /// # use vk_graph::driver::device::{Device, DeviceInfo};
285    /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
286    /// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
287    /// # use vk_graph::driver::shader::{Shader};
288    /// # use vk_graph::Graph;
289    /// # fn main() -> Result<(), DriverError> {
290    /// # let device = Device::new(DeviceInfo::default())?;
291    /// # let info = ComputePipelineInfo::default();
292    /// # let shader = Shader::new_compute([0u8; 1].as_slice());
293    /// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
294    /// # let mut my_graph = Graph::default();
295    /// my_graph
296    ///     .begin_cmd()
297    ///     .debug_name("compute the ultimate question")
298    ///     .bind_pipeline(&my_compute_pipeline)
299    ///     .record_cmd(move |cmd| {
300    ///         cmd
301    ///             .push_constants(0, &[42])
302    ///             .dispatch(1, 1, 1);
303    ///     });
304    /// # Ok(()) }
305    /// ```
306    ///
307    /// See [`vkCmdPushConstants`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdPushConstants.html).
308    #[profiling::function]
309    pub fn push_constants(&self, offset: u32, data: &[u8]) -> &Self {
310        self.cmd_push_constants(
311            self.pipeline.inner.layout,
312            self.pipeline.inner.push_constants.as_slice(),
313            offset,
314            data,
315        );
316
317        self
318    }
319}
320
321impl<'a> Deref for ComputeCommandRef<'a> {
322    type Target = CommandRef<'a>;
323
324    fn deref(&self) -> &Self::Target {
325        &self.cmd
326    }
327}
328
329#[allow(unused)]
330mod deprecated {
331    use {
332        crate::{
333            Node,
334            cmd::{
335                Binding, PipelineCommand, Subresource, SubresourceRange, ViewInfo,
336                compute::ComputeCommandRef,
337            },
338            driver::compute::ComputePipeline,
339        },
340        std::any::Any,
341        vk_sync::AccessType,
342    };
343
344    impl ComputeCommandRef<'_> {
345        #[deprecated = "use push_constants function"]
346        #[doc(hidden)]
347        pub fn push_constants_offset(&self, offset: u32, data: &[u8]) -> &Self {
348            self.push_constants(offset, data)
349        }
350    }
351
352    impl PipelineCommand<'_, ComputePipeline> {
353        #[deprecated = "use shader_resource_access with ComputeShaderReadOther"]
354        #[doc(hidden)]
355        pub fn read_descriptor<N>(self, descriptor: impl Into<Binding>, node: N) -> Self
356        where
357            N: Node + Subresource,
358            N::Info: Copy,
359            SubresourceRange: From<N::Info>,
360            ViewInfo: From<N::Info>,
361        {
362            self.shader_resource_access(descriptor, node, AccessType::ComputeShaderReadOther)
363        }
364
365        #[deprecated = "use shader_subresource_access with ComputeShaderReadOther"]
366        #[doc(hidden)]
367        pub fn read_descriptor_as<N>(
368            self,
369            descriptor: impl Into<Binding>,
370            node: N,
371            node_view: impl Into<N::Info>,
372        ) -> Self
373        where
374            N: Node + Subresource,
375            N::Info: Copy,
376            SubresourceRange: From<N::Info>,
377            ViewInfo: From<N::Info>,
378        {
379            self.shader_subresource_access(
380                descriptor,
381                node,
382                node_view,
383                AccessType::ComputeShaderReadOther,
384            )
385        }
386
387        #[deprecated = "use record_cmd function"]
388        #[doc(hidden)]
389        pub fn record_compute(
390            self,
391            func: impl FnOnce(ComputeCommandRef<'_>, ()) + Send + 'static,
392        ) -> Self {
393            self.record_cmd(|cmd| func(cmd, ()))
394        }
395
396        #[deprecated = "use shader_resource_access function with AccessType::ComputeShaderWrite"]
397        #[doc(hidden)]
398        pub fn write_descriptor<N>(self, descriptor: impl Into<Binding>, node: N) -> Self
399        where
400            N: Node + Subresource,
401            N::Info: Copy,
402            SubresourceRange: From<N::Info>,
403            ViewInfo: From<N::Info>,
404        {
405            self.shader_resource_access(descriptor, node, AccessType::ComputeShaderWrite)
406        }
407
408        #[deprecated = "use shader_subresource_access function with AccessType::ComputeShaderWrite"]
409        #[doc(hidden)]
410        pub fn write_descriptor_as<N>(
411            self,
412            descriptor: impl Into<Binding>,
413            node: N,
414            node_view: impl Into<N::Info>,
415        ) -> Self
416        where
417            N: Node + Subresource,
418            N::Info: Copy,
419            SubresourceRange: From<N::Info>,
420            ViewInfo: From<N::Info>,
421        {
422            self.shader_subresource_access(
423                descriptor,
424                node,
425                node_view,
426                AccessType::ComputeShaderWrite,
427            )
428        }
429    }
430}