switchyard/lib.rs
1//! Real-time compute-focused async executor with job pools, thread-local data, and priorities.
2//!
3//! # Example
4//!
5//! ```rust
6//! use switchyard::Switchyard;
7//! use switchyard::threads::{thread_info, one_to_one};
8//! // Create a new switchyard with one job pool and empty thread local data
9//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), ||()).unwrap();
10//!
//! // Spawn a task at priority 10 and get a JoinHandle
12//! let handle = yard.spawn(10, async move { 5 + 5 });
//! // Spawn a lower priority task
14//! let handle2 = yard.spawn(0, async move { 2 + 2 });
15//!
16//! // Wait on the results
17//! # futures_executor::block_on(async {
18//! assert_eq!(handle.await + handle2.await, 14);
19//! # });
20//! ```
21//!
22//! # How Switchyard is Different
23//!
24//! Switchyard is different from other existing async executors, focusing on situations where
25//! precise control of threads and execution order is needed. One such situation is using
26//! task parallelism to parallelize a compute workload.
27//!
//! ## Priorities
29//!
//! Each task has a priority and tasks are run in order from high priority to low priority.
31//!
32//! ```rust
33//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
34//! # let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), ||()).unwrap();
35//! // Spawn task with lowest priority.
36//! yard.spawn(0, async move { /* ... */ });
37//! // Spawn task with higher priority. If both tasks are waiting, this one will run first.
38//! yard.spawn(10, async move { /* ... */ });
39//! ```
40//!
41//! ## Thread Local Data
42//!
43//! Each yard has some thread local data that can be accessed using [`spawn_local`](Switchyard::spawn_local).
44//! Both the thread local data and the future generated by the async function passed to [`spawn_local`](Switchyard::spawn_local)
45//! may be `!Send` and `!Sync`. The future will only be resumed on the thread that created it.
46//!
47//! If the data is `Send`, then you can call [`access_per_thread_data`](Switchyard::access_per_thread_data) to get
//! a vector of mutable references to every thread's data. See its documentation for more information.
49//!
50//! ```rust
51//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
52//! # use std::cell::Cell;
53//! // Create yard with thread local data. The data is !Sync.
54//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || Cell::new(42)).unwrap();
55//!
56//! // Spawn task that uses thread local data. Each running thread will get their own copy.
57//! yard.spawn_local(0, |data| async move { data.set(10) });
58//! ```
59//!
60//! # MSRV
61//! 1.51
62//!
63//! Future MSRV bumps will be breaking changes.
64
65#![deny(future_incompatible)]
66#![deny(nonstandard_style)]
67#![deny(rust_2018_idioms)]
68
69use crate::{
70 task::{Job, Task, ThreadLocalJob, ThreadLocalTask},
71 threads::ThreadAllocationOutput,
72 util::ThreadLocalPointer,
73};
74use futures_intrusive::{
75 channel::shared::{oneshot_channel, ChannelReceiveFuture, OneshotReceiver},
76 sync::ManualResetEvent,
77};
78use futures_task::{Context, Poll};
79use parking_lot::{Condvar, Mutex, RawMutex};
80use priority_queue::PriorityQueue;
81use slotmap::{DefaultKey, DenseSlotMap};
82use std::{
83 any::Any,
84 future::Future,
85 panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
86 pin::Pin,
87 sync::{
88 atomic::{AtomicBool, AtomicUsize, Ordering},
89 Arc,
90 },
91};
92
93pub mod affinity;
94mod error;
95mod task;
96pub mod threads;
97mod util;
98mod worker;
99
100pub use error::*;
101
/// Integer alias for a priority.
///
/// Higher values are run sooner (see [`Switchyard::spawn`]).
pub type Priority = u32;
/// Integer alias for the maximum amount of pools.
pub type PoolCount = u8;
106
/// Handle to a currently running task.
///
/// Awaiting this future will give the return value of the task.
pub struct JoinHandle<T: 'static> {
    // Kept alive so the oneshot channel stays open while `receiver_future` is pending.
    _receiver: OneshotReceiver<Result<T, Box<dyn Any + Send + 'static>>>,
    // Resolves to the task's result; `Err` carries the boxed panic payload.
    receiver_future: ChannelReceiveFuture<RawMutex, Result<T, Box<dyn Any + Send + 'static>>>,
}
impl<T: 'static> Future for JoinHandle<T> {
    type Output = T;

    /// Resolves to the task's return value.
    ///
    /// # Panics
    ///
    /// Panics with "Job panicked!" if the underlying task panicked.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: structural pin projection — `receiver_future` is never moved out of
        // `self`, and we immediately re-pin the field reference.
        let fut = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().receiver_future) };
        let poll_res = fut.poll(ctx);

        match poll_res {
            Poll::Ready(None) => {
                // If this returns ready with none, that means the channel was closed
                // due to the waker dying. We can just return pending as this future will never
                // return.
                Poll::Pending
            }
            // `Err` here is the boxed panic payload captured by `CatchUnwind`.
            Poll::Ready(Some(value)) => Poll::Ready(value.unwrap_or_else(|_| panic!("Job panicked!"))),
            Poll::Pending => Poll::Pending,
        }
    }
}
133
/// Future adapter that turns a panic during `poll` into an `Err` holding the panic payload.
///
/// Vendored from futures-util as holy hell that's a large lib.
struct CatchUnwind<Fut>(Fut);

