1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
use std::future::Future;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use wgpu::Maintain;

#[cfg(not(target_arch = "wasm32"))]
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use crate::AsyncDevice;

/// Polls the device while a future says there is something to poll.
///
/// This object corresponds to a background thread that parks itself when no futures are
/// waiting on it, and otherwise calls `device.poll(Maintain::Wait)` repeatedly, blocking
/// while it has work that a future is waiting on.
///
/// The thread exits once this object has been dropped *and* no outstanding work remains
/// (i.e. the work counter has fallen back to 0).
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub(crate) struct PollLoop {
    /// The number of futures still waiting on resolution from the GPU.
    /// When this is 0, the thread can park itself.
    has_work: Arc<AtomicUsize>,
    /// Set to `true` when this `PollLoop` is dropped; the thread checks it after each
    /// idle period and exits once it is set and `has_work` is 0.
    is_done: Arc<AtomicBool>,
    /// Handle to the polling thread, used to unpark it when work arrives or on drop.
    /// Nothing in this type joins the thread.
    handle: std::thread::JoinHandle<()>,
}

#[cfg(not(target_arch = "wasm32"))]
impl PollLoop {
    /// Spawns the background thread that drives `device` on behalf of pending futures.
    ///
    /// The thread calls `device.poll(Maintain::Wait)` back to back while the shared work
    /// counter is non-zero, parks itself when the counter falls to zero, and returns once
    /// it observes the shutdown flag (set by dropping the `PollLoop`) with no work left.
    pub(crate) fn new(device: Arc<wgpu::Device>) -> Self {
        let has_work = Arc::new(AtomicUsize::new(0));
        let is_done = Arc::new(AtomicBool::new(false));

        let handle = {
            let pending = Arc::clone(&has_work);
            let shutdown = Arc::clone(&is_done);
            std::thread::spawn(move || {
                while !shutdown.load(Ordering::Acquire) {
                    // Keep driving the GPU for as long as any future is waiting on it.
                    while pending.load(Ordering::Acquire) > 0 {
                        device.poll(Maintain::Wait);
                    }

                    // Idle: sleep until `start_polling` or `drop` unparks us. Spurious
                    // wake-ups are harmless because both conditions are re-checked.
                    std::thread::park();
                }
            })
        };

        Self {
            has_work,
            is_done,
            handle,
        }
    }

