vortex_io/runtime/
tokio.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::{Arc, LazyLock};
5
6use futures::future::BoxFuture;
7use tracing::Instrument;
8
9use crate::runtime::{AbortHandle, AbortHandleRef, BlockingRuntime, Executor, Handle, IoTask};
10
11/// A Vortex runtime that drives all work the enclosed Tokio runtime handle.
12pub struct TokioRuntime(Arc<tokio::runtime::Handle>);
13
14impl TokioRuntime {
15    /// Create a new [`Handle`] that always uses the currently scoped Tokio runtime at the time
16    /// each operation is invoked.
17    pub fn current() -> Handle {
18        static CURRENT: LazyLock<Arc<dyn Executor>> =
19            LazyLock::new(|| Arc::new(CurrentTokioRuntime));
20        Handle::new(Arc::downgrade(&CURRENT))
21    }
22}
23
24impl From<&tokio::runtime::Handle> for TokioRuntime {
25    fn from(value: &tokio::runtime::Handle) -> Self {
26        Self::from(value.clone())
27    }
28}
29
30impl From<tokio::runtime::Handle> for TokioRuntime {
31    fn from(value: tokio::runtime::Handle) -> Self {
32        TokioRuntime(Arc::new(value))
33    }
34}
35
36impl Executor for tokio::runtime::Handle {
37    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
38        Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
39    }
40
41    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
42        Box::new(tokio::runtime::Handle::spawn(self, async move { cpu() }).abort_handle())
43    }
44
45    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
46        Box::new(tokio::runtime::Handle::spawn_blocking(self, task).abort_handle())
47    }
48
49    fn spawn_io(&self, task: IoTask) {
50        tokio::runtime::Handle::spawn(self, task.source.drive_send(task.stream).in_current_span());
51    }
52}
53
54/// A runtime implementation that grabs the current Tokio runtime handle on each call.
55struct CurrentTokioRuntime;
56
57impl Executor for CurrentTokioRuntime {
58    fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
59        Box::new(tokio::runtime::Handle::current().spawn(fut).abort_handle())
60    }
61
62    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
63        Box::new(
64            tokio::runtime::Handle::current()
65                .spawn(async move { cpu() })
66                .abort_handle(),
67        )
68    }
69
70    fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
71        Box::new(
72            tokio::runtime::Handle::current()
73                .spawn_blocking(task)
74                .abort_handle(),
75        )
76    }
77
78    fn spawn_io(&self, task: IoTask) {
79        tokio::runtime::Handle::current()
80            .spawn(task.source.drive_send(task.stream).in_current_span());
81    }
82}
83
84impl AbortHandle for tokio::task::AbortHandle {
85    fn abort(self: Box<Self>) {
86        tokio::task::AbortHandle::abort(&self)
87    }
88}
89
90// We depend on Tokio's rt-multi-thread feature for block-in-place
91impl BlockingRuntime for TokioRuntime {
92    type BlockingIterator<'a, R: 'a> = TokioBlockingIterator<'a, R>;
93
94    fn handle(&self) -> Handle {
95        let executor: Arc<dyn Executor> = self.0.clone();
96        Handle::new(Arc::downgrade(&executor))
97    }
98
99    fn block_on<F, Fut, R>(&self, f: F) -> R
100    where
101        F: FnOnce(Handle) -> Fut,
102        Fut: Future<Output = R>,
103    {
104        // Assert that we're not currently inside the Tokio context.
105        if tokio::runtime::Handle::try_current().is_ok() {
106            vortex_error::vortex_panic!("block_on cannot be called from within a Tokio runtime");
107        }
108        let handle = self.0.clone();
109        let fut = f(self.handle());
110        tokio::task::block_in_place(move || handle.block_on(fut))
111    }
112
113    fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R>
114    where
115        F: FnOnce(Handle) -> S,
116        S: futures::Stream<Item = R> + Send + 'a,
117        R: Send + 'a,
118    {
119        // Assert that we're not currently inside the Tokio context.
120        if tokio::runtime::Handle::try_current().is_ok() {
121            vortex_error::vortex_panic!(
122                "block_on_stream cannot be called from within a Tokio runtime"
123            );
124        }
125        let handle = self.0.clone();
126        let stream = Box::pin(f(self.handle()));
127        TokioBlockingIterator { handle, stream }
128    }
129}
130
131#[cfg(feature = "tokio")]
132pub struct TokioBlockingIterator<'a, T> {
133    handle: Arc<tokio::runtime::Handle>,
134    stream: futures::stream::BoxStream<'a, T>,
135}
136
137#[cfg(feature = "tokio")]
138impl<T> Iterator for TokioBlockingIterator<'_, T> {
139    type Item = T;
140
141    fn next(&mut self) -> Option<Self::Item> {
142        use futures::StreamExt;
143
144        tokio::task::block_in_place(|| self.handle.block_on(self.stream.next()))
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::FutureExt;
    use tokio::runtime::Runtime as TokioRt;

    use super::*;

    /// A future spawned through the handle passed to `block_on` runs, and its output is returned.
    #[test]
    fn test_spawn_simple_future() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());
        let result = runtime.block_on(|h| {
            h.spawn(async {
                let fut = async { 77 };
                fut.await
            })
        });
        assert_eq!(result, 77);
    }

    /// Dropping the task returned by `Handle::spawn` aborts the future before it can finish.
    #[test]
    fn test_spawn_and_abort() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());

        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        // Create a channel to ensure the future doesn't complete immediately
        let (send, recv) = tokio::sync::oneshot::channel::<()>();
        // Resolves with `Ok` if the future runs to completion, or with `Err` once the future is
        // dropped without completing (which is what aborting the task does).
        let (done_send, done_recv) = tokio::sync::oneshot::channel::<()>();

        let fut = async move {
            let _ = recv.await;
            c.fetch_add(1, Ordering::SeqCst);
            let _ = done_send.send(());
        };
        let task = runtime.handle().spawn(fut.boxed());
        drop(task);

        // Now we release the channel to let the future proceed if it wasn't aborted
        let _ = send.send(());

        // Wait until the spawned future has either finished or been dropped. Checking the counter
        // right away would pass even if abort did nothing, because the task may not have been
        // polled yet.
        let outcome = tokio_rt.block_on(done_recv);
        assert!(
            outcome.is_err(),
            "the aborted future must not run to completion"
        );
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}