impl<Fut> CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    /// Wrap `future` so panics raised while polling it are caught instead of unwinding.
    fn new(future: Fut) -> CatchUnwind<Fut> {
        CatchUnwind(future)
    }
}
145
impl<Fut> Future for CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: structural pin projection — the inner future is never moved out of `self`.
        let f = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
        // `catch_unwind` yields `Result<Poll<_>, Box<dyn Any + Send>>`; the `?` lifts a
        // caught panic directly into `Poll::Ready(Err(..))`, and a clean poll result is
        // re-wrapped via `map(Ok)`. `AssertUnwindSafe` is needed because the closure
        // captures a `&mut` (the pinned projection); the `Fut: UnwindSafe` bound keeps
        // the outer API contract honest.
        catch_unwind(AssertUnwindSafe(|| f.poll(cx)))?.map(Ok)
    }
}
157
/// Per-worker queue for `spawn_local` jobs, whose futures may be `!Send`/`!Sync`.
struct ThreadLocalQueue<TD> {
    // NOTE(review): presumably tasks that have been polled and are parked until their
    // waker fires — confirm against worker.rs.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<ThreadLocalTask<TD>>>>,
    // Runnable thread-local jobs, keyed by priority.
    inner: Mutex<PriorityQueue<ThreadLocalJob<TD>, u32>>,
}
/// A condition variable paired with a flag recording whether its worker is busy.
struct FlaggedCondvar {
    // True while the owning worker is processing jobs (initialized true in `Switchyard::new`);
    // `Queue::notify_one` skips condvars whose flag is set.
    running: AtomicBool,
    inner: Condvar,
}
/// Shared job queue plus one condvar per worker thread.
struct Queue<TD> {
    // NOTE(review): presumably tasks parked until their waker fires — confirm against worker.rs.
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<Task<TD>>>>,
    // Runnable jobs, keyed by priority.
    inner: Mutex<PriorityQueue<Job<TD>, u32>>,
    // One entry per worker, indexed by the worker's `queue_local_index` (see `Switchyard::new`).
    condvars: Vec<FlaggedCondvar>,
}
171impl<TD> Queue<TD> {
172 /// Must be called with `queue.inner`'s lock held.
173 fn notify_one(&self) {
174 for var in &self.condvars {
175 if !var.running.load(Ordering::Relaxed) {
176 var.inner.notify_one();
177 return;
178 }
179 }
180 }
181
182 /// Must be called with `queue.inner`'s lock held.
183 fn notify_all(&self) {
184 // We could be more efficient and not notify everyone, but this is more surefire
185 // and this function is only called on shutdown.
186 for var in &self.condvars {
187 var.inner.notify_all();
188 }
189 }
190}
191
/// State shared between the yard handle and every worker thread.
struct Shared<TD> {
    // Count of workers currently processing jobs; starts at the worker count in `new`.
    active_threads: AtomicUsize,
    // Signaled when all workers are starved of work; reset on every spawn (`spawn_header`).
    idle_wait: ManualResetEvent,
    // Number of jobs currently in flight; incremented in `spawn_header`.
    job_count: AtomicUsize,
    // Set by `finish` to tell workers to exit; checked (asserted) on every spawn.
    death_signal: AtomicBool,
    queue: Queue<TD>,
}
199
/// Compute focused async executor.
///
/// See crate documentation for more details.
pub struct Switchyard<TD: 'static> {
    shared: Arc<Shared<TD>>,
    // Join handles for every worker thread; drained (joined) by `finish`.
    threads: Vec<std::thread::JoinHandle<()>>,
    // Raw pointers to each worker's `Arc<TD>`, sent back by the workers during `new`.
    // Only dereferenced in `access_per_thread_data` (which requires `TD: Send`), and
    // cleared in `finish` before the workers are joined.
    thread_local_data: Vec<*mut Arc<TD>>,
}
208impl<TD: 'static> Switchyard<TD> {
    /// Create a new switchyard.
    ///
    /// For each element in the provided `thread_allocations` iterator, the yard will spawn a worker
    /// thread with the given settings. Helper functions in [`threads`] can generate these iterators
    /// for common situations.
    ///
    /// `thread_local_data_creation` will be called on each thread to create the thread local
    /// data accessible by `spawn_local`.
    ///
    /// # Errors
    ///
    /// Returns [`SwitchyardCreationError::InvalidAffinity`] if an allocation requests an
    /// affinity to a logical CPU that does not exist on this machine.
    pub fn new<TDFunc>(
        thread_allocations: impl IntoIterator<Item = ThreadAllocationOutput>,
        thread_local_data_creation: TDFunc,
    ) -> Result<Self, SwitchyardCreationError>
    where
        TDFunc: Fn() -> TD + Send + Sync + 'static,
    {
        // Each worker sends back a raw pointer to its thread-local `Arc<TD>` over this channel.
        let (thread_local_sender, thread_local_receiver) = std::sync::mpsc::channel();

        let thread_local_data_creation_arc = Arc::new(thread_local_data_creation);
        let allocation_vec: Vec<_> = thread_allocations.into_iter().collect();

        // Validate every requested affinity up front, before any thread is spawned.
        let num_logical_cpus = num_cpus::get();
        for allocation in allocation_vec.iter() {
            if let Some(affin) = allocation.affinity {
                if affin >= num_logical_cpus {
                    return Err(SwitchyardCreationError::InvalidAffinity {
                        affinity: affin,
                        total_threads: num_logical_cpus,
                    });
                }
            }
        }

        let mut shared = Arc::new(Shared {
            queue: Queue {
                waiting: Mutex::new(DenseSlotMap::new()),
                inner: Mutex::new(PriorityQueue::new()),
                condvars: Vec::new(),
            },
            // Every worker starts out counted as active.
            active_threads: AtomicUsize::new(allocation_vec.len()),
            idle_wait: ManualResetEvent::new(false),
            job_count: AtomicUsize::new(0),
            death_signal: AtomicBool::new(false),
        });

        // `unwrap` is sound: no worker has received a clone of `shared` yet,
        // so this is still the only `Arc`.
        let shared_guard = Arc::get_mut(&mut shared).unwrap();

        // Allocate one condvar per worker; each worker is handed its own index
        // into `queue.condvars`.
        let queue_local_indices: Vec<_> = allocation_vec
            .iter()
            .map(|_| {
                let condvar_array = &mut shared_guard.queue.condvars;

                let queue_local_index = condvar_array.len();
                condvar_array.push(FlaggedCondvar {
                    inner: Condvar::new(),
                    running: AtomicBool::new(true),
                });

                queue_local_index
            })
            .collect();

        let mut threads = Vec::with_capacity(allocation_vec.len());
        for (mut thread_info, queue_local_index) in allocation_vec.into_iter().zip(queue_local_indices) {
            // Apply optional name / stack size to the thread builder.
            let builder = std::thread::Builder::new();
            let builder = if let Some(name) = thread_info.name.take() {
                builder.name(name)
            } else {
                builder
            };
            let builder = if let Some(stack_size) = thread_info.stack_size.take() {
                builder.stack_size(stack_size)
            } else {
                builder
            };

            threads.push(
                builder
                    .spawn(worker::body::<TD, TDFunc>(
                        Arc::clone(&shared),
                        thread_info,
                        queue_local_index,
                        thread_local_sender.clone(),
                        thread_local_data_creation_arc.clone(),
                    ))
                    .unwrap_or_else(|_| panic!("Could not spawn thread")),
            );
        }
        // drop the sender we own, so we can retrieve pointers until all senders are dropped
        drop(thread_local_sender);

        // Collect one pointer per worker; `recv` errors out once every worker has sent
        // its pointer and dropped its clone of the sender, ending the loop.
        let mut thread_local_data = Vec::with_capacity(threads.len());
        while let Ok(ThreadLocalPointer(ptr)) = thread_local_receiver.recv() {
            thread_local_data.push(ptr);
        }

        Ok(Self {
            threads,
            shared,
            thread_local_data,
        })
    }