    /// Registers one more outstanding operation and wakes the poll thread.
    ///
    /// The returned [`PollToken`] keeps the operation counted until it is dropped. The
    /// unpark is unconditional: if the thread is already busy polling, the unpark token
    /// is simply consumed by its next `park`, which then returns immediately.
    fn start_polling(&self) -> PollToken {
        let previous = self.has_work.fetch_add(1, Ordering::AcqRel);
        debug_assert!(
            previous < usize::MAX,
            "cannot have more than `usize::MAX` outstanding operations on the GPU"
        );

        let token = PollToken {
            work_count: Arc::clone(&self.has_work),
        };
        self.handle.thread().unpark();
        token
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl Drop for PollLoop {
    /// Asks the poll thread to exit once it has no outstanding work left.
    ///
    /// The thread is not joined here; it finishes any remaining polling on its own
    /// and then returns.
    fn drop(&mut self) {
        let poll_thread = self.handle.thread();
        // Publish the shutdown flag first so the thread sees it once it is woken.
        self.is_done.store(true, Ordering::Release);
        poll_thread.unpark();
    }
}

/// RAII guard marking one outstanding GPU operation for the poll loop.
///
/// While the shared counter is non-zero the poll thread keeps polling the device;
/// dropping this token decrements that counter by one.
#[cfg(not(target_arch = "wasm32"))]
#[must_use = "the poll loop only counts this operation while the `PollToken` is held"]
struct PollToken {
    /// Outstanding-operation counter shared with the `PollLoop` that issued this token.
    work_count: Arc<AtomicUsize>,
}

#[cfg(not(target_arch = "wasm32"))]
impl Drop for PollToken {
    fn drop(&mut self) {
        // If the counter was already 0, tokens were dropped more often than they were
        // issued and the subtraction has wrapped around.
        let prev = self.work_count.fetch_sub(1, Ordering::AcqRel);
        debug_assert!(
            prev > 0,
            "PollToken dropped more times than `start_polling` was called"
        );
    }
}

/// The state that both the future and the callback hold.
///
/// Shared behind an `Arc<Mutex<_>>`: the callback stores `result` and takes `waker`,
/// while the future's `poll` takes `result` or records the polling task's waker.
struct WgpuFutureSharedState<T> {
    /// The value delivered to the callback, held until `poll` takes it out.
    result: Option<T>,
    /// Waker of the task that last polled the future while it was still pending, if any.
    waker: Option<Waker>,
}

/// A future that can be awaited once a callback completes. Created using [`AsyncDevice::do_async`].
///
/// Each poll first calls `device.poll(Maintain::Poll)`. If the callback has not yet
/// delivered a result, the future records the task's waker and, on non-wasm targets,
/// holds a token that keeps the device's background poll loop running until the
/// future resolves or is dropped.
pub struct WgpuFuture<T> {
    /// Device polled on every `poll` and owner of the background poll loop.
    device: AsyncDevice,
    /// Result slot and waker shared with the closure returned by `callback`.
    state: Arc<Mutex<WgpuFutureSharedState<T>>>,

    /// Held while this future is pending so the poll loop keeps polling; `None` before
    /// the first pending poll and after the result has been returned.
    #[cfg(not(target_arch = "wasm32"))]
    poll_token: Option<PollToken>,
}

impl<T: Send + 'static> WgpuFuture<T> {
    pub(crate) fn new(device: AsyncDevice) -> Self {
        Self {
            device,
            state: Arc::new(Mutex::new(WgpuFutureSharedState {
                result: None,
                waker: None,
            })),

            #[cfg(not(target_arch = "wasm32"))]
            poll_token: None,
        }
    }

    /// Generates a callback function for this future that wakes the waker and sets the shared state.
    pub(crate) fn callback(&self) -> Box<dyn FnOnce(T) + Send> {
        let shared_state = self.state.clone();
        return Box::new(move |res: T| {
            let mut lock = shared_state
                .lock()
                .expect("wgpu future was poisoned on complete");
            let shared_state = lock.deref_mut();
            shared_state.result = Some(res);

            if let Some(waker) = shared_state.waker.take() {
                waker.wake()
            }
        });
    }
}

impl<T> Future for WgpuFuture<T> {
    type Output = T;

    /// Resolves with the callback's result once it is available.
    ///
    /// Every call first polls the device with `Maintain::Poll` to see whether waiting can
    /// be avoided altogether. If no result is available yet, the current task's waker is
    /// recorded for the callback to wake. On non-wasm targets, a poll token is also taken
    /// so the background poll loop keeps running until this future resolves or is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the shared state's mutex has been poisoned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Poll whenever we enter to see if we can avoid waiting altogether.
        self.device.poll(Maintain::Poll);

        // Check the shared state under a scoped lock.
        {
            let Self {
                state,
                #[cfg(not(target_arch = "wasm32"))]
                poll_token,
                ..
            } = self.as_mut().get_mut();
            let mut lock = state.lock().expect("wgpu future was poisoned on poll");

            if let Some(res) = lock.result.take() {
                #[cfg(not(target_arch = "wasm32"))]
                {
                    // Drop the token; the poll loop no longer needs to run for us.
                    *poll_token = None;
                }

                return Poll::Ready(res);
            }

            // Only replace the stored waker when it would wake a different task. This
            // avoids cloning the waker on every repeated poll from the same task.
            let waker_is_current =
                matches!(&lock.waker, Some(stored) if stored.will_wake(cx.waker()));
            if !waker_is_current {
                lock.waker = Some(cx.waker().clone());
            }
        }

        // Not ready yet: make sure the poll loop is running (on non-wasm targets).
        #[cfg(not(target_arch = "wasm32"))]
        if self.poll_token.is_none() {
            self.poll_token = Some(self.device.poll_loop.start_polling());
        }

        Poll::Pending
    }
}