Skip to main content

vyre_wgpu/runtime/device/
device.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, OnceLock};
5use std::task::{Context, Poll, Wake, Waker};
6use std::thread::{self, Thread};
7use vyre::error::{Error, Result};
8
/// Address of the cached singleton `wgpu::Device`, or `0` while none has been
/// published yet.
///
/// Stored with `Release` ordering by `cached_device` and loaded with `Acquire`
/// ordering by `is_cached_device`, which compares it against the address of a
/// candidate device. In this file the value is only ever compared as an
/// integer (an identity check) and is never converted back into a pointer.
pub(crate) static CACHED_DEVICE_PTR: AtomicUsize = AtomicUsize::new(0);
10
/// Process-wide GPU state handed out by `cached_gpu`.
pub(crate) struct CachedGpu {
    /// The singleton device together with the queue returned alongside it by
    /// `init_device`.
    pub(crate) pair: (wgpu::Device, wgpu::Queue),
}
14
15/// Return the cached device/queue pair for repeated scans.
16///
17/// # Errors
18///
19/// Returns an error if the GPU adapter or device cannot be initialized.
20#[inline]
21pub fn cached_device() -> Result<&'static (wgpu::Device, wgpu::Queue)> {
22    let cached = cached_gpu()?;
23    CACHED_DEVICE_PTR.store(
24        std::ptr::from_ref::<wgpu::Device>(&cached.pair.0).addr(),
25        Ordering::Release,
26    );
27    Ok(&cached.pair)
28}
29
30/// Return the cached runtime device and its reusable buffer pool.
31#[inline]
32pub(crate) fn cached_gpu() -> Result<&'static CachedGpu> {
33    static GPU: OnceLock<Arc<CachedGpu>> = OnceLock::new();
34    static INIT_LOCK: Mutex<()> = Mutex::new(());
35    if let Some(gpu) = GPU.get() {
36        return Ok(gpu.as_ref());
37    }
38
39    let _guard = INIT_LOCK.lock().map_err(|source| Error::Gpu {
40        message: format!(
41            "cached GPU init mutex poisoned: {source}. Fix: restart the process or avoid panicking while initializing the GPU cache."
42        ),
43    })?;
44    if let Some(gpu) = GPU.get() {
45        return Ok(gpu.as_ref());
46    }
47
48    let pair = init_device()?;
49    let _ = GPU.set(Arc::new(CachedGpu { pair }));
50    let gpu = GPU.get().ok_or_else(|| Error::Gpu {
51        message: "cached GPU initialization failed after successful device creation. Fix: report this vyre runtime bug with the active platform and wgpu backend.".to_string(),
52    })?;
53    Ok(gpu.as_ref())
54}
55
56/// Return true when `device` is the runtime singleton device.
57#[inline]
58pub(crate) fn is_cached_device(device: &wgpu::Device) -> bool {
59    let device_ptr = std::ptr::from_ref(device).addr();
60    CACHED_DEVICE_PTR.load(Ordering::Acquire) == device_ptr
61}
62
63/// Initialize a new GPU device and queue.
64///
65/// # Errors
66///
67/// Returns an actionable GPU error if no compatible adapter is available, if
68/// the selected adapter is CPU-backed, or if device creation fails.
69#[inline]
70pub fn init_device() -> Result<(wgpu::Device, wgpu::Queue)> {
71    wait_for_gpu(acquire_gpu())
72}
73
74/// Asynchronously initialize a new GPU device and queue.
75///
76/// # Errors
77///
78/// Returns an actionable GPU error if no compatible adapter is available, if
79/// the selected adapter is CPU-backed, or if device creation fails.
80#[inline]
81pub async fn acquire_gpu() -> Result<(wgpu::Device, wgpu::Queue)> {
82    let instance = wgpu::Instance::default();
83    let adapter = instance
84        .request_adapter(&wgpu::RequestAdapterOptions::default())
85        .await
86        .ok_or_else(|| Error::Gpu {
87            message: "failed to acquire adapter. Fix: install a compatible GPU driver, expose a wgpu-supported adapter, or run on a host with GPU access.".to_string(),
88        })?;
89    let adapter_info = adapter.get_info();
90    if matches!(
91        adapter_info.device_type,
92        wgpu::DeviceType::Cpu | wgpu::DeviceType::Other
93    ) {
94        return Err(Error::Gpu {
95            message: format!(
96                "adapter '{}' has device type {:?}, which is not a real GPU execution target. Fix: expose a discrete, integrated, or virtual GPU adapter before running vyre.",
97                adapter_info.name, adapter_info.device_type
98            ),
99        });
100    }
101
102    let mut features = wgpu::Features::empty();
103    if adapter.features().contains(wgpu::Features::TIMESTAMP_QUERY) {
104        features |= wgpu::Features::TIMESTAMP_QUERY;
105    }
106
107    let adapter_limits = adapter.limits();
108    adapter.request_device(
109        &wgpu::DeviceDescriptor {
110            label: Some("vyre device"),
111            required_features: features,
112            required_limits: wgpu::Limits {
113                max_storage_buffers_per_shader_stage:
114                    adapter_limits.max_storage_buffers_per_shader_stage,
115                ..wgpu::Limits::default()
116            },
117            memory_hints: wgpu::MemoryHints::default(),
118        },
119        None,
120    )
121    .await
122    .map_err(|error| Error::Gpu {
123        message: format!("failed to acquire device: {error}. Fix: check requested wgpu limits/features against the adapter and update the GPU driver if limits are unexpectedly low."),
124    })
125}
126
/// Waker that unparks the thread it was created for.
///
/// Used by `wait_for_gpu` so that a pending future can resume the blocked
/// caller.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        // Consuming and borrowing wake-ups do the same thing.
        Self::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        let ThreadWaker(parked) = &**self;
        parked.unpark();
    }
}

/// Block the current thread until `future` completes and return its output.
///
/// This is a minimal single-future executor. The future is pinned on the heap
/// and polled with a waker that unparks this thread. After each
/// `Poll::Pending` the thread parks. It is polled again whenever it is
/// unparked (or `park` returns spuriously), until the future yields
/// `Poll::Ready`.
fn wait_for_gpu<T>(future: impl Future<Output = T>) -> T {
    let mut pinned: Pin<Box<_>> = Box::pin(future);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            break output;
        }
        thread::park();
    }
}
149}