tokio_par_stream/
lib.rs

1mod buffered;
2mod buffered_unordered;
3
4mod futures_ordered;
5mod futures_unordered;
6
7pub use self::{
8    futures_ordered::FuturesParallelOrdered, futures_unordered::FuturesParallelUnordered,
9};
10use futures::Stream;
11use std::future::Future;
12
13pub trait TokioParStream: Stream {
14    fn par_buffered(self, n: usize) -> BufferedParallel<Self>
15    where
16        Self: Sized,
17        Self::Item: Future;
18
19    fn par_buffered_unordered(self, n: usize) -> BufferedParallelUnordered<Self>
20    where
21        Self: Sized,
22        Self::Item: Future;
23}
24
25impl<St: Stream> TokioParStream for St {
26    fn par_buffered(self, n: usize) -> BufferedParallel<Self>
27    where
28        Self: Sized,
29        Self::Item: Future,
30    {
31        BufferedParallel::new(self, Some(n))
32    }
33
34    fn par_buffered_unordered(self, n: usize) -> BufferedParallelUnordered<Self>
35    where
36        Self: Sized,
37        Self::Item: Future,
38    {
39        BufferedParallelUnordered::new(self, Some(n))
40    }
41}
42
43pub(crate) mod order {
44    use pin_project_lite::pin_project;
45    use std::cmp::Ordering;
46    use std::future::Future;
47    use std::pin::Pin;
48    use std::task::{Context, Poll};
49    pin_project! {
50        #[must_use = "futures do nothing unless you `.await` or poll them"]
51        #[derive(Debug)]
52        pub(crate) struct OrderWrapper<T> {
53            #[pin]
54            pub(crate) data: T, // A future or a future's output
55            // Use i64 for index since isize may overflow in 32-bit targets.
56            pub(crate) index: i64,
57        }
58    }
59
60    impl<T> PartialEq for OrderWrapper<T> {
61        fn eq(&self, other: &Self) -> bool {
62            self.index == other.index
63        }
64    }
65
66    impl<T> Eq for OrderWrapper<T> {}
67
68    impl<T> PartialOrd for OrderWrapper<T> {
69        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70            Some(self.cmp(other))
71        }
72    }
73
74    impl<T> Ord for OrderWrapper<T> {
75        fn cmp(&self, other: &Self) -> Ordering {
76            // BinaryHeap is a max heap, so compare backwards here.
77            other.index.cmp(&self.index)
78        }
79    }
80
81    impl<T> Future for OrderWrapper<T>
82    where
83        T: Future,
84    {
85        type Output = OrderWrapper<T::Output>;
86
87        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88            let index = self.index;
89            self.project().data.poll(cx).map(|output| OrderWrapper {
90                data: output,
91                index,
92            })
93        }
94    }
95}
96
/// Generates a concurrency-limited stream adaptor that pulls futures from an inner stream
/// into an in-progress queue and yields that queue's outputs.
///
/// Arguments:
/// * `$ty` — name of the generated public struct.
/// * `$backing` — generic queue type (`$backing<St::Item>`) holding the in-flight futures.
///   The expansion calls `new()`, `len()`, `is_empty()` and the method named by `$push` on
///   it, and polls it through `StreamExt::poll_next_unpin`, so it must implement
///   `Stream + Unpin`. The order in which outputs come back is decided by this queue.
/// * `$push` — name of the `$backing` method used to enqueue a newly pulled future.
///
/// The expansion emits its own `use` declarations into the invoking module, so that module
/// must not import the same names itself. The generated `new` is `pub(super)`, so instances
/// are constructed from the invoking module's parent.
macro_rules! buffered_stream {
    ($ty: ident, $backing:ident, $push: ident) => {
        use futures::stream::{Fuse, FusedStream};
        use futures::{Stream, StreamExt};
        use pin_project_lite::pin_project;
        use std::fmt::{Debug, Formatter};
        use std::future::Future;
        use std::num::NonZeroUsize;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        pin_project! {
            /// Stream adaptor that keeps at most `limit` futures from the inner stream in
            /// progress at once (unbounded when `limit` is `None`) and yields their outputs.
            #[must_use = "streams do nothing unless polled"]
            pub struct $ty<St>
            where
                St: Stream,
                St::Item: Future
            {
                // Inner source of futures; fused so it is never polled after finishing.
                #[pin]
                stream: Fuse<St>,
                // Futures pulled from `stream` that have not produced an output yet.
                in_progress_queue: $backing<St::Item>,
                // Maximum size of `in_progress_queue`; `None` means no limit.
                limit: Option<NonZeroUsize>,
            }
        }

        impl<St> Debug for $ty<St>
        where
            St: Stream,
            St::Item: Future,
        {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // Report the generated type's own name. A hard-coded name would mislabel
                // every other instantiation of this macro.
                f.debug_struct(stringify!($ty))
                    .field("limit", &self.limit)
                    .finish_non_exhaustive()
            }
        }

        impl<St> $ty<St>
        where
            St: Stream,
            St::Item: Future,
        {
            /// Wraps `stream`, allowing at most `limit` futures in progress.
            ///
            /// `None` and `Some(0)` both mean "no limit".
            pub(super) fn new(stream: St, limit: Option<usize>) -> Self {
                Self {
                    stream: stream.fuse(),
                    in_progress_queue: $backing::new(),
                    // limit = 0 => no limit
                    limit: limit.and_then(NonZeroUsize::new),
                }
            }
        }

        impl<St> Stream for $ty<St>
        where
            St: Stream,
            St::Item: Future + Send + 'static,
            <St::Item as Future>::Output: Send,
        {
            type Item = <St::Item as Future>::Output;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut this = self.project();

                // Top up the in-progress queue from the inner stream. Stop when the limit is
                // reached, the inner stream is exhausted, or it has nothing ready right now.
                let limit = *this.limit;
                while limit.map_or(true, |limit| this.in_progress_queue.len() < limit.get()) {
                    match this.stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(fut)) => this.in_progress_queue.$push(fut),
                        Poll::Ready(None) | Poll::Pending => break,
                    }
                }

                // Hand back a finished output if there is one, or propagate the queue's
                // `Pending`.
                if let x @ (Poll::Pending | Poll::Ready(Some(_))) =
                    this.in_progress_queue.poll_next_unpin(cx)
                {
                    return x;
                }

                // The queue is empty. We are only done once the inner stream is done too;
                // otherwise more futures may still arrive.
                if this.stream.is_done() {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                // Items still to come = queued futures + whatever the inner stream has left.
                let queue_len = self.in_progress_queue.len();
                let (lower, upper) = self.stream.size_hint();
                (
                    lower.saturating_add(queue_len),
                    upper.and_then(|x| x.checked_add(queue_len)),
                )
            }
        }

        impl<St> FusedStream for $ty<St>
        where
            St: Stream,
            St::Item: Future + Send + 'static,
            <St::Item as Future>::Output: Send,
        {
            fn is_terminated(&self) -> bool {
                self.in_progress_queue.is_empty() && self.stream.is_terminated()
            }
        }
    };
}
204
205use crate::buffered::BufferedParallel;
206use crate::buffered_unordered::BufferedParallelUnordered;
207pub(crate) use buffered_stream;
208
209
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::iter;
    use futures::StreamExt;
    use std::ops::Range;

    const TEST_RANGE: Range<u64> = 0..256;

    fn transform(i: u64) -> u64 {
        i * 2
    }

    /// Expected outputs in submission order.
    fn expected() -> Vec<u64> {
        TEST_RANGE.map(transform).collect()
    }

    /// Stream of futures where item `i` yields to the scheduler `TEST_RANGE.end - i` times
    /// before resolving. Earlier items therefore do more work, so later-submitted items tend
    /// to finish first and ordering mistakes become visible.
    fn test_stream() -> impl Stream<Item: Future<Output = u64> + Send + 'static> {
        iter(TEST_RANGE).map(|i| async move {
            for _ in i..TEST_RANGE.end {
                tokio::task::yield_now().await
            }
            transform(i)
        })
    }

    #[tokio::test]
    async fn buffered() {
        // Collect everything rather than using `zip`, which stops at the shorter side and
        // would hide missing or extra items.
        let items = test_stream().par_buffered(8).collect::<Vec<_>>().await;
        assert_eq!(items, expected());
    }

    #[tokio::test]
    async fn buffered_zero_limit_is_unbounded() {
        // A limit of 0 means "no limit"; output order must still be preserved.
        let items = test_stream().par_buffered(0).collect::<Vec<_>>().await;
        assert_eq!(items, expected());
    }

    #[tokio::test]
    async fn buffered_unordered() {
        // Exercise this crate's adaptor, not `futures::StreamExt::buffer_unordered`.
        let mut items = test_stream()
            .par_buffered_unordered(8)
            .collect::<Vec<_>>()
            .await;
        items.sort_unstable();
        assert_eq!(items, expected());
    }
}