vortex_io/runtime/single.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::rc::Rc;
use std::sync::Arc;

use futures::Stream;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::LocalBoxStream;
use parking_lot::Mutex;
use smol::LocalExecutor;
use vortex_error::vortex_panic;

use crate::runtime::AbortHandle;
use crate::runtime::AbortHandleRef;
use crate::runtime::BlockingRuntime;
use crate::runtime::Executor;
use crate::runtime::Handle;
use crate::runtime::smol::SmolAbortHandle;

/// A runtime that drives all work on the current thread.
///
/// This is subtly different from using a current-thread runtime to drive a future since it is
/// capable of running `!Send` I/O futures.
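///
/// # Example
///
/// A minimal usage sketch (illustrative only; assumes [`BlockingRuntime`] is in scope and is not
/// compiled as a doctest):
///
/// ```ignore
/// use std::rc::Rc;
///
/// let runtime = SingleThreadRuntime::default();
/// // `Rc` is `!Send`, so this future could not run on a work-stealing runtime,
/// // but `block_on` here has no `Send` bound.
/// let value = Rc::new(41);
/// let result = runtime.block_on(async move { *value + 1 });
/// assert_eq!(result, 42);
/// ```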
pub struct SingleThreadRuntime {
    sender: Arc<Sender>,
    executor: Rc<LocalExecutor<'static>>,
}

impl Default for SingleThreadRuntime {
    fn default() -> Self {
        let executor = Rc::new(LocalExecutor::new());
        let sender = Arc::new(Sender::new(&executor));
        Self { sender, executor }
    }
}

struct Sender {
    scheduling: kanal::Sender<SpawnAsync<'static>>,
    cpu: kanal::Sender<SpawnSync<'static>>,
    blocking: kanal::Sender<SpawnSync<'static>>,
}

impl Sender {
    fn new(local: &Rc<LocalExecutor<'static>>) -> Self {
        let (scheduling_send, scheduling_recv) = kanal::unbounded::<SpawnAsync>();
        let (cpu_send, cpu_recv) = kanal::unbounded::<SpawnSync>();
        let (blocking_send, blocking_recv) = kanal::unbounded::<SpawnSync>();

        // We pass weak references to the local executor into the async tasks so that the tasks
        // don't keep the executor alive after the runtime is dropped.
        let weak_local = Rc::downgrade(local);

        // Drive scheduling tasks.
        let weak_local2 = weak_local.clone();
        local
            .spawn(async move {
                while let Ok(spawn) = scheduling_recv.as_async().recv().await {
                    if let Some(local) = weak_local2.upgrade() {
                        // Ignore send errors since it means the caller immediately detached.
                        drop(
                            spawn
                                .task_callback
                                .send(SmolAbortHandle::new_handle(local.spawn(spawn.future))),
                        );
                    }
                }
            })
            .detach();

        // Drive CPU tasks.
        let weak_local2 = weak_local.clone();
        local
            .spawn(async move {
                while let Ok(spawn) = cpu_recv.as_async().recv().await {
                    if let Some(local) = weak_local2.upgrade() {
                        let work = spawn.sync;
                        // Ignore send errors since it means the caller immediately detached.
                        drop(spawn.task_callback.send(SmolAbortHandle::new_handle(
                            local.spawn(async move { work() }),
                        )));
                    }
                }
            })
            .detach();

        // Drive blocking tasks.
        let weak_local2 = weak_local.clone();
        local
            .spawn(async move {
                while let Ok(spawn) = blocking_recv.as_async().recv().await {
                    if let Some(local) = weak_local2.upgrade() {
                        let work = spawn.sync;
                        // Ignore send errors since it means the caller immediately detached.
                        drop(spawn.task_callback.send(SmolAbortHandle::new_handle(
                            local.spawn(async move { work() }),
                        )));
                    }
                }
            })
            .detach();

        Self {
            scheduling: scheduling_send,
            cpu: cpu_send,
            blocking: blocking_send,
        }
    }
}

/// Since the [`Handle`], and therefore the runtime implementation, needs to be `Send` and `Sync`,
/// we cannot simply `impl Runtime for LocalExecutor`. Instead, we create channels that the handle
/// forwards its work into, and we drive the resulting tasks on a [`LocalExecutor`] on the
/// calling thread.
impl Executor for Sender {
    fn spawn(&self, future: BoxFuture<'static, ()>) -> AbortHandleRef {
        let (send, recv) = oneshot::channel();
        if let Err(e) = self.scheduling.send(SpawnAsync {
            future,
            task_callback: send,
        }) {
            vortex_panic!("Executor missing: {}", e);
        }
        Box::new(LazyAbortHandle {
            task: Mutex::new(recv),
        })
    }

    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
        let (send, recv) = oneshot::channel();
        if let Err(e) = self.cpu.send(SpawnSync {
            sync: cpu,
            task_callback: send,
        }) {
            vortex_panic!("Executor missing: {}", e);
        }
        Box::new(LazyAbortHandle {
            task: Mutex::new(recv),
        })
    }

    fn spawn_blocking(&self, work: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
        let (send, recv) = oneshot::channel();
        if let Err(e) = self.blocking.send(SpawnSync {
            sync: work,
            task_callback: send,
        }) {
            vortex_panic!("Executor missing: {}", e);
        }
        Box::new(LazyAbortHandle {
            task: Mutex::new(recv),
        })
    }
}

impl BlockingRuntime for SingleThreadRuntime {
    type BlockingIterator<'a, R: 'a> = SingleThreadIterator<'a, R>;

    fn handle(&self) -> Handle {
        let executor: Arc<dyn Executor> = self.sender.clone();
        Handle::new(Arc::downgrade(&executor))
    }

    fn block_on<Fut, R>(&self, fut: Fut) -> R
    where
        Fut: Future<Output = R>,
    {
        smol::block_on(self.executor.run(fut))
    }

    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
    where
        S: Stream<Item = R> + Send + 'a,
        R: Send + 'a,
    {
        SingleThreadIterator {
            executor: self.executor.clone(),
            stream: stream.boxed_local(),
        }
    }
}

/// Runs a future to completion on the current thread.
///
/// The future is provided a [`Handle`] to the runtime so that it may spawn additional tasks
/// to be executed concurrently.
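///
/// # Example
///
/// A minimal sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let result = block_on(|_handle| async { 40 + 2 });
/// assert_eq!(result, 42);
/// ```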
pub fn block_on<F, Fut, R>(f: F) -> R
where
    F: FnOnce(Handle) -> Fut,
    Fut: Future<Output = R>,
{
    let runtime = SingleThreadRuntime::default();
    let handle = runtime.handle();
    runtime.block_on(f(handle))
}

/// Returns an iterator wrapper around a stream, blocking the current thread for each item.
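///
/// # Example
///
/// A minimal sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let items: Vec<i32> = block_on_stream(|_handle| futures::stream::iter(0..3)).collect();
/// assert_eq!(items, vec![0, 1, 2]);
/// ```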
pub fn block_on_stream<'a, F, S, R>(f: F) -> SingleThreadIterator<'a, R>
where
    F: FnOnce(Handle) -> S,
    S: Stream<Item = R> + Send + Unpin + 'a,
    R: Send + 'a,
{
    let runtime = SingleThreadRuntime::default();
    let handle = runtime.handle();
    runtime.block_on_stream(f(handle))
}

/// A spawn request for a future.
///
/// We pass back the abort handle via a oneshot channel because this is a single-threaded runtime,
/// meaning the spawning channel's consumer must do some work before the caller can actually get
/// hold of its task handle.
///
/// We pass back a `SmolAbortHandle` rather than a `smol::Task` because we invert the behaviour of
/// abort and drop: dropping the abort handle results in the task being detached, whereas dropping
/// the `smol::Task` results in the task being canceled. This helps avoid a race where the caller
/// detaches the `LazyAbortHandle` before the `smol::Task` has been launched.
struct SpawnAsync<'rt> {
    future: BoxFuture<'rt, ()>,
    task_callback: oneshot::Sender<AbortHandleRef>,
}

/// A spawn request for a synchronous job.
struct SpawnSync<'rt> {
    sync: Box<dyn FnOnce() + Send + 'rt>,
    task_callback: oneshot::Sender<AbortHandleRef>,
}

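/// An [`AbortHandle`] whose underlying task handle arrives lazily: the real handle is only
/// delivered once the executor thread has processed the spawn request, so until then we hold
/// the receiving end of the oneshot channel.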
struct LazyAbortHandle {
    task: Mutex<oneshot::Receiver<AbortHandleRef>>,
}

impl AbortHandle for LazyAbortHandle {
    fn abort(self: Box<Self>) {
        // Aborting a smol::Task is done by dropping it.
        if let Ok(task) = self.task.lock().try_recv() {
            task.abort()
        }
    }
}

/// A blocking iterator that couples a stream with the executor that drives it.
pub struct SingleThreadIterator<'a, T> {
    executor: Rc<LocalExecutor<'static>>,
    stream: LocalBoxStream<'a, T>,
}

impl<T> Iterator for SingleThreadIterator<'_, T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
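        // Drive the local executor while blocking on the next item of the wrapped stream.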
        let fut = self.stream.next();
        smol::block_on(self.executor.run(fut))
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use futures::FutureExt;

    use crate::runtime::BlockingRuntime;
    use crate::runtime::single::SingleThreadRuntime;
    use crate::runtime::single::block_on;

    #[test]
    fn test_drive_simple_future() {
        let result = SingleThreadRuntime::default().block_on(async { 123 }.boxed_local());
        assert_eq!(result, 123);
    }

    #[test]
    fn test_spawn_cpu_task() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        block_on(|handle| async move {
            handle
                .spawn_cpu(move || {
                    c.fetch_add(1, Ordering::SeqCst);
                })
                .await
        });

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}