vortex_io/dispatcher/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(feature = "compio")]
5mod compio;
6#[cfg(not(target_arch = "wasm32"))]
7mod tokio;
8#[cfg(target_arch = "wasm32")]
9mod wasm;
10
11use std::future::Future;
12use std::sync::{Arc, LazyLock};
13use std::task::Poll;
14
15use cfg_if::cfg_if;
16use futures::FutureExt;
17use futures::channel::oneshot;
18use vortex_error::{VortexResult, vortex_err};
19
20static SHARED: LazyLock<IoDispatcher> = LazyLock::new(IoDispatcher::new);
21
22#[cfg(feature = "compio")]
23use self::compio::*;
24#[cfg(not(target_arch = "wasm32"))]
25use self::tokio::*;
26#[cfg(target_arch = "wasm32")]
27use self::wasm::*;
28
29mod sealed {
30    pub trait Sealed {}
31
32    impl Sealed for super::IoDispatcher {}
33
34    #[cfg(feature = "compio")]
35    impl Sealed for super::CompioDispatcher {}
36
37    #[cfg(not(target_arch = "wasm32"))]
38    impl Sealed for super::TokioDispatcher {}
39
40    #[cfg(target_arch = "wasm32")]
41    impl Sealed for super::WasmDispatcher {}
42}
43
44/// A trait for types that may be dispatched.
45pub trait Dispatch: sealed::Sealed {
46    /// Dispatch a new asynchronous task.
47    ///
48    /// The function spawning the task must be `Send` as it will be sent to
49    /// the driver thread.
50    ///
51    /// The returned `Future` will be executed to completion on a single thread,
52    /// thus it may be `!Send`.
53    fn dispatch<F, Fut, R>(&self, task: F) -> VortexResult<JoinHandle<R>>
54    where
55        F: (FnOnce() -> Fut) + Send + 'static,
56        Fut: Future<Output = R> + 'static,
57        R: Send + 'static;
58
59    /// Gracefully shutdown the dispatcher, consuming it.
60    ///
61    /// Existing tasks are awaited before exiting.
62    fn shutdown(self) -> VortexResult<()>;
63}
64
65/// <div class="warning">IoDispatcher is unstable and may change in the future.</div>
66///
67/// A cross-thread, cross-runtime dispatcher of async IO workloads.
68///
69/// `IoDispatcher`s are handles to an async runtime that can handle work submissions and
70/// multiplexes them across a set of worker threads. Unlike an async runtime, which is free
71/// to balance tasks as they see fit, the purpose of the Dispatcher is to enable the spawning
72/// of asynchronous, `!Send` tasks across potentially many worker threads, and allowing work
73/// submission from any other runtime.
74///
75#[derive(Clone, Debug)]
76pub struct IoDispatcher(Arc<Inner>);
77
78impl IoDispatcher {
79    pub fn new() -> Self {
80        cfg_if! {
81            if #[cfg(target_arch = "wasm32")] {
82                Self(Arc::new(Inner::Wasm(WasmDispatcher::new())))
83            } else if #[cfg(not(feature = "compio"))] {
84                Self(Arc::new(Inner::Tokio(TokioDispatcher::new(1))))
85            } else {
86                Self(Arc::new(Inner::Compio(CompioDispatcher::new(1))))
87            }
88        }
89    }
90
91    /// Create a new IO dispatcher that uses a set of Tokio `current_thread` runtimes to
92    /// execute both `Send` and `!Send` futures.
93    ///
94    /// A handle to the dispatcher can be passed freely among threads, allowing multiple parties to
95    /// perform dispatching across different threads.
96    #[cfg(not(target_arch = "wasm32"))]
97    pub fn new_tokio(num_thread: usize) -> Self {
98        Self(Arc::new(Inner::Tokio(TokioDispatcher::new(num_thread))))
99    }
100
101    #[cfg(feature = "compio")]
102    pub fn new_compio(num_threads: usize) -> Self {
103        Self(Arc::new(Inner::Compio(CompioDispatcher::new(num_threads))))
104    }
105
106    #[cfg(target_arch = "wasm32")]
107    pub fn new_wasm() -> Self {
108        Self(Arc::new(Inner::Wasm(WasmDispatcher)))
109    }
110}
111
112impl Default for IoDispatcher {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl IoDispatcher {
119    /// Returns a handle to the current process's shared Dispatcher.
120    pub fn shared() -> Self {
121        SHARED.clone()
122    }
123}
124
125pub struct JoinHandle<R>(oneshot::Receiver<R>);
126
127impl<R> Future for JoinHandle<R> {
128    type Output = VortexResult<R>;
129
130    fn poll(
131        mut self: std::pin::Pin<&mut Self>,
132        cx: &mut std::task::Context<'_>,
133    ) -> Poll<Self::Output> {
134        match self.0.poll_unpin(cx) {
135            Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
136            Poll::Ready(Err(_)) => Poll::Ready(Err(vortex_err!("Task was canceled"))),
137            Poll::Pending => Poll::Pending,
138        }
139    }
140}
141
142#[derive(Debug)]
143enum Inner {
144    #[cfg(not(target_arch = "wasm32"))]
145    Tokio(TokioDispatcher),
146    #[cfg(feature = "compio")]
147    Compio(CompioDispatcher),
148    #[cfg(target_arch = "wasm32")]
149    Wasm(WasmDispatcher),
150}
151
152impl Dispatch for IoDispatcher {
153    #[allow(unused_variables)] // If no features are enabled `task` ends up being unused
154    fn dispatch<F, Fut, R>(&self, task: F) -> VortexResult<JoinHandle<R>>
155    where
156        F: (FnOnce() -> Fut) + Send + 'static,
157        Fut: Future<Output = R> + 'static,
158        R: Send + 'static,
159    {
160        match self.0.as_ref() {
161            #[cfg(not(target_arch = "wasm32"))]
162            Inner::Tokio(tokio_dispatch) => tokio_dispatch.dispatch(task),
163            #[cfg(feature = "compio")]
164            Inner::Compio(compio_dispatch) => compio_dispatch.dispatch(task),
165            #[cfg(target_arch = "wasm32")]
166            Inner::Wasm(wasm_dispatch) => wasm_dispatch.dispatch(task),
167        }
168    }
169
170    fn shutdown(self) -> VortexResult<()> {
171        if let Ok(inner) = Arc::try_unwrap(self.0) {
172            match inner {
173                #[cfg(not(target_arch = "wasm32"))]
174                Inner::Tokio(tokio_dispatch) => tokio_dispatch.shutdown(),
175                #[cfg(feature = "compio")]
176                Inner::Compio(compio_dispatch) => compio_dispatch.shutdown(),
177                #[cfg(target_arch = "wasm32")]
178                Inner::Wasm(wasm_dispatch) => wasm_dispatch.shutdown(),
179            }
180        } else {
181            Ok(())
182        }
183    }
184}
185
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "tokio")]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::{Dispatch, IoDispatcher, JoinHandle};

    #[::tokio::test]
    async fn test_dispatcher_task_panic_handling() {
        let dispatcher = IoDispatcher::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        // Dispatch a task that will panic. The explicit `()` output type keeps inference
        // independent of never-type fallback.
        #[allow(clippy::panic)]
        let panic_handle: JoinHandle<()> = dispatcher
            .dispatch(move || async move {
                panic!("Task panic");
            })
            .unwrap();

        // Also dispatch a normal task to verify the dispatcher keeps working after a panic.
        let normal_handle = dispatcher
            .dispatch(move || async move {
                completed_clone.store(true, Ordering::SeqCst);
                42
            })
            .unwrap();

        // The normal task should complete.
        let result = normal_handle.await;
        assert_eq!(result.unwrap(), 42);
        assert!(completed.load(Ordering::SeqCst));

        // The panicking task never sends a value, so its handle must resolve to an error
        // rather than a value or a hang.
        assert!(panic_handle.await.is_err());

        dispatcher.shutdown().unwrap();
    }

    #[test]
    fn test_dispatcher_shutdown_empty_queue() {
        let dispatcher = IoDispatcher::new();
        // Immediate shutdown should work
        dispatcher.shutdown().unwrap();
    }

    #[::tokio::test]
    async fn test_dispatcher_many_threads() {
        // `IoDispatcher::new()` uses a thread count of 1; request several threads
        // explicitly so the test covers multi-threaded dispatch.
        let dispatcher = IoDispatcher::new_tokio(4);

        let handles: Vec<_> = (0..100usize)
            .map(|i| dispatcher.dispatch(move || async move { i * 2 }).unwrap())
            .collect();

        for (i, handle) in handles.into_iter().enumerate() {
            assert_eq!(handle.await.unwrap(), i * 2);
        }

        dispatcher.shutdown().unwrap();
    }
}