tokio_util/task/spawn_pinned.rs
use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use std::rc::Rc;
/// use tokio::task;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
///     let pool = LocalPoolHandle::new(5);
///
///     let output = pool.spawn_pinned(|| {
///         // `data` is !Send + !Sync
///         let data = Rc::new("local data");
///         let data_clone = data.clone();
///
///         async move {
///             task::spawn_local(async move {
///                 println!("{}", data_clone);
///             });
///
///             data.to_string()
///         }
///     }).await.unwrap();
///     println!("output: {}", output);
/// }
/// # }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
    pool: Arc<LocalPool>,
}

impl LocalPoolHandle {
    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
    /// pool via [`LocalPoolHandle::spawn_pinned`].
    ///
    /// # Panics
    ///
    /// Panics if the pool size is less than one.
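    ///
    /// # Examples
    ///
    /// A minimal sketch of constructing a pool and cloning its handle; the pool
    /// size of two is an arbitrary choice for illustration:
    ///
    /// ```
    /// # #[cfg(not(target_family = "wasm"))]
    /// # {
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// // Two dedicated worker threads; handles are cheap to clone and share.
    /// let pool = LocalPoolHandle::new(2);
    /// let _also_pool = pool.clone();
    /// # }
    /// ```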
    #[track_caller]
    pub fn new(pool_size: usize) -> LocalPoolHandle {
        assert!(pool_size > 0);

        let workers = (0..pool_size)
            .map(|_| LocalWorkerHandle::new_worker())
            .collect();

        let pool = Arc::new(LocalPool { workers });

        LocalPoolHandle { pool }
    }

    /// Returns the number of worker threads in the pool.
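    ///
    /// # Examples
    ///
    /// A small sketch; the pool size of four is arbitrary and only illustrates
    /// that the returned count matches the size the pool was created with:
    ///
    /// ```
    /// # #[cfg(not(target_family = "wasm"))]
    /// # {
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// let pool = LocalPoolHandle::new(4);
    /// assert_eq!(pool.num_threads(), 4);
    /// # }
    /// ```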
    #[inline]
    pub fn num_threads(&self) -> usize {
        self.pool.workers.len()
    }

    /// Returns the number of tasks scheduled on each worker. The indices of the
    /// worker threads correspond to the indices of the returned `Vec`.
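    ///
    /// # Examples
    ///
    /// A small sketch of reading the per-worker load. Exact counts depend on what
    /// has been spawned and has already finished, so only the length of the result
    /// is asserted here:
    ///
    /// ```
    /// # #[cfg(not(target_family = "wasm"))]
    /// # {
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// let pool = LocalPoolHandle::new(3);
    /// let loads = pool.get_task_loads_for_each_worker();
    /// // One entry per worker thread, in worker-index order.
    /// assert_eq!(loads.len(), 3);
    /// # }
    /// ```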
    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
        self.pool
            .workers
            .iter()
            .map(|worker| worker.task_count.load(Ordering::SeqCst))
            .collect::<Vec<_>>()
    }

    /// Spawn a task onto a worker thread and pin it there so it can't be moved
    /// off of the thread. Note that the future is not [`Send`], but the
    /// [`FnOnce`] which creates it is.
    ///
    /// # Examples
    /// ```
    /// # #[cfg(not(target_family = "wasm"))]
    /// # {
    /// use std::rc::Rc;
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     // Create the local pool
    ///     let pool = LocalPoolHandle::new(1);
    ///
    ///     // Spawn a !Send future onto the pool and await it
    ///     let output = pool
    ///         .spawn_pinned(|| {
    ///             // Rc is !Send + !Sync
    ///             let local_data = Rc::new("test");
    ///
    ///             // This future holds an Rc, so it is !Send
    ///             async move { local_data.to_string() }
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     assert_eq!(output, "test");
    /// }
    /// # }
    /// ```
    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
    }

    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
    /// of the pool, whereas `spawn_pinned` chooses the worker with the fewest
    /// scheduled tasks.
    ///
    /// A worker thread is chosen by index. Indices are 0-based and the largest valid
    /// index is `num_threads() - 1`.
    ///
    /// # Panics
    ///
    /// This method panics if the index is out of bounds.
    ///
    /// # Examples
    ///
    /// This method can be used to spawn a task on all worker threads of the pool:
    ///
    /// ```
    /// # #[cfg(not(target_family = "wasm"))]
    /// # {
    /// use tokio_util::task::LocalPoolHandle;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     const NUM_WORKERS: usize = 3;
    ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
    ///     let handles = (0..pool.num_threads())
    ///         .map(|worker_idx| {
    ///             pool.spawn_pinned_by_idx(
    ///                 || {
    ///                     async {
    ///                         "test"
    ///                     }
    ///                 },
    ///                 worker_idx,
    ///             )
    ///         })
    ///         .collect::<Vec<_>>();
    ///
    ///     for handle in handles {
    ///         handle.await.unwrap();
    ///     }
    /// }
    /// # }
    /// ```
    ///
    #[track_caller]
    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        self.pool
            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
    }
}

impl Debug for LocalPoolHandle {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("LocalPoolHandle")
    }
}

enum WorkerChoice {
    LeastBurdened,
    ByIdx(usize),
}

struct LocalPool {
    workers: Box<[LocalWorkerHandle]>,
}

impl LocalPool {
    /// Spawn a `?Send` future onto a worker
    #[track_caller]
    fn spawn_pinned<F, Fut>(
        &self,
        create_task: F,
        worker_choice: WorkerChoice,
    ) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();
        let (worker, job_guard) = match worker_choice {
            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
        };
        let worker_spawner = worker.spawner.clone();

        // Spawn a future onto the worker's runtime so we can immediately return
        // a join handle.
        worker.runtime_handle.spawn(async move {
            // Move the job guard into the task
            let _job_guard = job_guard;

            // Propagate aborts via Abortable/AbortHandle
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let _abort_guard = AbortGuard(abort_handle);

            // Inside the future we can't run spawn_local yet because we're not
            // in the context of a LocalSet. We need to send create_task to the
            // LocalSet task for spawning.
            let spawn_task = Box::new(move || {
                // Once we're in the LocalSet context we can call spawn_local
                let join_handle =
                    spawn_local(
                        async move { Abortable::new(create_task(), abort_registration).await },
                    );

                // Send the join handle back to the spawner. If sending fails,
                // we assume the parent task was canceled, so cancel this task
                // as well.
                if let Err(join_handle) = sender.send(join_handle) {
                    join_handle.abort()
                }
            });

            // Send the callback to the LocalSet task
            if let Err(e) = worker_spawner.send(spawn_task) {
                // Propagate the error as a panic in the join handle.
                panic!("Failed to send job to worker: {e}");
            }

            // Wait for the task's join handle
            let join_handle = match receiver.await {
                Ok(handle) => handle,
                Err(e) => {
                    // We sent the task successfully, but failed to get its
                    // join handle... We assume something happened to the worker
                    // and the task was not spawned. Propagate the error as a
                    // panic in the join handle.
                    panic!("Worker failed to send join handle: {e}");
                }
            };

            // Wait for the task to complete
            let join_result = join_handle.await;

            match join_result {
                Ok(Ok(output)) => output,
                Ok(Err(_)) => {
                    // Pinned task was aborted. But that only happens if this
                    // task is aborted. So this is an impossible branch.
                    unreachable!(
                        "Reaching this branch means this task was previously \
                         aborted but it continued running anyway"
                    )
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else if e.is_cancelled() {
                        // No one else should have the join handle, so this is
                        // unexpected. Forward this error as a panic in the join
                        // handle.
                        panic!("spawn_pinned task was canceled: {e}");
                    } else {
                        // Something unknown happened (not a panic or
                        // cancellation). Forward this error as a panic in the
                        // join handle.
                        panic!("spawn_pinned task failed: {e}");
                    }
                }
            }
        })
    }

    /// Find the worker with the fewest tasks, increment its task count, and
    /// return its handle. The caller must actually spawn a task on the worker
    /// so that the task count stays consistent with its real load.
    ///
    /// A job count guard is also returned to ensure the task count gets
    /// decremented when the job is done.
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                .expect("There must be at least one worker");

            // Make sure the task count hasn't changed since we chose this
            // worker. Otherwise, restart the search.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }

    #[track_caller]
    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
        let worker = &self.workers[idx];
        worker.task_count.fetch_add(1, Ordering::SeqCst);

        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
    }
}

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Decrement the job count
        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(previous_value >= 1);
    }
}

/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl Drop for AbortGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}

type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

struct LocalWorkerHandle {
    runtime_handle: tokio::runtime::Handle,
    spawner: UnboundedSender<PinnedFutureSpawner>,
    task_count: Arc<AtomicUsize>,
}

impl LocalWorkerHandle {
    /// Create a new worker for executing pinned tasks
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);

        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));

        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }

    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            while let Some(spawn_task) = task_receiver.recv().await {
                // Calls spawn_local(future)
                (spawn_task)();
            }
        });

        // If there are any tasks on the runtime associated with a LocalSet task
        // that has already completed, but whose output has not yet been
        // reported, let that task complete.
        //
        // Since the task_count is decremented when the runtime task exits,
        // reading that counter lets us know if any such tasks completed during
        // the call to `block_on`.
        //
        // Tasks on the LocalSet can't complete during this loop since they're
        // stored on the LocalSet and we aren't accessing it.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            // This call will also run tasks spawned on the runtime.
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }

        // It's now no longer possible for a task on the runtime to be
        // associated with a LocalSet task that has completed. Drop both the
        // LocalSet and runtime to let tasks on the runtime be cancelled if and
        // only if they are still on the LocalSet.
        //
        // Drop the LocalSet task first so that anyone awaiting the runtime
        // JoinHandle will see the cancelled error after the LocalSet task
        // destructor has completed.
        drop(local_set);
        drop(runtime);
    }
}