switchyard/
lib.rs

1//! Real-time compute-focused async executor with job pools, thread-local data, and priorities.
2//!
3//! # Example
4//!
5//! ```rust
6//! use switchyard::Switchyard;
7//! use switchyard::threads::{thread_info, one_to_one};
8//! // Create a new switchyard with one job pool and empty thread local data
9//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), ||()).unwrap();
10//!
11//! // Spawn a task on pool 0 and priority 10 and get a JoinHandle
12//! let handle = yard.spawn(10, async move { 5 + 5 });
13//! // Spawn a lower priority task on the same pool
14//! let handle2 = yard.spawn(0, async move { 2 + 2 });
15//!
16//! // Wait on the results
17//! # futures_executor::block_on(async {
18//! assert_eq!(handle.await + handle2.await, 14);
19//! # });
20//! ```
21//!
22//! # How Switchyard is Different
23//!
24//! Switchyard is different from other existing async executors, focusing on situations where
25//! precise control of threads and execution order is needed. One such situation is using
26//! task parallelism to parallelize a compute workload.
27//!
//! ## Priorities
29//!
//! Each task has a priority, and tasks are run in order from high priority to low priority.
31//!
32//! ```rust
33//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
34//! # let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), ||()).unwrap();
35//! // Spawn task with lowest priority.
36//! yard.spawn(0, async move { /* ... */ });
37//! // Spawn task with higher priority. If both tasks are waiting, this one will run first.
38//! yard.spawn(10, async move { /* ... */ });
39//! ```
40//!
41//! ## Thread Local Data
42//!
43//! Each yard has some thread local data that can be accessed using [`spawn_local`](Switchyard::spawn_local).
44//! Both the thread local data and the future generated by the async function passed to [`spawn_local`](Switchyard::spawn_local)
45//! may be `!Send` and `!Sync`. The future will only be resumed on the thread that created it.
46//!
47//! If the data is `Send`, then you can call [`access_per_thread_data`](Switchyard::access_per_thread_data) to get
//! a vector of mutable references to all threads' data. See its documentation for more information.
49//!
50//! ```rust
51//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
52//! # use std::cell::Cell;
53//! // Create yard with thread local data. The data is !Sync.
54//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || Cell::new(42)).unwrap();
55//!
56//! // Spawn task that uses thread local data. Each running thread will get their own copy.
57//! yard.spawn_local(0, |data| async move { data.set(10) });
58//! ```
59//!
60//! # MSRV
61//! 1.51
62//!
63//! Future MSRV bumps will be breaking changes.
64
65#![deny(future_incompatible)]
66#![deny(nonstandard_style)]
67#![deny(rust_2018_idioms)]
68
69use crate::{
70    task::{Job, Task, ThreadLocalJob, ThreadLocalTask},
71    threads::ThreadAllocationOutput,
72    util::ThreadLocalPointer,
73};
74use futures_intrusive::{
75    channel::shared::{oneshot_channel, ChannelReceiveFuture, OneshotReceiver},
76    sync::ManualResetEvent,
77};
78use futures_task::{Context, Poll};
79use parking_lot::{Condvar, Mutex, RawMutex};
80use priority_queue::PriorityQueue;
81use slotmap::{DefaultKey, DenseSlotMap};
82use std::{
83    any::Any,
84    future::Future,
85    panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
86    pin::Pin,
87    sync::{
88        atomic::{AtomicBool, AtomicUsize, Ordering},
89        Arc,
90    },
91};
92
93pub mod affinity;
94mod error;
95mod task;
96pub mod threads;
97mod util;
98mod worker;
99
100pub use error::*;
101
/// Integer alias for a priority. Higher values are run sooner.
pub type Priority = u32;
/// Integer alias for the maximum amount of pools.
pub type PoolCount = u8;
106
/// Handle to a currently running task.
///
/// Awaiting this future will give the return value of the task.
pub struct JoinHandle<T: 'static> {
    // Keeps the receiving end of the oneshot channel alive while
    // `receiver_future` may still be polled; never read directly.
    _receiver: OneshotReceiver<Result<T, Box<dyn Any + Send + 'static>>>,
    // Resolves to the task's result: `Ok(value)` on success, or the boxed
    // panic payload if the task panicked.
    receiver_future: ChannelReceiveFuture<RawMutex, Result<T, Box<dyn Any + Send + 'static>>>,
}
114impl<T: 'static> Future for JoinHandle<T> {
115    type Output = T;
116
117    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
118        let fut = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().receiver_future) };
119        let poll_res = fut.poll(ctx);
120
121        match poll_res {
122            Poll::Ready(None) => {
123                // If this returns ready with none, that means the channel was closed
124                // due to the waker dying. We can just return pending  as this future will never
125                // return.
126                Poll::Pending
127            }
128            Poll::Ready(Some(value)) => Poll::Ready(value.unwrap_or_else(|_| panic!("Job panicked!"))),
129            Poll::Pending => Poll::Pending,
130        }
131    }
132}
133
/// Vendored from futures-util as holy hell that's a large lib.
///
/// Wraps a future so that a panic raised while polling it is caught and
/// surfaced as an `Err` holding the panic payload, instead of unwinding
/// into the executor.
struct CatchUnwind<Fut>(Fut);

