vortex_io/runtime/
current.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use futures::Stream;
7use futures::StreamExt;
8use futures::stream::BoxStream;
9use smol::block_on;
10
11use crate::runtime::BlockingRuntime;
12use crate::runtime::Executor;
13use crate::runtime::Handle;
14pub use crate::runtime::pool::CurrentThreadWorkerPool;
15
16/// A current thread runtime allows callers to much more explicitly drive Vortex futures than with
17/// a Tokio runtime.
18///
19/// The current thread runtime will do no work unless `block_on` is called. In other words, the
20/// default behavior is single-threaded with code running on the thread that called `block_on`.
21///
22/// It's also possible to clone the runtime onto other threads, each of which can call `block_on`
23/// to drive work on that thread. Each thread shares the same underlying executor with the same
24/// set of tasks, allowing work to be driven in parallel.
25///
26/// For automatic driving of work, a [`CurrentThreadWorkerPool`] can be created from the runtime
27/// by calling [`new_pool`](CurrentThreadRuntime::new_pool). The returned pool can be configured
28/// with the desired number of worker threads that will drive work on behalf of the runtime.
29#[derive(Clone, Default)]
30pub struct CurrentThreadRuntime {
31    executor: Arc<smol::Executor<'static>>,
32}
33
34impl CurrentThreadRuntime {
35    /// Create a new current thread runtime.
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Create a new worker pool for driving the runtime in the background.
41    ///
42    /// This pool can be used to offload work from the current thread to a set of worker threads
43    /// that will drive the runtime's executor.
44    ///
45    /// By default, the pool has no worker threads; the caller must set the desired number of
46    /// worker threads using the `set_workers` method on the returned pool.
47    pub fn new_pool(&self) -> CurrentThreadWorkerPool {
48        CurrentThreadWorkerPool::new(self.executor.clone())
49    }
50
51    /// Returns an iterator wrapper around a stream, blocking the current thread for each item.
52    ///
53    /// ## Multi-threaded Usage
54    ///
55    /// To drive the iterator from multiple threads, simply clone it and call `next()` on each
56    /// clone. Results on each thread are ordered with respect to the stream, but there is no
57    /// ordering guarantee between threads.
58    pub fn block_on_stream_thread_safe<F, S, R>(&self, f: F) -> ThreadSafeIterator<R>
59    where
60        F: FnOnce(Handle) -> S,
61        S: Stream<Item = R> + Send + 'static,
62        R: Send + 'static,
63    {
64        let stream = f(self.handle());
65
66        // We create an MPMC result channel and spawn a task to drive the stream and send results.
67        // This allows multiple worker threads to drive the execution while all waiting for results
68        // on the channel.
69        let (result_tx, result_rx) = kanal::bounded_async(1);
70        self.executor
71            .spawn(async move {
72                futures::pin_mut!(stream);
73                while let Some(item) = stream.next().await {
74                    // If all receivers are dropped, we stop driving the stream.
75                    if let Err(e) = result_tx.send(item).await {
76                        tracing::trace!("all receivers dropped, stopping stream: {}", e);
77                        break;
78                    }
79                }
80            })
81            .detach();
82
83        ThreadSafeIterator {
84            executor: self.executor.clone(),
85            results: result_rx,
86        }
87    }
88}
89
90impl BlockingRuntime for CurrentThreadRuntime {
91    type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>;
92
93    fn handle(&self) -> Handle {
94        let executor: Arc<dyn Executor> = self.executor.clone();
95        Handle::new(Arc::downgrade(&executor))
96    }
97
98    fn block_on<Fut, R>(&self, fut: Fut) -> R
99    where
100        Fut: Future<Output = R>,
101    {
102        block_on(self.executor.run(fut))
103    }
104
105    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
106    where
107        S: Stream<Item = R> + Send + 'a,
108        R: Send + 'a,
109    {
110        CurrentThreadIterator {
111            executor: self.executor.clone(),
112            stream: stream.boxed(),
113        }
114    }
115}
116
117/// An iterator that wraps up a stream to drive it using the current thread execution.
118pub struct CurrentThreadIterator<'a, T> {
119    executor: Arc<smol::Executor<'static>>,
120    stream: BoxStream<'a, T>,
121}
122
123impl<T> Iterator for CurrentThreadIterator<'_, T> {
124    type Item = T;
125
126    fn next(&mut self) -> Option<Self::Item> {
127        block_on(self.executor.run(self.stream.next()))
128    }
129}
130
131/// An iterator that drives a stream from multiple threads.
132pub struct ThreadSafeIterator<T> {
133    executor: Arc<smol::Executor<'static>>,
134    results: kanal::AsyncReceiver<T>,
135}
136
137// Manual clone implementation since `T` does not need to be `Clone`.
138impl<T> Clone for ThreadSafeIterator<T> {
139    fn clone(&self) -> Self {
140        Self {
141            executor: self.executor.clone(),
142            results: self.results.clone(),
143        }
144    }
145}
146
147impl<T> Iterator for ThreadSafeIterator<T> {
148    type Item = T;
149
150    fn next(&mut self) -> Option<Self::Item> {
151        block_on(self.executor.run(self.results.recv())).ok()
152    }
153}
154
155#[allow(clippy::if_then_some_else_none)] // Clippy is wrong when if/else has await.
156#[cfg(test)]
157mod tests {
158    use std::sync::Arc;
159    use std::sync::Barrier;
160    use std::sync::atomic::AtomicUsize;
161    use std::sync::atomic::Ordering;
162    use std::thread;
163    use std::time::Duration;
164
165    use futures::StreamExt;
166    use futures::stream;
167    use parking_lot::Mutex;
168
169    use super::*;
170
171    #[test]
172    fn test_worker_thread() {
173        let runtime = CurrentThreadRuntime::new();
174
175        // We spawn a future that sets a value on a separate thread.
176        let value = Arc::new(AtomicUsize::new(0));
177        let value2 = value.clone();
178        runtime
179            .handle()
180            .spawn(async move {
181                value2.store(42, Ordering::SeqCst);
182            })
183            .detach();
184
185        // By default, nothing has driven the executor, so the value should still be 0.
186        assert_eq!(value.load(Ordering::SeqCst), 0);
187
188        // An empty pool still does nothing.
189        let pool = runtime.new_pool();
190        assert_eq!(value.load(Ordering::SeqCst), 0);
191
192        // Adding a worker thread should drive the executor.
193        pool.set_workers(1);
194        for _ in 0..10 {
195            if value.load(Ordering::SeqCst) == 42 {
196                break;
197            }
198            thread::sleep(Duration::from_millis(10));
199        }
200        assert_eq!(value.load(Ordering::SeqCst), 42);
201    }
202
203    #[test]
204    fn test_block_on_stream_single_thread() {
205        let mut iter =
206            CurrentThreadRuntime::new().block_on_stream(stream::iter(vec![1, 2, 3, 4, 5]).boxed());
207
208        assert_eq!(iter.next(), Some(1));
209        assert_eq!(iter.next(), Some(2));
210        assert_eq!(iter.next(), Some(3));
211        assert_eq!(iter.next(), Some(4));
212        assert_eq!(iter.next(), Some(5));
213        assert_eq!(iter.next(), None);
214    }
215
216    #[test]
217    fn test_block_on_stream_multiple_threads() {
218        let counter = Arc::new(AtomicUsize::new(0));
219        let num_threads = 4;
220        let items_per_thread = 25;
221        let total_items = 100;
222
223        let iter = CurrentThreadRuntime::new()
224            .block_on_stream_thread_safe(|_h| stream::iter(0..total_items).boxed());
225
226        let barrier = Arc::new(Barrier::new(num_threads));
227        let results = Arc::new(Mutex::new(Vec::new()));
228
229        let threads: Vec<_> = (0..num_threads)
230            .map(|_| {
231                let mut iter = iter.clone();
232                let counter = counter.clone();
233                let barrier = barrier.clone();
234                let results = results.clone();
235
236                thread::spawn(move || {
237                    barrier.wait();
238                    let mut local_results = Vec::new();
239
240                    for _ in 0..items_per_thread {
241                        if let Some(item) = iter.next() {
242                            counter.fetch_add(1, Ordering::SeqCst);
243                            local_results.push(item);
244                        }
245                    }
246
247                    results.lock().push(local_results);
248                })
249            })
250            .collect();
251
252        for thread in threads {
253            thread.join().unwrap();
254        }
255
256        assert_eq!(counter.load(Ordering::SeqCst), total_items);
257
258        let all_results = results.lock();
259        let mut collected: Vec<_> = all_results.iter().flatten().copied().collect();
260        collected.sort();
261        assert_eq!(collected, (0..total_items).collect::<Vec<_>>());
262    }
263
264    #[test]
265    fn test_block_on_stream_concurrent_clone_and_drive() {
266        let num_items = 50;
267        let num_threads = 3;
268
269        let iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|h| {
270            stream::unfold(0, move |state| {
271                let h = h.clone();
272                async move {
273                    if state < num_items {
274                        h.spawn_cpu(move || {
275                            thread::sleep(Duration::from_micros(10));
276                            state
277                        })
278                        .await;
279                        Some((state, state + 1))
280                    } else {
281                        None
282                    }
283                }
284            })
285        });
286
287        let collected = Arc::new(Mutex::new(Vec::new()));
288        let barrier = Arc::new(Barrier::new(num_threads));
289
290        let threads: Vec<_> = (0..num_threads)
291            .map(|thread_id| {
292                let iter = iter.clone();
293                let collected = collected.clone();
294                let barrier = barrier.clone();
295
296                thread::spawn(move || {
297                    barrier.wait();
298                    let mut local_items = Vec::new();
299
300                    for item in iter {
301                        local_items.push((thread_id, item));
302                        if local_items.len() >= 5 {
303                            break;
304                        }
305                    }
306
307                    collected.lock().extend(local_items);
308                })
309            })
310            .collect();
311
312        for thread in threads {
313            thread.join().unwrap();
314        }
315
316        let results = collected.lock();
317        let mut values: Vec<_> = results.iter().map(|(_, v)| *v).collect();
318        values.sort();
319        values.dedup();
320
321        assert!(values.len() >= 5);
322        assert!(values.iter().all(|&v| v < num_items));
323    }
324
325    #[test]
326    fn test_block_on_stream_async_work() {
327        let runtime = CurrentThreadRuntime::new();
328        let handle = runtime.handle();
329        let iter = runtime.block_on_stream({
330            stream::unfold((handle, 0), |(h, state)| async move {
331                if state < 10 {
332                    let value = h
333                        .spawn(async move { futures::future::ready(state * 2).await })
334                        .await;
335                    Some((value, (h, state + 1)))
336                } else {
337                    None
338                }
339            })
340        });
341
342        let results: Vec<_> = iter.collect();
343        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
344    }
345
346    #[test]
347    fn test_block_on_stream_drop_receivers_early() {
348        let counter = Arc::new(AtomicUsize::new(0));
349        let c = counter.clone();
350
351        let mut iter = CurrentThreadRuntime::new().block_on_stream({
352            stream::unfold(0, move |state| {
353                let c = c.clone();
354                async move {
355                    (state < 100).then(|| {
356                        c.fetch_add(1, Ordering::SeqCst);
357                        (state, state + 1)
358                    })
359                }
360            })
361            .boxed()
362        });
363
364        assert_eq!(iter.next(), Some(0));
365        assert_eq!(iter.next(), Some(1));
366        assert_eq!(iter.next(), Some(2));
367
368        drop(iter);
369
370        let final_count = counter.load(Ordering::SeqCst);
371        assert!(
372            final_count < 100,
373            "Stream should stop when all receivers are dropped"
374        );
375    }
376
377    #[test]
378    fn test_block_on_stream_interleaved_access() {
379        let barrier = Arc::new(Barrier::new(2));
380        let iter = CurrentThreadRuntime::new()
381            .block_on_stream_thread_safe(|_h| stream::iter(0..20).boxed());
382
383        let iter1 = iter.clone();
384        let iter2 = iter;
385        let barrier1 = barrier.clone();
386        let barrier2 = barrier;
387
388        let thread1 = thread::spawn(move || {
389            let mut iter = iter1;
390            let mut results = Vec::new();
391            barrier1.wait();
392
393            for _ in 0..5 {
394                if let Some(val) = iter.next() {
395                    results.push(val);
396                    thread::sleep(Duration::from_micros(50));
397                }
398            }
399            results
400        });
401
402        let thread2 = thread::spawn(move || {
403            let mut iter = iter2;
404            let mut results = Vec::new();
405            barrier2.wait();
406
407            for _ in 0..5 {
408                if let Some(val) = iter.next() {
409                    results.push(val);
410                    thread::sleep(Duration::from_micros(50));
411                }
412            }
413            results
414        });
415
416        let results1 = thread1.join().unwrap();
417        let results2 = thread2.join().unwrap();
418
419        let mut all_results = results1;
420        all_results.extend(results2);
421        all_results.sort();
422
423        assert_eq!(all_results, (0..10).collect::<Vec<_>>());
424
425        for i in 0..10 {
426            assert_eq!(all_results.iter().filter(|&&x| x == i).count(), 1);
427        }
428    }
429
430    #[test]
431    fn test_block_on_stream_stress_test() {
432        let num_threads = 10;
433        let num_items = 1000;
434
435        let iter = CurrentThreadRuntime::new()
436            .block_on_stream_thread_safe(|_h| stream::iter(0..num_items).boxed());
437
438        let received = Arc::new(Mutex::new(Vec::new()));
439        let barrier = Arc::new(Barrier::new(num_threads));
440
441        let threads: Vec<_> = (0..num_threads)
442            .map(|_| {
443                let iter = iter.clone();
444                let received = received.clone();
445                let barrier = barrier.clone();
446
447                thread::spawn(move || {
448                    barrier.wait();
449                    for val in iter {
450                        received.lock().push(val);
451                    }
452                })
453            })
454            .collect();
455
456        for thread in threads {
457            thread.join().unwrap();
458        }
459
460        let mut results = received.lock().clone();
461        results.sort();
462
463        assert_eq!(results.len(), num_items);
464        assert_eq!(results, (0..num_items).collect::<Vec<_>>());
465    }
466}