scoped_join_set/
lib.rs

1use tokio::{
2    runtime::{Handle, RuntimeFlavor},
3    task::{AbortHandle, Id, JoinSet},
4};
5
6use std::{
7    any::Any,
8    cell::UnsafeCell,
9    collections::HashMap,
10    error::Error,
11    fmt,
12    future::Future,
13    pin::Pin,
14    sync::{Arc, Weak},
15    task::{Context, Poll},
16};
17
/// Tokio's own join error, aliased so it does not clash with this crate's public [`JoinError`].
type TokioJoinError = tokio::task::JoinError;
19
20/// A scoped variant of [`tokio::task::JoinSet`] that ensures all futures outlive `'scope`,
21/// while allowing them to be spawned and awaited dynamically.
22///
23/// This structure tracks all spawned futures internally with a holder, allowing weak references
24/// to be used for polling in a way that is memory-safe and supports scoped lifetimes.
25///
26/// # Safety
27///
28/// There is only one `FutureHolder` and one `WeakFuture` at a time per task.
29/// If the `WeakFuture` successfully upgrades, the future is guaranteed to still be valid
30/// because `FutureHolder` holds the strong reference until it is dropped (after task join).
31#[derive(Default)]
32pub struct ScopedJoinSet<'scope, T>
33where
34    T: 'static,
35{
36    join_set: JoinSet<Option<T>>,
37    holders: HashMap<Id, FutureHolder<'scope, T>>,
38}
39
40impl<'scope, T> ScopedJoinSet<'scope, T>
41where
42    T: 'static,
43{
44    /// Creates a new, empty `ScopedJoinSet`.
45    pub fn new() -> Self {
46        Self {
47            join_set: JoinSet::new(),
48            holders: HashMap::new(),
49        }
50    }
51
52    /// Spawns a new task on the join set.
53    ///
54    /// The future must be `'scope`-bound and `Send`.
55    /// Internally, the future is wrapped in a `FutureHolder` and a weak reference is
56    /// passed to the join set for execution.
57    ///
58    /// # Panics
59    ///
60    /// This method panics if called outside of a Tokio runtime.
61    pub fn spawn<F>(&mut self, task: F)
62    where
63        F: Future<Output = T> + Send + 'scope,
64        T: Send,
65    {
66        let strong: Arc<UnsafeCell<dyn Future<Output = T> + Send + 'scope>> =
67            Arc::new(UnsafeCell::new(task));
68
69        let weak_future = WeakFuture {
70            future: unsafe {
71                std::mem::transmute::<
72                    Weak<UnsafeCell<dyn Future<Output = T> + Send + 'scope>>,
73                    Weak<UnsafeCell<dyn Future<Output = T> + Send>>,
74                >(Arc::downgrade(&strong))
75            },
76        };
77        let handle = self.join_set.spawn(weak_future);
78        let holder = FutureHolder {
79            abort_handle: handle.clone(),
80            future: strong,
81        };
82        self.holders.insert(handle.id(), holder);
83    }
84
85    /// Returns `true` if there are no remaining tasks in the join set.
86    pub fn is_empty(&self) -> bool {
87        self.join_set.is_empty()
88    }
89
90    /// Waits for the next task to complete.
91    ///
92    /// Returns:
93    /// - `Some(Ok(T))` if a task completed successfully.
94    /// - `Some(Err(JoinError))` if the task was cancelled or panicked.
95    /// - `None` if there are no more tasks in the set.
96    ///
97    /// # Cancel Safety
98    ///
99    /// This method is **cancellation safe** — if the `join_next().await` is itself
100    /// cancelled (e.g., due to timeout or a `select!` branch winning), no state is lost.
101    /// The underlying task remains in the join set and will be yielded again on the next
102    /// call to `join_next()`.
103    ///
104    /// Internally, the `JoinSet`'s `join_next_with_id()` ensures that the task is not removed
105    /// from the set unless it has completed and its result has been received.
106    ///
107    /// The associated future holder is only dropped and removed from internal tracking
108    /// once the task finishes and is returned from `join_next()`.
109    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
110        match self.join_set.join_next_with_id().await? {
111            Ok((id, Some(value))) => {
112                self.holders.remove(&id);
113                Some(Ok(value))
114            }
115            Ok((id, None)) => {
116                self.holders.remove(&id);
117                Some(Err(JoinError::Cancelled))
118            }
119            Err(error) => {
120                self.holders.remove(&error.id());
121                Some(Err(error.into()))
122            }
123        }
124    }
125}
126
/// Owns the long-lived strong reference to a spawned future, keeping it alive for `WeakFuture`.
///
/// The future is stored as `Arc<UnsafeCell<dyn Future>>` so that `WeakFuture::poll` can get a
/// mutable reference to it through an upgraded `Weak`. That access is `unsafe` and relies on the
/// task's `WeakFuture` being the only code that polls the future. Once this holder's strong
/// reference is gone, the weak reference can no longer be upgraded and polling returns `None`.
struct FutureHolder<'scope, T>
where
    T: 'static,
{
    // Used by `Drop` to abort the Tokio task that polls `future`.
    abort_handle: AbortHandle,
    // The spawned future. `WeakFuture` only holds an extra strong reference while a poll is running.
    future: Arc<UnsafeCell<dyn Future<Output = T> + Send + 'scope>>,
}
139
140impl<'scope, T> Drop for FutureHolder<'scope, T> {
141    fn drop(&mut self) {
142        self.abort_handle.abort();
143        if Handle::current().runtime_flavor() == RuntimeFlavor::CurrentThread {
144            return;
145        }
146        if Arc::strong_count(&self.future) > 1 {
147            ::tokio::task::block_in_place(|| {
148                let mut spin = 0;
149                while Arc::strong_count(&self.future) > 1 {
150                    spin += 1;
151                    if spin < 10 {
152                        std::hint::spin_loop();
153                    } else {
154                        std::thread::yield_now();
155                    }
156                }
157            });
158        }
159    }
160}
161
/// A weak reference to a spawned future; this is what actually runs as the Tokio task.
///
/// Each poll upgrades the weak reference. While the `FutureHolder` is alive, the upgrade
/// succeeds and the wrapped future is polled. Once the holder's strong reference is gone, the
/// upgrade fails and polling returns `Poll::Ready(None)` without touching the future.
///
/// The future's `'scope` lifetime has been erased in `ScopedJoinSet::spawn`, which is what lets
/// this type satisfy the `'static` bound of `JoinSet::spawn`.
struct WeakFuture<T> {
    future: Weak<UnsafeCell<dyn Future<Output = T> + Send>>,
}

