Skip to main content

vortex_io/runtime/
tokio.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5use std::sync::LazyLock;
6
7use futures::future::BoxFuture;
8
9use crate::runtime::AbortHandle;
10use crate::runtime::AbortHandleRef;
11use crate::runtime::BlockingRuntime;
12use crate::runtime::Executor;
13use crate::runtime::Handle;
14
15/// A Vortex runtime that drives all work the enclosed Tokio runtime handle.
16pub struct TokioRuntime(Arc<tokio::runtime::Handle>);
17
18impl TokioRuntime {
19    pub fn new(handle: tokio::runtime::Handle) -> Self {
20        Self(Arc::new(handle))
21    }
22
23    /// Create a new [`Handle`] that always uses the currently scoped Tokio runtime at the time
24    /// each operation is invoked.
25    pub fn current() -> Handle {
26        static CURRENT: LazyLock<Arc<dyn Executor>> =
27            LazyLock::new(|| Arc::new(CurrentTokioRuntime));
28        Handle::new(Arc::downgrade(&CURRENT))
29    }
30}
31
32impl From<&tokio::runtime::Handle> for TokioRuntime {
33    fn from(value: &tokio::runtime::Handle) -> Self {
34        Self::from(value.clone())
35    }
36}
37
38impl From<tokio::runtime::Handle> for TokioRuntime {
39    fn from(value: tokio::runtime::Handle) -> Self {
40        TokioRuntime(Arc::new(value))
41    }
42}
43
44impl Executor for tokio::runtime::Handle {
45    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
46        #[cfg(unix)]
47        {
48            use custom_labels::asynchronous::Label;
49
50            let fut = fut.with_current_labels();
51            Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
52        }
53        #[cfg(not(unix))]
54        {
55            Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
56        }
57    }
58
59    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
60        #[cfg(unix)]
61        {
62            use custom_labels::asynchronous::Label;
63
64            Box::new(
65                tokio::runtime::Handle::spawn(self, async move { cpu() }.with_current_labels())
66                    .abort_handle(),
67            )
68        }
69        #[cfg(not(unix))]
70        {
71            Box::new(tokio::runtime::Handle::spawn(self, async move { cpu() }).abort_handle())
72        }
73    }
74
75    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
76        #[cfg(unix)]
77        {
78            use custom_labels::Labelset;
79
80            let mut set = Labelset::clone_from_current();
81            Box::new(
82                tokio::runtime::Handle::spawn_blocking(self, move || set.enter(task))
83                    .abort_handle(),
84            )
85        }
86        #[cfg(not(unix))]
87        {
88            Box::new(tokio::runtime::Handle::spawn_blocking(self, task).abort_handle())
89        }
90    }
91}
92
93/// A runtime implementation that grabs the current Tokio runtime handle on each call.
94struct CurrentTokioRuntime;
95
96impl Executor for CurrentTokioRuntime {
97    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
98        Box::new(tokio::runtime::Handle::current().spawn(fut).abort_handle())
99    }
100
101    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
102        Box::new(
103            tokio::runtime::Handle::current()
104                .spawn(async move { cpu() })
105                .abort_handle(),
106        )
107    }
108
109    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
110        Box::new(
111            tokio::runtime::Handle::current()
112                .spawn_blocking(task)
113                .abort_handle(),
114        )
115    }
116}
117
118impl AbortHandle for tokio::task::AbortHandle {
119    fn abort(self: Box<Self>) {
120        tokio::task::AbortHandle::abort(&self)
121    }
122}
123
124// We depend on Tokio's rt-multi-thread feature for block-in-place
125impl BlockingRuntime for TokioRuntime {
126    type BlockingIterator<'a, R: 'a> = TokioBlockingIterator<'a, R>;
127
128    fn handle(&self) -> Handle {
129        let executor: Arc<dyn Executor> = self.0.clone();
130        Handle::new(Arc::downgrade(&executor))
131    }
132
133    fn block_on<Fut, R>(&self, fut: Fut) -> R
134    where
135        Fut: Future<Output = R>,
136    {
137        // Assert that we're not currently inside the Tokio context.
138        if tokio::runtime::Handle::try_current().is_ok() {
139            vortex_error::vortex_panic!("block_on cannot be called from within a Tokio runtime");
140        }
141        let handle = self.0.clone();
142        tokio::task::block_in_place(move || handle.block_on(fut))
143    }
144
145    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
146    where
147        S: futures::Stream<Item = R> + Send + 'a,
148        R: Send + 'a,
149    {
150        // Assert that we're not currently inside the Tokio context.
151        if tokio::runtime::Handle::try_current().is_ok() {
152            vortex_error::vortex_panic!(
153                "block_on_stream cannot be called from within a Tokio runtime"
154            );
155        }
156        let handle = self.0.clone();
157        let stream = Box::pin(stream);
158        TokioBlockingIterator { handle, stream }
159    }
160}
161
162#[cfg(feature = "tokio")]
163pub struct TokioBlockingIterator<'a, T> {
164    handle: Arc<tokio::runtime::Handle>,
165    stream: futures::stream::BoxStream<'a, T>,
166}
167
168#[cfg(feature = "tokio")]
169impl<T> Iterator for TokioBlockingIterator<'_, T> {
170    type Item = T;
171
172    fn next(&mut self) -> Option<Self::Item> {
173        use futures::StreamExt;
174
175        tokio::task::block_in_place(|| self.handle.block_on(self.stream.next()))
176    }
177}
178
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use futures::FutureExt;
    use tokio::runtime::Runtime as TokioRt;

    use super::*;

    /// A future spawned through the runtime's [`Handle`] can be awaited from `block_on`, and its
    /// output is returned to the caller.
    #[test]
    fn test_spawn_simple_future() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());
        let h = runtime.handle();
        let result = runtime.block_on({
            h.spawn(async {
                let fut = async { 77 };
                fut.await
            })
        });
        assert_eq!(result, 77);
    }

    /// Dropping the task returned by `spawn` must abort the spawned future, so it never runs to
    /// completion even after it is unblocked.
    #[test]
    fn test_spawn_and_abort() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());

        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        // Create a channel to ensure the future doesn't complete immediately
        let (send, recv) = tokio::sync::oneshot::channel::<()>();

        let fut = async move {
            let _ = recv.await;
            c.fetch_add(1, Ordering::SeqCst);
        };
        let task = runtime.handle().spawn(fut.boxed());
        drop(task);

        // Now we release the channel to let the future proceed if it wasn't aborted
        let _ = send.send(());

        // Give the runtime's worker threads time to run the future if it was NOT aborted.
        // Without this pause the assertion below would pass even if dropping the task did
        // nothing, because no worker would have polled the future yet.
        std::thread::sleep(Duration::from_millis(50));
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}