impl<Fut> CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    /// Wrap `future`, catching any panic raised during `poll`.
    fn new(future: Fut) -> CatchUnwind<Fut> {
        CatchUnwind(future)
    }
}
145
impl<Fut> Future for CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: structural pinning of the sole field; the inner future is
        // never moved out of the pinned `self`.
        let f = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
        // `catch_unwind` yields `Result<Poll<_>, payload>`; `?` converts a
        // caught panic into `Poll::Ready(Err(payload))`, and a clean poll
        // result is wrapped back into `Ok`.
        catch_unwind(AssertUnwindSafe(|| f.poll(cx)))?.map(Ok)
    }
}
157
/// Per-thread queue for tasks spawned with `spawn_local`; their futures may be
/// `!Send` and are only resumed on the thread that created them.
struct ThreadLocalQueue<TD> {
    // Tasks that have been started and are parked awaiting a wakeup.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<ThreadLocalTask<TD>>>>,
    // Runnable jobs keyed by priority (higher priority pops first).
    inner: Mutex<PriorityQueue<ThreadLocalJob<TD>, u32>>,
}
/// A condvar paired with a flag marking whether its owning worker is busy.
/// Workers with `running` set are skipped by [`Queue::notify_one`].
struct FlaggedCondvar {
    // True while the owning worker thread is processing a job.
    running: AtomicBool,
    inner: Condvar,
}
/// The yard-wide queue of thread-migratable tasks, plus one condvar per worker.
struct Queue<TD> {
    // Tasks that have been started and are parked awaiting a wakeup.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<Task<TD>>>>,
    // Runnable jobs keyed by priority (higher priority pops first).
    inner: Mutex<PriorityQueue<Job<TD>, u32>>,
    // One entry per worker thread; index matches the worker's queue-local index.
    condvars: Vec<FlaggedCondvar>,
}
171impl<TD> Queue<TD> {
172    /// Must be called with `queue.inner`'s lock held.
173    fn notify_one(&self) {
174        for var in &self.condvars {
175            if !var.running.load(Ordering::Relaxed) {
176                var.inner.notify_one();
177                return;
178            }
179        }
180    }
181
182    /// Must be called with `queue.inner`'s lock held.
183    fn notify_all(&self) {
184        // We could be more efficient and not notify everyone, but this is more surefire
185        // and this function is only called on shutdown.
186        for var in &self.condvars {
187            var.inner.notify_all();
188        }
189    }
190}
191
/// State shared between the `Switchyard` handle and all worker threads.
struct Shared<TD> {
    // Count of threads currently processing jobs.
    active_threads: AtomicUsize,
    // Signaled when workers are starved of work; reset in `spawn_header`
    // whenever a new job is queued. Awaited by `wait_for_idle`.
    idle_wait: ManualResetEvent,
    // Number of jobs currently in flight (incremented in `spawn_header`).
    job_count: AtomicUsize,
    // Set by `finish()`; once observed, no further jobs may be spawned.
    death_signal: AtomicBool,
    queue: Queue<TD>,
}
199
/// Compute focused async executor.
///
/// See crate documentation for more details.
pub struct Switchyard<TD: 'static> {
    shared: Arc<Shared<TD>>,
    // Worker thread handles; drained and joined in `finish`.
    threads: Vec<std::thread::JoinHandle<()>>,
    // Raw pointers to each worker's `Arc<TD>`, sent back by the workers during
    // startup. Only dereferenced in `access_per_thread_data`, after all
    // threads are proven idle.
    thread_local_data: Vec<*mut Arc<TD>>,
}
208impl<TD: 'static> Switchyard<TD> {
    /// Create a new switchyard.
    ///
    /// For each element in the provided `thread_allocations` iterator, the yard will spawn a worker
    /// thread with the given settings. Helper functions in [`threads`] can generate these iterators
    /// for common situations.
    ///
    /// `thread_local_data_creation` will be called on each thread to create the thread local
    /// data accessible by `spawn_local`.
    ///
    /// # Errors
    ///
    /// Returns [`SwitchyardCreationError::InvalidAffinity`] if any requested affinity refers to
    /// a logical CPU that does not exist on this machine.
    pub fn new<TDFunc>(
        thread_allocations: impl IntoIterator<Item = ThreadAllocationOutput>,
        thread_local_data_creation: TDFunc,
    ) -> Result<Self, SwitchyardCreationError>
    where
        TDFunc: Fn() -> TD + Send + Sync + 'static,
    {
        // Each worker sends back a pointer to its thread-local data over this channel.
        let (thread_local_sender, thread_local_receiver) = std::sync::mpsc::channel();

        let thread_local_data_creation_arc = Arc::new(thread_local_data_creation);
        let allocation_vec: Vec<_> = thread_allocations.into_iter().collect();

        // Validate every requested affinity up front, before spawning any thread.
        let num_logical_cpus = num_cpus::get();
        for allocation in allocation_vec.iter() {
            if let Some(affin) = allocation.affinity {
                if affin >= num_logical_cpus {
                    return Err(SwitchyardCreationError::InvalidAffinity {
                        affinity: affin,
                        total_threads: num_logical_cpus,
                    });
                }
            }
        }

        let mut shared = Arc::new(Shared {
            queue: Queue {
                waiting: Mutex::new(DenseSlotMap::new()),
                inner: Mutex::new(PriorityQueue::new()),
                condvars: Vec::new(),
            },
            active_threads: AtomicUsize::new(allocation_vec.len()),
            idle_wait: ManualResetEvent::new(false),
            job_count: AtomicUsize::new(0),
            death_signal: AtomicBool::new(false),
        });

        // No worker threads exist yet, so no other `Arc` clone is live and this
        // exclusive access cannot fail.
        let shared_guard = Arc::get_mut(&mut shared).unwrap();

        // One condvar per worker; each worker is told its own index into the array.
        let queue_local_indices: Vec<_> = allocation_vec
            .iter()
            .map(|_| {
                let condvar_array = &mut shared_guard.queue.condvars;

                let queue_local_index = condvar_array.len();
                condvar_array.push(FlaggedCondvar {
                    inner: Condvar::new(),
                    // Workers start flagged as running until they first go idle.
                    running: AtomicBool::new(true),
                });

                queue_local_index
            })
            .collect();

        let mut threads = Vec::with_capacity(allocation_vec.len());
        for (mut thread_info, queue_local_index) in allocation_vec.into_iter().zip(queue_local_indices) {
            let builder = std::thread::Builder::new();
            let builder = if let Some(name) = thread_info.name.take() {
                builder.name(name)
            } else {
                builder
            };
            let builder = if let Some(stack_size) = thread_info.stack_size.take() {
                builder.stack_size(stack_size)
            } else {
                builder
            };

            threads.push(
                builder
                    .spawn(worker::body::<TD, TDFunc>(
                        Arc::clone(&shared),
                        thread_info,
                        queue_local_index,
                        thread_local_sender.clone(),
                        thread_local_data_creation_arc.clone(),
                    ))
                    .unwrap_or_else(|_| panic!("Could not spawn thread")),
            );
        }
        // drop the sender we own, so we can retrieve pointers until all senders are dropped
        drop(thread_local_sender);

        // `recv` errors once every worker has sent its pointer and dropped its
        // sender, so this loop collects exactly one pointer per spawned thread.
        let mut thread_local_data = Vec::with_capacity(threads.len());
        while let Ok(ThreadLocalPointer(ptr)) = thread_local_receiver.recv() {
            thread_local_data.push(ptr);
        }

        Ok(Self {
            threads,
            shared,
            thread_local_data,
        })
    }
