use crate::backend::BackendBuffer;
use crate::buffer::GpuBuf;
use crate::dispatch;
use crate::error::{GpuError, Result};
use crate::kernel::Kernel;
use crate::ticket::Ticket;

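/// A recorded sequence of GPU work: kernel dispatches, optionally separated
/// by barriers, submitted to the device as a single unit.
///
/// A sketch of typical usage follows; the `device.new_batch()` constructor
/// and the kernel/buffer values are illustrative assumptions, not part of
/// this module:
///
/// ```ignore
/// let mut batch = device.new_batch()?; // hypothetical constructor
/// batch
///     .run(&scale_kernel, &[&input, &output], 1 << 20)?
///     .barrier()
///     .run(&reduce_kernel, &[&output, &result], 1 << 20)?;
/// batch.submit()?; // blocks until the GPU finishes
/// ```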
pub struct Batch {
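    /// Backend-specific recording state, chosen when the batch is created.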
    inner: BatchInner,
}

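/// One variant per compiled-in backend; `Batch` forwards every call to
/// whichever variant it holds.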
enum BatchInner {
    #[cfg(feature = "vulkan")]
    Vulkan(crate::backend::vulkan::VulkanBatch),
    #[cfg(feature = "cuda")]
    Cuda(crate::backend::cuda::CudaBatch),
}

impl Batch {
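    /// Wraps a backend-level Vulkan batch in the public `Batch` type.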
    #[cfg(feature = "vulkan")]
    pub(crate) const fn new_vulkan(vk_batch: crate::backend::vulkan::VulkanBatch) -> Self {
        Self {
            inner: BatchInner::Vulkan(vk_batch),
        }
    }

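    /// Wraps a backend-level CUDA batch in the public `Batch` type.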
    #[cfg(feature = "cuda")]
    pub(crate) const fn new_cuda(cuda_batch: crate::backend::cuda::CudaBatch) -> Self {
        Self {
            inner: BatchInner::Cuda(cuda_batch),
        }
    }

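    /// Records a dispatch of `kernel` over `invocations` total invocations,
    /// deriving the workgroup count from the kernel's declared workgroup
    /// size. Returns `&mut Self` so recording calls can be chained.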
    pub fn run(
        &mut self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
    ) -> Result<&mut Self> {
        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_configured(kernel, buffers, workgroups, None)
    }

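    /// Like [`Batch::run`], but additionally passes `push_constants` as raw
    /// bytes to the kernel.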
    pub fn run_with_push_constants(
        &mut self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
        push_constants: &[u8],
    ) -> Result<&mut Self> {
        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_configured(kernel, buffers, workgroups, Some(push_constants))
    }

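    /// Records a dispatch with explicit `workgroups` dimensions and optional
    /// push constants. This is the lowest-level recording entry point;
    /// [`Batch::run`] and [`Batch::run_with_push_constants`] both delegate
    /// here.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BindingMismatch`] if the number of buffers does
    /// not match the kernel's binding count, and
    /// [`GpuError::BackendUnavailable`] if the kernel or a buffer belongs to
    /// a different backend than this batch.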
    pub fn run_configured(
        &mut self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<&mut Self> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        // The kernel declares how many bindings it expects; reject the
        // dispatch up front rather than letting the backend fail later.
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        match &mut self.inner {
            #[cfg(feature = "vulkan")]
            BatchInner::Vulkan(vk_batch) => {
                // With only the Vulkan feature enabled this is the sole
                // kernel variant, making the let-else irrefutable; hence
                // the lint allow.
                #[allow(irrefutable_let_patterns)]
                let crate::backend::BackendKernel::Vulkan(vk_kernel) = &kernel.inner
                else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for Vulkan".into(),
                    ));
                };
                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = backend_bufs
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Vulkan(vb) => Ok(vb),
                        #[cfg(feature = "cuda")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected Vulkan buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                vk_batch.record_dispatch(vk_kernel, &vk_bufs, workgroups, push_constants)?;
            }
            #[cfg(feature = "cuda")]
            BatchInner::Cuda(cuda_batch) => {
                // Mirror of the Vulkan branch above, including the lint
                // allow for CUDA-only builds.
                #[allow(irrefutable_let_patterns)]
                let crate::backend::BackendKernel::Cuda(cuda_kernel) = &kernel.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for CUDA".into(),
                    ));
                };
                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = backend_bufs
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Cuda(cb) => Ok(cb),
                        #[cfg(feature = "vulkan")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected CUDA buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                cuda_batch.record_dispatch(cuda_kernel, &cuda_bufs, workgroups, push_constants)?;
            }
        }

        Ok(self)
    }

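    /// Records a barrier: dispatches recorded after this point will not
    /// start until every dispatch recorded before it has completed.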
    pub fn barrier(&mut self) -> &mut Self {
        match &mut self.inner {
            #[cfg(feature = "vulkan")]
            BatchInner::Vulkan(vk_batch) => vk_batch.record_barrier(),
            #[cfg(feature = "cuda")]
            BatchInner::Cuda(cuda_batch) => cuda_batch.record_barrier(),
        }
        self
    }

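    /// Submits the recorded work and blocks until the GPU has finished;
    /// shorthand for `self.submit_async()?.wait()`.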
    pub fn submit(self) -> Result<()> {
        self.submit_async()?.wait()
    }

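    /// Submits the recorded work without blocking and returns a [`Ticket`]
    /// that can be waited on later, allowing CPU work to overlap the GPU:
    ///
    /// ```ignore
    /// let ticket = batch.submit_async()?;
    /// // ... do other CPU work while the GPU runs ...
    /// ticket.wait()?;
    /// ```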
    pub fn submit_async(self) -> Result<Ticket> {
        match self.inner {
            #[cfg(feature = "vulkan")]
            BatchInner::Vulkan(vk_batch) => Ok(Ticket::new_vulkan(vk_batch.submit_async()?)),
            #[cfg(feature = "cuda")]
            BatchInner::Cuda(cuda_batch) => Ok(Ticket::new_cuda(cuda_batch.submit_async()?)),
        }
    }
}