rust_stream_ext_concurrent/
then_concurrent.rs1use futures::stream::{FuturesUnordered, Stream};
7use pin_project::pin_project;
8use std::{
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll},
12};
13
14#[pin_project(project = ThenConcurrentProj)]
16pub struct ThenConcurrent<St, Fut: Future, F> {
17 #[pin]
18 stream: St,
19 #[pin]
20 futures: FuturesUnordered<Fut>,
21 fun: F,
22 limit: Option<usize>,
23}
24
25impl<St, Fut, F, T> Stream for ThenConcurrent<St, Fut, F>
26where
27 St: Stream,
28 Fut: Future<Output = T>,
29 F: FnMut(St::Item) -> Fut,
30{
31 type Item = T;
32
33 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34 let ThenConcurrentProj {
35 mut stream,
36 mut futures,
37 fun,
38 limit,
39 } = self.project();
40
41 if limit.as_ref().is_none_or(|&l| futures.len() < l) {
43 loop {
44 match stream.as_mut().poll_next(cx) {
45 Poll::Ready(Some(n)) => {
46 futures.push(fun(n));
47 if limit.as_ref().is_some_and(|&l| futures.len() >= l) {
50 break;
51 }
52 }
53 Poll::Ready(None) => {
54 if futures.is_empty() {
55 return Poll::Ready(None);
56 }
57 break;
58 }
59 Poll::Pending => {
60 if futures.is_empty() {
61 return Poll::Pending;
62 }
63 break;
64 }
65 }
66 }
67 }
68
69 futures.as_mut().poll_next(cx)
70 }
71}
72
73pub trait StreamThenConcurrentExt: Stream {
75 fn then_concurrent<Fut, F, L>(self, f: F, limit: L) -> ThenConcurrent<Self, Fut, F>
81 where
82 Self: Sized,
83 Fut: Future,
84 F: FnMut(Self::Item) -> Fut,
85 L: Into<Option<usize>>;
86}
87
88impl<S: Stream> StreamThenConcurrentExt for S {
89 fn then_concurrent<Fut, F, L>(self, f: F, limit: L) -> ThenConcurrent<Self, Fut, F>
90 where
91 Self: Sized,
92 Fut: Future,
93 F: FnMut(Self::Item) -> Fut,
94 L: Into<Option<usize>>,
95 {
96 ThenConcurrent {
97 stream: self,
98 futures: FuturesUnordered::new(),
99 fun: f,
100 limit: limit.into().filter(|&l| l > 0),
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{channel::mpsc::unbounded, StreamExt};
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    /// Returns `x` after sleeping `x` milliseconds (immediately for `0`).
    async fn delayed_echo(x: u64) -> u64 {
        if x != 0 {
            tokio::time::sleep(Duration::from_millis(x)).await;
        }
        x
    }

    /// Runs `then_concurrent` over `0..count` with `limit` and returns the sorted outputs
    /// together with the largest number of futures observed running at the same time.
    async fn collect_with_peak(count: u64, limit: Option<usize>) -> (Vec<u64>, usize) {
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let stream = futures::stream::iter(0..count).then_concurrent(
            {
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                move |x| {
                    let in_flight = Arc::clone(&in_flight);
                    let peak = Arc::clone(&peak);
                    async move {
                        let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        peak.fetch_max(now, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(1)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        x
                    }
                }
            },
            limit,
        );

        let mut out = stream.collect::<Vec<u64>>().await;
        out.sort_unstable();
        (out, peak.load(Ordering::SeqCst))
    }

    #[tokio::test]
    async fn no_items() {
        let stream = futures::stream::iter::<Vec<u64>>(vec![]).then_concurrent(
            |_| async move {
                panic!("must not be called");
            },
            None,
        );

        assert!(stream.collect::<Vec<_>>().await.is_empty());
    }

    #[tokio::test]
    async fn paused_stream() {
        let (mut tx, rx) = unbounded::<u64>();
        let mut stream = rx.then_concurrent(delayed_echo, None);

        // The source is open but empty: the adapter must report `Pending`, not end.
        let mut first_item = stream.next();
        assert!(futures::poll!(&mut first_item).is_pending());
        tx.start_send(0).unwrap();
        assert_eq!(first_item.await, Some(0));

        let second_item = stream.next();
        tx.start_send(5).unwrap();
        assert_eq!(second_item.await, Some(5));
    }

    #[tokio::test]
    async fn fast_items() {
        let stream =
            futures::stream::iter(vec![0u64, 0, 7]).then_concurrent(delayed_echo, None);

        assert_eq!(stream.collect::<Vec<u64>>().await, vec![0, 0, 7]);
    }

    #[tokio::test]
    async fn reorder_items() {
        let stream =
            futures::stream::iter(vec![10u64, 5, 7]).then_concurrent(delayed_echo, None);

        assert_eq!(stream.collect::<Vec<u64>>().await, vec![5, 7, 10]);
    }

    #[tokio::test]
    async fn limit_one_preserves_source_order() {
        // With a single slot, each future must finish before the next item is pulled.
        let stream =
            futures::stream::iter(vec![3u64, 1, 2]).then_concurrent(delayed_echo, Some(1));

        assert_eq!(stream.collect::<Vec<u64>>().await, vec![3, 1, 2]);
    }

    #[tokio::test]
    async fn limit_bounds_concurrency() {
        let (out, peak) = collect_with_peak(6, Some(2)).await;

        assert_eq!(out, vec![0, 1, 2, 3, 4, 5]);
        assert_eq!(peak, 2);
    }

    #[tokio::test]
    async fn missing_or_zero_limit_is_unbounded() {
        for limit in [None, Some(0)] {
            let (out, peak) = collect_with_peak(4, limit).await;

            assert_eq!(out, vec![0, 1, 2, 3]);
            assert_eq!(peak, 4, "limit {limit:?} should not cap concurrency");
        }
    }
}