// vortex_io/runtime/single.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::rc::Rc;
5use std::rc::Weak as RcWeak;
6use std::sync::Arc;
7
8use futures::Stream;
9use futures::StreamExt;
10use futures::future::BoxFuture;
11use futures::stream::LocalBoxStream;
12use parking_lot::Mutex;
13use smol::LocalExecutor;
14use vortex_error::vortex_panic;
15
16use crate::runtime::AbortHandle;
17use crate::runtime::AbortHandleRef;
18use crate::runtime::BlockingRuntime;
19use crate::runtime::Executor;
20use crate::runtime::Handle;
21use crate::runtime::smol::SmolAbortHandle;
22
/// A runtime that drives all work on the current thread.
///
/// This is subtly different from using a current-thread runtime to drive a future since it is
/// capable of running `!Send` I/O futures.
pub struct SingleThreadRuntime {
    // The `Send + Sync` half: forwards spawn requests over channels into the local executor.
    sender: Arc<Sender>,
    // The executor that actually runs all tasks; driven by `block_on`/`block_on_stream`.
    executor: Rc<LocalExecutor<'static>>,
}
31
32impl Default for SingleThreadRuntime {
33    fn default() -> Self {
34        let executor = Rc::new(LocalExecutor::new());
35        let sender = Arc::new(Sender::new(&executor));
36        Self { sender, executor }
37    }
38}
39
// The `Send + Sync` half of the runtime: channels over which spawn requests are forwarded to
// driver loops running on the `LocalExecutor`.
struct Sender {
    // Requests to spawn async (scheduling) futures.
    scheduling: kanal::Sender<SpawnAsync<'static>>,
    // Requests to spawn CPU-bound closures.
    cpu: kanal::Sender<SpawnSync<'static>>,
    // Requests to spawn blocking I/O closures.
    blocking: kanal::Sender<SpawnSync<'static>>,
}
45
46impl Sender {
47    fn new(local: &Rc<LocalExecutor<'static>>) -> Self {
48        let (scheduling_send, scheduling_recv) = kanal::unbounded::<SpawnAsync>();
49        let (cpu_send, cpu_recv) = kanal::unbounded::<SpawnSync>();
50        let (blocking_send, blocking_recv) = kanal::unbounded::<SpawnSync>();
51
52        // We pass weak references to the local execution into the async tasks such that the task's
53        // reference doesn't keep the execution alive after the runtime is dropped.
54        let weak_local = Rc::downgrade(local);
55
56        // Drive scheduling tasks.
57        let weak_local2 = RcWeak::clone(&weak_local);
58        local
59            .spawn(async move {
60                while let Ok(spawn) = scheduling_recv.as_async().recv().await {
61                    if let Some(local) = weak_local2.upgrade() {
62                        // Ignore send errors since it means the caller immediately detached.
63                        drop(
64                            spawn
65                                .task_callback
66                                .send(SmolAbortHandle::new_handle(local.spawn(spawn.future))),
67                        );
68                    }
69                }
70            })
71            .detach();
72
73        // Drive CPU tasks.
74        let weak_local2 = RcWeak::clone(&weak_local);
75        local
76            .spawn(async move {
77                while let Ok(spawn) = cpu_recv.as_async().recv().await {
78                    if let Some(local) = weak_local2.upgrade() {
79                        let work = spawn.sync;
80                        // Ignore send errors since it means the caller immediately detached.
81                        drop(spawn.task_callback.send(SmolAbortHandle::new_handle(
82                            local.spawn(async move { work() }),
83                        )));
84                    }
85                }
86            })
87            .detach();
88
89        // Drive blocking tasks.
90        let weak_local2 = RcWeak::clone(&weak_local);
91        local
92            .spawn(async move {
93                while let Ok(spawn) = blocking_recv.as_async().recv().await {
94                    if let Some(local) = weak_local2.upgrade() {
95                        let work = spawn.sync;
96                        // Ignore send errors since it means the caller immediately detached.
97                        drop(spawn.task_callback.send(SmolAbortHandle::new_handle(
98                            local.spawn(async move { work() }),
99                        )));
100                    }
101                }
102            })
103            .detach();
104
105        Self {
106            scheduling: scheduling_send,
107            cpu: cpu_send,
108            blocking: blocking_send,
109        }
110    }
111}
112
113/// Since the [`Handle`], and therefore runtime implementation needs to be `Send` and `Sync`,
114/// we cannot just `impl Runtime for LocalExecutor`. Instead, we create channels that the handle
115/// can forward its work into, and we drive the resulting tasks on a [`LocalExecutor`] on the
116/// calling thread.
117impl Executor for Sender {
118    fn spawn(&self, future: BoxFuture<'static, ()>) -> AbortHandleRef {
119        let (send, recv) = oneshot::channel();
120        if let Err(e) = self.scheduling.send(SpawnAsync {
121            future,
122            task_callback: send,
123        }) {
124            vortex_panic!("Executor missing: {}", e);
125        }
126        Box::new(LazyAbortHandle {
127            task: Mutex::new(recv),
128        })
129    }
130
131    fn spawn_cpu(&self, cpu: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
132        let (send, recv) = oneshot::channel();
133        if let Err(e) = self.cpu.send(SpawnSync {
134            sync: cpu,
135            task_callback: send,
136        }) {
137            vortex_panic!("Executor missing: {}", e);
138        }
139        Box::new(LazyAbortHandle {
140            task: Mutex::new(recv),
141        })
142    }
143
144    fn spawn_blocking_io(&self, work: Box<dyn FnOnce() + Send + 'static>) -> AbortHandleRef {
145        let (send, recv) = oneshot::channel();
146        if let Err(e) = self.blocking.send(SpawnSync {
147            sync: work,
148            task_callback: send,
149        }) {
150            vortex_panic!("Executor missing: {}", e);
151        }
152        Box::new(LazyAbortHandle {
153            task: Mutex::new(recv),
154        })
155    }
156}
157
158impl BlockingRuntime for SingleThreadRuntime {
159    type BlockingIterator<'a, R: 'a> = SingleThreadIterator<'a, R>;
160
161    fn handle(&self) -> Handle {
162        let executor: Arc<dyn Executor> = Arc::clone(&self.sender) as Arc<dyn Executor>;
163        Handle::new(Arc::downgrade(&executor))
164    }
165
166    fn block_on<Fut, R>(&self, fut: Fut) -> R
167    where
168        Fut: Future<Output = R>,
169    {
170        smol::block_on(self.executor.run(fut))
171    }
172
173    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
174    where
175        S: Stream<Item = R> + Send + 'a,
176        R: Send + 'a,
177    {
178        SingleThreadIterator {
179            executor: Rc::clone(&self.executor),
180            stream: stream.boxed_local(),
181        }
182    }
183}
184
185/// Runs a future to completion on the current thread until it completes.
186///
187/// The future is provided a [`Handle`] to the runtime so that it may spawn additional tasks
188/// to be executed concurrently.
189pub fn block_on<F, Fut, R>(f: F) -> R
190where
191    F: FnOnce(Handle) -> Fut,
192    Fut: Future<Output = R>,
193{
194    let runtime = SingleThreadRuntime::default();
195    let handle = runtime.handle();
196    runtime.block_on(f(handle))
197}
198
199/// Returns an iterator wrapper around a stream, blocking the current thread for each item.
200pub fn block_on_stream<'a, F, S, R>(f: F) -> SingleThreadIterator<'a, R>
201where
202    F: FnOnce(Handle) -> S,
203    S: Stream<Item = R> + Send + Unpin + 'a,
204    R: Send + 'a,
205{
206    let runtime = SingleThreadRuntime::default();
207    let handle = runtime.handle();
208    runtime.block_on_stream(f(handle))
209}
210
/// A spawn request for a future.
///
/// We pass back the abort handle via oneshot channel because this is a single-threaded runtime,
/// meaning we need the spawning channel consumer to do some work before the caller can actually
/// get ahold of their task handle.
///
/// The reason we don't pass back a smol::Task, and instead pass back a SmolAbortHandle, is because
/// we invert the behaviour of abort and drop. Dropping the abort handle results in the task being
/// detached, whereas dropping the smol::Task results in the task being canceled. This helps avoid
/// a race where the caller detaches the LazyAbortHandle before the smol::Task has been launched.
struct SpawnAsync<'rt> {
    // The future to spawn onto the local executor.
    future: BoxFuture<'rt, ()>,
    // Channel on which the driver loop sends back the spawned task's abort handle.
    task_callback: oneshot::Sender<AbortHandleRef>,
}
225
/// A spawn request for a synchronous job.
struct SpawnSync<'rt> {
    // The synchronous work to run as a task on the local executor.
    sync: Box<dyn FnOnce() + Send + 'rt>,
    // Channel on which the driver loop sends back the spawned task's abort handle.
    task_callback: oneshot::Sender<AbortHandleRef>,
}
231
// An abort handle whose underlying task handle arrives asynchronously: it is sent over a
// oneshot channel by the driver loop once the task has actually been spawned (see
// `SpawnAsync` for why).
struct LazyAbortHandle {
    // Receiver for the real abort handle, populated by the spawning driver loop.
    task: Mutex<oneshot::Receiver<AbortHandleRef>>,
}
235
236impl AbortHandle for LazyAbortHandle {
237    fn abort(self: Box<Self>) {
238        // Aborting a smol::Task is done by dropping it.
239        if let Ok(task) = self.task.lock().try_recv() {
240            task.abort()
241        }
242    }
243}
244
/// A stream that wraps up the stream with the execution that drives it.
pub struct SingleThreadIterator<'a, T> {
    // The executor that must be driven while polling the stream so spawned tasks progress.
    executor: Rc<LocalExecutor<'static>>,
    // The underlying (possibly `!Send`) stream of items.
    stream: LocalBoxStream<'a, T>,
}
250
251impl<T> Iterator for SingleThreadIterator<'_, T> {
252    type Item = T;
253
254    fn next(&mut self) -> Option<Self::Item> {
255        let fut = self.stream.next();
256        smol::block_on(self.executor.run(fut))
257    }
258}
259
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use futures::FutureExt;

    use crate::runtime::BlockingRuntime;
    use crate::runtime::single::SingleThreadRuntime;
    use crate::runtime::single::block_on;

    #[test]
    fn test_drive_simple_future() {
        let runtime = SingleThreadRuntime::default();
        assert_eq!(runtime.block_on(async { 123 }.boxed_local()), 123);
    }

    #[test]
    fn test_spawn_cpu_task() {
        let hits = Arc::new(AtomicUsize::new(0));
        let task_hits = Arc::clone(&hits);

        block_on(|handle| async move {
            handle
                .spawn_cpu(move || {
                    task_hits.fetch_add(1, Ordering::SeqCst);
                })
                .await
        });

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }
}
293}