vortex_io/runtime/
single.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::rc::Rc;
5use std::sync::Arc;
6
7use futures::future::BoxFuture;
8use futures::stream::LocalBoxStream;
9use futures::{Stream, StreamExt};
10use parking_lot::Mutex;
11use smol::LocalExecutor;
12use vortex_error::vortex_panic;
13
14use crate::runtime::smol::SmolAbortHandle;
15use crate::runtime::{AbortHandle, AbortHandleRef, BlockingRuntime, Executor, Handle, IoTask};
16
17/// A runtime that drives all work on the current thread.
18///
19/// This is subtly different from using a current-thread runtime to drive a future since it is
20/// capable of running `!Send` I/O futures.
21pub struct SingleThreadRuntime {
22    sender: Arc<Sender>,
23    executor: Rc<LocalExecutor<'static>>,
24}
25
26impl Default for SingleThreadRuntime {
27    fn default() -> Self {
28        let executor = Rc::new(LocalExecutor::new());
29        let sender = Arc::new(Sender::new(&executor));
30        Self { sender, executor }
31    }
32}
33
34struct Sender {
35    scheduling: kanal::Sender<SpawnAsync<'static>>,
36    cpu: kanal::Sender<SpawnSync<'static>>,
37    blocking: kanal::Sender<SpawnSync<'static>>,
38    io: kanal::Sender<IoTask>,
39}
40
41impl Sender {
42    fn new(local: &Rc<LocalExecutor<'static>>) -> Self {
43        let (scheduling_send, scheduling_recv) = kanal::unbounded::<SpawnAsync>();
44        let (cpu_send, cpu_recv) = kanal::unbounded::<SpawnSync>();
45        let (blocking_send, blocking_recv) = kanal::unbounded::<SpawnSync>();
46        let (io_send, io_recv) = kanal::unbounded::<IoTask>();
47
48        // We pass weak references to the local executor into the async tasks such that the task's
49        // reference doesn't keep the executor alive after the runtime is dropped.
50        let weak_local = Rc::downgrade(local);
51
52        // Drive scheduling tasks.
53        let weak_local2 = weak_local.clone();
54        local
55            .spawn(async move {
56                while let Ok(spawn) = scheduling_recv.as_async().recv().await {
57                    if let Some(local) = weak_local2.upgrade() {
58                        // Ignore send errors since it means the caller immediately detached.
59                        let _ = spawn
60                            .task_callback
61                            .send(SmolAbortHandle::new_handle(local.spawn(spawn.future)));
62                    }
63                }
64            })
65            .detach();
66
67        // Drive CPU tasks.
68        let weak_local2 = weak_local.clone();
69        local
70            .spawn(async move {
71                while let Ok(spawn) = cpu_recv.as_async().recv().await {
72                    if let Some(local) = weak_local2.upgrade() {
73                        let work = spawn.sync;
74                        // Ignore send errors since it means the caller immediately detached.
75                        let _ = spawn.task_callback.send(SmolAbortHandle::new_handle(
76                            local.spawn(async move { work() }),
77                        ));
78                    }
79                }
80            })
81            .detach();
82
83        // Drive blocking tasks.
84        let weak_local2 = weak_local.clone();
85        local
86            .spawn(async move {
87                while let Ok(spawn) = blocking_recv.as_async().recv().await {
88                    if let Some(local) = weak_local2.upgrade() {
89                        let work = spawn.sync;
90                        // Ignore send errors since it means the caller immediately detached.
91                        let _ = spawn.task_callback.send(SmolAbortHandle::new_handle(
92                            local.spawn(async move { work() }),
93                        ));
94                    }
95                }
96            })
97            .detach();
98
99        // Drive I/O tasks.
100        let weak_local2 = weak_local;
101        local
102            .spawn(async move {
103                while let Ok(task) = io_recv.as_async().recv().await {
104                    if let Some(local) = weak_local2.upgrade() {
105                        local.spawn(task.source.drive_local(task.stream)).detach();
106                    }
107                }
108            })
109            .detach();
110
111        Self {
112            scheduling: scheduling_send,
113            cpu: cpu_send,
114            blocking: blocking_send,
115            io: io_send,
116        }
117    }
118}
119
120/// Since the [`Handle`], and therefore runtime implementation needs to be `Send` and `Sync`,
121/// we cannot just `impl Runtime for LocalExecutor`. Instead, we create channels that the handle
122/// can forward its work into, and we drive the resulting tasks on a [`LocalExecutor`] on the
123/// calling thread.
124impl Executor for Sender {
125    fn spawn(&self, future: BoxFuture<'static, ()>) -> AbortHandleRef {
126        let (send, recv) = oneshot::channel();
127        if let Err(e) = self.scheduling.send(SpawnAsync {
128            future,
129            task_callback: send,
130        }) {
131            vortex_panic!("Executor missing: {}", e);
132        }
133        Box::new(LazyAbortHandle {
134            task: Mutex::new(recv),
135        })
136    }
137
138    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
139        let (send, recv) = oneshot::channel();
140        if let Err(e) = self.cpu.send(SpawnSync {
141            sync: cpu,
142            task_callback: send,
143        }) {
144            vortex_panic!("Executor missing: {}", e);
145        }
146        Box::new(LazyAbortHandle {
147            task: Mutex::new(recv),
148        })
149    }
150
151    fn spawn_blocking(&self, work: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
152        let (send, recv) = oneshot::channel();
153        if let Err(e) = self.blocking.send(SpawnSync {
154            sync: work,
155            task_callback: send,
156        }) {
157            vortex_panic!("Executor missing: {}", e);
158        }
159        Box::new(LazyAbortHandle {
160            task: Mutex::new(recv),
161        })
162    }
163
164    fn spawn_io(&self, task: IoTask) {
165        if let Err(e) = self.io.send(task) {
166            vortex_panic!("Executor missing: {}", e);
167        }
168    }
169}
170
171impl BlockingRuntime for SingleThreadRuntime {
172    type BlockingIterator<'a, R: 'a> = SingleThreadIterator<'a, R>;
173
174    fn handle(&self) -> Handle {
175        let executor: Arc<dyn Executor> = self.sender.clone();
176        Handle::new(Arc::downgrade(&executor))
177    }
178
179    fn block_on<F, Fut, R>(&self, f: F) -> R
180    where
181        F: FnOnce(Handle) -> Fut,
182        Fut: Future<Output = R>,
183    {
184        smol::block_on(self.executor.run(f(self.handle())))
185    }
186
187    fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R>
188    where
189        F: FnOnce(Handle) -> S,
190        S: Stream<Item = R> + Send + 'a,
191        R: Send + 'a,
192    {
193        SingleThreadIterator {
194            executor: self.executor.clone(),
195            stream: f(self.handle()).boxed_local(),
196        }
197    }
198}
199
200/// Runs a future to completion on the current thread until it completes.
201///
202/// The future is provided a [`Handle`] to the runtime so that it may spawn additional tasks
203/// to be executed concurrently.
204pub fn block_on<F, Fut, R>(f: F) -> R
205where
206    F: FnOnce(Handle) -> Fut,
207    Fut: Future<Output = R>,
208{
209    SingleThreadRuntime::default().block_on(f)
210}
211
212/// Returns an iterator wrapper around a stream, blocking the current thread for each item.
213pub fn block_on_stream<'a, F, S, R>(f: F) -> SingleThreadIterator<'a, R>
214where
215    F: FnOnce(Handle) -> S,
216    S: Stream<Item = R> + Send + Unpin + 'a,
217    R: Send + 'a,
218{
219    SingleThreadRuntime::default().block_on_stream(f)
220}
221
222/// A spawn request for a future.
223///
224/// We pass back the abort handle via oneshot channel because this is a single-threaded runtime,
225/// meaning we need the spawning channel consumer to do some work before the caller can actually
226/// get ahold of their task handle.
227///
228/// The reason we don't pass back a smol::Task, and instead pass back a SmolAbortHandle, is because
229/// we invert the behaviour of abort and drop. Dropping the abort handle results in the task being
230/// detached, whereas dropping the smol::Task results in the task being canceled. This helps avoid
231/// a race where the caller detaches the LazyAbortHandle before the smol::Task has been launched.
232struct SpawnAsync<'rt> {
233    future: BoxFuture<'rt, ()>,
234    task_callback: oneshot::Sender<AbortHandleRef>,
235}
236
237// A spawn request for a synchronous job.
238struct SpawnSync<'rt> {
239    sync: Box<dyn FnOnce() + Send + 'rt>,
240    task_callback: oneshot::Sender<AbortHandleRef>,
241}
242
243struct LazyAbortHandle {
244    task: Mutex<oneshot::Receiver<AbortHandleRef>>,
245}
246
247impl AbortHandle for LazyAbortHandle {
248    fn abort(self: Box<Self>) {
249        // Aborting a smol::Task is done by dropping it.
250        if let Ok(task) = self.task.lock().try_recv() {
251            task.abort()
252        }
253    }
254}
255
256/// A stream that wraps up the stream with the executor that drives it.
257pub struct SingleThreadIterator<'a, T> {
258    executor: Rc<LocalExecutor<'static>>,
259    stream: LocalBoxStream<'a, T>,
260}
261
262impl<T> Iterator for SingleThreadIterator<'_, T> {
263    type Item = T;
264
265    fn next(&mut self) -> Option<Self::Item> {
266        let fut = self.stream.next();
267        smol::block_on(self.executor.run(fut))
268    }
269}
270
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::FutureExt;

    use crate::runtime::BlockingRuntime;
    use crate::runtime::single::{SingleThreadRuntime, block_on};

    /// A future that ignores the handle is driven to completion and its output is returned.
    #[test]
    fn test_drive_simple_future() {
        let runtime = SingleThreadRuntime::default();
        let output = runtime.block_on(|_handle| async { 123 }.boxed_local());
        assert_eq!(output, 123);
    }

    /// A closure submitted through `spawn_cpu` has run by the time `block_on` returns.
    #[test]
    fn test_spawn_cpu_task() {
        let counter = Arc::new(AtomicUsize::new(0));
        let task_counter = counter.clone();

        block_on(|handle| async move {
            let bump = move || {
                task_counter.fetch_add(1, Ordering::SeqCst);
            };
            handle.spawn_cpu(bump).await
        });

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}