// SAFETY: `Weak<UnsafeCell<_>>` is not `Send` on its own because `UnsafeCell` is not `Sync`.
// Sending it is sound here because:
// - the wrapped future is itself `Send`;
// - its `UnsafeCell` is only accessed from `WeakFuture::poll`, and Tokio polls a task from one
//   thread at a time;
// - `FutureHolder` only reads the reference count and drops the future.
unsafe impl<T: Send> Send for WeakFuture<T> {}
172
173impl<T> Future for WeakFuture<T> {
174    type Output = Option<T>;
175
176    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177        if let Some(fut) = self.future.upgrade() {
178            let fut = unsafe { fut.get().as_mut().unwrap() };
179            let fut = unsafe { Pin::new_unchecked(fut) };
180            Future::poll(fut, cx).map(Some)
181        } else {
182            Poll::Ready(None)
183        }
184    }
185}
186
/// Why a task spawned on a [`ScopedJoinSet`] did not produce its output.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled before completing (e.g. aborted), or its future was no longer
    /// available when the task polled it.
    Cancelled,

    /// The task panicked; holds the panic payload taken from Tokio's join error.
    Panicked(Box<dyn Any + Send + 'static>),
}
196
197impl fmt::Display for JoinError {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        match self {
200            JoinError::Cancelled => write!(f, "task was cancelled"),
201            JoinError::Panicked(_) => write!(f, "task panicked"),
202        }
203    }
204}
205
// `JoinError` wraps no underlying error, so the default `source()` (which returns `None`) is kept.
impl Error for JoinError {}
207
208impl JoinError {
209    /// Returns `true` if the task was cancelled.
210    pub fn is_cancelled(&self) -> bool {
211        matches!(self, JoinError::Cancelled)
212    }
213
214    /// Returns `true` if the task panicked.
215    pub fn is_panic(&self) -> bool {
216        matches!(self, JoinError::Panicked(_))
217    }
218}
219
220impl From<TokioJoinError> for JoinError {
221    fn from(err: TokioJoinError) -> Self {
222        if err.is_cancelled() {
223            JoinError::Cancelled
224        } else if err.is_panic() {
225            JoinError::Panicked(err.into_panic())
226        } else {
227            // Should never happen, but we guard anyway.
228            JoinError::Cancelled
229        }
230    }
231}