vortex_io/dispatcher/
mod.rs1#[cfg(feature = "compio")]
5mod compio;
6#[cfg(not(target_arch = "wasm32"))]
7mod tokio;
8#[cfg(target_arch = "wasm32")]
9mod wasm;
10
11use std::future::Future;
12use std::sync::{Arc, LazyLock};
13use std::task::Poll;
14
15use cfg_if::cfg_if;
16use futures::FutureExt;
17use futures::channel::oneshot;
18use vortex_error::{VortexResult, vortex_err};
19
/// Process-wide dispatcher instance, built with [`IoDispatcher::new`] the first time it is
/// accessed and handed out (as clones) by [`IoDispatcher::shared`].
///
/// Being a `static`, it keeps one reference to its backend for as long as the process runs.
static SHARED: LazyLock<IoDispatcher> = LazyLock::new(IoDispatcher::new);
21
22#[cfg(feature = "compio")]
23use self::compio::*;
24#[cfg(not(target_arch = "wasm32"))]
25use self::tokio::*;
26#[cfg(target_arch = "wasm32")]
27use self::wasm::*;
28
/// Private module holding the `Sealed` marker trait.
///
/// [`Dispatch`](super::Dispatch) lists `Sealed` as a supertrait, so a type can only implement
/// `Dispatch` if it also implements this marker. The module is not `pub`, which keeps the set
/// of implementors limited to code that can name `sealed::Sealed`.
mod sealed {
    /// Marker trait with no methods; its only purpose is to gate `Dispatch` implementations.
    pub trait Sealed {}

    // The public wrapper type.
    impl Sealed for super::IoDispatcher {}

    // Backend dispatchers; each impl is compiled under the same cfg as its backend module.
    #[cfg(feature = "compio")]
    impl Sealed for super::CompioDispatcher {}

    #[cfg(not(target_arch = "wasm32"))]
    impl Sealed for super::TokioDispatcher {}

    #[cfg(target_arch = "wasm32")]
    impl Sealed for super::WasmDispatcher {}
}
43
/// Interface for handing tasks to an I/O dispatcher.
///
/// The trait is sealed: because of the `sealed::Sealed` supertrait, only types that implement
/// that private marker can implement `Dispatch`.
pub trait Dispatch: sealed::Sealed {
    /// Submits `task` to the dispatcher and returns a [`JoinHandle`] for its output.
    ///
    /// `task` is a closure that produces the future to run. The closure itself must be
    /// `Send + 'static`, whereas the future it returns only has to be `'static` (it is not
    /// required to be `Send`). The output type `R` must be `Send + 'static`.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot accept the task; the exact conditions
    /// depend on the backend.
    fn dispatch<F, Fut, R>(&self, task: F) -> VortexResult<JoinHandle<R>>
    where
        F: (FnOnce() -> Fut) + Send + 'static,
        Fut: Future<Output = R> + 'static,
        R: Send + 'static;

