vortex_io/runtime/
current.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use futures::stream::BoxStream;
7use futures::{Stream, StreamExt};
8use smol::block_on;
9
10pub use crate::runtime::pool::CurrentThreadWorkerPool;
11use crate::runtime::{BlockingRuntime, Executor, Handle};
12
13/// A current thread runtime allows callers to much more explicitly drive Vortex futures than with
14/// a Tokio runtime.
15///
16/// The current thread runtime will do no work unless `block_on` is called. In other words, the
17/// default behavior is single-threaded with code running on the thread that called `block_on`.
18///
19/// It's also possible to clone the runtime onto other threads, each of which can call `block_on`
20/// to drive work on that thread. Each thread shares the same underlying executor with the same
21/// set of tasks, allowing work to be driven in parallel.
22///
23/// For automatic driving of work, a [`CurrentThreadWorkerPool`] can be created from the runtime
24/// by calling [`new_pool`](CurrentThreadRuntime::new_pool). The returned pool can be configured
25/// with the desired number of worker threads that will drive work on behalf of the runtime.
26#[derive(Clone, Default)]
27pub struct CurrentThreadRuntime {
28    executor: Arc<smol::Executor<'static>>,
29}
30
31impl CurrentThreadRuntime {
32    /// Create a new current thread runtime.
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Create a new worker pool for driving the runtime in the background.
38    ///
39    /// This pool can be used to offload work from the current thread to a set of worker threads
40    /// that will drive the runtime's executor.
41    ///
42    /// By default, the pool has no worker threads; the caller must set the desired number of
43    /// worker threads using the `set_workers` method on the returned pool.
44    pub fn new_pool(&self) -> CurrentThreadWorkerPool {
45        CurrentThreadWorkerPool::new(self.executor.clone())
46    }
47
48    /// Returns an iterator wrapper around a stream, blocking the current thread for each item.
49    ///
50    /// ## Multi-threaded Usage
51    ///
52    /// To drive the iterator from multiple threads, simply clone it and call `next()` on each
53    /// clone. Results on each thread are ordered with respect to the stream, but there is no
54    /// ordering guarantee between threads.
55    pub fn block_on_stream_thread_safe<F, S, R>(&self, f: F) -> ThreadSafeIterator<R>
56    where
57        F: FnOnce(Handle) -> S,
58        S: Stream<Item = R> + Send + 'static,
59        R: Send + 'static,
60    {
61        let stream = f(self.handle());
62
63        // We create an MPMC result channel and spawn a task to drive the stream and send results.
64        // This allows multiple worker threads to drive the execution while all waiting for results
65        // on the channel.
66        let (result_tx, result_rx) = kanal::bounded_async(1);
67        self.executor
68            .spawn(async move {
69                futures::pin_mut!(stream);
70                while let Some(item) = stream.next().await {
71                    // If all receivers are dropped, we stop driving the stream.
72                    if let Err(e) = result_tx.send(item).await {
73                        log::trace!("all receivers dropped, stopping stream: {}", e);
74                        break;
75                    }
76                }
77            })
78            .detach();
79
80        ThreadSafeIterator {
81            executor: self.executor.clone(),
82            results: result_rx,
83        }
84    }
85}
86
87impl BlockingRuntime for CurrentThreadRuntime {
88    type BlockingIterator<'a, R: 'a> = CurrentThreadIterator<'a, R>;
89
90    fn handle(&self) -> Handle {
91        let executor: Arc<dyn Executor> = self.executor.clone();
92        Handle::new(Arc::downgrade(&executor))
93    }
94
95    fn block_on<Fut, R>(&self, fut: Fut) -> R
96    where
97        Fut: Future<Output = R>,
98    {
99        block_on(self.executor.run(fut))
100    }
101
102    fn block_on_stream<'a, S, R>(&self, stream: S) -> Self::BlockingIterator<'a, R>
103    where
104        S: Stream<Item = R> + Send + 'a,
105        R: Send + 'a,
106    {
107        CurrentThreadIterator {
108            executor: self.executor.clone(),
109            stream: stream.boxed(),
110        }
111    }
112}
113
114/// An iterator that wraps up a stream to drive it using the current thread execution.
115pub struct CurrentThreadIterator<'a, T> {
116    executor: Arc<smol::Executor<'static>>,
117    stream: BoxStream<'a, T>,
118}
119
120impl<T> Iterator for CurrentThreadIterator<'_, T> {
121    type Item = T;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        block_on(self.executor.run(self.stream.next()))
125    }
126}
127
128/// An iterator that drives a stream from multiple threads.
129pub struct ThreadSafeIterator<T> {
130    executor: Arc<smol::Executor<'static>>,
131    results: kanal::AsyncReceiver<T>,
132}
133
134// Manual clone implementation since `T` does not need to be `Clone`.
135impl<T> Clone for ThreadSafeIterator<T> {
136    fn clone(&self) -> Self {
137        Self {
138            executor: self.executor.clone(),
139            results: self.results.clone(),
140        }
141    }
142}
143
144impl<T> Iterator for ThreadSafeIterator<T> {
145    type Item = T;
146
147    fn next(&mut self) -> Option<Self::Item> {
148        block_on(self.executor.run(self.results.recv())).ok()
149    }
150}
151
#[allow(clippy::if_then_some_else_none)] // Clippy is wrong when if/else has await.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::{Duration, Instant};

    use futures::{StreamExt, stream};
    use parking_lot::Mutex;

    use super::*;

    /// A spawned task must not run until something drives the executor; adding a worker
    /// thread to a pool is enough to drive it.
    #[test]
    fn test_worker_thread() {
        let runtime = CurrentThreadRuntime::new();

        // We spawn a future that sets a value on a separate thread.
        let value = Arc::new(AtomicUsize::new(0));
        let value2 = value.clone();
        runtime
            .handle()
            .spawn(async move {
                value2.store(42, Ordering::SeqCst);
            })
            .detach();

        // By default, nothing has driven the executor, so the value should still be 0.
        assert_eq!(value.load(Ordering::SeqCst), 0);

        // An empty pool still does nothing.
        let pool = runtime.new_pool();
        assert_eq!(value.load(Ordering::SeqCst), 0);

        // Adding a worker thread should drive the executor.
        pool.set_workers(1);

        // Poll against a generous deadline rather than a fixed handful of short sleeps, so a
        // slow or heavily loaded machine does not fail the test spuriously. The loop exits as
        // soon as the worker has run the task, so the common case stays fast.
        let deadline = Instant::now() + Duration::from_secs(5);
        while value.load(Ordering::SeqCst) != 42 && Instant::now() < deadline {
            thread::sleep(Duration::from_millis(1));
        }
        assert_eq!(value.load(Ordering::SeqCst), 42);
    }

    /// The single-threaded blocking iterator yields every stream item in order, then `None`.
    #[test]
    fn test_block_on_stream_single_thread() {
        let mut iter =
            CurrentThreadRuntime::new().block_on_stream(stream::iter(vec![1, 2, 3, 4, 5]).boxed());

        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(4));
        assert_eq!(iter.next(), Some(5));
        assert_eq!(iter.next(), None);
    }

    /// Several threads each pull an equal share from clones of one thread-safe iterator;
    /// together they must observe every item exactly once.
    #[test]
    fn test_block_on_stream_multiple_threads() {
        let counter = Arc::new(AtomicUsize::new(0));
        let num_threads = 4;
        let total_items = 100;
        // Derived rather than hard-coded so the three quantities cannot drift apart.
        let items_per_thread = total_items / num_threads;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..total_items).boxed());

        let barrier = Arc::new(Barrier::new(num_threads));
        let results = Arc::new(Mutex::new(Vec::new()));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let mut iter = iter.clone();
                let counter = counter.clone();
                let barrier = barrier.clone();
                let results = results.clone();

                thread::spawn(move || {
                    // Start all threads together to maximise contention on the iterator.
                    barrier.wait();
                    let mut local_results = Vec::new();

                    for _ in 0..items_per_thread {
                        if let Some(item) = iter.next() {
                            counter.fetch_add(1, Ordering::SeqCst);
                            local_results.push(item);
                        }
                    }

                    results.lock().push(local_results);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), total_items);

        let all_results = results.lock();
        let mut collected: Vec<_> = all_results.iter().flatten().copied().collect();
        collected.sort();
        assert_eq!(collected, (0..total_items).collect::<Vec<_>>());
    }

    /// Clones of a thread-safe iterator over a stream that awaits `spawn_cpu` work can be
    /// driven concurrently, with each thread stopping early after a few items.
    #[test]
    fn test_block_on_stream_concurrent_clone_and_drive() {
        let num_items = 50;
        let num_threads = 3;

        let iter = CurrentThreadRuntime::new().block_on_stream_thread_safe(|h| {
            stream::unfold(0, move |state| {
                let h = h.clone();
                async move {
                    if state < num_items {
                        h.spawn_cpu(move || {
                            thread::sleep(Duration::from_micros(10));
                            state
                        })
                        .await;
                        Some((state, state + 1))
                    } else {
                        None
                    }
                }
            })
        });

        let collected = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let iter = iter.clone();
                let collected = collected.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    let mut local_items = Vec::new();

                    // Each thread stops once it has gathered five items.
                    for item in iter {
                        local_items.push((thread_id, item));
                        if local_items.len() >= 5 {
                            break;
                        }
                    }

                    collected.lock().extend(local_items);
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let results = collected.lock();
        let mut values: Vec<_> = results.iter().map(|(_, v)| *v).collect();
        values.sort();
        values.dedup();

        assert!(values.len() >= 5);
        assert!(values.iter().all(|&v| v < num_items));
    }

    /// A stream that spawns a task on the runtime's own handle for every item can be driven
    /// to completion by the single-threaded blocking iterator.
    #[test]
    fn test_block_on_stream_async_work() {
        let runtime = CurrentThreadRuntime::new();
        let handle = runtime.handle();
        let iter = runtime.block_on_stream({
            stream::unfold((handle, 0), |(h, state)| async move {
                if state < 10 {
                    let value = h
                        .spawn(async move { futures::future::ready(state * 2).await })
                        .await;
                    Some((value, (h, state + 1)))
                } else {
                    None
                }
            })
        });

        let results: Vec<_> = iter.collect();
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
    }

    /// Dropping the iterator after a few items must leave the rest of the stream unproduced.
    ///
    /// NOTE(review): this exercises the single-threaded `block_on_stream`, although the test
    /// name and assertion message talk about receivers. Confirm whether
    /// `block_on_stream_thread_safe` was the intended target.
    #[test]
    fn test_block_on_stream_drop_receivers_early() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();

        let mut iter = CurrentThreadRuntime::new().block_on_stream({
            stream::unfold(0, move |state| {
                let c = c.clone();
                async move {
                    (state < 100).then(|| {
                        // Count every item the stream actually produces.
                        c.fetch_add(1, Ordering::SeqCst);
                        (state, state + 1)
                    })
                }
            })
            .boxed()
        });

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));

        drop(iter);

        let final_count = counter.load(Ordering::SeqCst);
        assert!(
            final_count < 100,
            "Stream should stop when all receivers are dropped"
        );
    }

    /// Two threads alternately pulling from clones of one iterator receive disjoint items
    /// that together form the first ten values of the stream.
    #[test]
    fn test_block_on_stream_interleaved_access() {
        let barrier = Arc::new(Barrier::new(2));
        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..20).boxed());

        let iter1 = iter.clone();
        let iter2 = iter;
        let barrier1 = barrier.clone();
        let barrier2 = barrier;

        let thread1 = thread::spawn(move || {
            let mut iter = iter1;
            let mut results = Vec::new();
            barrier1.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    // Pause briefly after each item so the two threads interleave.
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let thread2 = thread::spawn(move || {
            let mut iter = iter2;
            let mut results = Vec::new();
            barrier2.wait();

            for _ in 0..5 {
                if let Some(val) = iter.next() {
                    results.push(val);
                    // Pause briefly after each item so the two threads interleave.
                    thread::sleep(Duration::from_micros(50));
                }
            }
            results
        });

        let results1 = thread1.join().unwrap();
        let results2 = thread2.join().unwrap();

        let mut all_results = results1;
        all_results.extend(results2);
        all_results.sort();

        assert_eq!(all_results, (0..10).collect::<Vec<_>>());

        // Every value must have been delivered to exactly one of the two threads.
        for i in 0..10 {
            assert_eq!(all_results.iter().filter(|&&x| x == i).count(), 1);
        }
    }

    /// Many threads draining one iterator to exhaustion must collectively receive every item
    /// exactly once.
    #[test]
    fn test_block_on_stream_stress_test() {
        let num_threads = 10;
        let num_items = 1000;

        let iter = CurrentThreadRuntime::new()
            .block_on_stream_thread_safe(|_h| stream::iter(0..num_items).boxed());

        let received = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(num_threads));

        let threads: Vec<_> = (0..num_threads)
            .map(|_| {
                let iter = iter.clone();
                let received = received.clone();
                let barrier = barrier.clone();

                thread::spawn(move || {
                    barrier.wait();
                    for val in iter {
                        received.lock().push(val);
                    }
                })
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        let mut results = received.lock().clone();
        results.sort();

        assert_eq!(results.len(), num_items);
        assert_eq!(results, (0..num_items).collect::<Vec<_>>());
    }
}