310
311    /// Things that must be done every time a task is spawned
312    fn spawn_header(&self) {
313        assert!(
314            !self.shared.death_signal.load(Ordering::Acquire),
315            "finish() has been called on this Switchyard. No more jobs may be added."
316        );
317
318        // SAFETY: we must grab and increment this counter so `access_per_thread_data` knows
319        // we're in flight.
320        self.shared.job_count.fetch_add(1, Ordering::AcqRel);
321
322        // Say we're no longer idle so that `yard.spawn(); yard.wait_for_idle()`
323        // won't "return early". If the thread hasn't woken up fully yet by the
324        // time wait_for_idle is called, it will immediately return even though logically there's
325        // still an outstanding, active, job.
326        self.shared.idle_wait.reset();
327    }
328
    /// Spawn a future which can migrate between threads during execution at the given `priority`.
    ///
    /// A higher `priority` will cause the task to be run sooner.
    ///
    /// # Example
    ///
    /// ```rust
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with a single pool
    /// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
    ///
    /// // Spawn a task with priority 0 and get a handle to the result.
    /// let handle = yard.spawn(0, async move { 2 * 2 });
    ///
    /// // Await result
    /// # futures_executor::block_on(async move {
    /// assert_eq!(handle.await, 4);
    /// # });
    /// ```
    ///
    /// # Panics
    ///
    /// - [`finish`](Switchyard::finish) has been called on the pool.
    pub fn spawn<Fut, T>(&self, priority: Priority, fut: Fut) -> JoinHandle<T>
    where
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_header();

        // The task's result travels back to the `JoinHandle` over this oneshot channel.
        let (sender, receiver) = oneshot_channel();
        let job = Job::Future(Task::new(
            Arc::clone(&self.shared),
            async move {
                // We don't care about the result, if this fails, that just means the join handle
                // has been dropped. `CatchUnwind` converts a panic in `fut` into an `Err` payload.
                let _ = sender.send(CatchUnwind::new(std::panic::AssertUnwindSafe(fut)).await);
            },
            priority,
        ));

        let queue: &Queue<TD> = &self.shared.queue;

        let mut queue_guard = queue.inner.lock();
        queue_guard.push(job, priority);
        // the required guard is held in `queue_guard`
        queue.notify_one();
        drop(queue_guard);

        JoinHandle {
            receiver_future: receiver.receive(),
            _receiver: receiver,
        }
    }
