tokio_tasker/tasker.rs
1use std::fmt;
2use std::future::Future;
3use std::mem;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll, Waker};
8use std::vec;
9
10use futures_util::stream::FuturesUnordered;
11use futures_util::task::noop_waker_ref;
12use futures_util::{FutureExt as _, StreamExt};
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15use tokio::task::{self, JoinError, JoinHandle};
16
17use crate::{JoinStream, Signaller, Stopper};
18
19#[derive(Default)]
20pub(crate) struct Handles {
21 handles: FuturesUnordered<JoinHandle<()>>,
22 waker: Option<Waker>,
23}
24
25impl Handles {
26 fn new() -> Self {
27 Self {
28 handles: FuturesUnordered::new(),
29 waker: None,
30 }
31 }
32
33 fn wake(&mut self) {
34 self.waker.take().map(Waker::wake);
35 // ^ notifies runtime to repoll JoinStream, if any
36 }
37
38 fn push(&mut self, handle: JoinHandle<()>) {
39 self.handles.push(handle);
40 self.wake();
41 }
42
43 pub(crate) fn set_waker(&mut self, cx: &Context<'_>) {
44 self.waker = Some(cx.waker().clone());
45 }
46
47 pub(crate) fn poll_next(
48 &mut self,
49 cx: &mut Context<'_>,
50 ) -> Poll<Option<Result<(), JoinError>>> {
51 self.handles.poll_next_unpin(cx)
52 }
53}
54
55/// Internal data shared by [`Tasker`] as well as [`Stopper`] clones.
56pub(crate) struct Shared {
57 pub(crate) handles: Mutex<Handles>,
58 /// Number of outstanding `Tasker` clones.
59 // NB. We can't use Arc's strong count as `Stopper` also needs a (strong) clone.
60 num_clones: AtomicU32,
61 finished_clones: AtomicU32,
62 pub(crate) stopped: AtomicBool,
63 pub(crate) notify_stop: Notify,
64}
65
66impl Shared {
67 pub(crate) fn new() -> Self {
68 Self {
69 handles: Mutex::new(Handles::new()),
70 num_clones: AtomicU32::new(1),
71 finished_clones: AtomicU32::new(0),
72 stopped: AtomicBool::new(false),
73 notify_stop: Notify::new(),
74 }
75 }
76
77 /// Returns `true` if this call signalled stopping or `false`
78 /// if the [`Tasker`] was already stopped.
79 pub(crate) fn stop(&self) -> bool {
80 let stop = !self.stopped.swap(true, Ordering::SeqCst);
81 if stop {
82 self.notify_stop.notify_waiters();
83 }
84
85 stop
86 }
87
88 pub(crate) fn all_finished(&self) -> bool {
89 self.finished_clones.load(Ordering::SeqCst) == self.num_clones.load(Ordering::SeqCst)
90 }
91
92 pub(crate) fn ptr(&self) -> *const Self {
93 self as _
94 }
95}
96
97impl Default for Shared {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103/// Manages a group of tasks.
104///
105/// _See [library-level][crate] documentation._
106pub struct Tasker {
107 shared: Pin<Arc<Shared>>,
108}
109
110impl Tasker {
111 /// Constructs and empty task set.
112 pub fn new() -> Self {
113 Self {
114 shared: Arc::pin(Shared::new()),
115 }
116 }
117
118 /// Add a tokio task handle to the task group.
119 ///
120 /// It is your responsibility to make sure the task
121 /// is stopped by this group's [`Stopper`] future.
122 pub fn add_handle(&self, handle: JoinHandle<()>) {
123 self.shared.handles.lock().push(handle);
124 }
125
126 /// Spawn a `!Send` future the local task set and add its handle to the task group.
127 ///
128 /// It is your responsibility to make sure the task
129 /// is stopped by this group's [`Stopper`] future.
130 pub fn spawn<F>(&self, future: F)
131 where
132 F: Future + Send + 'static,
133 F::Output: Send + 'static,
134 {
135 let handle = task::spawn(future.map(|_| ()));
136 self.add_handle(handle);
137 }
138
139 /// Spawn a tokio task and add its handle to the task group.
140 /// See [`tokio::task::spawn_local()`].
141 ///
142 /// It is your responsibility to make sure the task
143 /// is stopped by this group's [`Stopper`] future.
144 pub fn spawn_local<F>(&self, future: F)
145 where
146 F: Future + 'static,
147 F::Output: 'static,
148 {
149 let handle = task::spawn_local(future.map(|_| ()));
150 self.add_handle(handle);
151 }
152
153 /// Dispense a [`Stopper`], a future that will resolve once `.stop()`
154 /// is called on any `Tasker` (or signaller [`Signaller`]) clone.
155 ///
156 /// The `Stopper` future can be used with our [`.unless()`]
157 /// extension to stop `Future`s, with [`.take_until()`] to stop `Stream`s,
158 /// as part of `tokio::select!()`, and similar...
159 ///
160 /// [`.unless()`]: crate::FutureExt::unless()
161 /// [`.take_until()`]: futures_util::stream::StreamExt::take_until()
162 pub fn stopper(&self) -> Stopper {
163 Stopper::new(&self.shared)
164 }
165
166 /// Stop the tasks in the group.
167 ///
168 /// This will resolve all [`Stopper`] futures (including ones obtained after this call).
169 ///
170 /// Returns `true` if this was the first effective stop call
171 /// or `false` if the group was already signalled to stop.
172 pub fn stop(&self) -> bool {
173 self.shared.stop()
174 }
175
176 /// Dispense a [`Signaller`], a `Tasker` clone which can be used to signal stopping,
177 /// but doesn't require dropping or calling [`finish()`][Tasker::finish()].
178 pub fn signaller(&self) -> Signaller {
179 let arc = Pin::into_inner(self.shared.clone());
180 let weak = Arc::downgrade(&arc);
181 Signaller::new(weak)
182 }
183
184 /// `true` if stopping was already signalled.
185 pub fn is_stopped(&self) -> bool {
186 self.shared.stopped.load(Ordering::SeqCst)
187 }
188
189 /// Number of tasks currently belonging to this task group.
190 pub fn num_tasks(&self) -> usize {
191 self.shared.handles.lock().handles.len()
192 }
193
194 fn mark_finished(&self) {
195 self.shared.finished_clones.fetch_add(1, Ordering::SeqCst);
196 self.shared.handles.lock().wake();
197 }
198
199 /// Join all the tasks in the group.
200 ///
201 /// **This function will panic if any of the tasks panicked**.
202 /// Use [`try_join()`][Tasker::try_join()] if you need to handle task panics yourself.
203 ///
204 /// Note that `join()` will only return once all other `Tasker` clones are finished
205 /// (via [`finish()`][Tasker::finish()] or drop).
206 pub async fn join(self) {
207 let mut join_stream = self.join_stream();
208 while let Some(handle) = join_stream.next().await {
209 handle.expect("Join error");
210 }
211 }
212
213 /// Join all the tasks in the group, returning a vector of join results
214 /// (ie. results of tokio's [`JoinHandle`][tokio::task::JoinHandle]).
215 ///
216 /// Note that `try_join()` will only return once all other `Tasker` clones are finished
217 /// (via [`finish()`][Tasker::finish()] or drop).
218 pub async fn try_join(self) -> Vec<Result<(), JoinError>> {
219 self.join_stream().collect().await
220 }
221
222 /// Returns a `Stream` which yields join results as they become available,
223 /// ie. as the tasks terminate.
224 ///
225 /// If any of the tasks terminates earlier than others, such as due to a panic,
226 /// its result should be readily avialable through this stream.
227 ///
228 /// The join stream stops after all tasks are joined *and* all other `Tasker`
229 /// clones are finished (via [`finish()`][Tasker::finish()] or drop).
230 pub fn join_stream(self) -> JoinStream {
231 JoinStream::new(self.shared.clone())
232 // self is dropped and marked finished
233 }
234
235 /// Mark this `Tasker` clone as finished.
236 ///
237 /// This has the same effect as dropping the clone, but is more explicit.
238 /// This lets the task group know that no new tasks will be added through this clone.
239 ///
240 /// All `Tasker` clones need to be finished/dropped in order for [`.join()`][Tasker::join()]
241 /// to be able to join all the tasks.
242 ///
243 /// Use [`.signaller()`][Tasker::signaller()] to get a special `Tasker` clone that doesn't
244 /// need to be dropped and can be used to `.stop()` the task group.
245 pub fn finish(self) {
246 mem::drop(self); // Drop impl will call mark_finished()
247 }
248
249 /// Poll tasks once and join those that are finished.
250 ///
251 /// Handles to tasks that are finished executing will be
252 /// removed from the internal storage.
253 ///
254 /// Returns the number of tasks that were joined.
255 ///
256 /// **This function will panic if any of the tasks panicked**.
257 /// Use [`try_poll_join()`][Tasker::try_poll_join()] if you need to handle task panics yourself.
258 pub fn poll_join(&self) -> usize {
259 let mut handles = self.shared.handles.lock();
260 let mut cx_noop = Context::from_waker(noop_waker_ref());
261
262 let mut num_joined = 0;
263 while let Poll::Ready(Some(result)) = handles.handles.poll_next_unpin(&mut cx_noop) {
264 result.expect("Join error");
265 num_joined += 1;
266 }
267
268 num_joined
269 }
270
271 /// Poll tasks once and join those that are already done.
272 ///
273 /// Handles to tasks that are already finished executing will be joined
274 /// and removed from the internal storage.
275 ///
276 /// Returns vector of join results of tasks that were joined
277 /// (may be empty if no tasks could've been joined).
278 pub fn try_poll_join(&self) -> Vec<Result<(), JoinError>> {
279 let mut handles = self.shared.handles.lock();
280
281 let mut cx_noop = Context::from_waker(noop_waker_ref());
282 let mut ready_results = vec![];
283
284 while let Poll::Ready(Some(result)) = handles.handles.poll_next_unpin(&mut cx_noop) {
285 ready_results.push(result);
286 }
287
288 ready_results
289 }
290}
291
292impl Default for Tasker {
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
298impl Clone for Tasker {
299 fn clone(&self) -> Self {
300 self.shared.num_clones.fetch_add(1, Ordering::SeqCst);
301
302 Self {
303 shared: self.shared.clone(),
304 }
305 }
306}
307
308impl Drop for Tasker {
309 fn drop(&mut self) {
310 self.mark_finished();
311 }
312}
313
314impl fmt::Debug for Tasker {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 f.debug_tuple("Tasker").field(&self.shared.ptr()).finish()
317 }
318}