// scry_gpu/batch.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Batched dispatch — multiple dispatches in a single GPU submission.
//!
//! A [`Batch`] records multiple kernel dispatches into one command buffer,
//! then submits them all with a single fence wait. This eliminates the
//! per-dispatch synchronization overhead that dominates bandwidth-bound
//! workloads.
//!
//! # Example
//!
//! ```ignore
//! let mut batch = gpu.batch()?;
//! batch.run(&kernel, &[&input, &pass1], n)?;
//! batch.barrier();  // ensure pass1 finishes before pass2 reads it
//! batch.run(&kernel, &[&pass1, &pass2], pass1_n)?;
//! batch.submit()?;
//! ```
19use crate::backend::BackendBuffer;
20use crate::buffer::GpuBuf;
21use crate::dispatch;
22use crate::error::{GpuError, Result};
23use crate::kernel::Kernel;
24use crate::ticket::Ticket;
25
26/// A batch of dispatches recorded into a single command buffer.
27///
28/// Created via [`Device::batch`](crate::Device::batch).
29/// Use [`barrier`](Batch::barrier) between dispatches that have data
30/// dependencies (where one dispatch reads from another's output).
31pub struct Batch {
32    inner: BatchInner,
33}
34
/// Backend-specific batch state — one variant per compiled-in backend.
enum BatchInner {
    #[cfg(feature = "vulkan")]
    Vulkan(crate::backend::vulkan::VulkanBatch),
    #[cfg(feature = "cuda")]
    Cuda(crate::backend::cuda::CudaBatch),
}

42impl Batch {
43    #[cfg(feature = "vulkan")]
44    pub(crate) const fn new_vulkan(vk_batch: crate::backend::vulkan::VulkanBatch) -> Self {
45        Self {
46            inner: BatchInner::Vulkan(vk_batch),
47        }
48    }
49
50    #[cfg(feature = "cuda")]
51    pub(crate) const fn new_cuda(cuda_batch: crate::backend::cuda::CudaBatch) -> Self {
52        Self {
53            inner: BatchInner::Cuda(cuda_batch),
54        }
55    }
56
57    /// Record a kernel dispatch with auto-calculated workgroups.
58    pub fn run(
59        &mut self,
60        kernel: &Kernel,
61        buffers: &[&dyn GpuBuf],
62        invocations: u32,
63    ) -> Result<&mut Self> {
64        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
65        self.run_configured(kernel, buffers, workgroups, None)
66    }
67
68    /// Record a kernel dispatch with push constants.
69    pub fn run_with_push_constants(
70        &mut self,
71        kernel: &Kernel,
72        buffers: &[&dyn GpuBuf],
73        invocations: u32,
74        push_constants: &[u8],
75    ) -> Result<&mut Self> {
76        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
77        self.run_configured(kernel, buffers, workgroups, Some(push_constants))
78    }
79
80    /// Record a kernel dispatch with explicit workgroups and optional push constants.
81    pub fn run_configured(
82        &mut self,
83        kernel: &Kernel,
84        buffers: &[&dyn GpuBuf],
85        workgroups: [u32; 3],
86        push_constants: Option<&[u8]>,
87    ) -> Result<&mut Self> {
88        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
89        if kernel.binding_count != backend_bufs.len() {
90            return Err(GpuError::BindingMismatch {
91                expected: kernel.binding_count,
92                got: backend_bufs.len(),
93            });
94        }
95
96        match &mut self.inner {
97            #[cfg(feature = "vulkan")]
98            BatchInner::Vulkan(vk_batch) => {
99                #[allow(irrefutable_let_patterns)]
100                let crate::backend::BackendKernel::Vulkan(vk_kernel) = &kernel.inner
101                else {
102                    return Err(GpuError::BackendUnavailable(
103                        "kernel was not compiled for Vulkan".into(),
104                    ));
105                };
106                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = backend_bufs
107                    .iter()
108                    .map(|buf| match buf {
109                        BackendBuffer::Vulkan(vb) => Ok(vb),
110                        #[cfg(feature = "cuda")]
111                        _ => Err(GpuError::BackendUnavailable(
112                            "buffer/backend mismatch: expected Vulkan buffer".into(),
113                        )),
114                    })
115                    .collect::<Result<Vec<_>>>()?;
116                vk_batch.record_dispatch(vk_kernel, &vk_bufs, workgroups, push_constants)?;
117            }
118            #[cfg(feature = "cuda")]
119            BatchInner::Cuda(cuda_batch) => {
120                let crate::backend::BackendKernel::Cuda(cuda_kernel) = &kernel.inner else {
121                    return Err(GpuError::BackendUnavailable(
122                        "kernel was not compiled for CUDA".into(),
123                    ));
124                };
125                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = backend_bufs
126                    .iter()
127                    .map(|buf| match buf {
128                        BackendBuffer::Cuda(cb) => Ok(cb),
129                        #[cfg(feature = "vulkan")]
130                        _ => Err(GpuError::BackendUnavailable(
131                            "buffer/backend mismatch: expected CUDA buffer".into(),
132                        )),
133                    })
134                    .collect::<Result<Vec<_>>>()?;
135                cuda_batch.record_dispatch(cuda_kernel, &cuda_bufs, workgroups, push_constants)?;
136            }
137        }
138
139        Ok(self)
140    }
141
142    /// Insert a compute-to-compute barrier.
143    ///
144    /// Use this between dispatches where a later dispatch reads from an
145    /// earlier dispatch's output buffer. Without a barrier, the GPU may
146    /// execute dispatches out of order or overlap writes with reads.
147    pub fn barrier(&mut self) -> &mut Self {
148        match &mut self.inner {
149            #[cfg(feature = "vulkan")]
150            BatchInner::Vulkan(vk_batch) => vk_batch.record_barrier(),
151            #[cfg(feature = "cuda")]
152            BatchInner::Cuda(cuda_batch) => cuda_batch.record_barrier(),
153        }
154        self
155    }
156
157    /// Submit all recorded dispatches and wait for completion.
158    ///
159    /// All dispatches execute in a single command buffer with one fence wait,
160    /// eliminating per-dispatch synchronization overhead.
161    ///
162    /// Equivalent to `self.submit_async()?.wait()`.
163    pub fn submit(self) -> Result<()> {
164        self.submit_async()?.wait()
165    }
166
167    /// Submit all recorded dispatches and return a [`Ticket`] for
168    /// non-blocking completion tracking.
169    ///
170    /// The GPU work is queued immediately. Use [`Ticket::wait`] to block
171    /// until completion, or [`Ticket::is_ready`] to poll.
172    ///
173    /// # Example
174    ///
175    /// ```ignore
176    /// let mut batch = gpu.batch()?;
177    /// batch.run(&kernel, &[&input, &output], n)?;
178    /// let ticket = batch.submit_async()?;
179    ///
180    /// // ... CPU work while GPU runs ...
181    ///
182    /// ticket.wait()?;
183    /// let result: Vec<f32> = output.download()?;
184    /// ```
185    pub fn submit_async(self) -> Result<Ticket> {
186        match self.inner {
187            #[cfg(feature = "vulkan")]
188            BatchInner::Vulkan(vk_batch) => Ok(Ticket::new_vulkan(vk_batch.submit_async()?)),
189            #[cfg(feature = "cuda")]
190            BatchInner::Cuda(cuda_batch) => Ok(Ticket::new_cuda(cuda_batch.submit_async()?)),
191        }
192    }
193}