1mod buffered;
2mod buffered_unordered;
3
4mod futures_ordered;
5mod futures_unordered;
6
7pub use self::{
8 futures_ordered::FuturesParallelOrdered, futures_unordered::FuturesParallelUnordered,
9};
10use futures::Stream;
11use std::future::Future;
12
13pub trait TokioParStream: Stream {
14 fn par_buffered(self, n: usize) -> BufferedParallel<Self>
15 where
16 Self: Sized,
17 Self::Item: Future;
18
19 fn par_buffered_unordered(self, n: usize) -> BufferedParallelUnordered<Self>
20 where
21 Self: Sized,
22 Self::Item: Future;
23}
24
25impl<St: Stream> TokioParStream for St {
26 fn par_buffered(self, n: usize) -> BufferedParallel<Self>
27 where
28 Self: Sized,
29 Self::Item: Future,
30 {
31 BufferedParallel::new(self, Some(n))
32 }
33
34 fn par_buffered_unordered(self, n: usize) -> BufferedParallelUnordered<Self>
35 where
36 Self: Sized,
37 Self::Item: Future,
38 {
39 BufferedParallelUnordered::new(self, Some(n))
40 }
41}
42
43pub(crate) mod order {
44 use pin_project_lite::pin_project;
45 use std::cmp::Ordering;
46 use std::future::Future;
47 use std::pin::Pin;
48 use std::task::{Context, Poll};
49 pin_project! {
50 #[must_use = "futures do nothing unless you `.await` or poll them"]
51 #[derive(Debug)]
52 pub(crate) struct OrderWrapper<T> {
53 #[pin]
54 pub(crate) data: T, pub(crate) index: i64,
57 }
58 }
59
60 impl<T> PartialEq for OrderWrapper<T> {
61 fn eq(&self, other: &Self) -> bool {
62 self.index == other.index
63 }
64 }
65
66 impl<T> Eq for OrderWrapper<T> {}
67
68 impl<T> PartialOrd for OrderWrapper<T> {
69 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70 Some(self.cmp(other))
71 }
72 }
73
74 impl<T> Ord for OrderWrapper<T> {
75 fn cmp(&self, other: &Self) -> Ordering {
76 other.index.cmp(&self.index)
78 }
79 }
80
81 impl<T> Future for OrderWrapper<T>
82 where
83 T: Future,
84 {
85 type Output = OrderWrapper<T::Output>;
86
87 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88 let index = self.index;
89 self.project().data.poll(cx).map(|output| OrderWrapper {
90 data: output,
91 index,
92 })
93 }
94 }
95}
96
/// Generates a concurrency-limited adapter over a stream of futures.
///
/// Expands, in the invoking module, to a `pub struct $ty<St>` plus a `pub(super) fn new`
/// constructor and `Debug`, `Stream` and `FusedStream` impls. It also brings the macro's own
/// `use` items into that module's scope.
///
/// * `$ty` – name of the generated adapter type.
/// * `$backing` – type constructor for the in-flight queue. It must provide `new()`,
///   `len()`, `is_empty()` and the `$push` method. It must also implement `Stream + Unpin`
///   over the futures' outputs, as required by `poll_next_unpin`.
/// * `$push` – name of the `$backing` method used to enqueue a freshly pulled future.
macro_rules! buffered_stream {
    ($ty:ident, $backing:ident, $push:ident) => {
        use futures::stream::{Fuse, FusedStream};
        use futures::{Stream, StreamExt};
        use pin_project_lite::pin_project;
        use std::fmt::{Debug, Formatter};
        use std::future::Future;
        use std::num::NonZeroUsize;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        pin_project! {
            /// Stream adapter that keeps up to `limit` futures from the source in flight and
            /// yields their outputs (see `buffered_stream!`).
            #[must_use = "streams do nothing unless polled"]
            pub struct $ty<St>
            where
                St: Stream,
                St::Item: Future
            {
                // Source of futures; fused so polling after exhaustion is safe.
                #[pin]
                stream: Fuse<St>,
                // Futures pulled from `stream` that have not yet yielded their output.
                in_progress_queue: $backing<St::Item>,
                // Maximum queue length; `None` means unbounded.
                limit: Option<NonZeroUsize>,
            }
        }

        impl<St> Debug for $ty<St>
        where
            St: Stream,
            St::Item: Future,
        {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // `stringify!($ty)` so every generated adapter reports its own type name
                // rather than one hard-coded name shared by all expansions.
                f.debug_struct(stringify!($ty))
                    .field("limit", &self.limit)
                    .finish_non_exhaustive()
            }
        }

        impl<St> $ty<St>
        where
            St: Stream,
            St::Item: Future,
        {
            /// Wraps `stream`, allowing at most `limit` futures in flight.
            ///
            /// `None` and `Some(0)` both mean "no limit".
            pub(super) fn new(stream: St, limit: Option<usize>) -> Self {
                Self {
                    stream: stream.fuse(),
                    in_progress_queue: $backing::new(),
                    limit: limit.and_then(NonZeroUsize::new),
                }
            }
        }

        impl<St> Stream for $ty<St>
        where
            St: Stream,
            St::Item: Future + Send + 'static,
            <St::Item as Future>::Output: Send,
        {
            type Item = <St::Item as Future>::Output;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut this = self.project();

                // Top up the in-flight queue from the source until the limit is reached
                // (never, when unbounded) or the source has nothing ready right now.
                let limit = *this.limit;
                while limit.map_or(true, |limit| this.in_progress_queue.len() < limit.get()) {
                    match this.stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(fut)) => this.in_progress_queue.$push(fut),
                        Poll::Ready(None) | Poll::Pending => break,
                    }
                }

                // Hand out a finished output, or propagate `Pending`, straight from the queue.
                match this.in_progress_queue.poll_next_unpin(cx) {
                    Poll::Ready(None) => {}
                    other => return other,
                }

                // The queue is drained: we are finished only once the source is exhausted too.
                if this.stream.is_done() {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                // Everything still in flight will be yielded in addition to what the source
                // has left.
                let queue_len = self.in_progress_queue.len();
                let (lower, upper) = self.stream.size_hint();
                (
                    lower.saturating_add(queue_len),
                    upper.and_then(|x| x.checked_add(queue_len)),
                )
            }
        }

        impl<St> FusedStream for $ty<St>
        where
            St: Stream,
            St::Item: Future + Send + 'static,
            <St::Item as Future>::Output: Send,
        {
            fn is_terminated(&self) -> bool {
                self.in_progress_queue.is_empty() && self.stream.is_terminated()
            }
        }
    };
}
204
205use crate::buffered::BufferedParallel;
206use crate::buffered_unordered::BufferedParallelUnordered;
207pub(crate) use buffered_stream;
208
209
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::iter;
    use futures::StreamExt;
    use std::ops::Range;

    /// Inputs fed through every adapter under test.
    const TEST_RANGE: Range<u64> = 0..256;

    /// Maps an input to the value its future resolves to.
    fn transform(i: u64) -> u64 {
        i * 2
    }

    /// Outputs expected from `test_stream`, in source order.
    fn expected() -> Vec<u64> {
        TEST_RANGE.map(transform).collect()
    }

    /// Stream of futures where item `i` yields `TEST_RANGE.end - i` times before resolving.
    /// Lower-indexed items therefore need more polls, which stresses output ordering when the
    /// futures run concurrently.
    fn test_stream() -> impl Stream<Item: Future<Output = u64> + Send + 'static> {
        iter(TEST_RANGE).map(|i| async move {
            for _ in i..TEST_RANGE.end {
                tokio::task::yield_now().await;
            }
            transform(i)
        })
    }

    #[tokio::test]
    async fn buffered() {
        // Collect everything, so a missing or extra item fails the test. A `zip` would
        // silently truncate.
        let items = test_stream().par_buffered(8).collect::<Vec<_>>().await;
        assert_eq!(items, expected());
    }

    #[tokio::test]
    async fn buffered_unlimited() {
        // `0` means "no concurrency limit"; ordering must still be preserved.
        let items = test_stream().par_buffered(0).collect::<Vec<_>>().await;
        assert_eq!(items, expected());
    }

    #[tokio::test]
    async fn buffered_unordered() {
        // Exercise this crate's adapter, not `futures`' own `buffer_unordered`.
        let mut items = test_stream()
            .par_buffered_unordered(8)
            .collect::<Vec<_>>()
            .await;
        items.sort_unstable();
        assert_eq!(items, expected());
    }
}