use std::future::Future;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use wgpu::Maintain;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use crate::AsyncDevice;
/// Background thread that repeatedly polls a `wgpu::Device` while GPU work is outstanding.
///
/// Futures register interest through [`PollLoop::start_polling`], which returns a
/// [`PollToken`]. While any token is alive the thread keeps calling
/// `device.poll(Maintain::Wait)`; otherwise it parks. Dropping the `PollLoop` asks the
/// thread to exit the next time it finds no outstanding work. The thread is not joined.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub(crate) struct PollLoop {
    /// Number of live [`PollToken`]s; the thread keeps polling while this is non-zero.
    has_work: Arc<AtomicUsize>,
    /// Set to `true` by `Drop` to tell the thread to stop looping.
    is_done: Arc<AtomicBool>,
    /// Handle to the polling thread, used to unpark it when work arrives or on shutdown.
    handle: std::thread::JoinHandle<()>,
}
#[cfg(not(target_arch = "wasm32"))]
impl PollLoop {
    /// Spawns the background polling thread for `device`.
    ///
    /// The thread repeatedly calls `device.poll(Maintain::Wait)` while the outstanding-work
    /// counter is non-zero and parks when it reaches zero. After the `PollLoop` is dropped,
    /// the thread exits the next time it runs out of outstanding work.
    pub(crate) fn new(device: Arc<wgpu::Device>) -> Self {
        let pending = Arc::new(AtomicUsize::new(0));
        let shutdown = Arc::new(AtomicBool::new(false));

        let worker = {
            let pending = Arc::clone(&pending);
            let shutdown = Arc::clone(&shutdown);
            move || loop {
                if shutdown.load(Ordering::Acquire) {
                    break;
                }
                // Keep driving the device for as long as any future is waiting on it.
                while pending.load(Ordering::Acquire) != 0 {
                    device.poll(Maintain::Wait);
                }
                // Sleep until `start_polling` or `drop` unparks us. An unpark that happened
                // before this call makes `park` return immediately, and a spurious wakeup
                // simply loops back to re-check both flags.
                std::thread::park();
            }
        };

        Self {
            has_work: pending,
            is_done: shutdown,
            handle: std::thread::spawn(worker),
        }
    }

    /// Registers one more outstanding GPU operation and wakes the polling thread.
    ///
    /// The returned [`PollToken`] decrements the counter again when dropped.
    fn start_polling(&self) -> PollToken {
        let outstanding_before = self.has_work.fetch_add(1, Ordering::AcqRel);
        debug_assert!(
            outstanding_before < usize::MAX,
            "cannot have more than `usize::MAX` outstanding operations on the GPU"
        );
        // The thread may be parked; unparking is harmless if it is already running.
        let worker = self.handle.thread();
        worker.unpark();
        let work_count = Arc::clone(&self.has_work);
        PollToken { work_count }
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl Drop for PollLoop {
    /// Signals the polling thread to shut down and wakes it in case it is parked.
    ///
    /// The thread is not joined here. It stops once it observes the shutdown flag, which
    /// happens after every outstanding [`PollToken`] has been dropped.
    fn drop(&mut self) {
        self.is_done.store(true, Ordering::Release);
        let worker = self.handle.thread();
        worker.unpark()
    }
}
/// Guard returned by `PollLoop::start_polling`.
///
/// While it is alive, the shared outstanding-work counter is non-zero, so the polling
/// thread keeps polling the device. Dropping it decrements the counter.
#[cfg(not(target_arch = "wasm32"))]
struct PollToken {
    /// The `PollLoop`'s outstanding-work counter (shared with the polling thread).
    work_count: Arc<AtomicUsize>,
}
#[cfg(not(target_arch = "wasm32"))]
impl Drop for PollToken {
    /// Releases this token's claim on the poll loop by decrementing the shared
    /// outstanding-work counter.
    ///
    /// When the counter reaches zero, the polling thread stops calling `device.poll`
    /// and parks.
    fn drop(&mut self) {
        let prev = self.work_count.fetch_sub(1, Ordering::AcqRel);
        // Every token is produced by `start_polling`, which incremented the counter first,
        // so it can never be zero here unless that invariant is broken.
        debug_assert!(
            prev > 0,
            "PollToken dropped without a matching `start_polling`"
        );
    }
}
/// State shared (behind a mutex) between a [`WgpuFuture`] and its completion callback.
struct WgpuFutureSharedState<T> {
    /// Written by the completion callback; taken by `poll` to resolve the future.
    result: Option<T>,
    /// Waker stored by the most recent pending `poll`; taken and woken by the callback.
    waker: Option<Waker>,
}
/// A future that resolves to the value passed to the callback from [`WgpuFuture::callback`].
///
/// Each `poll` polls the device once. On native targets, a pending future also holds a
/// `PollToken` that keeps the background poll loop running until the result arrives.
pub struct WgpuFuture<T> {
    /// Device polled at the start of every `poll`.
    device: AsyncDevice,
    /// Result/waker slot shared with the completion callback.
    state: Arc<Mutex<WgpuFutureSharedState<T>>>,
    /// Present while the future is pending and has asked the poll loop to run.
    #[cfg(not(target_arch = "wasm32"))]
    poll_token: Option<PollToken>,
}
impl<T: Send + 'static> WgpuFuture<T> {
pub(crate) fn new(device: AsyncDevice) -> Self {
Self {
device,
state: Arc::new(Mutex::new(WgpuFutureSharedState {
result: None,
waker: None,
})),
#[cfg(not(target_arch = "wasm32"))]
poll_token: None,
}
}
pub(crate) fn callback(&self) -> Box<dyn FnOnce(T) + Send> {
let shared_state = self.state.clone();
return Box::new(move |res: T| {
let mut lock = shared_state
.lock()
.expect("wgpu future was poisoned on complete");
let shared_state = lock.deref_mut();
shared_state.result = Some(res);
if let Some(waker) = shared_state.waker.take() {
waker.wake()
}
});
}
}
impl<T> Future for WgpuFuture<T> {
    type Output = T;

    /// Polls the underlying wgpu operation.
    ///
    /// Each call first runs one `device.poll(Maintain::Poll)` pass so already-due callbacks
    /// can fire. If the completion callback has stored a result, it is returned and, on
    /// native targets, the poll token is released so the background loop can go idle.
    /// Otherwise the current waker is registered. On native targets, the background poll
    /// loop is also started if this future has not started it already.
    ///
    /// # Panics
    ///
    /// Panics if the shared state mutex was poisoned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.device.poll(Maintain::Poll);

        // Same `Unpin` requirement as before; a plain `&mut Self` allows disjoint field borrows.
        let this = self.as_mut().get_mut();
        {
            let mut lock = this.state.lock().expect("wgpu future was poisoned on poll");
            if let Some(res) = lock.result.take() {
                #[cfg(not(target_arch = "wasm32"))]
                {
                    this.poll_token = None;
                }
                return Poll::Ready(res);
            }
            // Avoid cloning the waker on every poll when the stored one already wakes
            // this same task.
            let waker_is_current =
                matches!(&lock.waker, Some(waker) if waker.will_wake(cx.waker()));
            if !waker_is_current {
                lock.waker = Some(cx.waker().clone());
            }
        }
        #[cfg(not(target_arch = "wasm32"))]
        if this.poll_token.is_none() {
            this.poll_token = Some(this.device.poll_loop.start_polling());
        }
        Poll::Pending
    }
}