// rust_stream_ext_concurrent/then_concurrent.rs

1//! Highly inspired by [and_then_concurrent](https://docs.rs/and-then-concurrent/latest/and_then_concurrent/)
2//!
3//! `ThenConcurrent` extension adds the `then_concurrent` method to any `futures::stream::Stream`
4//! object allowing a concurrent execution of futures over the stream items.
5
6use futures::stream::{FuturesUnordered, Stream};
7use pin_project::pin_project;
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
/// Stream for the [`Stream::then_concurrent()`] method.
#[pin_project(project = ThenConcurrentProj)]
pub struct ThenConcurrent<St, Fut: Future, F> {
    /// Inner stream whose items are fed to `fun`.
    #[pin]
    stream: St,
    /// In-flight futures produced by `fun`; polled concurrently, completing in any order.
    #[pin]
    futures: FuturesUnordered<Fut>,
    /// Closure mapping each stream item to a future.
    fun: F,
    /// Maximum number of futures kept in flight at once; `None` means unbounded
    /// (a `0` passed by the caller is normalized to `None` at construction).
    limit: Option<usize>,
}
24
impl<St, Fut, F, T> Stream for ThenConcurrent<St, Fut, F>
where
    St: Stream,
    Fut: Future<Output = T>,
    F: FnMut(St::Item) -> Fut,
{
    type Item = T;

    /// Drains all currently-ready items from the inner stream (bounded by
    /// `limit`, when set), spawning one future per item into `futures`, then
    /// polls the futures set. Completed outputs are yielded in completion
    /// order, which may differ from the input order.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let ThenConcurrentProj {
            mut stream,
            mut futures,
            fun,
            limit,
        } = self.project();

        // Eagerly fetch all ready items from the stream if not full already
        if limit.as_ref().is_none_or(|&l| futures.len() < l) {
            loop {
                match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(n)) => {
                        futures.push(fun(n));
                        // Go to process existing futures, before pulling new ones from the Stream
                        // when full
                        if limit.as_ref().is_some_and(|&l| futures.len() >= l) {
                            break;
                        }
                    }
                    Poll::Ready(None) => {
                        // Inner stream exhausted: terminate only once no futures
                        // remain; otherwise drain the in-flight set below.
                        if futures.is_empty() {
                            return Poll::Ready(None);
                        }
                        break;
                    }
                    Poll::Pending => {
                        // Nothing buffered and nothing in flight: the stream has
                        // already registered our waker, so report Pending directly.
                        if futures.is_empty() {
                            return Poll::Pending;
                        }
                        break;
                    }
                }
            }
        }
        // Non-empty here on every path (empty cases returned above), so this
        // yields Ready(Some(_)) or Pending — never a spurious Ready(None).
        futures.as_mut().poll_next(cx)
    }
}
72
/// Extension to `futures::stream::Stream`
pub trait StreamThenConcurrentExt: Stream {
    /// Chain a computation onto each stream item, passing it to the closure `f`
    /// and running the returned futures concurrently.
    ///
    /// This function is similar to [`futures::stream::StreamExt::then`], but the
    /// stream is polled concurrently with the futures returned by `f`, so outputs
    /// are yielded in completion order rather than input order.
    ///
    /// `limit` bounds how many futures corresponding to past stream values are
    /// kept in flight via an internal `FuturesUnordered`; pass `None` (or `0`)
    /// for no bound.
    fn then_concurrent<Fut, F, L>(self, f: F, limit: L) -> ThenConcurrent<Self, Fut, F>
    where
        Self: Sized,
        Fut: Future,
        F: FnMut(Self::Item) -> Fut,
        L: Into<Option<usize>>;
}
87
88impl<S: Stream> StreamThenConcurrentExt for S {
89    fn then_concurrent<Fut, F, L>(self, f: F, limit: L) -> ThenConcurrent<Self, Fut, F>
90    where
91        Self: Sized,
92        Fut: Future,
93        F: FnMut(Self::Item) -> Fut,
94        L: Into<Option<usize>>,
95    {
96        ThenConcurrent {
97            stream: self,
98            futures: FuturesUnordered::new(),
99            fun: f,
100            limit: limit.into().filter(|&l| l > 0),
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{channel::mpsc::unbounded, StreamExt};

    /// Returns `x`, sleeping `x` milliseconds first unless `x` is zero.
    async fn sleepy(x: u64) -> u64 {
        if x != 0 {
            tokio::time::sleep(std::time::Duration::from_millis(x)).await;
        }
        x
    }

    #[tokio::test]
    async fn no_items() {
        let stream = futures::stream::iter::<Vec<u64>>(vec![]).then_concurrent(|_| async move {
            panic!("must not be called");
        }, None);

        assert_eq!(stream.collect::<Vec<_>>().await, vec![]);
    }

    #[tokio::test]
    async fn paused_stream() {
        let (mut tx, rx) = unbounded::<u64>();
        let mut stream = rx.then_concurrent(sleepy, None);

        // Poll before sending anything so FuturesUnordered starts out empty.
        let first = stream.next();
        tx.start_send(0).unwrap();
        assert_eq!(first.await, Some(0));

        let second = stream.next();
        // This item resolves only after a delay.
        tx.start_send(5).unwrap();
        assert_eq!(second.await, Some(5));
    }

    #[tokio::test]
    async fn fast_items() {
        // Two immediate items followed by one delayed item keep input order.
        let collected = futures::stream::iter(vec![0u64, 0, 7])
            .then_concurrent(sleepy, None)
            .collect::<Vec<u64>>()
            .await;

        assert_eq!(collected, vec![0, 0, 7]);
    }

    #[tokio::test]
    async fn reorder_items() {
        // Delays of 10 ms, 5 ms, 7 ms: outputs arrive in completion order.
        let collected = futures::stream::iter(vec![10u64, 5, 7])
            .then_concurrent(sleepy, None)
            .collect::<Vec<u64>>()
            .await;

        assert_eq!(collected, vec![5, 7, 10]);
    }
}