310
311 /// Things that must be done every time a task is spawned
312 fn spawn_header(&self) {
313 assert!(
314 !self.shared.death_signal.load(Ordering::Acquire),
315 "finish() has been called on this Switchyard. No more jobs may be added."
316 );
317
318 // SAFETY: we must grab and increment this counter so `access_per_thread_data` knows
319 // we're in flight.
320 self.shared.job_count.fetch_add(1, Ordering::AcqRel);
321
322 // Say we're no longer idle so that `yard.spawn(); yard.wait_for_idle()`
323 // won't "return early". If the thread hasn't woken up fully yet by the
324 // time wait_for_idle is called, it will immediately return even though logically there's
325 // still an outstanding, active, job.
326 self.shared.idle_wait.reset();
327 }
328
329 /// Spawn a future which can migrate between threads during executionat the given `priority`.
330 ///
331 /// A higher `priority` will cause the task to be run sooner.
332 ///
333 /// # Example
334 ///
335 /// ```rust
336 /// use switchyard::{Switchyard, threads::single_thread};
337 ///
338 /// // Create a yard with a single pool
339 /// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
340 ///
341 /// // Spawn a task with priority 0 and get a handle to the result.
342 /// let handle = yard.spawn(0, async move { 2 * 2 });
343 ///
344 /// // Await result
345 /// # futures_executor::block_on(async move {
346 /// assert_eq!(handle.await, 4);
347 /// # });
348 /// ```
349 ///
350 /// # Panics
351 ///
352 /// - [`finish`](Switchyard::finish) has been called on the pool.
353 pub fn spawn<Fut, T>(&self, priority: Priority, fut: Fut) -> JoinHandle<T>
354 where
355 Fut: Future<Output = T> + Send + 'static,
356 T: Send + 'static,
357 {
358 self.spawn_header();
359
360 let (sender, receiver) = oneshot_channel();
361 let job = Job::Future(Task::new(
362 Arc::clone(&self.shared),
363 async move {
364 // We don't care about the result, if this fails, that just means the join handle
365 // has been dropped.
366 let _ = sender.send(CatchUnwind::new(std::panic::AssertUnwindSafe(fut)).await);
367 },
368 priority,
369 ));
370
371 let queue: &Queue<TD> = &self.shared.queue;
372
373 let mut queue_guard = queue.inner.lock();
374 queue_guard.push(job, priority);
375 // the required guard is held in `queue_guard`
376 queue.notify_one();
377 drop(queue_guard);
378
379 JoinHandle {
380 receiver_future: receiver.receive(),
381 _receiver: receiver,
382 }
383 }
384
    /// Spawns an async function which is tied to a single thread during execution.
    ///
    /// Spawns the task at the given `priority`.
    ///
    /// The given async function will be provided an `Arc` to the thread-local data to create its future with.
    ///
    /// A higher `priority` will cause the task to be run sooner.
    ///
    /// The function must be `Send`, but the future returned by that function may be `!Send`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::{cell::Cell, sync::Arc};
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with thread local data.
    /// let yard: Switchyard<Cell<u64>> = Switchyard::new(
    ///     single_thread(None, None),
    ///     || Cell::new(42)
    /// ).unwrap();
    /// # let mut yard = yard;
    ///
    /// // Spawn an async function using the data.
    /// yard.spawn_local(0, |data: Arc<Cell<u64>>| async move {data.set(12);});
    /// # futures_executor::block_on(yard.wait_for_idle());
    /// # assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(12)]));
    ///
    /// async fn some_async(data: Arc<Cell<u64>>) -> u64 {
    ///     data.set(15);
    ///     2 * 2
    /// }
    ///
    /// // Works with normal async functions too
    /// let handle = yard.spawn_local(0, some_async);
    /// # futures_executor::block_on(yard.wait_for_idle());
    /// # assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(15)]));
    /// # futures_executor::block_on(async move {
    /// assert_eq!(handle.await, 4);
    /// # });
    /// ```
    ///
    /// # Panics
    ///
    /// - [`finish`](Switchyard::finish) has been called on the pool.
    pub fn spawn_local<Func, Fut, T>(&self, priority: Priority, async_fn: Func) -> JoinHandle<T>
    where
        Func: FnOnce(Arc<TD>) -> Fut + Send + 'static,
        Fut: Future<Output = T>,
        T: Send + 'static,
    {
        // Panics if the yard is shut down; also bumps the job count and clears the idle flag.
        self.spawn_header();

        let (sender, receiver) = oneshot_channel();
        let job = Job::Local(Box::new(move |td| {
            Box::pin(async move {
                // We don't care about the result, if this fails, that just means the join handle
                // has been dropped.
                // Catch panics in two stages: first while *calling* `async_fn` to build the
                // future, then (below) while polling the future itself.
                let unwind_async_fn = AssertUnwindSafe(async_fn);
                let unwind_td = AssertUnwindSafe(td);
                let future = catch_unwind(move || AssertUnwindSafe(unwind_async_fn.0(unwind_td.0)));

                let ret = match future {
                    Ok(fut) => CatchUnwind::new(AssertUnwindSafe(fut)).await,
                    Err(panic) => Err(panic),
                };

                let _ = sender.send(ret);
            })
        }));

        let queue: &Queue<TD> = &self.shared.queue;

        // Enqueue and wake a worker while still holding the queue lock.
        let mut queue_guard = queue.inner.lock();
        queue_guard.push(job, priority);
        // the required guard is held in `queue_guard`
        queue.notify_one();
        drop(queue_guard);

        JoinHandle {
            receiver_future: receiver.receive(),
            _receiver: receiver,
        }
    }
