tokio_par_util/try_stream/
try_parallel_buffer_unordered.rs1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::stream::TryBufferUnordered;
6use futures_util::{Stream, TryFuture, TryStream};
7use tokio_util::sync::CancellationToken;
8
9use crate::try_stream::try_into_tasks::TryIntoTasks;
10use crate::try_stream::try_parallel_buffer::TryParallelBuffer;
11#[cfg(doc)]
12use crate::try_stream::TryStreamParExt;
13
14#[must_use = "streams do nothing unless polled"]
31#[pin_project::pin_project]
32pub struct TryParallelBufferUnordered<St>(
33 #[pin] TryParallelBuffer<St, TryBufferUnordered<TryIntoTasks<St>>>,
34)
35where
36 St: TryStream,
37 St::Ok: TryFuture<Error = St::Error> + Send + 'static,
38 St::Error: Send,
39 <St::Ok as TryFuture>::Ok: Send,
40 <St::Ok as TryFuture>::Error: Send;
41
42impl<St> TryParallelBufferUnordered<St>
43where
44 St: TryStream,
45 St::Ok: TryFuture<Error = St::Error> + Send + 'static,
46 St::Error: Send,
47 <St::Ok as TryFuture>::Ok: Send,
48 <St::Ok as TryFuture>::Error: Send,
49{
50 pub fn awaiting_completion(self, value: bool) -> Self {
55 Self(self.0.awaiting_completion(value))
56 }
57
58 pub(crate) fn new(
59 stream: St,
60 cancellation_token: CancellationToken,
61 limit: usize,
62 ) -> TryParallelBufferUnordered<St> {
63 Self(TryParallelBuffer::new(stream, cancellation_token, limit))
64 }
65}
66
67impl<St> Stream for TryParallelBufferUnordered<St>
68where
69 St: TryStream,
70 St::Ok: TryFuture<Error = St::Error> + Send + 'static,
71 St::Error: Send,
72 <St::Ok as TryFuture>::Ok: Send,
73 <St::Ok as TryFuture>::Error: Send,
74{
75 type Item = Result<<<St as TryStream>::Ok as TryFuture>::Ok, <St as TryStream>::Error>;
76
77 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 self.project().0.poll_next(cx)
79 }
80
81 fn size_hint(&self) -> (usize, Option<usize>) {
82 self.0.size_hint()
83 }
84}
85
86impl<St> fmt::Debug for TryParallelBufferUnordered<St>
87where
88 St: TryStream + fmt::Debug,
89 St::Ok: fmt::Debug + TryFuture<Error = St::Error> + Send,
90 <St::Ok as TryFuture>::Ok: fmt::Debug + Send,
91 <St::Ok as TryFuture>::Error: fmt::Debug + Send,
92{
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_tuple("TryParallelBufferUnordered")
95 .field(&self.0)
96 .finish()
97 }
98}
99
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::future;
    use std::future::poll_fn;
    use std::pin::pin;
    use std::sync::Arc;

    use futures_util::{stream, Stream, StreamExt};
    use scopeguard::defer;
    use tokio::sync::Semaphore;
    use tokio::task;
    use tokio_util::sync::CancellationToken;

    use crate::stream::StreamParExt;
    use crate::try_stream::TryStreamParExt;

    /// Every input item is mapped and yielded exactly once, in some order.
    #[tokio::test]
    async fn test_parallel_buffer_unordered() -> anyhow::Result<()> {
        // Collect into a `Vec` (not a `HashSet`) so that an item yielded twice,
        // or one that goes missing, makes the test fail instead of being absorbed.
        let mut results: Vec<u32> = stream::iter([1, 2, 3, 4])
            .map(move |elem| async move { elem + 1 })
            .parallel_buffer_unordered(4)
            .collect()
            .await;
        // Completion order is unspecified, so compare in sorted order.
        results.sort_unstable();

        assert_eq!(results, [2, 3, 4, 5]);

        Ok(())
    }

    /// Aborting the task that drives the stream drops every in-flight future.
    #[tokio::test]
    async fn test_parallel_buffer_unordered_await_cancellation() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);

                move |elem| {
                    let drop_set = Arc::clone(&drop_set);
                    let semaphore = Arc::clone(&semaphore);
                    async move {
                        // Records `elem` when the async block's state is dropped.
                        defer! { drop_set.insert(elem); }
                        // Signals that this future has started running.
                        semaphore.add_permits(1);
                        // Never completes, so the future can only end by being dropped.
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffer_unordered(4)
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Wait until all four futures are running before aborting.
        drop(semaphore.acquire_many(4).await?);

        task.abort();

        let Err(err) = task.await else {
            panic!("expected task to be cancelled")
        };
        assert!(err.is_cancelled());

        assert!(drop_set.contains(&1));
        assert!(drop_set.contains(&2));
        assert!(drop_set.contains(&3));
        assert!(drop_set.contains(&4));

        Ok(())
    }

    /// Cancelling the token ends the stream early and drops every in-flight future.
    #[tokio::test]
    async fn test_parallel_buffer_unordered_cancel_via_token() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));
        let cancellation_token = CancellationToken::new();

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);

                move |elem| {
                    let drop_set = Arc::clone(&drop_set);
                    let semaphore = Arc::clone(&semaphore);
                    async move {
                        // Records `elem` when the async block's state is dropped.
                        defer! { drop_set.insert(elem); }
                        // Signals that this future has started running.
                        semaphore.add_permits(1);
                        // Never completes, so the future can only end by being dropped.
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffer_unordered_with_token(4, cancellation_token.clone())
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Wait until all four futures are running before cancelling.
        drop(semaphore.acquire_many(4).await?);

        cancellation_token.cancel();

        // No future ever completes, so the stream must end without yielding items.
        let returned_set = task.await?;
        assert!(returned_set.is_empty());

        assert!(drop_set.contains(&1));
        assert!(drop_set.contains(&2));
        assert!(drop_set.contains(&3));
        assert!(drop_set.contains(&4));

        Ok(())
    }

    /// A panic in one future propagates to the collector with its original payload,
    /// and every other in-flight future is dropped.
    #[tokio::test]
    async fn test_parallel_buffer_unordered_panic() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);

                move |elem| {
                    let drop_set = Arc::clone(&drop_set);
                    let semaphore = Arc::clone(&semaphore);
                    async move {
                        // Records `elem` when the async block's state is dropped or unwinds.
                        defer! { drop_set.insert(elem); }
                        // Signals that this future has started running.
                        semaphore.add_permits(1);
                        if elem == 2 {
                            panic!("allergic to the number 2")
                        }
                        // Never completes, so the future can only end by being dropped.
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffer_unordered(4)
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Wait until all four futures have started.
        drop(semaphore.acquire_many(4).await?);

        let res = task.await;

        assert!(drop_set.contains(&1));
        assert!(drop_set.contains(&2));
        assert!(drop_set.contains(&3));
        assert!(drop_set.contains(&4));

        let err = res.expect_err("expected the collecting task to panic");
        // `&'static str` is `Copy`, so the message outlives the temporary payload box.
        let panic_msg = err.into_panic().downcast_ref::<&'static str>().copied();
        assert_eq!(panic_msg, Some("allergic to the number 2"));

        Ok(())
    }

    /// Regression test: after the input stream yields an error, the adapter reports
    /// that error once and then returns `None` when polled again, rather than panicking.
    #[tokio::test]
    async fn test_try_parallel_buffer_unordered_no_panic_after_stream_error() {
        // The stream itself fails; the `Ok` type only needs to be some `TryFuture`.
        let input_stream = stream::iter([Err::<std::future::Ready<Result<u32, &str>>, _>(
            "stream error",
        )]);

        let mut buffered = pin!(input_stream.try_parallel_buffer_unordered(4));

        // First poll: the stream error is surfaced as an item.
        let item1 = poll_fn(|cx| buffered.as_mut().poll_next(cx)).await;
        assert!(
            matches!(item1, Some(Err("stream error"))),
            "expected Some(Err(\"stream error\")), got {:?}",
            item1
        );

        // Second poll: the adapter is exhausted and must report the end of the stream.
        let item2 = poll_fn(|cx| buffered.as_mut().poll_next(cx)).await;
        assert!(
            item2.is_none(),
            "expected None after error, got {:?}",
            item2
        );
    }
}