switchyard/
lib.rs

1//! Real-time compute-focused async executor with job pools, thread-local data, and priorities.
2//!
3//! # Example
4//!
5//! ```rust
6//! use switchyard::Switchyard;
7//! use switchyard::threads::{thread_info, one_to_one};
8//! // Create a new switchyard with one job pool and empty thread local data
9//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), ||()).unwrap();
10//!
11//! // Spawn a task on pool 0 and priority 10 and get a JoinHandle
12//! let handle = yard.spawn(10, async move { 5 + 5 });
13//! // Spawn a lower priority task on the same pool
14//! let handle2 = yard.spawn(0, async move { 2 + 2 });
15//!
16//! // Wait on the results
17//! # futures_executor::block_on(async {
18//! assert_eq!(handle.await + handle2.await, 14);
19//! # });
20//! ```
21//!
22//! # How Switchyard is Different
23//!
24//! Switchyard is different from other existing async executors, focusing on situations where
25//! precise control of threads and execution order is needed. One such situation is using
26//! task parallelism to parallelize a compute workload.
27//!
//! ## Priorities
29//!
//! Each task has a priority and tasks are run in order from high priority to low priority.
31//!
32//! ```rust
33//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
34//! # let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), ||()).unwrap();
35//! // Spawn task with lowest priority.
36//! yard.spawn(0, async move { /* ... */ });
37//! // Spawn task with higher priority. If both tasks are waiting, this one will run first.
38//! yard.spawn(10, async move { /* ... */ });
39//! ```
40//!
41//! ## Thread Local Data
42//!
43//! Each yard has some thread local data that can be accessed using [`spawn_local`](Switchyard::spawn_local).
44//! Both the thread local data and the future generated by the async function passed to [`spawn_local`](Switchyard::spawn_local)
45//! may be `!Send` and `!Sync`. The future will only be resumed on the thread that created it.
46//!
47//! ```rust
48//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
49//! # use std::cell::Cell;
50//! // Create yard with thread local data. The data is !Sync.
51//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || Cell::new(42)).unwrap();
52//!
53//! // Spawn task that uses thread local data. Each running thread will get their own copy.
54//! yard.spawn_local(0, |data| async move { data.set(10) });
55//! ```
56//!
57//! # MSRV
58//! 1.51
59//!
60//! Future MSRV bumps will be breaking changes.
61
62#![deny(future_incompatible)]
63#![deny(nonstandard_style)]
64#![deny(rust_2018_idioms)]
65
66use crate::{
67    task::{Job, Task, ThreadLocalJob, ThreadLocalTask},
68    threads::ThreadAllocationOutput,
69    util::ThreadLocalPointer,
70};
71use futures_intrusive::{
72    channel::shared::{oneshot_channel, ChannelReceiveFuture, OneshotReceiver},
73    sync::ManualResetEvent,
74};
75use futures_task::{Context, Poll};
76use parking_lot::{Condvar, Mutex, RawMutex};
77use priority_queue::PriorityQueue;
78use slotmap::{DefaultKey, DenseSlotMap};
79use std::{
80    any::Any,
81    future::Future,
82    panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
83    pin::Pin,
84    sync::{
85        atomic::{AtomicBool, AtomicUsize, Ordering},
86        Arc,
87    },
88};
89
90pub mod affinity;
91mod error;
92mod task;
93pub mod threads;
94mod util;
95mod worker;
96
97pub use error::*;
98
/// Integer alias for a priority.
///
/// Tasks with a numerically higher priority are run before tasks with a lower one.
pub type Priority = u32;
/// Integer alias for the maximum amount of pools.
pub type PoolCount = u8;
103
/// Handle to a currently running task.
///
/// Awaiting this future will give the return value of the task.
/// If the task panicked, awaiting the handle panics as well.
pub struct JoinHandle<T: 'static> {
    // Receiving half of the task's result channel; never read directly.
    // NOTE(review): presumably held only so the channel stays open for as long as
    // `receiver_future` is alive - confirm against the futures-intrusive docs.
    _receiver: OneshotReceiver<Result<T, Box<dyn Any + Send + 'static>>>,
    // Future resolving to the task's result: `Ok(value)` on completion or
    // `Err(payload)` when the task panicked.
    receiver_future: ChannelReceiveFuture<RawMutex, Result<T, Box<dyn Any + Send + 'static>>>,
}
111impl<T: 'static> Future for JoinHandle<T> {
112    type Output = T;
113
114    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
115        let fut = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().receiver_future) };
116        let poll_res = fut.poll(ctx);
117
118        match poll_res {
119            Poll::Ready(None) => {
120                // If this returns ready with none, that means the channel was closed
121                // due to the waker dying. We can just return pending  as this future will never
122                // return.
123                Poll::Pending
124            }
125            Poll::Ready(Some(value)) => Poll::Ready(value.unwrap_or_else(|_| panic!("Job panicked!"))),
126            Poll::Pending => Poll::Pending,
127        }
128    }
129}
130
/// Future adapter that converts a panic raised while polling the wrapped future
/// into an `Err` carrying the panic payload.
///
/// Vendored from futures-util to avoid depending on that (large) crate.
struct CatchUnwind<Fut>(Fut);

