// wgpu_async/wgpu_future.rs

1use std::future::Future;
2use std::ops::DerefMut;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll, Waker};
6use wgpu::Maintain;
7
8#[cfg(not(target_arch = "wasm32"))]
9use std::sync::{
10    atomic::{AtomicBool, AtomicUsize, Ordering},
11    Weak,
12};
13
14use crate::AsyncDevice;
15
/// Background thread that keeps calling `device.poll(Maintain::Wait)` while futures need it.
///
/// When no futures hold a poll token, the thread parks itself. While at least one token is
/// outstanding, it repeatedly calls `device.poll(Maintain::Wait)` to block on the GPU, and only
/// goes back to parking once every token has been released.
///
/// The thread exits after this object is dropped and all outstanding tokens are released. It
/// also exits early if, while it has work, it finds the device has been dropped everywhere else.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub(crate) struct PollLoop {
    /// Count of outstanding poll tokens: operations some future is still waiting on.
    /// The thread parks whenever this is 0.
    has_work: Arc<AtomicUsize>,
    /// Stop flag: set by `Drop`, or by the thread itself once the device can't be upgraded.
    is_done: Arc<AtomicBool>,
    /// Join handle of the polling thread; only `None` after `Drop` has taken it.
    handle: Option<std::thread::JoinHandle<()>>,
}
33
34#[cfg(not(target_arch = "wasm32"))]
35impl PollLoop {
36    pub(crate) fn new(device: Weak<wgpu::Device>) -> Self {
37        let has_work = Arc::new(AtomicUsize::new(0));
38        let is_done = Arc::new(AtomicBool::new(false));
39        let locally_has_work = Arc::clone(&has_work);
40        let locally_is_done = Arc::clone(&is_done);
41        Self {
42            has_work,
43            is_done,
44            handle: Some(std::thread::spawn(move || {
45                while !locally_is_done.load(Ordering::Acquire) {
46                    while locally_has_work.load(Ordering::Acquire) != 0 {
47                        match device.upgrade() {
48                            None => {
49                                // If all other references to the device are dropped, don't keep hold of the device here
50                                locally_is_done.store(true, Ordering::Release);
51                                return;
52                            }
53                            Some(device) => device.poll(Maintain::Wait),
54                        };
55                    }
56
57                    std::thread::park();
58                }
59                drop(device);
60            })),
61        }
62    }
63
64    /// If the loop wasn't polling, start it polling.
65    fn start_polling(&self) -> PollToken {
66        let prev = self.has_work.fetch_add(1, Ordering::AcqRel);
67        debug_assert!(
68            prev < usize::MAX,
69            "cannot have more than `usize::MAX` outstanding operations on the GPU"
70        );
71        self.handle
72            .as_ref()
73            .expect("handle set to None on drop")
74            .thread()
75            .unpark();
76        PollToken {
77            work_count: Arc::clone(&self.has_work),
78        }
79    }
80}
81
82#[cfg(not(target_arch = "wasm32"))]
83impl Drop for PollLoop {
84    fn drop(&mut self) {
85        self.is_done.store(true, Ordering::Release);
86
87        let handle = self.handle.take().expect("PollLoop dropped twice");
88        handle.thread().unpark();
89        handle.join().expect("PollLoop thread panicked");
90    }
91}
92
/// Keeps the background poll loop active for as long as it is held.
///
/// Each token stands for one increment of the poll loop's shared work counter; the counter is
/// decremented again when the token is dropped.
#[cfg(not(target_arch = "wasm32"))]
struct PollToken {
    /// The poll loop's shared count of outstanding tokens.
    work_count: Arc<AtomicUsize>,
}
98
99#[cfg(not(target_arch = "wasm32"))]
100impl Drop for PollToken {
101    fn drop(&mut self) {
102        // On the web we don't poll, so don't do anything
103        #[cfg(not(target_arch = "wasm32"))]
104        {
105            let prev = self.work_count.fetch_sub(1, Ordering::AcqRel);
106            debug_assert!(
107                prev > 0,
108                "stop_polling was called without calling start_polling"
109            );
110        }
111    }
112}
113
/// State shared between a `WgpuFuture` and the completion callback it hands out.
struct WgpuFutureSharedState<T> {
    /// Value delivered by the callback, held until the future takes it.
    result: Option<T>,
    /// Waker registered by the latest pending poll; taken and woken by the callback.
    waker: Option<Waker>,
}
119
120/// A future that can be awaited for once a callback completes. Created using [`AsyncDevice::do_async`].
121pub struct WgpuFuture<T> {
122    device: AsyncDevice,
123    state: Arc<Mutex<WgpuFutureSharedState<T>>>,
124
125    #[cfg(not(target_arch = "wasm32"))]
126    poll_token: Option<PollToken>,
127}
128
129impl<T: Send + 'static> WgpuFuture<T> {
130    pub(crate) fn new(device: AsyncDevice) -> Self {
131        Self {
132            device,
133            state: Arc::new(Mutex::new(WgpuFutureSharedState {
134                result: None,
135                waker: None,
136            })),
137
138            #[cfg(not(target_arch = "wasm32"))]
139            poll_token: None,
140        }
141    }
142
143    /// Generates a callback function for this future that wakes the waker and sets the shared state.
144    pub(crate) fn callback(&self) -> Box<dyn FnOnce(T) + Send> {
145        let shared_state = Arc::clone(&self.state);
146        Box::new(move |res: T| {
147            let mut lock = shared_state
148                .lock()
149                .expect("wgpu future was poisoned on complete");
150            let shared_state = lock.deref_mut();
151            shared_state.result = Some(res);
152
153            if let Some(waker) = shared_state.waker.take() {
154                waker.wake()
155            }
156        })
157    }
158}
159
160impl<T> Future for WgpuFuture<T> {
161    type Output = T;
162
163    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        // Poll whenever we enter to see if we can avoid waiting altogether
165        self.device.poll(Maintain::Poll);
166
167        // Check with scoped lock
168        {
169            let Self {
170                state,
171                #[cfg(not(target_arch = "wasm32"))]
172                poll_token,
173                ..
174            } = self.as_mut().get_mut();
175            let mut lock = state.lock().expect("wgpu future was poisoned on poll");
176
177            if let Some(res) = lock.result.take() {
178                #[cfg(not(target_arch = "wasm32"))]
179                {
180                    // Drop token, stopping poll loop.
181                    *poll_token = None;
182                }
183
184                return Poll::Ready(res);
185            }
186
187            lock.waker = Some(cx.waker().clone());
188        }
189
190        // If we're not ready, make sure the poll loop is running (on non-WASM)
191        #[cfg(not(target_arch = "wasm32"))]
192        if self.poll_token.is_none() {
193            self.poll_token = Some(self.device.poll_loop.start_polling());
194        }
195
196        Poll::Pending
197    }
198}