384
    /// Spawns an async function which is tied to a single thread during execution.
    ///
    /// Spawns at the given `priority`.
    ///
    /// The given async function will be provided an `Arc` to the thread-local data to create its future with.
    ///
    /// A higher `priority` will cause the task to be run sooner.
    ///
    /// The function must be `Send`, but the future returned by that function may be `!Send`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::{cell::Cell, sync::Arc};
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with thread local data.
    /// let yard: Switchyard<Cell<u64>> = Switchyard::new(
    ///     single_thread(None, None),
    ///     || Cell::new(42)
    /// ).unwrap();
    /// # let mut yard = yard;
    ///
    /// // Spawn an async function using the data.
    /// yard.spawn_local(0, |data: Arc<Cell<u64>>| async move {data.set(12);});
    /// # futures_executor::block_on(yard.wait_for_idle());
    /// # assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(12)]));
    ///
    /// async fn some_async(data: Arc<Cell<u64>>) -> u64 {
    ///     data.set(15);
    ///     2 * 2
    /// }
    ///
    /// // Works with normal async functions too
    /// let handle = yard.spawn_local(0, some_async);
    /// # futures_executor::block_on(yard.wait_for_idle());
    /// # assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(15)]));
    /// # futures_executor::block_on(async move {
    /// assert_eq!(handle.await, 4);
    /// # });
    /// ```
    ///
    /// # Panics
    ///
    /// - [`finish`](Switchyard::finish) has been called on the pool.
    pub fn spawn_local<Func, Fut, T>(&self, priority: Priority, async_fn: Func) -> JoinHandle<T>
    where
        Func: FnOnce(Arc<TD>) -> Fut + Send + 'static,
        Fut: Future<Output = T>,
        T: Send + 'static,
    {
        self.spawn_header();

        // The task's result travels back to the `JoinHandle` over this oneshot channel.
        let (sender, receiver) = oneshot_channel();
        let job = Job::Local(Box::new(move |td| {
            Box::pin(async move {
                // We don't care about the result, if this fails, that just means the join handle
                // has been dropped.
                let unwind_async_fn = AssertUnwindSafe(async_fn);
                let unwind_td = AssertUnwindSafe(td);
                // This `catch_unwind` covers panics raised while *constructing*
                // the future (calling `async_fn`)...
                let future = catch_unwind(move || AssertUnwindSafe(unwind_async_fn.0(unwind_td.0)));

                let ret = match future {
                    // ...while `CatchUnwind` covers panics raised while *polling* it.
                    Ok(fut) => CatchUnwind::new(AssertUnwindSafe(fut)).await,
                    Err(panic) => Err(panic),
                };

                let _ = sender.send(ret);
            })
        }));

        let queue: &Queue<TD> = &self.shared.queue;

        let mut queue_guard = queue.inner.lock();
        queue_guard.push(job, priority);
        // the required guard is held in `queue_guard`
        queue.notify_one();
        drop(queue_guard);

        JoinHandle {
            receiver_future: receiver.receive(),
            _receiver: receiver,
        }
    }
