Skip to main content

tokio_immediate/
lib.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3//! Primitives for calling asynchronous code from immediate mode GUIs.
4//!
5//! The [`single`] module contains [`AsyncCall`], which spawns a single
6//! [`Future`] onto a Tokio runtime and exposes its result through a
7//! poll-based API that fits naturally into an immediate mode update loop.
8//! An [`AsyncViewport`] ties a GUI viewport to a wake-up callback so
9//! that completed tasks automatically trigger a repaint.
10//!
11//! With the `sync` feature enabled, the [`sync`] and [`trigger`] modules
12//! provide channel wrappers that wake viewports when values are sent,
13//! enabling continuous progress reporting from async tasks to the UI.
14//! The `parallel` module provides `AsyncParallelRunner`.
15//! The `serial` module is enabled by the `sync` feature and provides
16//! `AsyncSerialRunner`.
17//!
18//! ## Feature flags
19#![cfg_attr(docsrs, feature(doc_cfg))]
20#![cfg_attr(
21    feature = "document-features",
22    cfg_attr(doc, doc = ::document_features::document_features!())
23)]
24//
25// Clippy lints.
26#![warn(clippy::pedantic)]
27#![warn(clippy::cargo)]
28#![warn(clippy::undocumented_unsafe_blocks)]
29
30use ::std::mem::replace;
31use ::std::ops::Deref;
32use ::std::panic::resume_unwind;
33use ::std::sync::atomic::{AtomicBool, Ordering};
34use ::std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard, Weak};
35
36use ::tokio::runtime::Handle;
37use ::tokio::task::{JoinHandle, JoinSet};
38
39/// Re-export `tokio` crate.
40pub use ::tokio;
41
42/// Async parallel runner: schedule futures to run concurrently.
43pub mod parallel;
44/// Async serial runner: schedule futures to run one after another.
45#[cfg(feature = "sync")]
46#[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
47pub mod serial;
48/// Async call: spawn one [`Future`] and track its result.
49pub mod single;
50/// Wrappers around `tokio::sync` primitives that wake up viewports on send.
51#[cfg(feature = "sync")]
52#[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
53pub mod sync;
54/// A notification channel for signaling async tasks from the UI thread.
55#[cfg(feature = "sync")]
56#[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
57pub mod trigger;
58
59use parallel::AsyncParallelRunner;
60#[cfg(feature = "sync")]
61use serial::AsyncSerialRunner;
62use single::AsyncCall;
63
64/// Represents a single GUI viewport (window) that can be woken up from
65/// asynchronous tasks.
66///
67/// An `AsyncViewport` owns a wake-up callback (typically one that
68/// requests a repaint of the viewport) and hands out [`AsyncWaker`]
69/// handles via [`new_waker()`](Self::new_waker).
70///
71/// The wake-up callback **must not block**, because it may be called from
72/// inside an async executor.
73#[derive(Clone)]
74pub struct AsyncViewport {
75    wake_up_requested: Arc<AtomicBool>,
76    wake_up: Arc<AsyncWakeUpSlot>,
77}
78
79/// A thread-safe collection of [`AsyncWaker`]s that can all be woken up
80/// at once.
81///
82/// This is primarily used by synchronisation primitives (e.g.
83/// [`crate::sync::watch`]) that need to notify every viewport observing the same
84/// shared value.
85#[derive(Clone)]
86pub struct AsyncWakerList {
87    inner: Arc<RwLock<AsyncWakerListInner>>,
88}
89
90struct AsyncWakerListInner {
91    wakers: Vec<Option<AsyncWaker>>,
92    free: Vec<usize>,
93}
94
95/// A lightweight, cloneable handle that can request a repaint of the
96/// [`AsyncViewport`] it was created from.
97///
98/// If the viewport has been dropped, [`AsyncWakeUp::wake_up`]
99/// becomes a no-op and returns `false`.
100#[derive(Clone)]
101pub struct AsyncWaker {
102    wake_up_requested: Arc<AtomicBool>,
103    wake_up: Weak<AsyncWakeUpSlot>,
104}
105
/// Storage slot for a viewport's wake-up callback; `None` means that no
/// callback is currently installed.
type AsyncWakeUpSlot = RwLock<Option<AsyncWakeUpCallback>>;
/// Shared, thread-safe callback that wakes up a viewport (typically by
/// requesting a repaint).
pub type AsyncWakeUpCallback = Arc<dyn Fn() + Send + Sync>;
108
109/// RAII guard that wakes up on drop.
110pub struct AsyncWakeUpGuard<W>
111where
112    W: AsyncWakeUp,
113{
114    waker: W,
115}
116
117/// Common interface for types that can request a viewport wake-up.
118pub trait AsyncWakeUp {
119    /// Creates a guard that calls [`AsyncWakeUp::wake_up`] when dropped.
120    #[must_use]
121    fn wake_up_guard(&self) -> AsyncWakeUpGuard<&Self>
122    where
123        Self: Sized,
124    {
125        AsyncWakeUpGuard { waker: self }
126    }
127
128    #[must_use]
129    fn wake_up_guard_owned(self) -> AsyncWakeUpGuard<Self>
130    where
131        Self: Sized,
132    {
133        AsyncWakeUpGuard { waker: self }
134    }
135
136    /// Requests a wake-up.
137    fn wake_up(&self);
138}
139
140/// Abstraction over how [`AsyncCall`] accesses a Tokio runtime.
141///
142/// Implemented for [`AsyncCurrentRuntime`] (thread-local context) and
143/// [`Handle`] (explicit handle stored inside `AsyncCall`).
144pub trait AsyncRuntime {
145    /// Spawns a future onto the runtime, returning a [`JoinHandle`].
146    ///
147    /// # Panics
148    ///
149    /// Implementations may panic if their runtime access preconditions are not
150    /// met.
151    fn spawn<Fut, T>(&mut self, future: Fut) -> JoinHandle<T>
152    where
153        Fut: 'static + Send + Future<Output = T>,
154        T: 'static + Send;
155
156    /// Spawns a future onto the runtime and tracks it in a [`JoinSet`].
157    ///
158    /// # Panics
159    ///
160    /// Implementations may panic if their runtime access preconditions are not
161    /// met.
162    fn spawn_join_set<Fut, T>(&mut self, join_set: &mut JoinSet<T>, future: Fut)
163    where
164        Fut: 'static + Send + Future<Output = T>,
165        T: 'static + Send;
166
167    /// Blocks the current thread until the task completes.
168    ///
169    /// Returns `Some(value)` on success, or `None` if the task was
170    /// cancelled. Re-raises the panic (via [`resume_unwind`]) if the task
171    /// panicked.
172    ///
173    /// # Panics
174    ///
175    /// Implementations may panic if their runtime access preconditions are not
176    /// met.
177    ///
178    /// Re-raises panics from the joined task in the calling thread.
179    fn block_on<T>(&mut self, join_handle: JoinHandle<T>) -> Option<T>
180    where
181        T: 'static + Send;
182}
183
/// [`AsyncRuntime`] backed by the Tokio runtime context of the current
/// thread; this is the default runtime of [`AsyncCall`].
///
/// The context in question is the one established by
/// [`Runtime::enter()`](tokio::runtime::Runtime::enter). Relying on it is
/// convenient whenever the context is guaranteed to be present, but it
/// panics when the context is missing (e.g. in a deferred viewport callback
/// running on a different thread). Pass a [`Handle`] explicitly in such
/// situations.
#[derive(Default)]
pub struct AsyncCurrentRuntime;
194
195impl Default for AsyncViewport {
196    fn default() -> Self {
197        Self {
198            wake_up_requested: Arc::new(AtomicBool::new(false)),
199            wake_up: Arc::new(RwLock::new(None)),
200        }
201    }
202}
203
204impl AsyncWakeUp for AsyncViewport {
205    fn wake_up(&self) {
206        if self
207            .wake_up_requested
208            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
209            .is_ok()
210        {
211            let wake_up = self
212                .wake_up
213                .read()
214                .expect("Failed to read-lock AsyncViewport callback: poisoned by panic")
215                .clone();
216            if let Some(wake_up) = wake_up {
217                (wake_up)();
218            }
219        }
220    }
221}
222
223impl AsyncViewport {
224    /// Creates a new viewport with the given wake-up callback.
225    ///
226    /// `wake_up` is called whenever an async task associated with this
227    /// viewport finishes or a synchronisation primitive is updated. It
228    /// **must not block**.
229    #[must_use]
230    pub fn new_with_wake_up(wake_up: AsyncWakeUpCallback) -> Self {
231        let viewport = Self::default();
232        let _ = viewport.replace_wake_up(Some(wake_up));
233        viewport
234    }
235
236    /// Replaces the viewport wake-up callback, returning the previous one.
237    ///
238    /// # Panics
239    ///
240    /// Panics if the callback lock is poisoned by another thread panicking
241    /// while holding the write lock.
242    #[must_use]
243    pub fn replace_wake_up(
244        &self,
245        wake_up: Option<AsyncWakeUpCallback>,
246    ) -> Option<AsyncWakeUpCallback> {
247        replace(
248            &mut *self
249                .wake_up
250                .write()
251                .expect("Failed to write-lock AsyncViewport callback: poisoned by panic"),
252            wake_up,
253        )
254    }
255
256    /// Creates an [`AsyncCall`] wired to this viewport, using `A::default()`
257    /// as the runtime.
258    #[must_use]
259    pub fn new_call<T, A>(&self) -> AsyncCall<T, A>
260    where
261        T: 'static + Send,
262        A: Default + AsyncRuntime,
263    {
264        AsyncCall::new(self.new_waker())
265    }
266
267    /// Creates an [`AsyncCall`] wired to this viewport with an explicit
268    /// runtime.
269    #[must_use]
270    pub fn new_call_with_runtime<T, A>(&self, runtime: A) -> AsyncCall<T, A>
271    where
272        T: 'static + Send,
273        A: AsyncRuntime,
274    {
275        AsyncCall::new_with_runtime(self.new_waker(), runtime)
276    }
277
278    /// Creates an [`AsyncSerialRunner`] wired to this viewport, using
279    /// `A::default()` as the runtime.
280    #[must_use]
281    #[cfg(feature = "sync")]
282    #[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
283    pub fn new_serial_runner<T, A>(&self) -> AsyncSerialRunner<T, A>
284    where
285        T: 'static + Send,
286        A: Default + AsyncRuntime,
287    {
288        AsyncSerialRunner::new(self.new_waker())
289    }
290
291    /// Creates an [`AsyncSerialRunner`] wired to this viewport with an
292    /// explicit runtime.
293    #[must_use]
294    #[cfg(feature = "sync")]
295    #[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
296    pub fn new_serial_runner_with_runtime<T, A>(&self, runtime: A) -> AsyncSerialRunner<T, A>
297    where
298        T: 'static + Send,
299        A: AsyncRuntime,
300    {
301        AsyncSerialRunner::new_with_runtime(self.new_waker(), runtime)
302    }
303
304    /// Creates an [`AsyncParallelRunner`] wired to this viewport, using
305    /// `A::default()` as the runtime.
306    #[must_use]
307    pub fn new_parallel_runner<T, A>(&self) -> AsyncParallelRunner<T, A>
308    where
309        T: 'static + Send,
310        A: Default + AsyncRuntime,
311    {
312        AsyncParallelRunner::new(self.new_waker())
313    }
314
315    /// Creates an [`AsyncParallelRunner`] wired to this viewport with an
316    /// explicit runtime.
317    #[must_use]
318    pub fn new_parallel_runner_with_runtime<T, A>(&self, runtime: A) -> AsyncParallelRunner<T, A>
319    where
320        T: 'static + Send,
321        A: AsyncRuntime,
322    {
323        AsyncParallelRunner::new_with_runtime(self.new_waker(), runtime)
324    }
325
326    /// Creates a new [`AsyncWaker`] that can request a repaint of this
327    /// viewport.
328    #[must_use]
329    pub fn new_waker(&self) -> AsyncWaker {
330        AsyncWaker {
331            wake_up_requested: self.wake_up_requested.clone(),
332            wake_up: Arc::downgrade(&self.wake_up),
333        }
334    }
335
336    /// Acknowledges that the viewport has been repainted after a wake-up
337    /// request, clearing the pending flag.
338    ///
339    /// Call this at the start of every frame (before polling any
340    /// [`AsyncCall`] instances) so that subsequent wake-up requests are not
341    /// swallowed.
342    pub fn woke_up(&self) {
343        self.wake_up_requested.store(false, Ordering::Relaxed);
344    }
345
346    /// Returns `true` if `self` and `other` represent the same viewport.
347    #[must_use]
348    pub fn is_same_viewport(&self, other: &Self) -> bool {
349        self.wake_up_requested.as_ptr() == other.wake_up_requested.as_ptr()
350    }
351}
352
353impl Default for AsyncWakerList {
354    fn default() -> Self {
355        Self::with_capacity(1)
356    }
357}
358
359impl AsyncWakeUp for AsyncWakerList {
360    fn wake_up(&self) {
361        for waker in self.inner().wakers.iter().flatten() {
362            waker.wake_up();
363        }
364    }
365}
366
367impl AsyncWakerList {
368    /// Creates a new waker list pre-allocated for `capacity` wakers.
369    #[must_use]
370    pub fn with_capacity(capacity: usize) -> Self {
371        Self {
372            inner: Arc::new(RwLock::new(AsyncWakerListInner {
373                wakers: Vec::with_capacity(capacity),
374                free: Vec::with_capacity(capacity),
375            })),
376        }
377    }
378
379    /// Registers a waker and returns its index.
380    ///
381    /// The returned index must later be passed to
382    /// [`remove_waker()`](Self::remove_waker) exactly once when the waker is
383    /// no longer needed.
384    #[must_use]
385    pub fn add_waker(&self, waker: AsyncWaker) -> usize {
386        let mut inner = self.inner_mut();
387        if let Some(idx) = inner.free.pop() {
388            // SAFETY: We never shrink `wakers`, and `free` contains only indexes
389            // previously produced by `add_waker`.
390            let place = unsafe { inner.wakers.get_unchecked_mut(idx) };
391            *place = Some(waker);
392            idx
393        } else {
394            let idx = inner.wakers.len();
395            inner.wakers.push(Some(waker));
396            let free_vec_reserve = inner.wakers.capacity() - inner.free.len();
397            inner.free.reserve_exact(free_vec_reserve);
398            idx
399        }
400    }
401
402    /// Removes a previously registered waker by index.
403    ///
404    /// # Safety
405    ///
406    /// `idx` must be a value previously returned by
407    /// [`add_waker()`](Self::add_waker), and this method must be called
408    /// **exactly once** per index.
409    pub unsafe fn remove_waker(&self, idx: usize) {
410        let mut inner = self.inner_mut();
411        // SAFETY: `idx` must satisfy this function's safety contract.
412        let place = unsafe { inner.wakers.get_unchecked_mut(idx) };
413        *place = None;
414        inner.free.push(idx);
415    }
416
417    fn inner(&'_ self) -> RwLockReadGuard<'_, AsyncWakerListInner> {
418        self.inner
419            .read()
420            .expect("Failed to read-lock AsyncWakerList: poisoned by panic in another thread")
421    }
422
423    fn inner_mut(&'_ self) -> RwLockWriteGuard<'_, AsyncWakerListInner> {
424        self.inner
425            .write()
426            .expect("Failed to write-lock AsyncWakerList: poisoned by panic in another thread")
427    }
428}
429
430impl AsyncWakeUp for AsyncWaker {
431    fn wake_up(&self) {
432        if self
433            .wake_up_requested
434            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
435            .is_ok()
436        {
437            if let Some(wake_up) = self.wake_up.upgrade() {
438                let wake_up = wake_up
439                    .read()
440                    .expect("Failed to read-lock AsyncWaker callback: poisoned by panic")
441                    .clone();
442                if let Some(wake_up) = wake_up {
443                    (wake_up)();
444                }
445            } else {
446                self.wake_up_requested.store(false, Ordering::Relaxed);
447            }
448        }
449    }
450}
451
452impl AsyncWaker {
453    /// Returns `true` if the owning [`AsyncViewport`] is still alive.
454    #[must_use]
455    pub fn is_alive(&self) -> bool {
456        self.wake_up.strong_count() > 0
457    }
458
459    /// Returns `true` if `self` and `other` belong to the same viewport.
460    #[must_use]
461    pub fn is_same_viewport(&self, other: &Self) -> bool {
462        self.wake_up_requested.as_ptr() == other.wake_up_requested.as_ptr()
463    }
464}
465
466impl<W> Deref for AsyncWakeUpGuard<W>
467where
468    W: AsyncWakeUp,
469{
470    type Target = W;
471
472    fn deref(&self) -> &Self::Target {
473        &self.waker
474    }
475}
476
477impl<W, T> AsRef<T> for AsyncWakeUpGuard<W>
478where
479    W: AsyncWakeUp,
480    <Self as Deref>::Target: AsRef<T>,
481{
482    fn as_ref(&self) -> &T {
483        self.deref().as_ref()
484    }
485}
486
487impl<W> Drop for AsyncWakeUpGuard<W>
488where
489    W: AsyncWakeUp,
490{
491    fn drop(&mut self) {
492        self.waker.wake_up();
493    }
494}
495
496impl<T> AsyncWakeUp for &T
497where
498    T: AsyncWakeUp,
499{
500    fn wake_up(&self) {
501        (*self).wake_up();
502    }
503}
504
505impl AsyncRuntime for AsyncCurrentRuntime {
506    fn spawn<Fut, T>(&mut self, future: Fut) -> JoinHandle<T>
507    where
508        Fut: 'static + Send + Future<Output = T>,
509        T: 'static + Send,
510    {
511        tokio::spawn(future)
512    }
513
514    fn spawn_join_set<Fut, T>(&mut self, join_set: &mut JoinSet<T>, future: Fut)
515    where
516        Fut: 'static + Send + Future<Output = T>,
517        T: 'static + Send,
518    {
519        drop(join_set.spawn(future));
520    }
521
522    fn block_on<T>(&mut self, join_handle: JoinHandle<T>) -> Option<T>
523    where
524        T: 'static + Send,
525    {
526        match Handle::current().block_on(join_handle) {
527            Ok(value) => Some(value),
528
529            Err(error) => {
530                if error.is_cancelled() {
531                    None
532                } else {
533                    resume_unwind(error.into_panic());
534                }
535            }
536        }
537    }
538}
539
540impl AsyncRuntime for Handle {
541    fn spawn<Fut, T>(&mut self, future: Fut) -> JoinHandle<T>
542    where
543        Fut: 'static + Send + Future<Output = T>,
544        T: 'static + Send,
545    {
546        Handle::spawn(self, future)
547    }
548
549    fn spawn_join_set<Fut, T>(&mut self, join_set: &mut JoinSet<T>, future: Fut)
550    where
551        Fut: 'static + Send + Future<Output = T>,
552        T: 'static + Send,
553    {
554        drop(join_set.spawn_on(future, self));
555    }
556
557    fn block_on<T>(&mut self, join_handle: JoinHandle<T>) -> Option<T>
558    where
559        T: 'static + Send,
560    {
561        match Handle::block_on(self, join_handle) {
562            Ok(value) => Some(value),
563
564            Err(error) => {
565                if error.is_cancelled() {
566                    None
567                } else {
568                    resume_unwind(error.into_panic());
569                }
570            }
571        }
572    }
573}