tokio_util/task/
spawn_pinned.rs

1use futures_util::future::{AbortHandle, Abortable};
2use std::fmt;
3use std::fmt::{Debug, Formatter};
4use std::future::Future;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7use tokio::runtime::Builder;
8use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
9use tokio::sync::oneshot;
10use tokio::task::{spawn_local, JoinHandle, LocalSet};
11
12/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
13///
14/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
15/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
16/// execute on the same thread) inside the Future you supply to the various spawn methods
17/// of `LocalPoolHandle`.
18///
19/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
20/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
21///
22/// # Examples
23///
24/// ```
25/// # #[cfg(not(target_family = "wasm"))]
26/// # {
27/// use std::rc::Rc;
28/// use tokio::task;
29/// use tokio_util::task::LocalPoolHandle;
30///
31/// #[tokio::main(flavor = "current_thread")]
32/// async fn main() {
33///     let pool = LocalPoolHandle::new(5);
34///
35///     let output = pool.spawn_pinned(|| {
36///         // `data` is !Send + !Sync
37///         let data = Rc::new("local data");
38///         let data_clone = data.clone();
39///
40///         async move {
41///             task::spawn_local(async move {
42///                 println!("{}", data_clone);
43///             });
44///
45///             data.to_string()
46///         }
47///     }).await.unwrap();
48///     println!("output: {}", output);
49/// }
50/// # }
51/// ```
52///
53#[derive(Clone)]
54pub struct LocalPoolHandle {
55    pool: Arc<LocalPool>,
56}
57
58impl LocalPoolHandle {
59    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
60    /// pool via [`LocalPoolHandle::spawn_pinned`].
61    ///
62    /// # Panics
63    ///
64    /// Panics if the pool size is less than one.
65    #[track_caller]
66    pub fn new(pool_size: usize) -> LocalPoolHandle {
67        assert!(pool_size > 0);
68
69        let workers = (0..pool_size)
70            .map(|_| LocalWorkerHandle::new_worker())
71            .collect();
72
73        let pool = Arc::new(LocalPool { workers });
74
75        LocalPoolHandle { pool }
76    }
77
78    /// Returns the number of threads of the Pool.
79    #[inline]
80    pub fn num_threads(&self) -> usize {
81        self.pool.workers.len()
82    }
83
84    /// Returns the number of tasks scheduled on each worker. The indices of the
85    /// worker threads correspond to the indices of the returned `Vec`.
86    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
87        self.pool
88            .workers
89            .iter()
90            .map(|worker| worker.task_count.load(Ordering::SeqCst))
91            .collect::<Vec<_>>()
92    }
93
94    /// Spawn a task onto a worker thread and pin it there so it can't be moved
95    /// off of the thread. Note that the future is not [`Send`], but the
96    /// [`FnOnce`] which creates it is.
97    ///
98    /// # Examples
99    /// ```
100    /// # #[cfg(not(target_family = "wasm"))]
101    /// # {
102    /// use std::rc::Rc;
103    /// use tokio_util::task::LocalPoolHandle;
104    ///
105    /// #[tokio::main]
106    /// async fn main() {
107    ///     // Create the local pool
108    ///     let pool = LocalPoolHandle::new(1);
109    ///
110    ///     // Spawn a !Send future onto the pool and await it
111    ///     let output = pool
112    ///         .spawn_pinned(|| {
113    ///             // Rc is !Send + !Sync
114    ///             let local_data = Rc::new("test");
115    ///
116    ///             // This future holds an Rc, so it is !Send
117    ///             async move { local_data.to_string() }
118    ///         })
119    ///         .await
120    ///         .unwrap();
121    ///
122    ///     assert_eq!(output, "test");
123    /// }
124    /// # }
125    /// ```
126    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
127    where
128        F: FnOnce() -> Fut,
129        F: Send + 'static,
130        Fut: Future + 'static,
131        Fut::Output: Send + 'static,
132    {
133        self.pool
134            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
135    }
136
137    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
138    /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
139    /// number of tasks scheduled.
140    ///
141    /// A worker thread is chosen by index. Indices are 0 based and the largest index
142    /// is given by `num_threads() - 1`
143    ///
144    /// # Panics
145    ///
146    /// This method panics if the index is out of bounds.
147    ///
148    /// # Examples
149    ///
150    /// This method can be used to spawn a task on all worker threads of the pool:
151    ///
152    /// ```
153    /// # #[cfg(not(target_family = "wasm"))]
154    /// # {
155    /// use tokio_util::task::LocalPoolHandle;
156    ///
157    /// #[tokio::main]
158    /// async fn main() {
159    ///     const NUM_WORKERS: usize = 3;
160    ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
161    ///     let handles = (0..pool.num_threads())
162    ///         .map(|worker_idx| {
163    ///             pool.spawn_pinned_by_idx(
164    ///                 || {
165    ///                     async {
166    ///                         "test"
167    ///                     }
168    ///                 },
169    ///                 worker_idx,
170    ///             )
171    ///         })
172    ///         .collect::<Vec<_>>();
173    ///
174    ///     for handle in handles {
175    ///         handle.await.unwrap();
176    ///     }
177    /// }
178    /// # }
179    /// ```
180    ///
181    #[track_caller]
182    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
183    where
184        F: FnOnce() -> Fut,
185        F: Send + 'static,
186        Fut: Future + 'static,
187        Fut::Output: Send + 'static,
188    {
189        self.pool
190            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
191    }
192}
193
194impl Debug for LocalPoolHandle {
195    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196        f.write_str("LocalPoolHandle")
197    }
198}
199
/// How `LocalPool::spawn_pinned` selects the worker that receives a task.
enum WorkerChoice {
    /// Pick the worker that currently has the fewest scheduled tasks.
    LeastBurdened,
    /// Pick the worker at the given 0-based index.
    ByIdx(usize),
}
204
205struct LocalPool {
206    workers: Box<[LocalWorkerHandle]>,
207}
208
209impl LocalPool {
210    /// Spawn a `?Send` future onto a worker
211    #[track_caller]
212    fn spawn_pinned<F, Fut>(
213        &self,
214        create_task: F,
215        worker_choice: WorkerChoice,
216    ) -> JoinHandle<Fut::Output>
217    where
218        F: FnOnce() -> Fut,
219        F: Send + 'static,
220        Fut: Future + 'static,
221        Fut::Output: Send + 'static,
222    {
223        let (sender, receiver) = oneshot::channel();
224        let (worker, job_guard) = match worker_choice {
225            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
226            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
227        };
228        let worker_spawner = worker.spawner.clone();
229
230        // Spawn a future onto the worker's runtime so we can immediately return
231        // a join handle.
232        worker.runtime_handle.spawn(async move {
233            // Move the job guard into the task
234            let _job_guard = job_guard;
235
236            // Propagate aborts via Abortable/AbortHandle
237            let (abort_handle, abort_registration) = AbortHandle::new_pair();
238            let _abort_guard = AbortGuard(abort_handle);
239
240            // Inside the future we can't run spawn_local yet because we're not
241            // in the context of a LocalSet. We need to send create_task to the
242            // LocalSet task for spawning.
243            let spawn_task = Box::new(move || {
244                // Once we're in the LocalSet context we can call spawn_local
245                let join_handle =
246                    spawn_local(
247                        async move { Abortable::new(create_task(), abort_registration).await },
248                    );
249
250                // Send the join handle back to the spawner. If sending fails,
251                // we assume the parent task was canceled, so cancel this task
252                // as well.
253                if let Err(join_handle) = sender.send(join_handle) {
254                    join_handle.abort()
255                }
256            });
257
258            // Send the callback to the LocalSet task
259            if let Err(e) = worker_spawner.send(spawn_task) {
260                // Propagate the error as a panic in the join handle.
261                panic!("Failed to send job to worker: {e}");
262            }
263
264            // Wait for the task's join handle
265            let join_handle = match receiver.await {
266                Ok(handle) => handle,
267                Err(e) => {
268                    // We sent the task successfully, but failed to get its
269                    // join handle... We assume something happened to the worker
270                    // and the task was not spawned. Propagate the error as a
271                    // panic in the join handle.
272                    panic!("Worker failed to send join handle: {e}");
273                }
274            };
275
276            // Wait for the task to complete
277            let join_result = join_handle.await;
278
279            match join_result {
280                Ok(Ok(output)) => output,
281                Ok(Err(_)) => {
282                    // Pinned task was aborted. But that only happens if this
283                    // task is aborted. So this is an impossible branch.
284                    unreachable!(
285                        "Reaching this branch means this task was previously \
286                         aborted but it continued running anyways"
287                    )
288                }
289                Err(e) => {
290                    if e.is_panic() {
291                        std::panic::resume_unwind(e.into_panic());
292                    } else if e.is_cancelled() {
293                        // No one else should have the join handle, so this is
294                        // unexpected. Forward this error as a panic in the join
295                        // handle.
296                        panic!("spawn_pinned task was canceled: {e}");
297                    } else {
298                        // Something unknown happened (not a panic or
299                        // cancellation). Forward this error as a panic in the
300                        // join handle.
301                        panic!("spawn_pinned task failed: {e}");
302                    }
303                }
304            }
305        })
306    }
307
308    /// Find the worker with the least number of tasks, increment its task
309    /// count, and return its handle. Make sure to actually spawn a task on
310    /// the worker so the task count is kept consistent with load.
311    ///
312    /// A job count guard is also returned to ensure the task count gets
313    /// decremented when the job is done.
314    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
315        loop {
316            let (worker, task_count) = self
317                .workers
318                .iter()
319                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
320                .min_by_key(|&(_, count)| count)
321                .expect("There must be more than one worker");
322
323            // Make sure the task count hasn't changed since when we choose this
324            // worker. Otherwise, restart the search.
325            if worker
326                .task_count
327                .compare_exchange(
328                    task_count,
329                    task_count + 1,
330                    Ordering::SeqCst,
331                    Ordering::Relaxed,
332                )
333                .is_ok()
334            {
335                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
336            }
337        }
338    }
339
340    #[track_caller]
341    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
342        let worker = &self.workers[idx];
343        worker.task_count.fetch_add(1, Ordering::SeqCst);
344
345        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
346    }
347}
348
/// Guard tied to one job on a worker: dropping it (when the job finishes)
/// takes one off that worker's job count.
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        let JobCountGuard(job_count) = self;

        // The count must have been incremented for this job beforehand, so it
        // can never already be zero here.
        let previous_value = job_count.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(previous_value >= 1);
    }
}
360
361/// Calls abort on the handle when dropped.
362struct AbortGuard(AbortHandle);
363
364impl Drop for AbortGuard {
365    fn drop(&mut self) {
366        self.0.abort();
367    }
368}
369
/// A one-shot, sendable callback delivered to a worker thread, which invokes
/// it from inside its `LocalSet` context.
type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
371
372struct LocalWorkerHandle {
373    runtime_handle: tokio::runtime::Handle,
374    spawner: UnboundedSender<PinnedFutureSpawner>,
375    task_count: Arc<AtomicUsize>,
376}
377
378impl LocalWorkerHandle {
379    /// Create a new worker for executing pinned tasks
380    fn new_worker() -> LocalWorkerHandle {
381        let (sender, receiver) = unbounded_channel();
382        let runtime = Builder::new_current_thread()
383            .enable_all()
384            .build()
385            .expect("Failed to start a pinned worker thread runtime");
386        let runtime_handle = runtime.handle().clone();
387        let task_count = Arc::new(AtomicUsize::new(0));
388        let task_count_clone = Arc::clone(&task_count);
389
390        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
391
392        LocalWorkerHandle {
393            runtime_handle,
394            spawner: sender,
395            task_count,
396        }
397    }
398
399    fn run(
400        runtime: tokio::runtime::Runtime,
401        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
402        task_count: Arc<AtomicUsize>,
403    ) {
404        let local_set = LocalSet::new();
405        local_set.block_on(&runtime, async {
406            while let Some(spawn_task) = task_receiver.recv().await {
407                // Calls spawn_local(future)
408                (spawn_task)();
409            }
410        });
411
412        // If there are any tasks on the runtime associated with a LocalSet task
413        // that has already completed, but whose output has not yet been
414        // reported, let that task complete.
415        //
416        // Since the task_count is decremented when the runtime task exits,
417        // reading that counter lets us know if any such tasks completed during
418        // the call to `block_on`.
419        //
420        // Tasks on the LocalSet can't complete during this loop since they're
421        // stored on the LocalSet and we aren't accessing it.
422        let mut previous_task_count = task_count.load(Ordering::SeqCst);
423        loop {
424            // This call will also run tasks spawned on the runtime.
425            runtime.block_on(tokio::task::yield_now());
426            let new_task_count = task_count.load(Ordering::SeqCst);
427            if new_task_count == previous_task_count {
428                break;
429            } else {
430                previous_task_count = new_task_count;
431            }
432        }
433
434        // It's now no longer possible for a task on the runtime to be
435        // associated with a LocalSet task that has completed. Drop both the
436        // LocalSet and runtime to let tasks on the runtime be cancelled if and
437        // only if they are still on the LocalSet.
438        //
439        // Drop the LocalSet task first so that anyone awaiting the runtime
440        // JoinHandle will see the cancelled error after the LocalSet task
441        // destructor has completed.
442        drop(local_set);
443        drop(runtime);
444    }
445}