Skip to main content

vortex_io/runtime/
tokio.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5use std::sync::LazyLock;
6
7use futures::future::BoxFuture;
8
9use crate::runtime::AbortHandle;
10use crate::runtime::AbortHandleRef;
11use crate::runtime::BlockingRuntime;
12use crate::runtime::Executor;
13use crate::runtime::Handle;
14
15/// A Vortex runtime that drives all work the enclosed Tokio runtime handle.
16pub struct TokioRuntime(Arc<tokio::runtime::Handle>);
17
18impl TokioRuntime {
19    pub fn new(handle: tokio::runtime::Handle) -> Self {
20        Self(Arc::new(handle))
21    }
22
23    /// Create a new [`Handle`] that always uses the currently scoped Tokio runtime at the time
24    /// each operation is invoked.
25    pub fn current() -> Handle {
26        static CURRENT: LazyLock<Arc<dyn Executor>> =
27            LazyLock::new(|| Arc::new(CurrentTokioRuntime));
28        Handle::new(Arc::downgrade(&CURRENT))
29    }
30}
31
32impl From<&tokio::runtime::Handle> for TokioRuntime {
33    fn from(value: &tokio::runtime::Handle) -> Self {
34        Self::from(value.clone())
35    }
36}
37
38impl From<tokio::runtime::Handle> for TokioRuntime {
39    fn from(value: tokio::runtime::Handle) -> Self {
40        TokioRuntime(Arc::new(value))
41    }
42}
43
44impl Executor for tokio::runtime::Handle {
45    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
46        #[cfg(unix)]
47        {
48            use custom_labels::asynchronous::Label;
49
50            let fut = fut.with_current_labels();
51            Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
52        }
53        #[cfg(not(unix))]
54        {
55            Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
56        }
57    }
58
59    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
60        #[cfg(unix)]
61        {
62            use custom_labels::asynchronous::Label;
63
64            Box::new(
65                tokio::runtime::Handle::spawn(self, async move { cpu() }.with_current_labels())
66                    .abort_handle(),
67            )
68        }
69        #[cfg(not(unix))]
70        {
71            Box::new(tokio::runtime::Handle::spawn(self, async move { cpu() }).abort_handle())
72        }
73    }
74
75    fn spawn_blocking_io(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
76        #[cfg(unix)]
77        {
78            use custom_labels::Labelset;
79
80            let mut set = Labelset::clone_from_current();
81            Box::new(
82                tokio::runtime::Handle::spawn_blocking(self, move || set.enter(task))
83                    .abort_handle(),
84            )
85        }
86        #[cfg(not(unix))]
87        {
88            Box::new(tokio::runtime::Handle::spawn_blocking(self, task).abort_handle())
89        }
90    }
91}
92
93/// A runtime implementation that grabs the current Tokio runtime handle on each call.
94struct CurrentTokioRuntime;
95
96impl Executor for CurrentTokioRuntime {
97    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
98        Executor::spawn(&tokio::runtime::Handle::current(), fut)
99    }
100
101    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
102        Executor::spawn_cpu(&tokio::runtime::Handle::current(), cpu)
103    }
104
105    fn spawn_blocking_io(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
106        Executor::spawn_blocking_io(&tokio::runtime::Handle::current(), task)
107    }
108}
109
110impl AbortHandle for tokio::task::AbortHandle {
111    fn abort(self: Box<Self>) {
112        tokio::task::AbortHandle::abort(&self)
113    }
114}
115
116// We depend on Tokio's rt-multi-thread feature for block-in-place
117impl BlockingRuntime for TokioRuntime {
118    type BlockingIterator<'a, R: 'a> = TokioBlockingIterator<'a, R>;
119
120    fn handle(&self) -> Handle {
121        let executor: Arc<dyn Executor> = Arc::clone(&self.0) as Arc<dyn Executor>;
122        Handle::new(Arc::downgrade(&executor))
123    }
124
125    fn block_on<Fut, R>(&self, fut: Fut) -> R
126    where
127        Fut: Future<Output = R>,
128    {
129        // Assert that we're not currently inside the Tokio context.
130        if tokio::runtime::Handle::try_current().is_ok() {
131            vortex_error::vortex_panic!("block_on cannot be called from within a Tokio runtime");
132        }
133        let handle = Arc::clone(&self.0);
134        tokio::task::block_in_place(move || handle.block_on(fut))
135    }
136
137    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
138    where
139        S: futures::Stream<Item = R> + Send + 'a,
140        R: Send + 'a,
141    {
142        // Assert that we're not currently inside the Tokio context.
143        if tokio::runtime::Handle::try_current().is_ok() {
144            vortex_error::vortex_panic!(
145                "block_on_stream cannot be called from within a Tokio runtime"
146            );
147        }
148        let handle = Arc::clone(&self.0);
149        let stream = Box::pin(stream);
150        TokioBlockingIterator { handle, stream }
151    }
152}
153
154pub struct TokioBlockingIterator<'a, T> {
155    handle: Arc<tokio::runtime::Handle>,
156    stream: futures::stream::BoxStream<'a, T>,
157}
158
159impl<T> Iterator for TokioBlockingIterator<'_, T> {
160    type Item = T;
161
162    fn next(&mut self) -> Option<Self::Item> {
163        use futures::StreamExt;
164
165        tokio::task::block_in_place(|| self.handle.block_on(self.stream.next()))
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use futures::FutureExt;
    use tokio::runtime::Runtime as TokioRt;

    use super::*;

    /// A future spawned through the handle resolves to its value under `block_on`.
    #[test]
    fn test_spawn_simple_future() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());
        let spawner = runtime.handle();

        let task = spawner.spawn(async {
            let inner = async { 77 };
            inner.await
        });
        let output = runtime.block_on(task);

        assert_eq!(output, 77);
    }

    /// After the spawned task is dropped, its body must not have bumped the counter.
    #[test]
    fn test_spawn_and_abort() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());

        let hits = Arc::new(AtomicUsize::new(0));
        let task_hits = Arc::clone(&hits);

        // Gate the future on a oneshot channel so it cannot finish right away.
        let (release, gate) = tokio::sync::oneshot::channel::<()>();

        let gated = async move {
            let _ = gate.await;
            task_hits.fetch_add(1, Ordering::SeqCst);
        };
        let task = runtime.handle().spawn(gated.boxed());
        drop(task);

        // Open the gate: the future would proceed now if it had not been aborted.
        // NOTE(review): the assertion runs immediately after the send, without waiting
        // for the runtime to schedule anything.
        let _ = release.send(());
        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }
}