    /// Consumes the dispatcher and shuts it down.
    ///
    /// What shutting down entails, and when it can fail, is defined by each implementation.
    fn shutdown(self) -> VortexResult<()>;
}
64
/// Handle to an I/O dispatcher backend.
///
/// The backend lives behind an [`Arc`], so cloning an `IoDispatcher` produces another handle
/// to the same backend rather than a new backend.
#[derive(Clone, Debug)]
pub struct IoDispatcher(Arc<Inner>);
77
78impl IoDispatcher {
79 pub fn new() -> Self {
80 cfg_if! {
81 if #[cfg(target_arch = "wasm32")] {
82 Self(Arc::new(Inner::Wasm(WasmDispatcher::new())))
83 } else if #[cfg(not(feature = "compio"))] {
84 Self(Arc::new(Inner::Tokio(TokioDispatcher::new(1))))
85 } else {
86 Self(Arc::new(Inner::Compio(CompioDispatcher::new(1))))
87 }
88 }
89 }
90
91 #[cfg(not(target_arch = "wasm32"))]
97 pub fn new_tokio(num_thread: usize) -> Self {
98 Self(Arc::new(Inner::Tokio(TokioDispatcher::new(num_thread))))
99 }
100
101 #[cfg(feature = "compio")]
102 pub fn new_compio(num_threads: usize) -> Self {
103 Self(Arc::new(Inner::Compio(CompioDispatcher::new(num_threads))))
104 }
105
106 #[cfg(target_arch = "wasm32")]
107 pub fn new_wasm() -> Self {
108 Self(Arc::new(Inner::Wasm(WasmDispatcher)))
109 }
110}
111
112impl Default for IoDispatcher {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl IoDispatcher {
119 pub fn shared() -> Self {
121 SHARED.clone()
122 }
123}
124
/// Handle to the result of a dispatched task.
///
/// Wraps the receiving half of a oneshot channel. Awaiting the handle yields `Ok(value)` once
/// a value arrives on the channel, or an error if the channel is closed without one.
pub struct JoinHandle<R>(oneshot::Receiver<R>);
126
127impl<R> Future for JoinHandle<R> {
128 type Output = VortexResult<R>;
129
130 fn poll(
131 mut self: std::pin::Pin<&mut Self>,
132 cx: &mut std::task::Context<'_>,
133 ) -> Poll<Self::Output> {
134 match self.0.poll_unpin(cx) {
135 Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
136 Poll::Ready(Err(_)) => Poll::Ready(Err(vortex_err!("Task was canceled"))),
137 Poll::Pending => Poll::Pending,
138 }
139 }
140}
141
/// The concrete backend behind an [`IoDispatcher`].
///
/// Each variant is compiled only when its backend is available, so the set of variants
/// depends on the target architecture and the enabled features.
#[derive(Debug)]
enum Inner {
    /// Tokio backend; present on every target except `wasm32`.
    #[cfg(not(target_arch = "wasm32"))]
    Tokio(TokioDispatcher),
    /// Compio backend; present when the `compio` feature is enabled.
    #[cfg(feature = "compio")]
    Compio(CompioDispatcher),
    /// WASM backend; present only on `wasm32` targets.
    #[cfg(target_arch = "wasm32")]
    Wasm(WasmDispatcher),
}
151
152impl Dispatch for IoDispatcher {
153 #[allow(unused_variables)] fn dispatch<F, Fut, R>(&self, task: F) -> VortexResult<JoinHandle<R>>
155 where
156 F: (FnOnce() -> Fut) + Send + 'static,
157 Fut: Future<Output = R> + 'static,
158 R: Send + 'static,
159 {
160 match self.0.as_ref() {
161 #[cfg(not(target_arch = "wasm32"))]
162 Inner::Tokio(tokio_dispatch) => tokio_dispatch.dispatch(task),
163 #[cfg(feature = "compio")]
164 Inner::Compio(compio_dispatch) => compio_dispatch.dispatch(task),
165 #[cfg(target_arch = "wasm32")]
166 Inner::Wasm(wasm_dispatch) => wasm_dispatch.dispatch(task),
167 }
168 }
169
170 fn shutdown(self) -> VortexResult<()> {
171 if let Ok(inner) = Arc::try_unwrap(self.0) {
172 match inner {
173 #[cfg(not(target_arch = "wasm32"))]
174 Inner::Tokio(tokio_dispatch) => tokio_dispatch.shutdown(),
175 #[cfg(feature = "compio")]
176 Inner::Compio(compio_dispatch) => compio_dispatch.shutdown(),
177 #[cfg(target_arch = "wasm32")]
178 Inner::Wasm(wasm_dispatch) => wasm_dispatch.shutdown(),
179 }
180 } else {
181 Ok(())
182 }
183 }
184}
185
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "tokio")]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::{Dispatch, IoDispatcher};

    /// A task that panics must not stop a task dispatched afterwards from completing.
    #[::tokio::test]
    async fn test_dispatcher_task_panic_handling() {
        let io = IoDispatcher::new();
        let ran = Arc::new(AtomicBool::new(false));
        let ran_in_task = Arc::clone(&ran);

        // Dispatch the panicking task first. The result stays bound to a named variable so
        // it is only dropped at the end of the test.
        #[allow(clippy::panic)]
        let _panicking = io.dispatch(move || async move {
            panic!("Task panic");
        });

        // A second, well-behaved task on the same dispatcher.
        let healthy = io
            .dispatch(move || async move {
                ran_in_task.store(true, Ordering::SeqCst);
                42
            })
            .unwrap();

        assert_eq!(healthy.await.unwrap(), 42);
        assert!(ran.load(Ordering::SeqCst));

        io.shutdown().unwrap();
    }

    /// Shutting down a dispatcher that never received a task succeeds.
    #[test]
    fn test_dispatcher_shutdown_empty_queue() {
        let io = IoDispatcher::new();
        io.shutdown().unwrap();
    }

    /// Dispatches 100 tasks and checks that each handle yields its own task's result.
    // NOTE(review): despite the name, this uses the default dispatcher, which is built with
    // an argument of 1 — confirm whether a multi-threaded dispatcher was intended.
    #[::tokio::test]
    async fn test_dispatcher_many_threads() {
        let io = IoDispatcher::new();

        // Submit every task up front, keeping the handles in submission order.
        let pending: Vec<_> = (0..100usize)
            .map(|n| io.dispatch(move || async move { n * 2 }).unwrap())
            .collect();

        for (idx, join) in pending.into_iter().enumerate() {
            assert_eq!(join.await.unwrap(), idx * 2);
        }

        io.shutdown().unwrap();
    }
}