scoped_join_set/lib.rs
1use tokio::{
2 runtime::{Handle, RuntimeFlavor},
3 task::{AbortHandle, Id, JoinSet},
4};
5
6use std::{
7 any::Any,
8 cell::UnsafeCell,
9 collections::HashMap,
10 error::Error,
11 fmt,
12 future::Future,
13 pin::Pin,
14 sync::{Arc, Weak},
15 task::{Context, Poll},
16};
17
18type TokioJoinError = tokio::task::JoinError;
19
20/// A scoped variant of [`tokio::task::JoinSet`] that ensures all futures outlive `'scope`,
21/// while allowing them to be spawned and awaited dynamically.
22///
23/// This structure tracks all spawned futures internally with a holder, allowing weak references
24/// to be used for polling in a way that is memory-safe and supports scoped lifetimes.
25///
26/// # Safety
27///
28/// There is only one `FutureHolder` and one `WeakFuture` at a time per task.
29/// If the `WeakFuture` successfully upgrades, the future is guaranteed to still be valid
30/// because `FutureHolder` holds the strong reference until it is dropped (after task join).
31#[derive(Default)]
32pub struct ScopedJoinSet<'scope, T>
33where
34 T: 'static,
35{
36 join_set: JoinSet<Option<T>>,
37 holders: HashMap<Id, FutureHolder<'scope, T>>,
38}
39
40impl<'scope, T> ScopedJoinSet<'scope, T>
41where
42 T: 'static,
43{
44 /// Creates a new, empty `ScopedJoinSet`.
45 pub fn new() -> Self {
46 Self {
47 join_set: JoinSet::new(),
48 holders: HashMap::new(),
49 }
50 }
51
52 /// Spawns a new task on the join set.
53 ///
54 /// The future must be `'scope`-bound and `Send`.
55 /// Internally, the future is wrapped in a `FutureHolder` and a weak reference is
56 /// passed to the join set for execution.
57 ///
58 /// # Panics
59 ///
60 /// This method panics if called outside of a Tokio runtime.
61 pub fn spawn<F>(&mut self, task: F)
62 where
63 F: Future<Output = T> + Send + 'scope,
64 T: Send,
65 {
66 let strong: Arc<UnsafeCell<dyn Future<Output = T> + Send + 'scope>> =
67 Arc::new(UnsafeCell::new(task));
68
69 let weak_future = WeakFuture {
70 future: unsafe {
71 std::mem::transmute::<
72 Weak<UnsafeCell<dyn Future<Output = T> + Send + 'scope>>,
73 Weak<UnsafeCell<dyn Future<Output = T> + Send>>,
74 >(Arc::downgrade(&strong))
75 },
76 };
77 let handle = self.join_set.spawn(weak_future);
78 let holder = FutureHolder {
79 abort_handle: handle.clone(),
80 future: strong,
81 };
82 self.holders.insert(handle.id(), holder);
83 }
84
85 /// Returns `true` if there are no remaining tasks in the join set.
86 pub fn is_empty(&self) -> bool {
87 self.join_set.is_empty()
88 }
89
90 /// Waits for the next task to complete.
91 ///
92 /// Returns:
93 /// - `Some(Ok(T))` if a task completed successfully.
94 /// - `Some(Err(JoinError))` if the task was cancelled or panicked.
95 /// - `None` if there are no more tasks in the set.
96 ///
97 /// # Cancel Safety
98 ///
99 /// This method is **cancellation safe** — if the `join_next().await` is itself
100 /// cancelled (e.g., due to timeout or a `select!` branch winning), no state is lost.
101 /// The underlying task remains in the join set and will be yielded again on the next
102 /// call to `join_next()`.
103 ///
104 /// Internally, the `JoinSet`'s `join_next_with_id()` ensures that the task is not removed
105 /// from the set unless it has completed and its result has been received.
106 ///
107 /// The associated future holder is only dropped and removed from internal tracking
108 /// once the task finishes and is returned from `join_next()`.
109 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
110 match self.join_set.join_next_with_id().await? {
111 Ok((id, Some(value))) => {
112 self.holders.remove(&id);
113 Some(Ok(value))
114 }
115 Ok((id, None)) => {
116 self.holders.remove(&id);
117 Some(Err(JoinError::Cancelled))
118 }
119 Err(error) => {
120 self.holders.remove(&error.id());
121 Some(Err(error.into()))
122 }
123 }
124 }
125}
126
127/// Holds a strong reference to the future so it can be upgraded by the `WeakFuture`.
128///
129/// Internally uses an `Arc<UnsafeCell<dyn Future>>` to allow safe access from the
130/// `poll` method in `WeakFuture`. Only `FutureHolder` owns the strong reference;
131/// once dropped, the weak reference becomes invalid and polling returns `None`.
132struct FutureHolder<'scope, T>
133where
134 T: 'static,
135{
136 abort_handle: AbortHandle,
137 future: Arc<UnsafeCell<dyn Future<Output = T> + Send + 'scope>>,
138}
139
140impl<'scope, T> Drop for FutureHolder<'scope, T> {
141 fn drop(&mut self) {
142 self.abort_handle.abort();
143 if Handle::current().runtime_flavor() == RuntimeFlavor::CurrentThread {
144 return;
145 }
146 if Arc::strong_count(&self.future) > 1 {
147 ::tokio::task::block_in_place(|| {
148 let mut spin = 0;
149 while Arc::strong_count(&self.future) > 1 {
150 spin += 1;
151 if spin < 10 {
152 std::hint::spin_loop();
153 } else {
154 std::thread::yield_now();
155 }
156 }
157 });
158 }
159 }
160}
161
/// Task-side handle to a future owned by a `FutureHolder`. This is the value actually
/// spawned on the inner `JoinSet`.
///
/// Each poll tries to upgrade the `Weak`. While the upgrade succeeds, the inner future is
/// polled and its output is wrapped in `Some`. Once the holder's strong reference is gone,
/// the poll resolves to `Poll::Ready(None)` instead.
struct WeakFuture<T> {
    /// Pointer to the holder's future, with its `'scope` bound erased to `'static`.
    future: Weak<UnsafeCell<dyn Future<Output = T> + Send + 'static>>,
}
170
171unsafe impl<T: Send> Send for WeakFuture<T> {}
172
173impl<T> Future for WeakFuture<T> {
174 type Output = Option<T>;
175
176 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177 if let Some(fut) = self.future.upgrade() {
178 let fut = unsafe { fut.get().as_mut().unwrap() };
179 let fut = unsafe { Pin::new_unchecked(fut) };
180 Future::poll(fut, cx).map(Some)
181 } else {
182 Poll::Ready(None)
183 }
184 }
185}
186
/// Represents failure when joining a task.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled (e.g., dropped before completing).
    Cancelled,

    /// The task panicked; carries the panic payload.
    Panicked(Box<dyn Any + Send + 'static>),
}

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            JoinError::Cancelled => write!(f, "task was cancelled"),
            JoinError::Panicked(_) => write!(f, "task panicked"),
        }
    }
}

