vortex_io/runtime/
tokio.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::{Arc, LazyLock};
5
6use futures::future::BoxFuture;
7use tracing::Instrument;
8
9use crate::runtime::{AbortHandle, AbortHandleRef, BlockingRuntime, Executor, Handle, IoTask};
10
11/// A Vortex runtime that drives all work the enclosed Tokio runtime handle.
12pub struct TokioRuntime(Arc<tokio::runtime::Handle>);
13
14impl TokioRuntime {
15    pub fn new(handle: tokio::runtime::Handle) -> Self {
16        Self(Arc::new(handle))
17    }
18
19    /// Create a new [`Handle`] that always uses the currently scoped Tokio runtime at the time
20    /// each operation is invoked.
21    pub fn current() -> Handle {
22        static CURRENT: LazyLock<Arc<dyn Executor>> =
23            LazyLock::new(|| Arc::new(CurrentTokioRuntime));
24        Handle::new(Arc::downgrade(&CURRENT))
25    }
26}
27
28impl From<&tokio::runtime::Handle> for TokioRuntime {
29    fn from(value: &tokio::runtime::Handle) -> Self {
30        Self::from(value.clone())
31    }
32}
33
34impl From<tokio::runtime::Handle> for TokioRuntime {
35    fn from(value: tokio::runtime::Handle) -> Self {
36        TokioRuntime(Arc::new(value))
37    }
38}
39
40impl Executor for tokio::runtime::Handle {
41    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
42        Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
43    }
44
45    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
46        Box::new(tokio::runtime::Handle::spawn(self, async move { cpu() }).abort_handle())
47    }
48
49    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
50        Box::new(tokio::runtime::Handle::spawn_blocking(self, task).abort_handle())
51    }
52
53    fn spawn_io(&self, task: IoTask) {
54        tokio::runtime::Handle::spawn(self, task.source.drive_send(task.stream).in_current_span());
55    }
56}
57
58/// A runtime implementation that grabs the current Tokio runtime handle on each call.
59struct CurrentTokioRuntime;
60
61impl Executor for CurrentTokioRuntime {
62    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
63        Box::new(tokio::runtime::Handle::current().spawn(fut).abort_handle())
64    }
65
66    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
67        Box::new(
68            tokio::runtime::Handle::current()
69                .spawn(async move { cpu() })
70                .abort_handle(),
71        )
72    }
73
74    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
75        Box::new(
76            tokio::runtime::Handle::current()
77                .spawn_blocking(task)
78                .abort_handle(),
79        )
80    }
81
82    fn spawn_io(&self, task: IoTask) {
83        tokio::runtime::Handle::current()
84            .spawn(task.source.drive_send(task.stream).in_current_span());
85    }
86}
87
88impl AbortHandle for tokio::task::AbortHandle {
89    fn abort(self: Box<Self>) {
90        tokio::task::AbortHandle::abort(&self)
91    }
92}
93
94// We depend on Tokio's rt-multi-thread feature for block-in-place
95impl BlockingRuntime for TokioRuntime {
96    type BlockingIterator<'a, R: 'a> = TokioBlockingIterator<'a, R>;
97
98    fn handle(&self) -> Handle {
99        let executor: Arc<dyn Executor> = self.0.clone();
100        Handle::new(Arc::downgrade(&executor))
101    }
102
103    fn block_on<Fut, R>(&self, fut: Fut) -> R
104    where
105        Fut: Future<Output = R>,
106    {
107        // Assert that we're not currently inside the Tokio context.
108        if tokio::runtime::Handle::try_current().is_ok() {
109            vortex_error::vortex_panic!("block_on cannot be called from within a Tokio runtime");
110        }
111        let handle = self.0.clone();
112        tokio::task::block_in_place(move || handle.block_on(fut))
113    }
114
115    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
116    where
117        S: futures::Stream<Item = R> + Send + 'a,
118        R: Send + 'a,
119    {
120        // Assert that we're not currently inside the Tokio context.
121        if tokio::runtime::Handle::try_current().is_ok() {
122            vortex_error::vortex_panic!(
123                "block_on_stream cannot be called from within a Tokio runtime"
124            );
125        }
126        let handle = self.0.clone();
127        let stream = Box::pin(stream);
128        TokioBlockingIterator { handle, stream }
129    }
130}
131
/// Blocking [`Iterator`] returned by [`BlockingRuntime::block_on_stream`] for [`TokioRuntime`].
///
/// Each call to [`Iterator::next`] blocks the current thread on the enclosed Tokio handle until
/// the wrapped stream yields its next item or ends.
#[cfg(feature = "tokio")]
pub struct TokioBlockingIterator<'a, T> {
    /// Tokio runtime handle used to drive the stream.
    handle: Arc<tokio::runtime::Handle>,
    /// The type-erased, pinned stream being drained.
    stream: futures::stream::BoxStream<'a, T>,
}
137
#[cfg(feature = "tokio")]
impl<T> Iterator for TokioBlockingIterator<'_, T> {
    type Item = T;

    /// Blocks until the underlying stream produces its next item, returning `None` once the
    /// stream is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        use futures::StreamExt;

        // Borrow the two fields separately: the handle shared, the stream mutably.
        let runtime = &self.handle;
        let pending = self.stream.next();
        tokio::task::block_in_place(move || runtime.block_on(pending))
    }
}
148
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::FutureExt;
    use tokio::runtime::Runtime as TokioRt;

    use super::*;

    #[test]
    fn test_spawn_simple_future() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());
        let h = runtime.handle();
        let result = runtime.block_on({
            h.spawn(async {
                let fut = async { 77 };
                fut.await
            })
        });
        assert_eq!(result, 77);
    }

    #[test]
    fn test_spawn_and_abort() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());

        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        // The future parks on this channel, so it can only reach the increment if we send.
        let (mut send, recv) = tokio::sync::oneshot::channel::<()>();

        let fut = async move {
            let _ = recv.await;
            c.fetch_add(1, Ordering::SeqCst);
        };
        let task = runtime.handle().spawn(fut.boxed());
        // Dropping the task is expected to abort the spawned future.
        drop(task);

        // Aborting drops the future, and with it the receiver. Wait for that to happen instead
        // of asserting immediately. An immediate assertion would also pass if the abort never
        // took effect, simply because the task had not been scheduled yet.
        runtime.block_on(send.closed());

        // The receiver is gone, so sending must fail and the increment can never run.
        assert!(send.send(()).is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}