// vyre_wgpu/runtime/device/device.rs
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

use vyre::error::{Error, Result};

/// Address of the process-wide cached `wgpu::Device`, or 0 while no address has been
/// published yet. Stored with `Release` and loaded with `Acquire`, so
/// `is_cached_device` can recognize the cached device by identity without taking a lock.
pub(crate) static CACHED_DEVICE_PTR: AtomicUsize = AtomicUsize::new(0);

/// GPU handles that are created once and then kept in a process-wide static.
pub(crate) struct CachedGpu {
    /// Device and queue produced by `init_device`.
    pub(crate) pair: (wgpu::Device, wgpu::Queue),
}

15#[inline]
21pub fn cached_device() -> Result<&'static (wgpu::Device, wgpu::Queue)> {
22 let cached = cached_gpu()?;
23 CACHED_DEVICE_PTR.store(
24 std::ptr::from_ref::<wgpu::Device>(&cached.pair.0).addr(),
25 Ordering::Release,
26 );
27 Ok(&cached.pair)
28}
29
30#[inline]
32pub(crate) fn cached_gpu() -> Result<&'static CachedGpu> {
33 static GPU: OnceLock<Arc<CachedGpu>> = OnceLock::new();
34 static INIT_LOCK: Mutex<()> = Mutex::new(());
35 if let Some(gpu) = GPU.get() {
36 return Ok(gpu.as_ref());
37 }
38
39 let _guard = INIT_LOCK.lock().map_err(|source| Error::Gpu {
40 message: format!(
41 "cached GPU init mutex poisoned: {source}. Fix: restart the process or avoid panicking while initializing the GPU cache."
42 ),
43 })?;
44 if let Some(gpu) = GPU.get() {
45 return Ok(gpu.as_ref());
46 }
47
48 let pair = init_device()?;
49 let _ = GPU.set(Arc::new(CachedGpu { pair }));
50 let gpu = GPU.get().ok_or_else(|| Error::Gpu {
51 message: "cached GPU initialization failed after successful device creation. Fix: report this vyre runtime bug with the active platform and wgpu backend.".to_string(),
52 })?;
53 Ok(gpu.as_ref())
54}
55
56#[inline]
58pub(crate) fn is_cached_device(device: &wgpu::Device) -> bool {
59 let device_ptr = std::ptr::from_ref(device).addr();
60 CACHED_DEVICE_PTR.load(Ordering::Acquire) == device_ptr
61}
62
63#[inline]
70pub fn init_device() -> Result<(wgpu::Device, wgpu::Queue)> {
71 wait_for_gpu(acquire_gpu())
72}
73
74#[inline]
81pub async fn acquire_gpu() -> Result<(wgpu::Device, wgpu::Queue)> {
82 let instance = wgpu::Instance::default();
83 let adapter = instance
84 .request_adapter(&wgpu::RequestAdapterOptions::default())
85 .await
86 .ok_or_else(|| Error::Gpu {
87 message: "failed to acquire adapter. Fix: install a compatible GPU driver, expose a wgpu-supported adapter, or run on a host with GPU access.".to_string(),
88 })?;
89 let adapter_info = adapter.get_info();
90 if matches!(
91 adapter_info.device_type,
92 wgpu::DeviceType::Cpu | wgpu::DeviceType::Other
93 ) {
94 return Err(Error::Gpu {
95 message: format!(
96 "adapter '{}' has device type {:?}, which is not a real GPU execution target. Fix: expose a discrete, integrated, or virtual GPU adapter before running vyre.",
97 adapter_info.name, adapter_info.device_type
98 ),
99 });
100 }
101
102 let mut features = wgpu::Features::empty();
103 if adapter.features().contains(wgpu::Features::TIMESTAMP_QUERY) {
104 features |= wgpu::Features::TIMESTAMP_QUERY;
105 }
106
107 let adapter_limits = adapter.limits();
108 adapter.request_device(
109 &wgpu::DeviceDescriptor {
110 label: Some("vyre device"),
111 required_features: features,
112 required_limits: wgpu::Limits {
113 max_storage_buffers_per_shader_stage:
114 adapter_limits.max_storage_buffers_per_shader_stage,
115 ..wgpu::Limits::default()
116 },
117 memory_hints: wgpu::MemoryHints::default(),
118 },
119 None,
120 )
121 .await
122 .map_err(|error| Error::Gpu {
123 message: format!("failed to acquire device: {error}. Fix: check requested wgpu limits/features against the adapter and update the GPU driver if limits are unexpectedly low."),
124 })
125}
126
/// Waker that unparks the thread blocked in `wait_for_gpu` when its future can progress.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }

    // Overridden so that waking by reference does not have to clone the `Arc`.
    fn wake_by_ref(self: &Arc<Self>) {
        self.0.unpark();
    }
}

/// Drives `future` to completion on the current thread and returns its output.
///
/// This is a minimal single-future executor. After every `Poll::Pending` the thread
/// parks until the future's waker unparks it. The future is pinned on the stack, so no
/// heap allocation is made per call.
fn wait_for_gpu<T>(future: impl Future<Output = T>) -> T {
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut context = Context::from_waker(&waker);
    let mut future = std::pin::pin!(future);
    loop {
        match Pin::as_mut(&mut future).poll(&mut context) {
            Poll::Ready(value) => return value,
            // `park` returns immediately if `unpark` already ran (for example, a
            // self-wake during `poll`), and it may also return spuriously. Either way
            // the loop simply polls again.
            Poll::Pending => thread::park(),
        }
    }
}