rust_stream_ext_concurrent/
then_concurrent.rs

//! Heavily inspired by [and_then_concurrent](https://docs.rs/and-then-concurrent/latest/and_then_concurrent/)
//!
//! The [`StreamThenConcurrentExt`] extension trait adds the `then_concurrent` method to any
//! `futures::stream::Stream`, allowing the futures produced from the stream's items to run
//! concurrently.
5
6use futures::stream::{FuturesUnordered, Stream};
7use pin_project::pin_project;
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
14/// Stream for the [`Stream::then_concurrent()`] method.
15#[pin_project(project = ThenConcurrentProj)]
16pub struct ThenConcurrent<St, Fut: Future, F> {
17    #[pin]
18    stream: St,
19    #[pin]
20    futures: FuturesUnordered<Fut>,
21    fun: F,
22}
23
24impl<St, Fut, F, T> Stream for ThenConcurrent<St, Fut, F>
25where
26    St: Stream,
27    Fut: Future<Output = T>,
28    F: FnMut(St::Item) -> Fut,
29{
30    type Item = T;
31
32    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
33        let ThenConcurrentProj {
34            mut stream,
35            mut futures,
36            fun,
37        } = self.project();
38
39        // Eagerly fetch all ready items from the stream
40        loop {
41            match stream.as_mut().poll_next(cx) {
42                Poll::Ready(Some(n)) => {
43                    futures.push(fun(n));
44                }
45                Poll::Ready(None) => {
46                    if futures.is_empty() {
47                        return Poll::Ready(None);
48                    }
49                    break;
50                }
51                Poll::Pending => {
52                    if futures.is_empty() {
53                        return Poll::Pending;
54                    }
55                    break;
56                }
57            }
58        }
59
60        futures.as_mut().poll_next(cx)
61    }
62}
63
64/// Extension to `futures::stream::Stream`
65pub trait StreamThenConcurrentExt: Stream {
66    /// Chain a computation when a stream value is ready, passing `Ok` values to the closure `f`.
67    ///
68    /// This function is similar to [`futures::stream::StreamExt::then`], but the
69    /// stream is polled concurrently with the futures returned by `f`. An unbounded number of
70    /// futures corresponding to past stream values is kept via `FuturesUnordered`.
71    fn then_concurrent<Fut, F>(self, f: F) -> ThenConcurrent<Self, Fut, F>
72    where
73        Self: Sized,
74        Fut: Future,
75        F: FnMut(Self::Item) -> Fut;
76}
77
78impl<S: Stream> StreamThenConcurrentExt for S {
79    fn then_concurrent<Fut, F>(self, f: F) -> ThenConcurrent<Self, Fut, F>
80    where
81        Self: Sized,
82        Fut: Future,
83        F: FnMut(Self::Item) -> Fut,
84    {
85        ThenConcurrent {
86            stream: self,
87            futures: FuturesUnordered::new(),
88            fun: f,
89        }
90    }
91}
92
93#[cfg(test)]
94mod tests {
95    use super::*;
96    use futures::{channel::mpsc::unbounded, StreamExt};
97
98    #[async_std::test]
99    async fn no_items() {
100        let stream = futures::stream::iter::<Vec<u64>>(vec![]).then_concurrent(|_| async move {
101            panic!("must not be called");
102        });
103
104        assert_eq!(stream.collect::<Vec<_>>().await, vec![]);
105    }
106
107    #[async_std::test]
108    async fn paused_stream() {
109        let (mut tx, rx) = unbounded::<u64>();
110
111        let mut stream = rx.then_concurrent(|x| async move {
112            if x == 0 {
113                x
114            } else {
115                async_std::task::sleep(std::time::Duration::from_millis(x)).await;
116                x
117            }
118        });
119
120        // we need to poll the stream such that FuturesUnordered gets empty
121        let first_item = stream.next();
122
123        tx.start_send(0).unwrap();
124
125        assert_eq!(first_item.await, Some(0));
126
127        let second_item = stream.next();
128
129        // item produces a delay
130        tx.start_send(5).unwrap();
131
132        assert_eq!(second_item.await, Some(5));
133    }
134
135    #[async_std::test]
136    async fn fast_items() {
137        let item_1 = 0u64;
138        let item_2 = 0u64;
139        let item_3 = 7u64;
140
141        let stream =
142            futures::stream::iter(vec![item_1, item_2, item_3]).then_concurrent(|x| async move {
143                if x == 0 {
144                    x
145                } else {
146                    async_std::task::sleep(std::time::Duration::from_millis(x)).await;
147                    x
148                }
149            });
150        let actual_packets = stream.collect::<Vec<u64>>().await;
151
152        assert_eq!(actual_packets, vec![0, 0, 7]);
153    }
154
155    #[async_std::test]
156    async fn reorder_items() {
157        let item_1 = 10u64; // 3rd in the output
158        let item_2 = 5u64; // 1st in the output
159        let item_3 = 7u64; // 2nd in the output
160
161        let stream =
162            futures::stream::iter(vec![item_1, item_2, item_3]).then_concurrent(|x| async move {
163                async_std::task::sleep(std::time::Duration::from_millis(x)).await;
164                x
165            });
166        let actual_packets = stream.collect::<Vec<u64>>().await;
167
168        assert_eq!(actual_packets, vec![5, 7, 10]);
169    }
170}