469
    /// Wait until all working threads are starved of work due
    /// to lack of jobs or all jobs waiting.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added while the future returns.
    /// - Jobs may be woken while the future returns.
    pub async fn wait_for_idle(&self) {
        // We don't reset it, threads will reset it when they become active again
        // (and `spawn_header` resets it eagerly on every spawn).
        self.shared.idle_wait.wait().await;
    }
482
    /// Current amount of jobs in flight.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added after the value is received and before it is returned.
    pub fn jobs(&self) -> usize {
        // Relaxed is fine: this is an advisory snapshot, per the safety notes above.
        self.shared.job_count.load(Ordering::Relaxed)
    }
492
    /// Count of threads currently processing jobs.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added after the value is received and before it is returned re-activating threads.
    pub fn active_threads(&self) -> usize {
        // Relaxed is fine: this is an advisory snapshot, per the safety notes above.
        self.shared.active_threads.load(Ordering::Relaxed)
    }
502
    /// Access the per-thread data of each thread. Only available if `TD` is `Send`.
    ///
    /// This function requires `&mut self` in order to be sound. If you have the yard in a global,
    /// you need to wrap it with `RwLock` so you can get a `&mut` from a `&`.
    ///
    /// Two conditions need to be true for this to return `Some`. First all threads must be idle
    /// (i.e. `wait_for_idle`'s future would immediately return). Second no references to any thread's
    /// local data may be alive.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::{cell::Cell, sync::Arc};
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with thread local data.
    /// let mut yard: Switchyard<Cell<u64>> = Switchyard::new(
    ///     single_thread(None, None),
    ///     || Cell::new(42)
    /// ).unwrap();
    ///
    /// // Wait for all threads to get themselves situated.
    /// # futures_executor::block_on(async {
    /// yard.wait_for_idle().await;
    /// # });
    ///
    /// // View that thread-local data. The yard has one thread, so returns a vec of length one.
    /// assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(42)]));
    ///
    /// // Launch a task to change that data
    /// let handle = yard.spawn_local(0, |data| async move { data.set(525_600); });
    ///
    /// // If the task isn't finished yet, this will return None.
    /// yard.access_per_thread_data();
    ///
    /// // Wait for task to be done
    /// # futures_executor::block_on(async {
    /// assert_eq!(handle.await, ());
    /// # });
    ///
    /// // We also need to wait for all threads to come to a stopping place
    /// # futures_executor::block_on(async {
    /// yard.wait_for_idle().await;
    /// # });
    ///
    /// // Observe changed value
    /// assert_eq!(yard.access_per_thread_data(), Some(vec![&mut Cell::new(525_600)]));
    /// ```
    ///
    /// # Safety
    ///
    /// - This function guarantees that there exist no other references to this data if `Some` is returned.
    /// - This function guarantees that `jobs()` is 0 and will stay zero while the returned references are still live.
    pub fn access_per_thread_data(&mut self) -> Option<Vec<&mut TD>>
    where
        TD: Send,
    {
        let threads_live = self.shared.active_threads.load(Ordering::Acquire);

        // SAFETY: No more jobs can be added and threads woken because we have an exclusive reference to the yard.
        if threads_live != 0 {
            return None;
        }

        // SAFETY:
        // - We know there are no threads running because `count` is zero and we have an exclusive reference to the yard.
        // - Threads do not keep references to their `Arc`'s around while idle, nor hand them to tasks.
        // - `TD` is allowed to be `!Sync` because we never actually touch a `&TD`, only `&mut TD`.
        let arcs = self.thread_local_data.iter().map(|&ptr| unsafe { &mut *ptr });

        // `Arc::get_mut` returns `None` for any `Arc` that is still shared (a task holds a
        // clone); `collect` on `Option` then makes the whole call return `None`.
        let data: Option<Vec<&mut TD>> = arcs.map(|arc| Arc::get_mut(arc)).collect();

        data
    }
