vortex_io/runtime/
tokio.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5use std::sync::LazyLock;
6
7use futures::future::BoxFuture;
8use tracing::Instrument;
9
10use crate::runtime::AbortHandle;
11use crate::runtime::AbortHandleRef;
12use crate::runtime::BlockingRuntime;
13use crate::runtime::Executor;
14use crate::runtime::Handle;
15use crate::runtime::IoTask;
16
17/// A Vortex runtime that drives all work the enclosed Tokio runtime handle.
18pub struct TokioRuntime(Arc<tokio::runtime::Handle>);
19
20impl TokioRuntime {
21    pub fn new(handle: tokio::runtime::Handle) -> Self {
22        Self(Arc::new(handle))
23    }
24
25    /// Create a new [`Handle`] that always uses the currently scoped Tokio runtime at the time
26    /// each operation is invoked.
27    pub fn current() -> Handle {
28        static CURRENT: LazyLock<Arc<dyn Executor>> =
29            LazyLock::new(|| Arc::new(CurrentTokioRuntime));
30        Handle::new(Arc::downgrade(&CURRENT))
31    }
32}
33
34impl From<&tokio::runtime::Handle> for TokioRuntime {
35    fn from(value: &tokio::runtime::Handle) -> Self {
36        Self::from(value.clone())
37    }
38}
39
40impl From<tokio::runtime::Handle> for TokioRuntime {
41    fn from(value: tokio::runtime::Handle) -> Self {
42        TokioRuntime(Arc::new(value))
43    }
44}
45
46impl Executor for tokio::runtime::Handle {
47    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
48        Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
49    }
50
51    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
52        Box::new(tokio::runtime::Handle::spawn(self, async move { cpu() }).abort_handle())
53    }
54
55    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
56        Box::new(tokio::runtime::Handle::spawn_blocking(self, task).abort_handle())
57    }
58
59    fn spawn_io(&self, task: IoTask) {
60        tokio::runtime::Handle::spawn(self, task.source.drive_send(task.stream).in_current_span());
61    }
62}
63
64/// A runtime implementation that grabs the current Tokio runtime handle on each call.
65struct CurrentTokioRuntime;
66
67impl Executor for CurrentTokioRuntime {
68    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
69        Box::new(tokio::runtime::Handle::current().spawn(fut).abort_handle())
70    }
71
72    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
73        Box::new(
74            tokio::runtime::Handle::current()
75                .spawn(async move { cpu() })
76                .abort_handle(),
77        )
78    }
79
80    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
81        Box::new(
82            tokio::runtime::Handle::current()
83                .spawn_blocking(task)
84                .abort_handle(),
85        )
86    }
87
88    fn spawn_io(&self, task: IoTask) {
89        tokio::runtime::Handle::current()
90            .spawn(task.source.drive_send(task.stream).in_current_span());
91    }
92}
93
94impl AbortHandle for tokio::task::AbortHandle {
95    fn abort(self: Box<Self>) {
96        tokio::task::AbortHandle::abort(&self)
97    }
98}
99
100// We depend on Tokio's rt-multi-thread feature for block-in-place
101impl BlockingRuntime for TokioRuntime {
102    type BlockingIterator<'a, R: 'a> = TokioBlockingIterator<'a, R>;
103
104    fn handle(&self) -> Handle {
105        let executor: Arc<dyn Executor> = self.0.clone();
106        Handle::new(Arc::downgrade(&executor))
107    }
108
109    fn block_on<Fut, R>(&self, fut: Fut) -> R
110    where
111        Fut: Future<Output = R>,
112    {
113        // Assert that we're not currently inside the Tokio context.
114        if tokio::runtime::Handle::try_current().is_ok() {
115            vortex_error::vortex_panic!("block_on cannot be called from within a Tokio runtime");
116        }
117        let handle = self.0.clone();
118        tokio::task::block_in_place(move || handle.block_on(fut))
119    }
120
121    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
122    where
123        S: futures::Stream<Item = R> + Send + 'a,
124        R: Send + 'a,
125    {
126        // Assert that we're not currently inside the Tokio context.
127        if tokio::runtime::Handle::try_current().is_ok() {
128            vortex_error::vortex_panic!(
129                "block_on_stream cannot be called from within a Tokio runtime"
130            );
131        }
132        let handle = self.0.clone();
133        let stream = Box::pin(stream);
134        TokioBlockingIterator { handle, stream }
135    }
136}
137
/// A blocking [`Iterator`] over the items of a boxed stream, driven by a Tokio runtime handle.
#[cfg(feature = "tokio")]
pub struct TokioBlockingIterator<'a, T> {
    /// Runtime handle used to drive the stream on every call to `next`.
    handle: Arc<tokio::runtime::Handle>,
    /// The stream whose items are yielded one at a time.
    stream: futures::stream::BoxStream<'a, T>,
}

#[cfg(feature = "tokio")]
impl<T> Iterator for TokioBlockingIterator<'_, T> {
    type Item = T;

    /// Block the calling thread until the stream yields its next item, returning `None` once
    /// the stream is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        use futures::StreamExt;

        tokio::task::block_in_place(|| {
            let next_item = self.stream.next();
            self.handle.block_on(next_item)
        })
    }
}
154
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use futures::FutureExt;
    use tokio::runtime::Runtime as TokioRt;

    use super::*;

    /// A future spawned through the Vortex handle yields its output when driven by `block_on`.
    #[test]
    fn test_spawn_simple_future() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());
        let h = runtime.handle();

        let spawned = h.spawn(async { async { 77 }.await });
        let result = runtime.block_on(spawned);

        assert_eq!(result, 77);
    }

    /// Dropping a spawned task is expected to cancel it: the counter must stay untouched even
    /// after the future is unblocked.
    #[test]
    fn test_spawn_and_abort() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());

        let counter = Arc::new(AtomicUsize::new(0));

        // The future parks on this channel, so it cannot finish before we drop the task.
        let (send, recv) = tokio::sync::oneshot::channel::<()>();

        let fut = {
            let counter = Arc::clone(&counter);
            async move {
                let _ = recv.await;
                counter.fetch_add(1, Ordering::SeqCst);
            }
        };

        let task = runtime.handle().spawn(fut.boxed());
        drop(task);

        // Unblock the future; it would bump the counter if it were still alive.
        let _ = send.send(());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}