impl<Fut> CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    /// Wraps `future` so that panics during `poll` are captured instead of propagated.
    fn new(future: Fut) -> CatchUnwind<Fut> {
        Self(future)
    }
}

impl<Fut> Future for CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    /// Polls the inner future inside `catch_unwind`.
    ///
    /// Yields `Ok(output)` when the inner future completes, `Err(payload)` when polling
    /// it panicked, and stays pending otherwise.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: pin projection onto the only field. The inner future is never moved
        // out of `self`; it is only polled in place through the pinned reference.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.0) };

        match catch_unwind(AssertUnwindSafe(|| inner.poll(cx))) {
            Ok(Poll::Ready(output)) => Poll::Ready(Ok(output)),
            Ok(Poll::Pending) => Poll::Pending,
            Err(payload) => Poll::Ready(Err(payload)),
        }
    }
}
154
/// Queue state for thread-local work, mirroring the layout of [`Queue`].
///
/// NOTE(review): not referenced in this file; presumably each worker thread owns one
/// for its `spawn_local` tasks - confirm in the `worker` module.
struct ThreadLocalQueue<TD> {
    // Thread-local tasks keyed by slot-map key.
    // NOTE(review): presumably tasks parked until they are woken - confirm.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<ThreadLocalTask<TD>>>>,
    // Thread-local jobs ordered by their `u32` priority.
    inner: Mutex<PriorityQueue<ThreadLocalJob<TD>, u32>>,
}
/// Condition variable paired with a flag.
///
/// `Queue::notify_one` only signals condvars whose `running` flag is `false`.
struct FlaggedCondvar {
    // Initialised to `true` in `Switchyard::new`.
    // NOTE(review): presumably cleared by the owning worker before it waits - confirm
    // in the `worker` module.
    running: AtomicBool,
    // The condition variable that gets notified.
    inner: Condvar,
}
/// Job queue shared between the yard and its worker threads.
struct Queue<TD> {
    // Tasks keyed by slot-map key.
    // NOTE(review): presumably tasks parked until they are woken - confirm.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<Task<TD>>>>,
    // Jobs pushed by `spawn`/`spawn_local`, ordered by their `u32` priority.
    inner: Mutex<PriorityQueue<Job<TD>, u32>>,
    // One entry per worker thread, created in `Switchyard::new`.
    condvars: Vec<FlaggedCondvar>,
}
168impl<TD> Queue<TD> {
169    /// Must be called with `queue.inner`'s lock held.
170    fn notify_one(&self) {
171        for var in &self.condvars {
172            if !var.running.load(Ordering::Relaxed) {
173                var.inner.notify_one();
174                return;
175            }
176        }
177    }
178
179    /// Must be called with `queue.inner`'s lock held.
180    fn notify_all(&self) {
181        // We could be more efficient and not notify everyone, but this is more surefire
182        // and this function is only called on shutdown.
183        for var in &self.condvars {
184            var.inner.notify_all();
185        }
186    }
187}
188
/// State shared, behind an `Arc`, between a [`Switchyard`] and its worker threads.
struct Shared<TD> {
    // Starts at the number of worker threads; reported by `Switchyard::active_threads`.
    active_threads: AtomicUsize,
    // Event awaited by `Switchyard::wait_for_idle`; reset whenever a job is spawned.
    idle_wait: ManualResetEvent,
    // Incremented for every spawned job; reported by `Switchyard::jobs`.
    job_count: AtomicUsize,
    // Set by `Switchyard::finish`; spawning panics once it is set.
    death_signal: AtomicBool,
    // Queue of jobs plus the condvars used to signal worker threads.
    queue: Queue<TD>,
}
196
/// Compute focused async executor.
///
/// See crate documentation for more details.
pub struct Switchyard<TD: 'static> {
    // State shared with every worker thread.
    shared: Arc<Shared<TD>>,
    // Join handles of the worker threads; drained and joined by `finish`.
    threads: Vec<std::thread::JoinHandle<()>>,
    // Raw pointers sent back by the worker threads during `new`; cleared by `finish`.
    // NOTE(review): presumably each points at a worker's thread-local `Arc<TD>` - confirm
    // in the `worker` module.
    thread_local_data: Vec<*mut Arc<TD>>,
}
205impl<TD: 'static> Switchyard<TD> {
206    /// Create a new switchyard.
207    ///
208    /// For each element in the provided `thread_allocations` iterator, the yard will spawn a worker
209    /// thread with the given settings. Helper functions in [`threads`] can generate these iterators
210    /// for common situations.
211    ///
212    /// `thread_local_data_creation` will be called on each thread to create the thread local
213    /// data accessible by `spawn_local`.
214    pub fn new<TDFunc>(
215        thread_allocations: impl IntoIterator<Item = ThreadAllocationOutput>,
216        thread_local_data_creation: TDFunc,
217    ) -> Result<Self, SwitchyardCreationError>
218    where
219        TDFunc: Fn() -> TD + Send + Sync + 'static,
220    {
221        let (thread_local_sender, thread_local_receiver) = std::sync::mpsc::channel();
222
223        let thread_local_data_creation_arc = Arc::new(thread_local_data_creation);
224        let allocation_vec: Vec<_> = thread_allocations.into_iter().collect();
225
226        let num_logical_cpus = num_cpus::get();
227        for allocation in allocation_vec.iter() {
228            if let Some(affin) = allocation.affinity {
229                if affin >= num_logical_cpus {
230                    return Err(SwitchyardCreationError::InvalidAffinity {
231                        affinity: affin,
232                        total_threads: num_logical_cpus,
233                    });
234                }
235            }
236        }
237
238        let mut shared = Arc::new(Shared {
239            queue: Queue {
240                waiting: Mutex::new(DenseSlotMap::new()),
241                inner: Mutex::new(PriorityQueue::new()),
242                condvars: Vec::new(),
243            },
244            active_threads: AtomicUsize::new(allocation_vec.len()),
245            idle_wait: ManualResetEvent::new(false),
246            job_count: AtomicUsize::new(0),
247            death_signal: AtomicBool::new(false),
248        });
249
250        let shared_guard = Arc::get_mut(&mut shared).unwrap();
251
252        let queue_local_indices: Vec<_> = allocation_vec
253            .iter()
254            .map(|_| {
255                let condvar_array = &mut shared_guard.queue.condvars;
256
257                let queue_local_index = condvar_array.len();
258                condvar_array.push(FlaggedCondvar {
259                    inner: Condvar::new(),
260                    running: AtomicBool::new(true),
261                });
262
263                queue_local_index
264            })
265            .collect();
266
267        let mut threads = Vec::with_capacity(allocation_vec.len());
268        for (mut thread_info, queue_local_index) in allocation_vec.into_iter().zip(queue_local_indices) {
269            let builder = std::thread::Builder::new();
270            let builder = if let Some(name) = thread_info.name.take() {
271                builder.name(name)
272            } else {
273                builder
274            };
275            let builder = if let Some(stack_size) = thread_info.stack_size.take() {
276                builder.stack_size(stack_size)
277            } else {
278                builder
279            };
280
281            threads.push(
282                builder
283                    .spawn(worker::body::<TD, TDFunc>(
284                        Arc::clone(&shared),
285                        thread_info,
286                        queue_local_index,
287                        thread_local_sender.clone(),
288                        thread_local_data_creation_arc.clone(),
289                    ))
290                    .unwrap_or_else(|_| panic!("Could not spawn thread")),
291            );
292        }
293        // drop the sender we own, so we can retrieve pointers until all senders are dropped
294        drop(thread_local_sender);
295
296        let mut thread_local_data = Vec::with_capacity(threads.len());
297        while let Ok(ThreadLocalPointer(ptr)) = thread_local_receiver.recv() {
298            thread_local_data.push(ptr);
299        }
300
301        Ok(Self {
302            threads,
303            shared,
304            thread_local_data,
305        })
306    }
307
308    /// Things that must be done every time a task is spawned
309    fn spawn_header(&self) {
310        assert!(
311            !self.shared.death_signal.load(Ordering::Acquire),
312            "finish() has been called on this Switchyard. No more jobs may be added."
313        );
314
315        self.shared.job_count.fetch_add(1, Ordering::AcqRel);
316
317        // Say we're no longer idle so that `yard.spawn(); yard.wait_for_idle()`
318        // won't "return early". If the thread hasn't woken up fully yet by the
319        // time wait_for_idle is called, it will immediately return even though logically there's
320        // still an outstanding, active, job.
321        self.shared.idle_wait.reset();
322    }
323
324    /// Spawn a future which can migrate between threads during executionat the given `priority`.
325    ///
326    /// A higher `priority` will cause the task to be run sooner.
327    ///
328    /// # Example
329    ///
330    /// ```rust
331    /// use switchyard::{Switchyard, threads::single_thread};
332    ///
333    /// // Create a yard with a single pool
334    /// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
335    ///
336    /// // Spawn a task with priority 0 and get a handle to the result.
337    /// let handle = yard.spawn(0, async move { 2 * 2 });
338    ///
339    /// // Await result
340    /// # futures_executor::block_on(async move {
341    /// assert_eq!(handle.await, 4);
342    /// # });
343    /// ```
344    ///
345    /// # Panics
346    ///
347    /// - [`finish`](Switchyard::finish) has been called on the pool.
348    pub fn spawn<Fut, T>(&self, priority: Priority, fut: Fut) -> JoinHandle<T>
349    where
350        Fut: Future<Output = T> + Send + 'static,
351        T: Send + 'static,
352    {
353        self.spawn_header();
354
355        let (sender, receiver) = oneshot_channel();
356        let job = Job::Future(Task::new(
357            Arc::clone(&self.shared),
358            async move {
359                // We don't care about the result, if this fails, that just means the join handle
360                // has been dropped.
361                let _ = sender.send(CatchUnwind::new(std::panic::AssertUnwindSafe(fut)).await);
362            },
363            priority,
364        ));
365
366        let queue: &Queue<TD> = &self.shared.queue;
367
368        let mut queue_guard = queue.inner.lock();
369        queue_guard.push(job, priority);
370        // the required guard is held in `queue_guard`
371        queue.notify_one();
372        drop(queue_guard);
373
374        JoinHandle {
375            receiver_future: receiver.receive(),
376            _receiver: receiver,
377        }
378    }
379
380    /// Spawns an async function which is tied to a single thread during execution.
381    ///
382    /// Spawns to the given job `pool` at the given `priority`.
383    ///
384    /// The given async function will be provided an `Arc` to the thread-local data to create its future with.
385    ///
386    /// A higher `priority` will cause the task to be run sooner.
387    ///
388    /// The function must be `Send`, but the future returned by that function may be `!Send`.
389    ///
390    /// # Example
391    ///
392    /// ```rust
393    /// use std::{cell::Cell, sync::Arc};
394    /// use switchyard::{Switchyard, threads::single_thread};
395    ///
396    /// // Create a yard with thread local data.
397    /// let yard: Switchyard<Cell<u64>> = Switchyard::new(
398    ///     single_thread(None, None),
399    ///     || Cell::new(42)
400    /// ).unwrap();
401    /// # let mut yard = yard;
402    ///
403    /// // Spawn an async function using the data.
404    /// yard.spawn_local(0, |data: Arc<Cell<u64>>| async move {data.set(12);});
405    /// # futures_executor::block_on(yard.wait_for_idle());
406    ///
407    /// async fn some_async(data: Arc<Cell<u64>>) -> u64 {
408    ///     data.set(15);
409    ///     2 * 2
410    /// }
411    ///
412    /// // Works with normal async functions too
413    /// let handle = yard.spawn_local(0, some_async);
414    /// # futures_executor::block_on(yard.wait_for_idle());
415    /// # futures_executor::block_on(async move {
416    /// assert_eq!(handle.await, 4);
417    /// # });
418    /// ```
419    ///
420    /// # Panics
421    ///
422    /// - Panics is `pool` refers to a non-existent job pool.
423    pub fn spawn_local<Func, Fut, T>(&self, priority: Priority, async_fn: Func) -> JoinHandle<T>
424    where
425        Func: FnOnce(Arc<TD>) -> Fut + Send + 'static,
426        Fut: Future<Output = T>,
427        T: Send + 'static,
428    {
429        self.spawn_header();
430
431        let (sender, receiver) = oneshot_channel();
432        let job = Job::Local(Box::new(move |td| {
433            Box::pin(async move {
434                // We don't care about the result, if this fails, that just means the join handle
435                // has been dropped.
436                let unwind_async_fn = AssertUnwindSafe(async_fn);
437                let unwind_td = AssertUnwindSafe(td);
438                let future = catch_unwind(move || AssertUnwindSafe(unwind_async_fn.0(unwind_td.0)));
439
440                let ret = match future {
441                    Ok(fut) => CatchUnwind::new(AssertUnwindSafe(fut)).await,
442                    Err(panic) => Err(panic),
443                };
444
445                let _ = sender.send(ret);
446            })
447        }));
448
449        let queue: &Queue<TD> = &self.shared.queue;
450
451        let mut queue_guard = queue.inner.lock();
452        queue_guard.push(job, priority);
453        // the required guard is held in `queue_guard`
454        queue.notify_one();
455        drop(queue_guard);
456
457        JoinHandle {
458            receiver_future: receiver.receive(),
459            _receiver: receiver,
460        }
461    }
462
463    /// Wait until all working threads are starved of work due
464    /// to lack of jobs or all jobs waiting.
465    ///
466    /// # Safety
467    ///
468    /// - This function provides no safety guarantees.
469    /// - Jobs may be added while the future returns.
470    /// - Jobs may be woken while the future returns.
471    pub async fn wait_for_idle(&self) {
472        // We don't reset it, threads will reset it when they become active again
473        self.shared.idle_wait.wait().await;
474    }
475
476    /// Current amount of jobs in flight.
477    ///
478    /// # Safety
479    ///
480    /// - This function provides no safety guarantees.
481    /// - Jobs may be added after the value is received and before it is returned.
482    pub fn jobs(&self) -> usize {
483        self.shared.job_count.load(Ordering::Relaxed)
484    }
485
486    /// Count of threads currently processing jobs.
487    ///
488    /// # Safety
489    ///
490    /// - This function provides no safety guarantees.
491    /// - Jobs may be added after the value is received and before it is returned re-activating threads.
492    pub fn active_threads(&self) -> usize {
493        self.shared.active_threads.load(Ordering::Relaxed)
494    }
495
496    /// Kill all threads as soon as they finish their jobs. All calls to spawn and spawn_local will
497    /// panic after this function is called.
498    ///
499    /// This is equivalent to calling drop. Calling this function twice will be a no-op
500    /// the second time.
501    pub fn finish(&mut self) {
502        // send death signal then wake everyone up
503        self.shared.death_signal.store(true, Ordering::Release);
504        let lock = self.shared.queue.inner.lock();
505        self.shared.queue.notify_all();
506        drop(lock);
507
508        self.thread_local_data.clear();
509        for thread in self.threads.drain(..) {
510            thread.join().unwrap();
511        }
512    }
513}
514
515impl<TD: 'static> Drop for Switchyard<TD> {
516    fn drop(&mut self) {
517        self.finish()
518    }
519}
520
// `Switchyard` stores raw pointers (`thread_local_data`), so it is neither `Send` nor
// `Sync` automatically; both are asserted manually below.
//
// SAFETY: NOTE(review): in this file the raw pointers are only collected in `new` and
// cleared in `finish`, never dereferenced. Soundness presumably relies on that, and on
// the remaining fields being safe to share for any `TD` - confirm against the `worker`
// and `task` modules.
unsafe impl<TD> Send for Switchyard<TD> {}
// SAFETY: see the note on the `Send` impl above.
unsafe impl<TD> Sync for Switchyard<TD> {}