wgpu_async/
wgpu_future.rs1use std::future::Future;
2use std::ops::DerefMut;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll, Waker};
6use wgpu::Maintain;
7
8#[cfg(not(target_arch = "wasm32"))]
9use std::sync::{
10 atomic::{AtomicBool, AtomicUsize, Ordering},
11 Weak,
12};
13
14use crate::AsyncDevice;
15
/// Native-only background thread that keeps calling `wgpu::Device::poll` while
/// there is outstanding GPU work registered through [`PollToken`]s.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub(crate) struct PollLoop {
    /// Number of live work registrations; the thread polls while this is non-zero.
    has_work: Arc<AtomicUsize>,
    /// Shutdown flag, set by `Drop` (or by the thread itself once the device is gone).
    is_done: Arc<AtomicBool>,
    /// Join handle of the polling thread; it is `None` only after `Drop` has taken it.
    handle: Option<std::thread::JoinHandle<()>>,
}
33
34#[cfg(not(target_arch = "wasm32"))]
35impl PollLoop {
36 pub(crate) fn new(device: Weak<wgpu::Device>) -> Self {
37 let has_work = Arc::new(AtomicUsize::new(0));
38 let is_done = Arc::new(AtomicBool::new(false));
39 let locally_has_work = Arc::clone(&has_work);
40 let locally_is_done = Arc::clone(&is_done);
41 Self {
42 has_work,
43 is_done,
44 handle: Some(std::thread::spawn(move || {
45 while !locally_is_done.load(Ordering::Acquire) {
46 while locally_has_work.load(Ordering::Acquire) != 0 {
47 match device.upgrade() {
48 None => {
49 locally_is_done.store(true, Ordering::Release);
51 return;
52 }
53 Some(device) => device.poll(Maintain::Wait),
54 };
55 }
56
57 std::thread::park();
58 }
59 drop(device);
60 })),
61 }
62 }
63
64 fn start_polling(&self) -> PollToken {
66 let prev = self.has_work.fetch_add(1, Ordering::AcqRel);
67 debug_assert!(
68 prev < usize::MAX,
69 "cannot have more than `usize::MAX` outstanding operations on the GPU"
70 );
71 self.handle
72 .as_ref()
73 .expect("handle set to None on drop")
74 .thread()
75 .unpark();
76 PollToken {
77 work_count: Arc::clone(&self.has_work),
78 }
79 }
80}
81
82#[cfg(not(target_arch = "wasm32"))]
83impl Drop for PollLoop {
84 fn drop(&mut self) {
85 self.is_done.store(true, Ordering::Release);
86
87 let handle = self.handle.take().expect("PollLoop dropped twice");
88 handle.thread().unpark();
89 handle.join().expect("PollLoop thread panicked");
90 }
91}
92
/// Guard for one unit of outstanding GPU work registered with a [`PollLoop`].
///
/// `PollLoop::start_polling` increments the shared counter and hands out a token.
/// Dropping the token decrements the counter again. The polling thread keeps
/// polling the device while the counter is non-zero.
#[cfg(not(target_arch = "wasm32"))]
struct PollToken {
    /// Shared outstanding-work counter (a clone of `PollLoop::has_work`).
    work_count: Arc<AtomicUsize>,
}

#[cfg(not(target_arch = "wasm32"))]
impl Drop for PollToken {
    /// Releases this token's unit of work.
    ///
    /// # Panics
    /// In debug builds, panics if the counter was already zero. That would mean
    /// more tokens were dropped than `start_polling` handed out.
    fn drop(&mut self) {
        // The impl itself is already native-only, so no inner `cfg` is needed here.
        let prev = self.work_count.fetch_sub(1, Ordering::AcqRel);
        debug_assert!(
            prev > 0,
            "PollToken dropped without a matching `PollLoop::start_polling` call"
        );
    }
}
113
/// State shared between a [`WgpuFuture`] and the completion callback it hands out.
struct WgpuFutureSharedState<T> {
    /// Value stored by the callback; `poll` takes it out when returning `Ready`.
    result: Option<T>,
    /// Waker registered by the most recent pending `poll`; the callback takes and wakes it.
    waker: Option<Waker>,
}
119
120pub struct WgpuFuture<T> {
122 device: AsyncDevice,
123 state: Arc<Mutex<WgpuFutureSharedState<T>>>,
124
125 #[cfg(not(target_arch = "wasm32"))]
126 poll_token: Option<PollToken>,
127}
128
129impl<T: Send + 'static> WgpuFuture<T> {
130 pub(crate) fn new(device: AsyncDevice) -> Self {
131 Self {
132 device,
133 state: Arc::new(Mutex::new(WgpuFutureSharedState {
134 result: None,
135 waker: None,
136 })),
137
138 #[cfg(not(target_arch = "wasm32"))]
139 poll_token: None,
140 }
141 }
142
143 pub(crate) fn callback(&self) -> Box<dyn FnOnce(T) + Send> {
145 let shared_state = Arc::clone(&self.state);
146 Box::new(move |res: T| {
147 let mut lock = shared_state
148 .lock()
149 .expect("wgpu future was poisoned on complete");
150 let shared_state = lock.deref_mut();
151 shared_state.result = Some(res);
152
153 if let Some(waker) = shared_state.waker.take() {
154 waker.wake()
155 }
156 })
157 }
158}
159
160impl<T> Future for WgpuFuture<T> {
161 type Output = T;
162
163 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 self.device.poll(Maintain::Poll);
166
167 {
169 let Self {
170 state,
171 #[cfg(not(target_arch = "wasm32"))]
172 poll_token,
173 ..
174 } = self.as_mut().get_mut();
175 let mut lock = state.lock().expect("wgpu future was poisoned on poll");
176
177 if let Some(res) = lock.result.take() {
178 #[cfg(not(target_arch = "wasm32"))]
179 {
180 *poll_token = None;
182 }
183
184 return Poll::Ready(res);
185 }
186
187 lock.waker = Some(cx.waker().clone());
188 }
189
190 #[cfg(not(target_arch = "wasm32"))]
192 if self.poll_token.is_none() {
193 self.poll_token = Some(self.device.poll_loop.start_polling());
194 }
195
196 Poll::Pending
197 }
198}