577
    /// Kill all threads as soon as they finish their jobs. All calls to spawn and spawn_local will
    /// panic after this function is called.
    ///
    /// This is equivalent to calling drop. Calling this function twice will be a no-op
    /// the second time.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread itself panicked (its join handle is unwrapped).
    pub fn finish(&mut self) {
        // send death signal then wake everyone up
        self.shared.death_signal.store(true, Ordering::Release);
        // Hold the queue lock while notifying so the wake-up can't race a worker
        // that is between checking the queue and parking on its condvar.
        let lock = self.shared.queue.inner.lock();
        self.shared.queue.notify_all();
        drop(lock);

        // These raw pointers dangle once the workers exit; drop them before joining.
        self.thread_local_data.clear();
        // `drain` empties `threads`, which is what makes a second call a no-op.
        for thread in self.threads.drain(..) {
            thread.join().unwrap();
        }
    }
595}
596
impl<TD: 'static> Drop for Switchyard<TD> {
    /// Shuts the yard down via [`Switchyard::finish`]; harmless if `finish`
    /// was already called manually, since a second call is a no-op.
    fn drop(&mut self) {
        self.finish()
    }
}
602
// NOTE(review): these impls are unconditional — `Switchyard<TD>` is `Send`/`Sync`
// even when `TD` is not. The raw `*mut Arc<TD>` fields are only dereferenced in
// `access_per_thread_data`, which itself requires `TD: Send`, and `spawn_local`
// futures run only on their owning worker thread — presumably that is the
// soundness argument. TODO confirm no other path touches `TD` across threads.
unsafe impl<TD> Send for Switchyard<TD> {}
unsafe impl<TD> Sync for Switchyard<TD> {}