scry_gpu/device.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Device acquisition and the primary user-facing API.

use crate::backend::{Backend, BackendBuffer, BackendKernel};
use crate::buffer::{Buffer, GpuBuf};
use crate::dispatch::{self, DispatchConfig};
use crate::error::{GpuError, Result};
use crate::kernel::Kernel;
use crate::shader;

/// A GPU compute device.
///
/// This is the main entry point for scry-gpu. A `Device` wraps a single
/// GPU and provides methods to upload data, dispatch shaders, and read
/// results back.
///
/// # Example
///
/// ```ignore
/// let gpu = Device::auto()?;
///
/// let input = gpu.upload(&[1.0f32, 2.0, 3.0, 4.0])?;
/// let output = gpu.alloc::<f32>(4)?;
///
/// gpu.dispatch(SHADER_SRC, &[&input, &output], 4)?;
///
/// let result: Vec<f32> = output.download()?;
/// ```
pub struct Device {
    inner: DeviceInner,
}

enum DeviceInner {
    #[cfg(feature = "vulkan")]
    Vulkan(crate::backend::vulkan::VulkanBackend),
    #[cfg(feature = "cuda")]
    Cuda(crate::backend::cuda::CudaBackend),
}

/// Available backend types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    /// Vulkan (Linux, Windows, Android).
    Vulkan,
    /// CUDA (NVIDIA GPUs).
    Cuda,
    // Metal, // future
}

impl Device {
    /// Auto-select the best available GPU.
    ///
    /// Tries backends in order of preference: CUDA → Vulkan → (Metal in future).
    /// CUDA is preferred when available because it enables cuBLAS matmul and
    /// native CUDA kernel dispatch.
    pub fn auto() -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            use crate::backend::cuda::CudaBackend;
            if let Ok(backend) = CudaBackend::create() {
                return Ok(Self {
                    inner: DeviceInner::Cuda(backend),
                });
            }
        }

        #[cfg(feature = "vulkan")]
        {
            use crate::backend::vulkan::VulkanBackend;
            if let Ok(backend) = VulkanBackend::create() {
                return Ok(Self {
                    inner: DeviceInner::Vulkan(backend),
                });
            }
        }

        Err(GpuError::NoDevice)
    }

    /// Create a device with a specific backend.
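    ///
    /// # Example
    ///
    /// A minimal sketch; which backends are available depends on the enabled
    /// Cargo features and the host machine:
    ///
    /// ```ignore
    /// let gpu = Device::with_backend(BackendKind::Vulkan)?;
    /// println!("using {:?} on {}", gpu.backend_kind(), gpu.name());
    /// ```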
    pub fn with_backend(kind: BackendKind) -> Result<Self> {
        match kind {
            BackendKind::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    use crate::backend::vulkan::VulkanBackend;
                    let backend = VulkanBackend::create()?;
                    Ok(Self {
                        inner: DeviceInner::Vulkan(backend),
                    })
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "vulkan feature not enabled".into(),
                    ))
                }
            }
            BackendKind::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::backend::cuda::CudaBackend;
                    let backend = CudaBackend::create()?;
                    Ok(Self {
                        inner: DeviceInner::Cuda(backend),
                    })
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Err(GpuError::BackendUnavailable(
                        "cuda feature not enabled".into(),
                    ))
                }
            }
        }
    }

    /// Upload a slice to GPU memory, returning a typed buffer.
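    ///
    /// # Example
    ///
    /// A minimal sketch; the input data is illustrative:
    ///
    /// ```ignore
    /// let gpu = Device::auto()?;
    /// let weights = gpu.upload(&[0.5f32, 1.5, 2.5])?;
    /// ```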
    pub fn upload<T: bytemuck::Pod>(&self, data: &[T]) -> Result<Buffer<T>> {
        let bytes = bytemuck::cast_slice(data);
        let inner = self.upload_raw(bytes)?;
        Ok(Buffer {
            inner,
            len: data.len(),
            _marker: std::marker::PhantomData,
        })
    }

    /// Allocate an uninitialized GPU buffer for `count` elements of type `T`.
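    ///
    /// # Example
    ///
    /// A minimal sketch; the contents are undefined until written by a
    /// shader dispatch or a copy:
    ///
    /// ```ignore
    /// let gpu = Device::auto()?;
    /// let output = gpu.alloc::<f32>(1024)?;
    /// ```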
    pub fn alloc<T: bytemuck::Pod>(&self, count: usize) -> Result<Buffer<T>> {
        let size = count.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
            GpuError::AllocationFailed {
                requested: u64::MAX,
                device_max: self.memory(),
            }
        })? as u64;
        let inner = self.alloc_raw(size)?;
        Ok(Buffer {
            inner,
            len: count,
            _marker: std::marker::PhantomData,
        })
    }

    /// Dispatch a WGSL compute shader.
    ///
    /// Buffers are bound in order to `@binding(0)`, `@binding(1)`, etc.
    /// Workgroup dispatch dimensions are auto-calculated from `invocations`
    /// and the shader's `@workgroup_size`.
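    ///
    /// # Example
    ///
    /// A minimal sketch with an illustrative in-place doubling shader:
    ///
    /// ```ignore
    /// const DOUBLE: &str = r#"
    ///     @group(0) @binding(0) var<storage, read_write> data: array<f32>;
    ///
    ///     @compute @workgroup_size(64)
    ///     fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    ///         if (id.x < arrayLength(&data)) {
    ///             data[id.x] = data[id.x] * 2.0;
    ///         }
    ///     }
    /// "#;
    ///
    /// let buf = gpu.upload(&[1.0f32, 2.0, 3.0, 4.0])?;
    /// gpu.dispatch(DOUBLE, &[&buf], 4)?;
    /// let doubled: Vec<f32> = buf.download()?;
    /// ```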
    pub fn dispatch(
        &self,
        shader_src: &str,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
    ) -> Result<()> {
        let entry = "main";
        let compiled = shader::compile_wgsl(shader_src, entry)?;

        let expected = shader::binding_count(&compiled.module);
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if expected != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected,
                got: backend_bufs.len(),
            });
        }

        let wg_size = dispatch::extract_workgroup_size(&compiled.module, entry);
        let workgroups = dispatch::calc_dispatch(invocations, wg_size);

        self.dispatch_spirv(&compiled.spirv, entry, &backend_bufs, workgroups, None)
    }

    /// Dispatch with full configuration.
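    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes `DispatchConfig` can be built as a plain
    /// struct literal (see `dispatch.rs` for the actual constructor) and the
    /// shader source is illustrative:
    ///
    /// ```ignore
    /// let config = DispatchConfig {
    ///     shader: SHADER_SRC,
    ///     entry_point: Some("main"),
    ///     invocations: 1024,
    ///     workgroups: None,
    ///     push_constants: None,
    /// };
    /// gpu.dispatch_configured(&config, &[&input, &output])?;
    /// ```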
    pub fn dispatch_configured(
        &self,
        config: &DispatchConfig<'_>,
        buffers: &[&dyn GpuBuf],
    ) -> Result<()> {
        let entry = config.entry_point.unwrap_or("main");
        let compiled = shader::compile_wgsl(config.shader, entry)?;

        let expected = shader::binding_count(&compiled.module);
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if expected != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected,
                got: backend_bufs.len(),
            });
        }

        let workgroups = config.workgroups.unwrap_or_else(|| {
            let wg_size = dispatch::extract_workgroup_size(&compiled.module, entry);
            dispatch::calc_dispatch(config.invocations, wg_size)
        });

        self.dispatch_spirv(
            &compiled.spirv,
            entry,
            &backend_bufs,
            workgroups,
            config.push_constants,
        )
    }

    /// Compile a WGSL compute shader into a reusable [`Kernel`].
    ///
    /// The returned kernel holds all GPU objects (pipeline, layouts,
    /// shader module) and can be dispatched many times via [`Device::run`].
    ///
    /// Uses `"main"` as the entry point. See [`Device::compile_named`]
    /// for a custom entry point.
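    ///
    /// # Example
    ///
    /// A minimal sketch; the shader source and buffers are assumed to exist:
    ///
    /// ```ignore
    /// let kernel = gpu.compile(SHADER_SRC)?;
    /// for _ in 0..10 {
    ///     gpu.run(&kernel, &[&input, &output], 1024)?;
    /// }
    /// ```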
    pub fn compile(&self, shader_src: &str) -> Result<Kernel> {
        self.compile_named(shader_src, "main")
    }

    /// Compile a WGSL shader with a specific entry point name.
    pub fn compile_named(&self, shader_src: &str, entry_point: &str) -> Result<Kernel> {
        let compiled = shader::compile_wgsl(shader_src, entry_point)?;
        let binding_count = shader::binding_count(&compiled.module);
        let workgroup_size = dispatch::extract_workgroup_size(&compiled.module, entry_point);
        let push_constant_size = shader::push_constant_size(&compiled.module);

        let inner = self.create_pipeline(
            &compiled.spirv,
            entry_point,
            binding_count,
            push_constant_size,
        )?;

        Ok(Kernel {
            inner,
            binding_count,
            workgroup_size,
            entry_point: entry_point.to_string(),
        })
    }

    /// Dispatch a precompiled kernel.
    ///
    /// Buffers are bound in order to `@binding(0)`, `@binding(1)`, etc.
    /// Workgroup dispatch dimensions are auto-calculated from `invocations`
    /// and the kernel's compiled `@workgroup_size`.
    pub fn run(&self, kernel: &Kernel, buffers: &[&dyn GpuBuf], invocations: u32) -> Result<()> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_pipeline(kernel, &backend_bufs, workgroups, None)
    }

    /// Dispatch a precompiled kernel with push constants.
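    ///
    /// # Example
    ///
    /// A minimal sketch; the byte layout must match the shader's push-constant
    /// block, and the `scale` value here is illustrative:
    ///
    /// ```ignore
    /// let scale: f32 = 0.5;
    /// gpu.run_with_push_constants(&kernel, &[&buf], 1024, bytemuck::bytes_of(&scale))?;
    /// ```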
    pub fn run_with_push_constants(
        &self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
        push_constants: &[u8],
    ) -> Result<()> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_pipeline(kernel, &backend_bufs, workgroups, Some(push_constants))
    }

    /// Dispatch a precompiled kernel with explicit workgroup dimensions.
    ///
    /// Use this for 2D/3D dispatches or when you need precise control over
    /// workgroup counts. For simple 1D dispatches, prefer [`Device::run`].
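    ///
    /// # Example
    ///
    /// A minimal sketch for a 2D dispatch over a 512×512 grid, assuming the
    /// kernel was compiled with `@workgroup_size(16, 16)`:
    ///
    /// ```ignore
    /// gpu.run_configured(&kernel, &[&image], [512 / 16, 512 / 16, 1], None)?;
    /// ```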
    pub fn run_configured(
        &self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
        if kernel.binding_count != backend_bufs.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: backend_bufs.len(),
            });
        }

        self.run_pipeline(kernel, &backend_bufs, workgroups, push_constants)
    }

    /// Create a GPU-to-GPU copy of a buffer.
    ///
    /// Allocates a new buffer on the same device and copies the contents
    /// of `src` into it. The copy is synchronous (blocks until complete).
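    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let original = gpu.upload(&[1u32, 2, 3])?;
    /// let snapshot = gpu.copy_buffer(&original)?;
    /// ```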
    pub fn copy_buffer<T: bytemuck::Pod>(&self, src: &Buffer<T>) -> Result<Buffer<T>> {
        let size = src.byte_size();
        let inner = self.copy_buffer_raw(&src.inner, size)?;
        Ok(Buffer {
            inner,
            len: src.len,
            _marker: std::marker::PhantomData,
        })
    }

    /// Begin a batched dispatch session.
    ///
    /// Records multiple dispatches into a single command buffer, submitted
    /// with one fence wait via [`Batch::submit`].
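    ///
    /// # Example
    ///
    /// A minimal sketch; the recording methods on `Batch` live in `batch.rs`
    /// and are not shown here:
    ///
    /// ```ignore
    /// let batch = gpu.batch()?;
    /// // ...record dispatches on `batch`...
    /// batch.submit()?;
    /// ```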
    pub fn batch(&self) -> Result<crate::batch::Batch> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let vk_batch = b.begin_batch()?;
                Ok(crate::batch::Batch::new_vulkan(vk_batch))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let cuda_batch = b.begin_batch()?;
                Ok(crate::batch::Batch::new_cuda(cuda_batch))
            }
        }
    }

    /// Device name (for diagnostics / logging).
    pub fn name(&self) -> &str {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => b.device_name(),
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => b.device_name(),
        }
    }

    /// Total device memory in bytes.
    pub fn memory(&self) -> u64 {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => b.device_memory(),
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => b.device_memory(),
        }
    }

    /// Subgroup (warp/wavefront) size.
    ///
    /// Typically 32 on NVIDIA and Intel, 64 on AMD.
    /// Useful for sizing subgroup-aware shaders.
    pub fn subgroup_size(&self) -> u32 {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => b.subgroup_size(),
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => b.subgroup_size(),
        }
    }

    /// Which backend this device is using.
    pub const fn backend_kind(&self) -> BackendKind {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(_) => BackendKind::Vulkan,
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(_) => BackendKind::Cuda,
        }
    }

    // ── private helpers ──

    fn upload_raw(&self, data: &[u8]) -> Result<BackendBuffer> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let buf = b.upload(data)?;
                Ok(BackendBuffer::Vulkan(buf))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let buf = b.upload(data)?;
                Ok(BackendBuffer::Cuda(buf))
            }
        }
    }

    fn copy_buffer_raw(&self, src: &BackendBuffer, size: u64) -> Result<BackendBuffer> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                #[allow(irrefutable_let_patterns)]
                let BackendBuffer::Vulkan(vk_src) = src
                else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer/backend mismatch: expected Vulkan buffer".into(),
                    ));
                };
                let buf = b.copy_buffer(vk_src, size)?;
                Ok(BackendBuffer::Vulkan(buf))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                #[allow(irrefutable_let_patterns)]
                let BackendBuffer::Cuda(cuda_src) = src
                else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer/backend mismatch: expected CUDA buffer".into(),
                    ));
                };
                let buf = b.copy_buffer(cuda_src, size)?;
                Ok(BackendBuffer::Cuda(buf))
            }
        }
    }

    fn alloc_raw(&self, size: u64) -> Result<BackendBuffer> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let buf = b.alloc(size)?;
                Ok(BackendBuffer::Vulkan(buf))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let buf = b.alloc(size)?;
                Ok(BackendBuffer::Cuda(buf))
            }
        }
    }

    fn dispatch_spirv(
        &self,
        spirv: &[u32],
        entry_point: &str,
        buffers: &[&BackendBuffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Vulkan(vb) => Ok(vb),
                        #[cfg(feature = "cuda")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected Vulkan buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch(spirv, entry_point, &vk_bufs, workgroups, push_constants)
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Cuda(cb) => Ok(cb),
                        #[cfg(feature = "vulkan")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected CUDA buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch(spirv, entry_point, &cuda_bufs, workgroups, push_constants)
            }
        }
    }

    fn create_pipeline(
        &self,
        spirv: &[u32],
        entry_point: &str,
        binding_count: usize,
        push_constant_size: u32,
    ) -> Result<BackendKernel> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                let kernel =
                    b.create_pipeline(spirv, entry_point, binding_count, push_constant_size)?;
                Ok(BackendKernel::Vulkan(kernel))
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                let kernel =
                    b.create_pipeline(spirv, entry_point, binding_count, push_constant_size)?;
                Ok(BackendKernel::Cuda(kernel))
            }
        }
    }

    fn run_pipeline(
        &self,
        kernel: &Kernel,
        buffers: &[&BackendBuffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()> {
        match &self.inner {
            #[cfg(feature = "vulkan")]
            DeviceInner::Vulkan(b) => {
                #[allow(irrefutable_let_patterns)]
                let BackendKernel::Vulkan(vk_kernel) = &kernel.inner
                else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for Vulkan".into(),
                    ));
                };
                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Vulkan(vb) => Ok(vb),
                        #[cfg(feature = "cuda")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected Vulkan buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch_pipeline(vk_kernel, &vk_bufs, workgroups, push_constants)
            }
            #[cfg(feature = "cuda")]
            DeviceInner::Cuda(b) => {
                #[allow(irrefutable_let_patterns)]
                let BackendKernel::Cuda(cuda_kernel) = &kernel.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for CUDA".into(),
                    ));
                };
                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = buffers
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Cuda(cb) => Ok(cb),
                        #[cfg(feature = "vulkan")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected CUDA buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                b.dispatch_pipeline(cuda_kernel, &cuda_bufs, workgroups, push_constants)
            }
        }
    }
}

// ── CUDA-specific methods ──

#[cfg(feature = "cuda")]
impl Device {
    /// Compile a CUDA C kernel source into a reusable [`Kernel`].
    ///
    /// Only available on the CUDA backend. Uses NVRTC for compilation.
    ///
    /// Unlike [`Device::compile`] (which uses WGSL→SPIR-V), this accepts
    /// native CUDA C source. Because CUDA kernels don't embed metadata
    /// like WGSL's `@workgroup_size` and `@binding`, you must provide
    /// `binding_count` and `workgroup_size` explicitly.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BackendUnavailable`] if the device is not using
    /// the CUDA backend.
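    ///
    /// # Example
    ///
    /// A minimal sketch with an illustrative kernel; the buffer and launch
    /// sizes are assumptions:
    ///
    /// ```ignore
    /// const SRC: &str = r#"
    ///     extern "C" __global__ void scale(float* data) {
    ///         int i = blockIdx.x * blockDim.x + threadIdx.x;
    ///         data[i] *= 2.0f;
    ///     }
    /// "#;
    ///
    /// let kernel = gpu.compile_cuda(SRC, "scale", 1, [256, 1, 1])?;
    /// gpu.run(&kernel, &[&buf], 1024)?;
    /// ```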
    pub fn compile_cuda(
        &self,
        source: &str,
        entry_point: &str,
        binding_count: usize,
        workgroup_size: [u32; 3],
    ) -> Result<Kernel> {
        match &self.inner {
            DeviceInner::Cuda(b) => {
                let block_dim = (workgroup_size[0], workgroup_size[1], workgroup_size[2]);
                let cuda_kernel = b.compile_cuda(source, entry_point, block_dim)?;
                Ok(Kernel {
                    inner: BackendKernel::Cuda(cuda_kernel),
                    binding_count,
                    workgroup_size,
                    entry_point: entry_point.to_string(),
                })
            }
            #[cfg(feature = "vulkan")]
            _ => Err(GpuError::BackendUnavailable(
                "compile_cuda requires CUDA backend".into(),
            )),
        }
    }

    /// Run cuBLAS SGEMM: `C = A × B` (row-major `f32` matrices).
    ///
    /// Dimensions: A is `m×k`, B is `k×n`, C is `m×n`.
    ///
    /// This is the recommended matmul path on CUDA; it typically reaches
    /// 80%+ of peak throughput without any custom kernels.
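    ///
    /// # Example
    ///
    /// A minimal sketch for a 128×128 multiply (dimensions are illustrative):
    ///
    /// ```ignore
    /// let a = gpu.upload(&vec![1.0f32; 128 * 128])?;
    /// let b = gpu.upload(&vec![1.0f32; 128 * 128])?;
    /// let mut c = gpu.alloc::<f32>(128 * 128)?;
    /// gpu.cublas_matmul(&a, &b, &mut c, 128, 128, 128)?;
    /// ```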
    #[allow(clippy::many_single_char_names)]
    pub fn cublas_matmul(
        &self,
        a: &Buffer<f32>,
        b: &Buffer<f32>,
        c: &mut Buffer<f32>,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<()> {
        match &self.inner {
            DeviceInner::Cuda(backend) => {
                let BackendBuffer::Cuda(a_buf) = &a.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer not from CUDA backend".into(),
                    ));
                };
                let BackendBuffer::Cuda(b_buf) = &b.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer not from CUDA backend".into(),
                    ));
                };
                let BackendBuffer::Cuda(c_buf) = &mut c.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "buffer not from CUDA backend".into(),
                    ));
                };
                backend.cublas_matmul(a_buf, b_buf, c_buf, m, n, k)
            }
            #[cfg(feature = "vulkan")]
            _ => Err(GpuError::BackendUnavailable(
                "cublas_matmul requires CUDA backend".into(),
            )),
        }
    }
}

impl std::fmt::Debug for Device {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Device")
            .field("name", &self.name())
            .field("memory_mb", &(self.memory() / (1024 * 1024)))
            .finish()
    }
}