Skip to main content

vk_graph/cmd/
ray_trace.rs

1use {
2    super::{PipelineCommand, cmd_ref::CommandRef},
3    crate::driver::{device::Device, ray_trace::RayTracePipeline},
4    ash::vk,
5    std::ops::Deref,
6};
7
8impl PipelineCommand<'_, RayTracePipeline> {
9    /// Begin recording a ray trace pipeline command buffer.
10    pub fn record_cmd(
11        mut self,
12        func: impl FnOnce(RayTraceCommandRef<'_>) + Send + 'static,
13    ) -> Self {
14        self.record_cmd_mut(func);
15        self
16    }
17
18    /// Begin recording a ray trace pipeline command buffer.
19    pub fn record_cmd_mut(&mut self, func: impl FnOnce(RayTraceCommandRef<'_>) + Send + 'static) {
20        let pipeline = self
21            .cmd
22            .cmd()
23            .expect_last_pipeline()
24            .expect_ray_trace()
25            .clone();
26
27        #[cfg(debug_assertions)]
28        let dynamic_stack_size = pipeline.inner.info.dynamic_stack_size;
29
30        self.cmd.push_exec(move |cmd| {
31            func(RayTraceCommandRef {
32                cmd,
33
34                #[cfg(debug_assertions)]
35                dynamic_stack_size,
36
37                pipeline,
38            });
39        });
40    }
41}
42
/// Recording interface for ray tracing commands.
///
/// This structure provides a strongly-typed set of methods which allow ray trace shader code to be
/// executed. An instance is passed to the closure given to [`PipelineCommand::record_cmd`] (or
/// `record_cmd_mut`), which becomes available after binding a [`RayTracePipeline`] to a command.
/// It also dereferences to `CommandRef`, exposing the general command recording methods.
///
/// # Examples
///
/// Basic usage:
///
/// ```no_run
/// # use ash::vk;
/// # use vk_graph::driver::DriverError;
/// # use vk_graph::driver::device::{Device, DeviceInfo};
/// # use vk_graph::driver::ray_trace::{
/// #     RayTracePipeline,
/// #     RayTracePipelineInfo,
/// #     RayTraceShaderGroup,
/// # };
/// # use vk_graph::driver::shader::Shader;
/// # use vk_graph::Graph;
/// # fn main() -> Result<(), DriverError> {
/// # let device = Device::new(DeviceInfo::default())?;
/// # let info = RayTracePipelineInfo::default();
/// # let my_miss_code = [0u8; 1];
/// # let my_ray_trace_pipeline = RayTracePipeline::create(&device, info,
/// #     [Shader::new_miss(my_miss_code.as_slice())],
/// #     [RayTraceShaderGroup::new_general(0)],
/// # )?;
/// # let mut my_graph = Graph::default();
/// my_graph.begin_cmd()
///         .debug_name("my ray trace command")
///         .bind_pipeline(&my_ray_trace_pipeline)
///         .record_cmd(move |cmd| {
///             // During this closure we have access to the ray trace functions!
///         });
/// # Ok(()) }
/// ```
pub struct RayTraceCommandRef<'a> {
    /// The underlying command being recorded; this is what the `Deref` impl exposes.
    cmd: CommandRef<'a>,

    /// Whether the bound pipeline was created with a dynamic stack size. Only tracked in debug
    /// builds, where `set_stack_size` asserts on it.
    #[cfg(debug_assertions)]
    dynamic_stack_size: bool,

    /// The bound ray trace pipeline; its layout and push constant ranges are used by
    /// `push_constants`.
    pipeline: RayTracePipeline,
}
90
91impl RayTraceCommandRef<'_> {
92    /// Updates push constants.
93    ///
94    /// Push constants represent a high speed path to modify constant data in pipelines that is
95    /// expected to outperform memory-backed resource updates.
96    ///
97    /// Push constant values can be updated incrementally, causing shader stages to read the new
98    /// data for push constants modified by this command, while still reading the previous data for
99    /// push constants not modified by this command.
100    ///
101    /// # Device limitations
102    ///
103    /// See
104    /// [`device.physical_device.props.limits.max_push_constants_size`](vk::PhysicalDeviceLimits)
105    /// for the limits of the current device. You may also check [gpuinfo.org] for a listing of
106    /// reported limits on other devices.
107    ///
108    /// # Examples
109    ///
110    /// Basic usage:
111    ///
112    /// ```
113    /// # vk_shader_macros::glsl!(target: vulkan1_2, r#"
114    /// #version 460
115    /// #pragma shader_stage(closest)
116    ///
117    /// layout(push_constant) uniform PushConstants {
118    ///     layout(offset = 0) uint some_val;
119    /// } push_constants;
120    ///
121    /// void main() {
122    ///     // TODO: Add bindings to write things!
123    /// }
124    /// # "#);
125    /// ```
126    ///
127    /// ```no_run
128    /// # use ash::vk;
129    /// # use vk_graph::driver::DriverError;
130    /// # use vk_graph::driver::device::{Device, DeviceInfo};
131    /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
132    /// # use vk_graph::driver::ray_trace::{
133    /// #     RayTracePipeline,
134    /// #     RayTracePipelineInfo,
135    /// #     RayTraceShaderGroup,
136    /// # };
137    /// # use vk_graph::driver::shader::Shader;
138    /// # use vk_graph::Graph;
139    /// # fn main() -> Result<(), DriverError> {
140    /// # let device = Device::new(DeviceInfo::default())?;
141    /// # let shader = [0u8; 1];
142    /// # let info = RayTracePipelineInfo::default();
143    /// # let my_miss_code = [0u8; 1];
144    /// # let my_ray_trace_pipeline = RayTracePipeline::create(&device, info,
145    /// #     [Shader::new_miss(my_miss_code.as_slice())],
146    /// #     [RayTraceShaderGroup::new_general(0)],
147    /// # )?;
148    /// # let rgen_sbt = vk::StridedDeviceAddressRegionKHR {
149    /// #     device_address: 0,
150    /// #     stride: 0,
151    /// #     size: 0,
152    /// # };
153    /// # let hit_sbt = vk::StridedDeviceAddressRegionKHR {
154    /// #     device_address: 0,
155    /// #     stride: 0,
156    /// #     size: 0,
157    /// # };
158    /// # let miss_sbt = vk::StridedDeviceAddressRegionKHR {
159    /// #     device_address: 0,
160    /// #     stride: 0,
161    /// #     size: 0,
162    /// # };
163    /// # let call_sbt = vk::StridedDeviceAddressRegionKHR {
164    /// #     device_address: 0,
165    /// #     stride: 0,
166    /// #     size: 0,
167    /// # };
168    /// # let mut my_graph = Graph::default();
169    /// my_graph.begin_cmd()
170    ///         .debug_name("draw a cornell box")
171    ///         .bind_pipeline(&my_ray_trace_pipeline)
172    ///         .record_cmd(move |cmd| {
173    ///             cmd.push_constants(0, &[0xcb])
174    ///                    .trace_rays(&rgen_sbt, &hit_sbt, &miss_sbt, &call_sbt, 320, 200, 1);
175    ///         });
176    /// # Ok(()) }
177    /// ```
178    ///
179    /// See [`vkCmdPushConstants`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdPushConstants.html).
180    #[profiling::function]
181    pub fn push_constants(&self, offset: u32, data: &[u8]) -> &Self {
182        self.cmd_push_constants(
183            self.pipeline.inner.layout,
184            &self.pipeline.inner.push_constants,
185            offset,
186            data,
187        );
188
189        self
190    }
191
192    /// Set the stack size dynamically for a ray trace pipeline.
193    ///
194    /// See
195    /// `RayTracePipelineInfo::dynamic_stack_size` and see the Vulkan spec.
196    #[profiling::function]
197    pub fn set_stack_size(&self, pipeline_stack_size: u32) -> &Self {
198        #[cfg(debug_assertions)]
199        assert!(self.dynamic_stack_size);
200
201        let ray_trace_ext = Device::expect_ray_trace_ext(&self.cmd.device);
202
203        unsafe {
204            ray_trace_ext
205                .cmd_set_ray_tracing_pipeline_stack_size(self.cmd.handle, pipeline_stack_size);
206        }
207
208        self
209    }
210
211    // TODO: If the rayTraversalPrimitiveCulling or rayQuery features are enabled, the
212    // SkipTrianglesKHR and SkipAABBsKHR ray flags can be specified when tracing a ray.
213    // SkipTrianglesKHR and SkipAABBsKHR are mutually exclusive.
214
215    /// Ray traces using the currently-bound [`RayTracePipeline`] and the given shader binding
216    /// tables.
217    ///
218    /// Shader binding tables must be constructed according to this [example].
219    ///
220    /// # Examples
221    ///
222    /// Basic usage:
223    ///
224    /// ```no_run
225    /// # use ash::vk;
226    /// # use vk_graph::driver::DriverError;
227    /// # use vk_graph::driver::device::{Device, DeviceInfo};
228    /// # use vk_graph::driver::buffer::{Buffer, BufferInfo};
229    /// # use vk_graph::driver::ray_trace::{
230    /// #     RayTracePipeline,
231    /// #     RayTracePipelineInfo,
232    /// #     RayTraceShaderGroup,
233    /// # };
234    /// # use vk_graph::driver::shader::Shader;
235    /// # use vk_graph::Graph;
236    /// # fn main() -> Result<(), DriverError> {
237    /// # let device = Device::new(DeviceInfo::default())?;
238    /// # let shader = [0u8; 1];
239    /// # let info = RayTracePipelineInfo::default();
240    /// # let my_miss_code = [0u8; 1];
241    /// # let my_ray_trace_pipeline = RayTracePipeline::create(&device, info,
242    /// #     [Shader::new_miss(my_miss_code.as_slice())],
243    /// #     [RayTraceShaderGroup::new_general(0)],
244    /// # )?;
245    /// # let rgen_sbt = vk::StridedDeviceAddressRegionKHR {
246    /// #     device_address: 0,
247    /// #     stride: 0,
248    /// #     size: 0,
249    /// # };
250    /// # let hit_sbt = vk::StridedDeviceAddressRegionKHR {
251    /// #     device_address: 0,
252    /// #     stride: 0,
253    /// #     size: 0,
254    /// # };
255    /// # let miss_sbt = vk::StridedDeviceAddressRegionKHR {
256    /// #     device_address: 0,
257    /// #     stride: 0,
258    /// #     size: 0,
259    /// # };
260    /// # let call_sbt = vk::StridedDeviceAddressRegionKHR {
261    /// #     device_address: 0,
262    /// #     stride: 0,
263    /// #     size: 0,
264    /// # };
265    /// # let mut my_graph = Graph::default();
266    /// my_graph.begin_cmd()
267    ///         .debug_name("draw a cornell box")
268    ///         .bind_pipeline(&my_ray_trace_pipeline)
269    ///         .record_cmd(move |cmd| {
270    ///             cmd.trace_rays(&rgen_sbt, &hit_sbt, &miss_sbt, &call_sbt, 320, 200, 1);
271    ///         });
272    /// # Ok(()) }
273    /// ```
274    ///
275    /// [example]: https://github.com/attackgoat/vk-graph/blob/master/examples/ray_trace.rs
276    #[allow(clippy::too_many_arguments)]
277    #[profiling::function]
278    pub fn trace_rays(
279        &self,
280        raygen_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
281        miss_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
282        hit_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
283        callable_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
284        width: u32,
285        height: u32,
286        depth: u32,
287    ) -> &Self {
288        let ray_trace_ext = Device::expect_ray_trace_ext(&self.cmd.device);
289
290        unsafe {
291            ray_trace_ext.cmd_trace_rays(
292                self.cmd.handle,
293                raygen_shader_binding_table,
294                miss_shader_binding_table,
295                hit_shader_binding_table,
296                callable_shader_binding_table,
297                width,
298                height,
299                depth,
300            );
301        }
302
303        self
304    }
305
306    /// Ray traces using the currently-bound [`RayTracePipeline`] and the given shader binding
307    /// tables.
308    ///
309    /// `indirect_device_address` is a [buffer device address] which is a pointer to a
310    /// [`vk::TraceRaysIndirectCommandKHR`] structure containing the trace ray parameters.
311    ///
312    /// See [`vkCmdTraceRaysIndirectKHR`](https://registry.khronos.org/vulkan/specs/latest/man/html/vkCmdTraceRaysIndirectKHR.html).
313    ///
314    /// [buffer device address]: crate::driver::buffer::Buffer::device_address
315    #[profiling::function]
316    pub fn trace_rays_indirect(
317        &self,
318        raygen_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
319        miss_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
320        hit_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
321        callable_shader_binding_table: &vk::StridedDeviceAddressRegionKHR,
322        indirect_device_address: vk::DeviceAddress,
323    ) -> &Self {
324        let ray_trace_ext = Device::expect_ray_trace_ext(&self.cmd.device);
325
326        unsafe {
327            ray_trace_ext.cmd_trace_rays_indirect(
328                self.cmd.handle,
329                raygen_shader_binding_table,
330                miss_shader_binding_table,
331                hit_shader_binding_table,
332                callable_shader_binding_table,
333                indirect_device_address,
334            )
335        }
336
337        self
338    }
339}
340
341impl<'a> Deref for RayTraceCommandRef<'a> {
342    type Target = CommandRef<'a>;
343
344    fn deref(&self) -> &Self::Target {
345        &self.cmd
346    }
347}
348
349#[allow(unused)]
350mod deprecated {
351    use {
352        crate::{
353            Node,
354            cmd::{
355                Binding, PipelineCommand, Subresource, SubresourceRange, ViewInfo,
356                ray_trace::RayTraceCommandRef,
357            },
358            driver::ray_trace::RayTracePipeline,
359        },
360        vk_sync::AccessType,
361    };
362
363    impl RayTraceCommandRef<'_> {
364        #[deprecated = "use push_constants function"]
365        #[doc(hidden)]
366        pub fn push_constants_offset(&self, offset: u32, data: &[u8]) -> &Self {
367            self.push_constants(offset, data)
368        }
369    }
370
371    impl PipelineCommand<'_, RayTracePipeline> {
372        #[deprecated = "use shader_resource_access"]
373        #[doc(hidden)]
374        pub fn read_descriptor<N>(self, descriptor: impl Into<Binding>, node: N) -> Self
375        where
376            N: Node + Subresource,
377            N::Info: Copy,
378            SubresourceRange: From<N::Info>,
379            ViewInfo: From<N::Info>,
380        {
381            self.shader_resource_access(
382                descriptor,
383                node,
384                AccessType::RayTracingShaderReadSampledImageOrUniformTexelBuffer,
385            )
386        }
387
388        #[deprecated = "use shader_subresource_access"]
389        #[doc(hidden)]
390        pub fn read_descriptor_as<N>(
391            self,
392            descriptor: impl Into<Binding>,
393            node: N,
394            node_view: impl Into<N::Info>,
395        ) -> Self
396        where
397            N: Node + Subresource,
398            N::Info: Copy,
399            SubresourceRange: From<N::Info>,
400            ViewInfo: From<N::Info>,
401        {
402            self.shader_subresource_access(
403                descriptor,
404                node,
405                node_view,
406                AccessType::RayTracingShaderReadSampledImageOrUniformTexelBuffer,
407            )
408        }
409
410        #[deprecated = "use record_cmd function"]
411        #[doc(hidden)]
412        pub fn record_ray_trace(
413            self,
414            func: impl FnOnce(RayTraceCommandRef<'_>, ()) + Send + 'static,
415        ) -> Self {
416            self.record_cmd(|cmd| {
417                func(cmd, ());
418            })
419        }
420
421        #[deprecated = "use shader_resource_access function with AccessType::AnyShaderWrite"]
422        #[doc(hidden)]
423        pub fn write_descriptor<N>(self, descriptor: impl Into<Binding>, node: N) -> Self
424        where
425            N: Node + Subresource,
426            N::Info: Copy,
427            SubresourceRange: From<N::Info>,
428            ViewInfo: From<N::Info>,
429        {
430            self.shader_resource_access(descriptor, node, AccessType::AnyShaderWrite)
431        }
432
433        #[deprecated = "use shader_subresource_access function with AccessType::AnyShaderWrite"]
434        #[doc(hidden)]
435        pub fn write_descriptor_as<N>(
436            self,
437            descriptor: impl Into<Binding>,
438            node: N,
439            node_view: impl Into<N::Info>,
440        ) -> Self
441        where
442            N: Node + Subresource,
443            N::Info: Copy,
444            SubresourceRange: From<N::Info>,
445            ViewInfo: From<N::Info>,
446        {
447            self.shader_subresource_access(descriptor, node, node_view, AccessType::AnyShaderWrite)
448        }
449    }
450}