vortex_io/runtime/current.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use smol::block_on;

use crate::runtime::{BlockingRuntime, Executor, Handle};

/// A current thread runtime allows users to explicitly drive Vortex futures from multiple worker
/// threads that they manage. This is useful in environments where the user already has a thread
/// pool and wants to integrate Vortex into that pool, for example query engines.
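///
/// ## Example
///
/// A minimal sketch of blocking on a future from a caller-managed thread. The `vortex_io`
/// import path is an assumption, so the block is not compiled as a doctest:
///
/// ```ignore
/// use vortex_io::runtime::{BlockingRuntime, CurrentThreadRuntime};
///
/// let runtime = CurrentThreadRuntime::new();
/// // The closure receives a `Handle` that can be used to spawn work onto the executor.
/// let value = runtime.block_on(|_handle| async { 1 + 1 });
/// assert_eq!(value, 2);
/// ```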
#[derive(Clone, Default)]
pub struct CurrentThreadRuntime {
    executor: Arc<smol::Executor<'static>>,
}

impl BlockingRuntime for CurrentThreadRuntime {
    type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>;

    fn handle(&self) -> Handle {
        let executor: Arc<dyn Executor> = self.executor.clone();
        Handle::new(Arc::downgrade(&executor))
    }

    fn block_on<F, Fut, R>(&self, f: F) -> R
    where
        F: FnOnce(Handle) -> Fut,
        Fut: Future<Output = R>,
    {
        block_on(self.executor.run(f(self.handle())))
    }

    fn block_on_stream<'a, F, S, R>(&self, f: F) -> Self::BlockingIterator<'a, R>
    where
        F: FnOnce(Handle) -> S,
        S: Stream<Item = R> + Send + 'a,
        R: Send + 'a,
    {
        CurrentThreadIterator {
            executor: self.executor.clone(),
            stream: f(self.handle()).boxed(),
        }
    }
}

impl CurrentThreadRuntime {
    /// Create a new current thread runtime.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns an iterator wrapper around a stream, blocking the current thread for each item.
    ///
    /// ## Multi-threaded Usage
    ///
    /// To drive the iterator from multiple threads, simply clone it and call `next()` on each
    /// clone. Each item is delivered to exactly one clone: results on each thread follow the
    /// stream's order, but there is no ordering guarantee between threads.
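    ///
    /// A minimal sketch of driving the stream from two caller-managed threads. The `vortex_io`
    /// import path is an assumption, so the block is not compiled as a doctest:
    ///
    /// ```ignore
    /// use std::thread;
    ///
    /// use futures::stream;
    /// use vortex_io::runtime::CurrentThreadRuntime;
    ///
    /// let runtime = CurrentThreadRuntime::new();
    /// let iter = runtime.block_on_stream_thread_safe(|_handle| stream::iter(0..10));
    ///
    /// // Each clone drives the shared executor and receives a disjoint subset of the items.
    /// let workers: Vec<_> = (0..2)
    ///     .map(|_| {
    ///         let iter = iter.clone();
    ///         thread::spawn(move || iter.collect::<Vec<_>>())
    ///     })
    ///     .collect();
    /// drop(iter);
    ///
    /// let mut items: Vec<i32> = workers.into_iter().flat_map(|w| w.join().unwrap()).collect();
    /// items.sort();
    /// assert_eq!(items, (0..10).collect::<Vec<_>>());
    /// ```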
    pub fn block_on_stream_thread_safe<F, S, R>(&self, f: F) -> ThreadSafeIterator<R>
    where
        F: FnOnce(Handle) -> S,
        S: Stream<Item = R> + Send + 'static,
        R: Send + 'static,
    {
        let stream = f(self.handle());

        // We create an MPMC result channel and spawn a task to drive the stream and send results.
        // This allows multiple worker threads to drive the executor while they all wait for
        // results on the channel.
        let (result_tx, result_rx) = kanal::bounded_async(1);
        self.executor
            .spawn(async move {
                futures::pin_mut!(stream);
                while let Some(item) = stream.next().await {
                    // If all receivers are dropped, we stop driving the stream.
                    if let Err(e) = result_tx.send(item).await {
                        log::trace!("all receivers dropped, stopping stream: {}", e);
                        break;
                    }
                }
            })
            .detach();

        ThreadSafeIterator {
            executor: self.executor.clone(),
            results: result_rx,
        }
    }
}

/// An iterator that wraps a stream and drives it on the current-thread executor, blocking the
/// calling thread for each item.
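///
/// A minimal sketch of how this iterator is typically obtained (the `vortex_io` import path is
/// an assumption, so the block is not compiled as a doctest):
///
/// ```ignore
/// use futures::stream;
/// use vortex_io::runtime::{BlockingRuntime, CurrentThreadRuntime};
///
/// let runtime = CurrentThreadRuntime::new();
/// // `block_on_stream` returns a `CurrentThreadIterator` that polls the stream lazily.
/// let iter = runtime.block_on_stream(|_handle| stream::iter(vec![1, 2, 3]));
/// assert_eq!(iter.collect::<Vec<_>>(), vec![1, 2, 3]);
/// ```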
pub struct CurrentThreadIterator<'a, T> {
    executor: Arc<smol::Executor<'static>>,
    stream: BoxStream<'a, T>,
}

impl<T> Iterator for CurrentThreadIterator<'_, T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        block_on(self.executor.run(self.stream.next()))
    }
}

/// An iterator over a shared stream that can be cloned and driven from multiple threads; it is
/// returned by [`CurrentThreadRuntime::block_on_stream_thread_safe`].
pub struct ThreadSafeIterator<T> {
    executor: Arc<smol::Executor<'static>>,
    results: kanal::AsyncReceiver<T>,
}

// Manual `Clone` implementation: deriving would add an unnecessary `T: Clone` bound.
impl<T> Clone for ThreadSafeIterator<T> {
    fn clone(&self) -> Self {
        Self {
            executor: self.executor.clone(),
            results: self.results.clone(),
        }
    }
}

impl<T> Iterator for ThreadSafeIterator<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        block_on(self.executor.run(self.results.recv())).ok()
    }
}

#[allow(clippy::if_then_some_else_none)] // Clippy is wrong when if/else has await.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    use futures::{StreamExt, stream};
    use parking_lot::Mutex;

    use super::*;

    #[test]
    fn test_block_on_stream_single_thread() {
        let mut iter = CurrentThreadRuntime::new()
            .block_on_stream(|_h| stream::iter(vec![1, 2, 3, 4, 5]).boxed());

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(4));
        assert_eq!(iter.next(), Some(5));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_block_on_stream_multiple_threads() {
        let counter = Arc::new(AtomicUsize::new(0));
        let num_threads = 4;
        let items_per_thread = 25;
        let total_items = 100;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..total_items).boxed());

        let barrier = Arc::new(Barrier::new(num_threads));
        let results = Arc::new(Mutex::new(Vec::new()));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let mut iter = iter.clone();
                let counter = counter.clone();
                let barrier = barrier.clone();
                let results = results.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_results = Vec::new();

                    for _ in 0..items_per_thread {
                        if let Some(item) = iter.next() {
                            counter.fetch_add(1, Ordering::SeqCst);
                            local_results.push(item);
                        }
                    }

                    results.lock().push(local_results);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), total_items);

        let all_results = results.lock();
        let mut collected: Vec<_> = all_results.iter().flatten().copied().collect();
        collected.sort();
        assert_eq!(collected, (0..total_items).collect::<Vec<_>>());
    }

    #[test]
    fn test_block_on_stream_concurrent_clone_and_drive() {
        let num_items = 50;
        let num_threads = 3;

        let iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|h| {
            stream::unfold(0, move |state| {
                let h = h.clone();
                async move {
                    if state < num_items {
                        h.spawn_cpu(move || {
                            thread::sleep(Duration::from_micros(10));
                            state
                        })
                        .await;
                        Some((state, state + 1))
                    } else {
                        None
                    }
                }
            })
        });

        let collected = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let iter = iter.clone();
                let collected = collected.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_items = Vec::new();

                    for item in iter {
                        local_items.push((thread_id, item));
                        if local_items.len() >= 5 {
                            break;
                        }
                    }

                    collected.lock().extend(local_items);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let results = collected.lock();
        let mut values: Vec<_> = results.iter().map(|(_, v)| *v).collect();
        values.sort();
        values.dedup();

        assert!(values.len() >= 5);
        assert!(values.iter().all(|&v| v < num_items));
    }

    #[test]
    fn test_block_on_stream_async_work() {
        let iter = CurrentThreadRuntime::new().block_on_stream(|h| {
            stream::unfold((h, 0), |(h, state)| async move {
                if state < 10 {
                    let value = h
                        .spawn(async move { futures::future::ready(state * 2).await })
                        .await;
                    Some((value, (h, state + 1)))
                } else {
                    None
                }
            })
        });

        let results: Vec<_> = iter.collect();
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
    }

    #[test]
    fn test_block_on_stream_drop_receivers_early() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let mut iter = CurrentThreadRuntime::new().block_on_stream(|_h| {
            stream::unfold(0, move |state| {
                let c = c.clone();
                async move {
                    (state < 100).then(|| {
                        c.fetch_add(1, Ordering::SeqCst);
                        (state, state + 1)
                    })
                }
            })
            .boxed()
        });

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));

        drop(iter);

        let final_count = counter.load(Ordering::SeqCst);
        assert!(
            final_count < 100,
            "Stream should stop being polled once the iterator is dropped"
        );
    }

    #[test]
    fn test_block_on_stream_interleaved_access() {
        let barrier = Arc::new(Barrier::new(2));
        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..20).boxed());

        let iter1 = iter.clone();
        let iter2 = iter;
        let barrier1 = barrier.clone();
        let barrier2 = barrier;

        let thread1 = thread::spawn(move || {
            let mut iter = iter1;
            let mut results = Vec::new();
            barrier1.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let thread2 = thread::spawn(move || {
            let mut iter = iter2;
            let mut results = Vec::new();
            barrier2.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let results1 = thread1.join().unwrap();
        let results2 = thread2.join().unwrap();

        let mut all_results = results1;
        all_results.extend(results2);
        all_results.sort();

        assert_eq!(all_results, (0..10).collect::<Vec<_>>());

        for i in 0..10 {
            assert_eq!(all_results.iter().filter(|&&x| x == i).count(), 1);
        }
    }

    #[test]
    fn test_block_on_stream_stress_test() {
        let num_threads = 10;
        let num_items = 1000;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..num_items).boxed());

        let received = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let iter = iter.clone();
                let received = received.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    for val in iter {
                        received.lock().push(val);
                    }
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let mut results = received.lock().clone();
        results.sort();

        assert_eq!(results.len(), num_items);
        assert_eq!(results, (0..num_items).collect::<Vec<_>>());
    }
}