vk_graph/cmd/compute.rs
1use {
2 super::{cmd_ref::CommandRef, pipeline::PipelineCommand},
3 crate::{driver::compute::ComputePipeline, node::AnyBufferNode},
4 ash::vk,
5 std::ops::Deref,
6};
7
8impl PipelineCommand<'_, ComputePipeline> {
9 /// Begin recording a compute pipeline command buffer.
10 pub fn record_cmd(mut self, func: impl FnOnce(ComputeCommandRef<'_>) + Send + 'static) -> Self {
11 self.record_cmd_mut(func);
12 self
13 }
14
15 /// Begin recording a compute pipeline command buffer.
16 pub fn record_cmd_mut(&mut self, func: impl FnOnce(ComputeCommandRef<'_>) + Send + 'static) {
17 let pipeline = self
18 .cmd
19 .cmd()
20 .expect_last_pipeline()
21 .expect_compute()
22 .clone();
23
24 self.cmd.push_exec(move |cmd| {
25 func(ComputeCommandRef { cmd, pipeline });
26 });
27 }
28}
29
30/// Recording interface for computing commands.
31///
32/// This structure provides a strongly-typed set of methods which allow compute shader code to be
33/// executed. An instance is provided to the closure argument of
34/// [`PipelineCommand::record_cmd`] which may be accessed by binding a [`ComputePipeline`] to a
35/// command.
36///
37/// # Examples
38///
39/// Basic usage:
40///
41/// ```no_run
42/// # use ash::vk;
43/// # use vk_graph::driver::DriverError;
44/// # use vk_graph::driver::device::{Device, DeviceInfo};
45/// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
46/// # use vk_graph::driver::shader::{Shader};
47/// # use vk_graph::Graph;
48/// # fn main() -> Result<(), DriverError> {
49/// # let device = Device::new(DeviceInfo::default())?;
50/// # let info = ComputePipelineInfo::default();
51/// # let shader = Shader::new_compute([0u8; 1].as_slice());
52/// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
53/// # let mut my_graph = Graph::default();
54/// my_graph
55/// .begin_cmd()
56/// .bind_pipeline(&my_compute_pipeline)
57/// .record_cmd(move |cmd| {
58/// // During this closure we have access to the compute functions!
59/// cmd.dispatch(64, 1, 1);
60/// });
61/// # Ok(()) }
62/// ```
63pub struct ComputeCommandRef<'a> {
64 cmd: CommandRef<'a>,
65 pipeline: ComputePipeline,
66}
67
68impl ComputeCommandRef<'_> {
69 /// [`Self::dispatch`] compute work items.
70 ///
71 /// When the command is executed, a global workgroup consisting of
72 /// `group_count_x × group_count_y × group_count_z` local workgroups is assembled.
73 ///
74 /// # Examples
75 ///
76 /// Basic usage:
77 ///
78 /// ```
79 /// # vk_shader_macros::glsl!(r#"
80 /// #version 450
81 /// #pragma shader_stage(compute)
82 ///
83 /// layout(set = 0, binding = 0, std430) restrict writeonly buffer MyBufer {
84 /// uint my_buf[];
85 /// };
86 ///
87 /// void main() {
88 /// // TODO
89 /// }
90 /// # "#);
91 /// ```
92 ///
93 /// ```no_run
94 /// # use ash::vk;
95 /// # use vk_graph::driver::{AccessType, DriverError};
96 /// # use vk_graph::driver::device::{Device, DeviceInfo};
97 /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
98 /// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
99 /// # use vk_graph::driver::shader::{Shader};
100 /// # use vk_graph::Graph;
101 /// # fn main() -> Result<(), DriverError> {
102 /// # let device = Device::new(DeviceInfo::default())?;
103 /// # let buf_info = BufferInfo::device_mem(8, vk::BufferUsageFlags::STORAGE_BUFFER);
104 /// # let my_buf = Buffer::create(&device, buf_info)?;
105 /// # let info = ComputePipelineInfo::default();
106 /// # let shader = Shader::new_compute([0u8; 1].as_slice());
107 /// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
108 /// # let mut my_graph = Graph::default();
109 /// # let my_buf_node = my_graph.bind_resource(my_buf);
110 /// my_graph
111 /// .begin_cmd()
112 /// .debug_name("fill my_buf_node with data")
113 /// .bind_pipeline(&my_compute_pipeline)
114 /// .shader_resource_access(0, my_buf_node, AccessType::ComputeShaderWrite)
115 /// .record_cmd(move |cmd| {
116 /// cmd.dispatch(128, 64, 32);
117 /// });
118 /// # Ok(()) }
119 /// ```
120 ///
121 /// See [`vkCmdDispatch`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdDispatch.html).
122 #[profiling::function]
123 pub fn dispatch(&self, group_count_x: u32, group_count_y: u32, group_count_z: u32) -> &Self {
124 unsafe {
125 self.cmd.device.cmd_dispatch(
126 self.cmd.handle,
127 group_count_x,
128 group_count_y,
129 group_count_z,
130 );
131 }
132
133 self
134 }
135
136 /// [`Self::dispatch_base`] compute work items with non-zero base values for the workgroup IDs.
137 ///
138 /// When the command is executed, a global workgroup consisting of
139 /// `group_count_x × group_count_y × group_count_z` local workgroups is assembled, with
140 /// WorkgroupId values ranging from `[base_group*, base_group* + group_count*)` in each
141 /// component.
142 ///
143 /// [`Self::dispatch`] is equivalent to
144 /// `dispatch_base(0, 0, 0, group_count_x, group_count_y, group_count_z)`.
145 ///
146 /// See [`vkCmdDispatchBase`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdDispatchBase.html).
147 #[profiling::function]
148 pub fn dispatch_base(
149 &self,
150 base_group_x: u32,
151 base_group_y: u32,
152 base_group_z: u32,
153 group_count_x: u32,
154 group_count_y: u32,
155 group_count_z: u32,
156 ) -> &Self {
157 unsafe {
158 self.cmd.device.cmd_dispatch_base(
159 self.cmd.handle,
160 base_group_x,
161 base_group_y,
162 base_group_z,
163 group_count_x,
164 group_count_y,
165 group_count_z,
166 );
167 }
168
169 self
170 }
171
172 /// Dispatch compute work items with indirect parameters.
173 ///
174 /// `dispatch_indirect` behaves similarly to [`Self::dispatch`] except that the parameters
175 /// are read by the device from `args_buf` during execution. The parameters of the dispatch are
176 /// encoded in a [`vk::DispatchIndirectCommand`] structure taken from `args_buf` starting at
177 /// `args_offset`.
178 ///
179 /// # Examples
180 ///
181 /// Basic usage:
182 ///
183 /// ```no_run
184 /// # use ash::vk;
185 /// # use bytemuck::{bytes_of, Pod, Zeroable};
186 /// # use vk_graph::driver::{AccessType, DriverError};
187 /// # use vk_graph::driver::device::{Device, DeviceInfo};
188 /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
189 /// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
190 /// # use vk_graph::driver::shader::{Shader};
191 /// # use vk_graph::Graph;
192 /// # fn main() -> Result<(), DriverError> {
193 /// # let device = Device::new(DeviceInfo::default())?;
194 /// # let buf_info = BufferInfo::device_mem(8, vk::BufferUsageFlags::STORAGE_BUFFER);
195 /// # let my_buf = Buffer::create(&device, buf_info)?;
196 /// # let info = ComputePipelineInfo::default();
197 /// # let shader = Shader::new_compute([0u8; 1].as_slice());
198 /// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
199 /// # let mut my_graph = Graph::default();
200 /// # let my_buf_node = my_graph.bind_resource(my_buf);
201 /// # #[repr(C)]
202 /// # #[derive(Clone, Copy, Pod, Zeroable)]
203 /// # struct DispatchIndirectCommand { x: u32, y: u32, z: u32, }
204 /// let args = DispatchIndirectCommand {
205 /// x: 1,
206 /// y: 2,
207 /// z: 3,
208 /// };
209 /// let data = bytes_of(&args);
210 /// let usage = vk::BufferUsageFlags::INDIRECT_BUFFER | vk::BufferUsageFlags::STORAGE_BUFFER;
211 /// let args_buf = Buffer::create_from_slice(&device, usage, data)?;
212 /// let args_buf = my_graph.bind_resource(args_buf);
213 ///
214 /// my_graph
215 /// .begin_cmd()
216 /// .debug_name("fill my_buf_node with data")
217 /// .bind_pipeline(&my_compute_pipeline)
218 /// .resource_access(args_buf, AccessType::IndirectBuffer)
219 /// .shader_resource_access(0, my_buf_node, AccessType::ComputeShaderWrite)
220 /// .record_cmd(move |cmd| {
221 /// cmd.dispatch_indirect(args_buf, 0);
222 /// });
223 /// # Ok(()) }
224 /// ```
225 ///
226 /// See [`vkCmdDispatchIndirect`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdDispatchIndirect.html).
227 #[profiling::function]
228 pub fn dispatch_indirect(
229 &self,
230 args_buf: impl Into<AnyBufferNode>,
231 args_offset: vk::DeviceSize,
232 ) -> &Self {
233 let args_buf = args_buf.into();
234 let args_buf = self.resource(args_buf);
235
236 unsafe {
237 self.cmd
238 .device
239 .cmd_dispatch_indirect(self.cmd.handle, args_buf.handle, args_offset);
240 }
241
242 self
243 }
244
245 /// Updates push constants.
246 ///
247 /// Push constants represent a high speed path to modify constant data in pipelines that is
248 /// expected to outperform memory-backed resource updates.
249 ///
250 /// Push constant values can be updated incrementally, causing shader stages to read the new
251 /// data for push constants modified by this command, while still reading the previous data for
252 /// push constants not modified by this command.
253 ///
254 /// # Device limitations
255 ///
256 /// See
257 /// [`device.physical_device.props.limits.max_push_constants_size`](vk::PhysicalDeviceLimits)
258 /// for the limits of the current device. You may also check [gpuinfo.org] for a listing of
259 /// reported limits on other devices.
260 ///
261 /// # Examples
262 ///
263 /// Basic usage:
264 ///
265 /// ```
266 /// # vk_shader_macros::glsl!(r#"
267 /// #version 450
268 /// #pragma shader_stage(compute)
269 ///
270 /// layout(push_constant) uniform PushConstants {
271 /// layout(offset = 0) uint the_answer;
272 /// } push_constants;
273 ///
274 /// void main()
275 /// {
276 /// // TODO: Add bindings to read/write things!
277 /// }
278 /// # "#);
279 /// ```
280 ///
281 /// ```no_run
282 /// # use ash::vk;
283 /// # use vk_graph::driver::DriverError;
284 /// # use vk_graph::driver::device::{Device, DeviceInfo};
285 /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
286 /// # use vk_graph::driver::compute::{ComputePipeline, ComputePipelineInfo};
287 /// # use vk_graph::driver::shader::{Shader};
288 /// # use vk_graph::Graph;
289 /// # fn main() -> Result<(), DriverError> {
290 /// # let device = Device::new(DeviceInfo::default())?;
291 /// # let info = ComputePipelineInfo::default();
292 /// # let shader = Shader::new_compute([0u8; 1].as_slice());
293 /// # let my_compute_pipeline = ComputePipeline::create(&device, info, shader)?;
294 /// # let mut my_graph = Graph::default();
295 /// my_graph
296 /// .begin_cmd()
297 /// .debug_name("compute the ultimate question")
298 /// .bind_pipeline(&my_compute_pipeline)
299 /// .record_cmd(move |cmd| {
300 /// cmd
301 /// .push_constants(0, &[42])
302 /// .dispatch(1, 1, 1);
303 /// });
304 /// # Ok(()) }
305 /// ```
306 ///
307 /// See [`vkCmdPushConstants`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdPushConstants.html).
308 #[profiling::function]
309 pub fn push_constants(&self, offset: u32, data: &[u8]) -> &Self {
310 self.cmd_push_constants(
311 self.pipeline.inner.layout,
312 self.pipeline.inner.push_constants.as_slice(),
313 offset,
314 data,
315 );
316
317 self
318 }
319}
320
321impl<'a> Deref for ComputeCommandRef<'a> {
322 type Target = CommandRef<'a>;
323
324 fn deref(&self) -> &Self::Target {
325 &self.cmd
326 }
327}
328
329#[allow(unused)]
330mod deprecated {
331 use {
332 crate::{
333 Node,
334 cmd::{
335 Binding, PipelineCommand, Subresource, SubresourceRange, ViewInfo,
336 compute::ComputeCommandRef,
337 },
338 driver::compute::ComputePipeline,
339 },
340 std::any::Any,
341 vk_sync::AccessType,
342 };
343
344 impl ComputeCommandRef<'_> {
345 #[deprecated = "use push_constants function"]
346 #[doc(hidden)]
347 pub fn push_constants_offset(&self, offset: u32, data: &[u8]) -> &Self {
348 self.push_constants(offset, data)
349 }
350 }
351
352 impl PipelineCommand<'_, ComputePipeline> {
353 #[deprecated = "use shader_resource_access with ComputeShaderReadOther"]
354 #[doc(hidden)]
355 pub fn read_descriptor<N>(self, descriptor: impl Into<Binding>, node: N) -> Self
356 where
357 N: Node + Subresource,
358 N::Info: Copy,
359 SubresourceRange: From<N::Info>,
360 ViewInfo: From<N::Info>,
361 {
362 self.shader_resource_access(descriptor, node, AccessType::ComputeShaderReadOther)
363 }
364
365 #[deprecated = "use shader_subresource_access with ComputeShaderReadOther"]
366 #[doc(hidden)]
367 pub fn read_descriptor_as<N>(
368 self,
369 descriptor: impl Into<Binding>,
370 node: N,
371 node_view: impl Into<N::Info>,
372 ) -> Self
373 where
374 N: Node + Subresource,
375 N::Info: Copy,
376 SubresourceRange: From<N::Info>,
377 ViewInfo: From<N::Info>,
378 {
379 self.shader_subresource_access(
380 descriptor,
381 node,
382 node_view,
383 AccessType::ComputeShaderReadOther,
384 )
385 }
386
387 #[deprecated = "use record_cmd function"]
388 #[doc(hidden)]
389 pub fn record_compute(
390 self,
391 func: impl FnOnce(ComputeCommandRef<'_>, ()) + Send + 'static,
392 ) -> Self {
393 self.record_cmd(|cmd| func(cmd, ()))
394 }
395
396 #[deprecated = "use shader_resource_access function with AccessType::ComputeShaderWrite"]
397 #[doc(hidden)]
398 pub fn write_descriptor<N>(self, descriptor: impl Into<Binding>, node: N) -> Self
399 where
400 N: Node + Subresource,
401 N::Info: Copy,
402 SubresourceRange: From<N::Info>,
403 ViewInfo: From<N::Info>,
404 {
405 self.shader_resource_access(descriptor, node, AccessType::ComputeShaderWrite)
406 }
407
408 #[deprecated = "use shader_subresource_access function with AccessType::ComputeShaderWrite"]
409 #[doc(hidden)]
410 pub fn write_descriptor_as<N>(
411 self,
412 descriptor: impl Into<Binding>,
413 node: N,
414 node_view: impl Into<N::Info>,
415 ) -> Self
416 where
417 N: Node + Subresource,
418 N::Info: Copy,
419 SubresourceRange: From<N::Info>,
420 ViewInfo: From<N::Info>,
421 {
422 self.shader_subresource_access(
423 descriptor,
424 node,
425 node_view,
426 AccessType::ComputeShaderWrite,
427 )
428 }
429 }
430}