vortex_io/runtime/
tokio.rs1use std::sync::{Arc, LazyLock};
5
6use futures::future::BoxFuture;
7use tracing::Instrument;
8
9use crate::runtime::{AbortHandle, AbortHandleRef, BlockingRuntime, Executor, Handle, IoTask};
10
11pub struct TokioRuntime(Arc<tokio::runtime::Handle>);
13
14impl TokioRuntime {
15 pub fn current() -> Handle {
18 static CURRENT: LazyLock<Arc<dyn Executor>> =
19 LazyLock::new(|| Arc::new(CurrentTokioRuntime));
20 Handle::new(Arc::downgrade(&CURRENT))
21 }
22}
23
24impl From<&tokio::runtime::Handle> for TokioRuntime {
25 fn from(value: &tokio::runtime::Handle) -> Self {
26 Self::from(value.clone())
27 }
28}
29
30impl From<tokio::runtime::Handle> for TokioRuntime {
31 fn from(value: tokio::runtime::Handle) -> Self {
32 TokioRuntime(Arc::new(value))
33 }
34}
35
36impl Executor for tokio::runtime::Handle {
37 fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
38 Box::new(tokio::runtime::Handle::spawn(self, fut).abort_handle())
39 }
40
41 fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
42 Box::new(tokio::runtime::Handle::spawn(self, async move { cpu() }).abort_handle())
43 }
44
45 fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
46 Box::new(tokio::runtime::Handle::spawn_blocking(self, task).abort_handle())
47 }
48
49 fn spawn_io(&self, task: IoTask) {
50 tokio::runtime::Handle::spawn(self, task.source.drive_send(task.stream).in_current_span());
51 }
52}
53
/// Executor that looks up the ambient Tokio runtime with `tokio::runtime::Handle::current()`
/// each time it spawns work, instead of holding a handle to a specific runtime.
struct CurrentTokioRuntime;
56
57impl Executor for CurrentTokioRuntime {
58 fn spawn(&self, fut: BoxFuture<'static, ()>) -> AbortHandleRef {
59 Box::new(tokio::runtime::Handle::current().spawn(fut).abort_handle())
60 }
61
62 fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
63 Box::new(
64 tokio::runtime::Handle::current()
65 .spawn(async move { cpu() })
66 .abort_handle(),
67 )
68 }
69
70 fn spawn_blocking(&self, task: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
71 Box::new(
72 tokio::runtime::Handle::current()
73 .spawn_blocking(task)
74 .abort_handle(),
75 )
76 }
77
78 fn spawn_io(&self, task: IoTask) {
79 tokio::runtime::Handle::current()
80 .spawn(task.source.drive_send(task.stream).in_current_span());
81 }
82}
83
84impl AbortHandle for tokio::task::AbortHandle {
85 fn abort(self: Box<Self>) {
86 tokio::task::AbortHandle::abort(&self)
87 }
88}
89
90impl BlockingRuntime for TokioRuntime {
92 type BlockingIterator<'a, R: 'a> = TokioBlockingIterator<'a, R>;
93
94 fn handle(&self) -> Handle {
95 let executor: Arc<dyn Executor> = self.0.clone();
96 Handle::new(Arc::downgrade(&executor))
97 }
98
99 fn block_on<F, Fut, R>(&self, f: F) -> R
100 where
101 F: FnOnce(Handle) -> Fut,
102 Fut: Future<Output = R>,
103 {
104 if tokio::runtime::Handle::try_current().is_ok() {
106 vortex_error::vortex_panic!("block_on cannot be called from within a Tokio runtime");
107 }
108 let handle = self.0.clone();
109 let fut = f(self.handle());
110 tokio::task::block_in_place(move || handle.block_on(fut))
111 }
112
113 fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R>
114 where
115 F: FnOnce(Handle) -> S,
116 S: futures::Stream<Item = R> + Send + 'a,
117 R: Send + 'a,
118 {
119 if tokio::runtime::Handle::try_current().is_ok() {
121 vortex_error::vortex_panic!(
122 "block_on_stream cannot be called from within a Tokio runtime"
123 );
124 }
125 let handle = self.0.clone();
126 let stream = Box::pin(f(self.handle()));
127 TokioBlockingIterator { handle, stream }
128 }
129}
130
/// Blocking iterator returned by `block_on_stream` on [`TokioRuntime`].
///
/// Each call to `next` blocks the calling thread until the wrapped stream yields its next item
/// (or ends), driving it on the stored Tokio handle.
#[cfg(feature = "tokio")]
pub struct TokioBlockingIterator<'a, T> {
    /// Tokio handle used to drive the stream.
    handle: Arc<tokio::runtime::Handle>,
    /// Boxed, pinned stream whose items this iterator yields.
    stream: futures::stream::BoxStream<'a, T>,
}
136
#[cfg(feature = "tokio")]
impl<T> Iterator for TokioBlockingIterator<'_, T> {
    type Item = T;

    /// Blocks until the stream produces its next item, returning `None` once it is exhausted.
    ///
    /// `tokio::task::block_in_place` panics if this is called from a current-thread Tokio
    /// runtime.
    fn next(&mut self) -> Option<Self::Item> {
        use futures::StreamExt;

        // Borrow the two fields separately so the `next` future can hold `stream` mutably.
        let runtime = &self.handle;
        let pending = self.stream.next();
        tokio::task::block_in_place(|| runtime.block_on(pending))
    }
}
147
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::FutureExt;
    use tokio::runtime::Runtime as TokioRt;

    use super::*;

    #[test]
    fn test_spawn_simple_future() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());
        let result = runtime.block_on(|h| {
            h.spawn(async {
                let fut = async { 77 };
                fut.await
            })
        });
        assert_eq!(result, 77);
    }

    #[test]
    fn test_spawn_and_abort() {
        let tokio_rt = TokioRt::new().unwrap();
        let runtime = TokioRuntime::from(tokio_rt.handle());

        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let (mut send, recv) = tokio::sync::oneshot::channel::<()>();

        // The task parks on `recv` and only bumps the counter once a value arrives.
        let fut = async move {
            let _ = recv.await;
            c.fetch_add(1, Ordering::SeqCst);
        };
        let task = runtime.handle().spawn(fut.boxed());
        // Dropping the task is expected to abort it.
        drop(task);

        // Wait until the aborted task's future has been dropped, which drops `recv`. Without
        // this, the counter check below races the runtime and passes even if abort is a no-op.
        tokio_rt.block_on(send.closed());

        // With the receiver gone, the value can no longer reach the task.
        assert!(send.send(()).is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}