switchyard/lib.rs
//! Real-time compute-focused async executor with job pools, thread-local data, and priorities.
//!
//! # Example
//!
//! ```rust
//! use switchyard::Switchyard;
//! use switchyard::threads::{thread_info, one_to_one};
//! // Create a new switchyard with empty thread local data
//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || ()).unwrap();
//!
//! // Spawn a task with priority 10 and get a JoinHandle
//! let handle = yard.spawn(10, async move { 5 + 5 });
//! // Spawn a lower priority task on the same yard
//! let handle2 = yard.spawn(0, async move { 2 + 2 });
//!
//! // Wait on the results
//! # futures_executor::block_on(async {
//! assert_eq!(handle.await + handle2.await, 14);
//! # });
//! ```
//!
//! # How Switchyard is Different
//!
//! Switchyard differs from existing async executors by focusing on situations where
//! precise control of threads and execution order is needed. One such situation is using
//! task parallelism to parallelize a compute workload.
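//!
//! For example, a compute workload can be split into chunks, with each chunk spawned as its
//! own task and the results joined afterwards. A minimal sketch (using the same helpers as the
//! example above):
//!
//! ```rust
//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
//! # let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || ()).unwrap();
//! // Spawn one task per chunk of the workload.
//! let handles: Vec<_> = (0..8u64)
//!     .map(|chunk| yard.spawn(0, async move { chunk * chunk }))
//!     .collect();
//! # futures_executor::block_on(async move {
//! // Join all of the results.
//! let mut total = 0;
//! for handle in handles {
//!     total += handle.await;
//! }
//! assert_eq!(total, 140);
//! # });
//! ```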
//!
//! ## Priorities
//!
//! Each task has a priority, and tasks are run in order from high priority to low priority.
//!
//! ```rust
//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
//! # let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || ()).unwrap();
//! // Spawn task with lowest priority.
//! yard.spawn(0, async move { /* ... */ });
//! // Spawn task with higher priority. If both tasks are waiting, this one will run first.
//! yard.spawn(10, async move { /* ... */ });
//! ```
//!
//! ## Thread Local Data
//!
//! Each yard has some thread local data that can be accessed using [`spawn_local`](Switchyard::spawn_local).
//! Both the thread local data and the future generated by the async function passed to [`spawn_local`](Switchyard::spawn_local)
//! may be `!Send` and `!Sync`. The future will only be resumed on the thread that created it.
//!
//! ```rust
//! # use switchyard::{Switchyard, threads::{thread_info, one_to_one}};
//! # use std::cell::Cell;
//! // Create a yard with thread local data. The data is !Sync.
//! let yard = Switchyard::new(one_to_one(thread_info(), Some("thread-name")), || Cell::new(42)).unwrap();
//!
//! // Spawn a task that uses the thread local data. Each running thread gets its own copy.
//! yard.spawn_local(0, |data| async move { data.set(10) });
//! ```
//!
//! # MSRV
//! 1.51
//!
//! Future MSRV bumps will be breaking changes.

#![deny(future_incompatible)]
#![deny(nonstandard_style)]
#![deny(rust_2018_idioms)]

use crate::{
    task::{Job, Task, ThreadLocalJob, ThreadLocalTask},
    threads::ThreadAllocationOutput,
    util::ThreadLocalPointer,
};
use futures_intrusive::{
    channel::shared::{oneshot_channel, ChannelReceiveFuture, OneshotReceiver},
    sync::ManualResetEvent,
};
use futures_task::{Context, Poll};
use parking_lot::{Condvar, Mutex, RawMutex};
use priority_queue::PriorityQueue;
use slotmap::{DefaultKey, DenseSlotMap};
use std::{
    any::Any,
    future::Future,
    panic::{catch_unwind, AssertUnwindSafe, UnwindSafe},
    pin::Pin,
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        Arc,
    },
};

pub mod affinity;
mod error;
mod task;
pub mod threads;
mod util;
mod worker;

pub use error::*;

/// Integer alias for a priority.
pub type Priority = u32;
/// Integer alias for the maximum number of pools.
pub type PoolCount = u8;

/// Handle to a currently running task.
///
/// Awaiting this future will give the return value of the task.
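///
/// # Example
///
/// A minimal sketch of awaiting a handle, mirroring the crate-level example:
///
/// ```rust
/// # use switchyard::{Switchyard, threads::single_thread};
/// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
/// // `spawn` hands back a `JoinHandle` whose output is the task's return value.
/// let handle = yard.spawn(0, async move { 2 + 2 });
/// # futures_executor::block_on(async move {
/// assert_eq!(handle.await, 4);
/// # });
/// ```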
pub struct JoinHandle<T: 'static> {
    _receiver: OneshotReceiver<Result<T, Box<dyn Any + Send + 'static>>>,
    receiver_future: ChannelReceiveFuture<RawMutex, Result<T, Box<dyn Any + Send + 'static>>>,
}
impl<T: 'static> Future for JoinHandle<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let fut = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().receiver_future) };
        let poll_res = fut.poll(ctx);

        match poll_res {
            Poll::Ready(None) => {
                // If this returns ready with none, that means the channel was closed
                // due to the waker dying. We can just return pending as this future will never
                // return.
                Poll::Pending
            }
            Poll::Ready(Some(value)) => Poll::Ready(value.unwrap_or_else(|_| panic!("Job panicked!"))),
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Vendored from futures-util as holy hell that's a large lib.
struct CatchUnwind<Fut>(Fut);

impl<Fut> CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    fn new(future: Fut) -> CatchUnwind<Fut> {
        CatchUnwind(future)
    }
}

impl<Fut> Future for CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let f = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
        // `catch_unwind` yields `Result<Poll<_>, _>`; `?` turns a caught panic into
        // `Poll::Ready(Err(payload))`, and `.map(Ok)` wraps a successful poll result.
        catch_unwind(AssertUnwindSafe(|| f.poll(cx)))?.map(Ok)
    }
}

struct ThreadLocalQueue<TD> {
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<ThreadLocalTask<TD>>>>,
    inner: Mutex<PriorityQueue<ThreadLocalJob<TD>, u32>>,
}
struct FlaggedCondvar {
    running: AtomicBool,
    inner: Condvar,
}
struct Queue<TD> {
    waiting: Mutex<DenseSlotMap<DefaultKey, Arc<Task<TD>>>>,
    inner: Mutex<PriorityQueue<Job<TD>, u32>>,
    condvars: Vec<FlaggedCondvar>,
}
impl<TD> Queue<TD> {
    /// Must be called with `queue.inner`'s lock held.
    fn notify_one(&self) {
        for var in &self.condvars {
            if !var.running.load(Ordering::Relaxed) {
                var.inner.notify_one();
                return;
            }
        }
    }

    /// Must be called with `queue.inner`'s lock held.
    fn notify_all(&self) {
        // We could be more efficient and not notify everyone, but this is more surefire
        // and this function is only called on shutdown.
        for var in &self.condvars {
            var.inner.notify_all();
        }
    }
}

struct Shared<TD> {
    active_threads: AtomicUsize,
    idle_wait: ManualResetEvent,
    job_count: AtomicUsize,
    death_signal: AtomicBool,
    queue: Queue<TD>,
}

/// Compute focused async executor.
///
/// See crate documentation for more details.
pub struct Switchyard<TD: 'static> {
    shared: Arc<Shared<TD>>,
    threads: Vec<std::thread::JoinHandle<()>>,
    thread_local_data: Vec<*mut Arc<TD>>,
}
impl<TD: 'static> Switchyard<TD> {
    /// Create a new switchyard.
    ///
    /// For each element in the provided `thread_allocations` iterator, the yard will spawn a worker
    /// thread with the given settings. Helper functions in [`threads`] can generate these iterators
    /// for common situations.
    ///
    /// `thread_local_data_creation` will be called on each thread to create the thread local
    /// data accessible by `spawn_local`.
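    ///
    /// # Example
    ///
    /// A minimal sketch using the `one_to_one` helper; the `String` thread local data here is
    /// only an illustration:
    ///
    /// ```rust
    /// use switchyard::Switchyard;
    /// use switchyard::threads::{thread_info, one_to_one};
    ///
    /// // Each worker thread builds its own copy of the thread local data.
    /// let yard = Switchyard::new(
    ///     one_to_one(thread_info(), Some("worker")),
    ///     || String::from("per-thread data"),
    /// ).unwrap();
    /// ```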
    pub fn new<TDFunc>(
        thread_allocations: impl IntoIterator<Item = ThreadAllocationOutput>,
        thread_local_data_creation: TDFunc,
    ) -> Result<Self, SwitchyardCreationError>
    where
        TDFunc: Fn() -> TD + Send + Sync + 'static,
    {
        let (thread_local_sender, thread_local_receiver) = std::sync::mpsc::channel();

        let thread_local_data_creation_arc = Arc::new(thread_local_data_creation);
        let allocation_vec: Vec<_> = thread_allocations.into_iter().collect();

        let num_logical_cpus = num_cpus::get();
        for allocation in allocation_vec.iter() {
            if let Some(affin) = allocation.affinity {
                if affin >= num_logical_cpus {
                    return Err(SwitchyardCreationError::InvalidAffinity {
                        affinity: affin,
                        total_threads: num_logical_cpus,
                    });
                }
            }
        }

        let mut shared = Arc::new(Shared {
            queue: Queue {
                waiting: Mutex::new(DenseSlotMap::new()),
                inner: Mutex::new(PriorityQueue::new()),
                condvars: Vec::new(),
            },
            active_threads: AtomicUsize::new(allocation_vec.len()),
            idle_wait: ManualResetEvent::new(false),
            job_count: AtomicUsize::new(0),
            death_signal: AtomicBool::new(false),
        });

        let shared_guard = Arc::get_mut(&mut shared).unwrap();

        let queue_local_indices: Vec<_> = allocation_vec
            .iter()
            .map(|_| {
                let condvar_array = &mut shared_guard.queue.condvars;

                let queue_local_index = condvar_array.len();
                condvar_array.push(FlaggedCondvar {
                    inner: Condvar::new(),
                    running: AtomicBool::new(true),
                });

                queue_local_index
            })
            .collect();

        let mut threads = Vec::with_capacity(allocation_vec.len());
        for (mut thread_info, queue_local_index) in allocation_vec.into_iter().zip(queue_local_indices) {
            let builder = std::thread::Builder::new();
            let builder = if let Some(name) = thread_info.name.take() {
                builder.name(name)
            } else {
                builder
            };
            let builder = if let Some(stack_size) = thread_info.stack_size.take() {
                builder.stack_size(stack_size)
            } else {
                builder
            };

            threads.push(
                builder
                    .spawn(worker::body::<TD, TDFunc>(
                        Arc::clone(&shared),
                        thread_info,
                        queue_local_index,
                        thread_local_sender.clone(),
                        thread_local_data_creation_arc.clone(),
                    ))
                    .unwrap_or_else(|_| panic!("Could not spawn thread")),
            );
        }
        // drop the sender we own, so we can retrieve pointers until all senders are dropped
        drop(thread_local_sender);

        let mut thread_local_data = Vec::with_capacity(threads.len());
        while let Ok(ThreadLocalPointer(ptr)) = thread_local_receiver.recv() {
            thread_local_data.push(ptr);
        }

        Ok(Self {
            threads,
            shared,
            thread_local_data,
        })
    }

    /// Things that must be done every time a task is spawned.
    fn spawn_header(&self) {
        assert!(
            !self.shared.death_signal.load(Ordering::Acquire),
            "finish() has been called on this Switchyard. No more jobs may be added."
        );

        self.shared.job_count.fetch_add(1, Ordering::AcqRel);

        // Mark the yard as no longer idle so that `yard.spawn(); yard.wait_for_idle()`
        // won't return early. If the worker thread hasn't fully woken up by the time
        // wait_for_idle is called, it would otherwise return immediately even though
        // there is logically still an outstanding, active job.
        self.shared.idle_wait.reset();
    }

    /// Spawn a future which can migrate between threads during execution at the given `priority`.
    ///
    /// A higher `priority` will cause the task to be run sooner.
    ///
    /// # Example
    ///
    /// ```rust
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with a single worker thread
    /// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
    ///
    /// // Spawn a task with priority 0 and get a handle to the result.
    /// let handle = yard.spawn(0, async move { 2 * 2 });
    ///
    /// // Await the result
    /// # futures_executor::block_on(async move {
    /// assert_eq!(handle.await, 4);
    /// # });
    /// ```
    ///
    /// # Panics
    ///
    /// - [`finish`](Switchyard::finish) has been called on the yard.
    pub fn spawn<Fut, T>(&self, priority: Priority, fut: Fut) -> JoinHandle<T>
    where
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_header();

        let (sender, receiver) = oneshot_channel();
        let job = Job::Future(Task::new(
            Arc::clone(&self.shared),
            async move {
                // We don't care about the result; if this fails, that just means the join handle
                // has been dropped.
                let _ = sender.send(CatchUnwind::new(AssertUnwindSafe(fut)).await);
            },
            priority,
        ));

        let queue: &Queue<TD> = &self.shared.queue;

        let mut queue_guard = queue.inner.lock();
        queue_guard.push(job, priority);
        // the required guard is held in `queue_guard`
        queue.notify_one();
        drop(queue_guard);

        JoinHandle {
            receiver_future: receiver.receive(),
            _receiver: receiver,
        }
    }

    /// Spawns an async function which is tied to a single thread during execution.
    ///
    /// The task is spawned at the given `priority`.
    ///
    /// The given async function will be provided an `Arc` to the thread-local data to create its future with.
    ///
    /// A higher `priority` will cause the task to be run sooner.
    ///
    /// The function must be `Send`, but the future returned by that function may be `!Send`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::{cell::Cell, sync::Arc};
    /// use switchyard::{Switchyard, threads::single_thread};
    ///
    /// // Create a yard with thread local data.
    /// let yard: Switchyard<Cell<u64>> = Switchyard::new(
    ///     single_thread(None, None),
    ///     || Cell::new(42)
    /// ).unwrap();
    /// # let mut yard = yard;
    ///
    /// // Spawn an async function using the data.
    /// yard.spawn_local(0, |data: Arc<Cell<u64>>| async move { data.set(12); });
    /// # futures_executor::block_on(yard.wait_for_idle());
    ///
    /// async fn some_async(data: Arc<Cell<u64>>) -> u64 {
    ///     data.set(15);
    ///     2 * 2
    /// }
    ///
    /// // Works with normal async functions too
    /// let handle = yard.spawn_local(0, some_async);
    /// # futures_executor::block_on(yard.wait_for_idle());
    /// # futures_executor::block_on(async move {
    /// assert_eq!(handle.await, 4);
    /// # });
    /// ```
    ///
    /// # Panics
    ///
    /// - [`finish`](Switchyard::finish) has been called on the yard.
    pub fn spawn_local<Func, Fut, T>(&self, priority: Priority, async_fn: Func) -> JoinHandle<T>
    where
        Func: FnOnce(Arc<TD>) -> Fut + Send + 'static,
        Fut: Future<Output = T>,
        T: Send + 'static,
    {
        self.spawn_header();

        let (sender, receiver) = oneshot_channel();
        let job = Job::Local(Box::new(move |td| {
            Box::pin(async move {
                // Catch panics both while building the future and while polling it.
                let unwind_async_fn = AssertUnwindSafe(async_fn);
                let unwind_td = AssertUnwindSafe(td);
                let future = catch_unwind(move || AssertUnwindSafe(unwind_async_fn.0(unwind_td.0)));

                let ret = match future {
                    Ok(fut) => CatchUnwind::new(AssertUnwindSafe(fut)).await,
                    Err(panic) => Err(panic),
                };

                // We don't care about the result; if this fails, that just means the join handle
                // has been dropped.
                let _ = sender.send(ret);
            })
        }));

        let queue: &Queue<TD> = &self.shared.queue;

        let mut queue_guard = queue.inner.lock();
        queue_guard.push(job, priority);
        // the required guard is held in `queue_guard`
        queue.notify_one();
        drop(queue_guard);

        JoinHandle {
            receiver_future: receiver.receive(),
            _receiver: receiver,
        }
    }

    /// Wait until all worker threads are starved of work, either because there are no
    /// jobs left or because all remaining jobs are waiting.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added while this future is returning.
    /// - Jobs may be woken while this future is returning.
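    ///
    /// # Example
    ///
    /// A minimal sketch: spawn some work, then block on the returned future until the yard
    /// goes idle (using `futures_executor` as in the other examples):
    ///
    /// ```rust
    /// # use switchyard::{Switchyard, threads::single_thread};
    /// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
    /// yard.spawn(0, async move { /* some compute */ });
    /// futures_executor::block_on(yard.wait_for_idle());
    /// ```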
    pub async fn wait_for_idle(&self) {
        // We don't reset it; threads will reset it when they become active again
        self.shared.idle_wait.wait().await;
    }

    /// Current number of jobs in flight.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added after the value is received and before it is returned.
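    ///
    /// # Example
    ///
    /// A minimal sketch: only one job has been spawned here, so at most one can be in flight:
    ///
    /// ```rust
    /// # use switchyard::{Switchyard, threads::single_thread};
    /// let yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
    /// yard.spawn(0, async move { /* ... */ });
    /// // The job may already have completed, so all we can say is that it is at most 1.
    /// assert!(yard.jobs() <= 1);
    /// ```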
    pub fn jobs(&self) -> usize {
        self.shared.job_count.load(Ordering::Relaxed)
    }

    /// Count of threads currently processing jobs.
    ///
    /// # Safety
    ///
    /// - This function provides no safety guarantees.
    /// - Jobs may be added after the value is received and before it is returned, re-activating threads.
    pub fn active_threads(&self) -> usize {
        self.shared.active_threads.load(Ordering::Relaxed)
    }

    /// Kill all threads as soon as they finish their jobs. All calls to spawn and spawn_local will
    /// panic after this function is called.
    ///
    /// This is equivalent to calling drop. Calling this function twice will be a no-op
    /// the second time.
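    ///
    /// # Example
    ///
    /// A minimal sketch: await outstanding work, then shut the yard down explicitly:
    ///
    /// ```rust
    /// # use switchyard::{Switchyard, threads::single_thread};
    /// let mut yard: Switchyard<()> = Switchyard::new(single_thread(None, None), || ()).unwrap();
    /// let handle = yard.spawn(0, async move { 1 + 1 });
    /// # futures_executor::block_on(async move {
    /// assert_eq!(handle.await, 2);
    /// # });
    /// // Workers exit once they run out of work; further spawns would panic.
    /// yard.finish();
    /// ```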
    pub fn finish(&mut self) {
        // send death signal then wake everyone up
        self.shared.death_signal.store(true, Ordering::Release);
        let lock = self.shared.queue.inner.lock();
        self.shared.queue.notify_all();
        drop(lock);

        self.thread_local_data.clear();
        for thread in self.threads.drain(..) {
            thread.join().unwrap();
        }
    }
}

impl<TD: 'static> Drop for Switchyard<TD> {
    fn drop(&mut self) {
        self.finish()
    }
}

unsafe impl<TD> Send for Switchyard<TD> {}
unsafe impl<TD> Sync for Switchyard<TD> {}