469
    /// Wait until all working threads are starved of work due
    /// to lack of jobs or all jobs waiting.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added while this future is resolving.
    /// - Jobs may be woken while this future is resolving.
    pub async fn wait_for_idle(&self) {
        // We don't reset it, threads will reset it when they become active again
        self.shared.idle_wait.wait().await;
    }
482
483    /// Current amount of jobs in flight.
484    ///
485    /// # Safety
486    ///
487    /// - This function provides no safety guarantees.
488    /// - Jobs may be added after the value is received and before it is returned.
489    pub fn jobs(&self) -> usize {
490        self.shared.job_count.load(Ordering::Relaxed)
491    }
492
493    /// Count of threads currently processing jobs.
494    ///
495    /// # Safety
496    ///
497    /// - This function provides no safety guarantees.
498    /// - Jobs may be added after the value is received and before it is returned re-activating threads.
499    pub fn active_threads(&self) -> usize {
500        self.shared.active_threads.load(Ordering::Relaxed)
501    }
502
    /// Access the per-thread data of each thread. Only available if `TD` is `Send`.
    ///
    /// This function requires `&mut self` in order to be sound. If you have the yard in a global,
    /// you need to wrap it with `RwLock` so you can get a `&mut` from a `&`.
    ///
    /// Two conditions need to be true for this to return `Some`. First all threads must be idle
    /// (i.e. `wait_for_idle`'s future would immediately return). Second no references to any thread's
    /// local data may be alive.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::{cell::Cell, sync::Arc};
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with thread local data.
    /// let mut yard: Switchyard<Cell<u64>> = Switchyard::new(
    ///     single_thread(None, None),
    ///     || Cell::new(42)
    /// ).unwrap();
    ///
    /// // Wait for all threads to get themselves situated.
    /// # futures_executor::block_on(async {
    /// yard.wait_for_idle().await;
    /// # });
    ///
    /// // View that thread-local data. The yard has one thread, so returns a vec of length one.
    /// assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(42)]));
    ///
    /// // Launch a task to change that data
    /// let handle = yard.spawn_local(0, |data| async move { data.set(525_600); });
    ///
    /// // If the task isn't finished yet, this will return None.
    /// yard.access_per_thread_data();
    ///
    /// // Wait for task to be done
    /// # futures_executor::block_on(async {
    /// assert_eq!(handle.await, ());
    /// # });
    ///
    /// // We also need to wait for all threads to come to a stopping place
    /// # futures_executor::block_on(async {
    /// yard.wait_for_idle().await;
    /// # });
    ///
    /// // Observe changed value
    /// assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(525_600)]));
    /// ```
    ///
    /// # Safety
    ///
    /// - This function guarantees that there exist no other references to this data if `Some` is returned.
    /// - This function guarantees that `jobs()` is 0 and will stay zero while the returned references are still live.
    pub fn access_per_thread_data(&mut self) -> Option<Vec<&mut TD>>
    where
        TD: Send,
    {
        let threads_live = self.shared.active_threads.load(Ordering::Acquire);

        // SAFETY: No more jobs can be added and threads woken because we have an exclusive reference to the yard.
        if threads_live != 0 {
            return None;
        }

        // SAFETY:
        //  - We know there are no threads running because `count` is zero and we have an exclusive reference to the yard.
        //  - Threads do not keep references to their `Arc`'s around while idle, nor hand them to tasks.
        //  - `TD` is allowed to be `!Sync` because we never actually touch a `&TD`, only `&mut TD`.
        let arcs = self.thread_local_data.iter().map(|&ptr| unsafe { &mut *ptr });

        // If any thread's `Arc` still has another strong reference (e.g. a task
        // holding its data alive), `Arc::get_mut` yields `None`, which collapses
        // the whole collected result to `None`.
        let data: Option<Vec<&mut TD>> = arcs.map(|arc| Arc::get_mut(arc)).collect();

        data
    }