impl Error for JoinError {}

impl JoinError {
    /// Returns `true` if the task was cancelled.
    pub fn is_cancelled(&self) -> bool {
        matches!(self, JoinError::Cancelled)
    }

    /// Returns `true` if the task panicked.
    pub fn is_panic(&self) -> bool {
        matches!(self, JoinError::Panicked(_))
    }

    /// Consumes the error and returns the panic payload if the task panicked.
    ///
    /// For any other reason, the error itself is returned in `Err`. This mirrors
    /// `tokio::task::JoinError::try_into_panic`.
    pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> {
        match self {
            JoinError::Panicked(payload) => Ok(payload),
            JoinError::Cancelled => Err(JoinError::Cancelled),
        }
    }

    /// Consumes the error and returns the panic payload.
    ///
    /// Typically combined with [`std::panic::resume_unwind`] to re-raise a task's panic in
    /// the joining task. This mirrors `tokio::task::JoinError::into_panic`.
    ///
    /// # Panics
    ///
    /// Panics if the error does not represent a panic.
    #[track_caller]
    pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
        self.try_into_panic()
            .expect("`JoinError` reason is not a panic.")
    }
}
219
220impl From<TokioJoinError> for JoinError {
221 fn from(err: TokioJoinError) -> Self {
222 if err.is_cancelled() {
223 JoinError::Cancelled
224 } else if err.is_panic() {
225 JoinError::Panicked(err.into_panic())
226 } else {
227 // Should never happen, but we guard anyway.
228 JoinError::Cancelled
229 }
230 }
231}