577
578    /// Kill all threads as soon as they finish their jobs. All calls to spawn and spawn_local will
579    /// panic after this function is called.
580    ///
581    /// This is equivalent to calling drop. Calling this function twice will be a no-op
582    /// the second time.
583    pub fn finish(&mut self) {
584        // send death signal then wake everyone up
585        self.shared.death_signal.store(true, Ordering::Release);
586        let lock = self.shared.queue.inner.lock();
587        self.shared.queue.notify_all();
588        drop(lock);
589
590        self.thread_local_data.clear();
591        for thread in self.threads.drain(..) {
592            thread.join().unwrap();
593        }
594    }
595}
596
597impl<TD: 'static> Drop for Switchyard<TD> {
598    fn drop(&mut self) {
599        self.finish()
600    }
601}
602
// SAFETY: NOTE(review) — these blanket impls place no bounds on `TD`. The raw
// `*mut Arc<TD>` pointers in `thread_local_data` are only dereferenced by
// `access_per_thread_data`, which separately requires `TD: Send`; presumably
// that is what keeps sharing the yard handle itself sound even for `!Send`
// `TD` — verify against the `worker` module before relying on this.
unsafe impl<TD> Send for Switchyard<TD> {}
unsafe impl<TD> Sync for Switchyard<TD> {}