tokio_par_util/stream/
parallel_buffer_unordered.rs1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_util::stream::BufferUnordered;
7use futures_util::Stream;
8use tokio_util::sync::CancellationToken;
9
10use crate::stream::into_tasks::IntoTasks;
11use crate::stream::parallel_buffer::ParallelBuffer;
12#[cfg(doc)]
13use crate::stream::StreamParExt;
14
15#[must_use = "streams do nothing unless polled"]
31#[pin_project::pin_project]
32pub struct ParallelBufferUnordered<St>(#[pin] ParallelBuffer<St, BufferUnordered<IntoTasks<St>>>)
33where
34 St: Stream,
35 St::Item: Future + Send + 'static,
36 <St::Item as Future>::Output: Send;
37
38impl<St> ParallelBufferUnordered<St>
39where
40 St: Stream,
41 St::Item: Future + Send + 'static,
42 <St::Item as Future>::Output: Send,
43{
44 pub fn awaiting_completion(self, value: bool) -> Self {
49 Self(self.0.awaiting_completion(value))
50 }
51
52 pub(crate) fn new(
53 stream: St,
54 cancellation_token: CancellationToken,
55 limit: usize,
56 ) -> ParallelBufferUnordered<St> {
57 Self(ParallelBuffer::new(stream, cancellation_token, limit))
58 }
59}
60
61impl<St> Stream for ParallelBufferUnordered<St>
62where
63 St: Stream,
64 St::Item: Future + Send + 'static,
65 <St::Item as Future>::Output: Send,
66{
67 type Item = <<St as Stream>::Item as Future>::Output;
68
69 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70 self.project().0.poll_next(cx)
71 }
72
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 self.0.size_hint()
75 }
76}
77
78impl<St> fmt::Debug for ParallelBufferUnordered<St>
79where
80 St: Stream + fmt::Debug,
81 St::Item: fmt::Debug + Future + Send,
82 <St::Item as Future>::Output: fmt::Debug + Send,
83{
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_tuple("ParallelBufferUnordered")
86 .field(&self.0)
87 .finish()
88 }
89}
90
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::future::{self, Future};
    use std::sync::Arc;

    use dashmap::DashSet;
    use futures_util::{stream, Stream, StreamExt};
    use scopeguard::defer;
    use tokio::sync::Semaphore;
    use tokio::task;
    use tokio_util::sync::CancellationToken;

    use crate::stream::StreamParExt;

    /// Elements fed into every test stream.
    const ELEMS: [u32; 4] = [1, 2, 3, 4];

    /// A future that never completes on its own.
    ///
    /// On its first poll it adds one permit to `started`, so a test can wait
    /// until every future is running. It also arms a guard that inserts `elem`
    /// into `drop_set` once the future is dropped or unwinds. When
    /// `panic_on_two` is set, the future for element 2 panics right after
    /// signalling.
    async fn tracked_pending(
        elem: u32,
        drop_set: Arc<DashSet<u32>>,
        started: Arc<Semaphore>,
        panic_on_two: bool,
    ) {
        defer! { drop_set.insert(elem); }
        started.add_permits(1);
        if panic_on_two && elem == 2 {
            // A literal message gives a `&'static str` payload, which the
            // panic test downcasts to exactly that type.
            panic!("allergic to the number 2")
        }
        future::pending::<()>().await;
    }

    /// Stream of one `tracked_pending` future per element of `ELEMS`.
    ///
    /// The arguments are taken as owned `Arc`s so the returned stream borrows
    /// nothing and can be moved into `task::spawn`.
    fn tracked_futures(
        drop_set: Arc<DashSet<u32>>,
        started: Arc<Semaphore>,
        panic_on_two: bool,
    ) -> impl Stream<Item = impl Future<Output = ()> + Send + 'static> {
        stream::iter(ELEMS).map(move |elem| {
            tracked_pending(elem, Arc::clone(&drop_set), Arc::clone(&started), panic_on_two)
        })
    }

    /// Asserts that the future for every element of `ELEMS` has been dropped.
    fn assert_all_dropped(drop_set: &DashSet<u32>) {
        for elem in ELEMS {
            assert!(
                drop_set.contains(&elem),
                "future for element {} was not dropped",
                elem
            );
        }
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered() -> anyhow::Result<()> {
        let mut results: Vec<u32> = stream::iter(ELEMS)
            .map(|elem| async move { elem + 1 })
            .parallel_buffer_unordered(4)
            .collect()
            .await;

        // Completion order is unspecified, so sort before comparing. An exact
        // comparison also catches missing, duplicated or extra outputs, which
        // set-membership checks cannot.
        results.sort_unstable();
        assert_eq!(results, [2, 3, 4, 5]);

        Ok(())
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered_await_cancellation() -> anyhow::Result<()> {
        let drop_set = Arc::new(DashSet::new());
        let started = Arc::new(Semaphore::new(0));

        let future = tracked_futures(Arc::clone(&drop_set), Arc::clone(&started), false)
            .parallel_buffer_unordered(4)
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Wait until all four futures are running before aborting.
        drop(started.acquire_many(4).await?);

        task.abort();

        let err = task.await.expect_err("expected task to be cancelled");
        assert!(err.is_cancelled());

        assert_all_dropped(&drop_set);

        Ok(())
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered_cancel_via_token() -> anyhow::Result<()> {
        let drop_set = Arc::new(DashSet::new());
        let started = Arc::new(Semaphore::new(0));
        let cancellation_token = CancellationToken::new();

        let future = tracked_futures(Arc::clone(&drop_set), Arc::clone(&started), false)
            .parallel_buffer_unordered_with_token(4, cancellation_token.clone())
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Wait until all four futures are running before cancelling.
        drop(started.acquire_many(4).await?);

        cancellation_token.cancel();

        // Cancellation ends the stream normally; none of the futures completed.
        let returned_set = task.await?;
        assert!(returned_set.is_empty());

        assert_all_dropped(&drop_set);

        Ok(())
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered_panic() -> anyhow::Result<()> {
        let drop_set = Arc::new(DashSet::new());
        let started = Arc::new(Semaphore::new(0));

        let future = tracked_futures(Arc::clone(&drop_set), Arc::clone(&started), true)
            .parallel_buffer_unordered(4)
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Wait until all four futures have started (element 2 panics right after).
        drop(started.acquire_many(4).await?);

        let res = task.await;

        assert_all_dropped(&drop_set);

        let err = res.expect_err("expected the panic to propagate out of the task");
        let payload = err.into_panic();
        assert_eq!(
            payload.downcast_ref::<&'static str>().copied(),
            Some("allergic to the number 2")
        );

        Ok(())
    }
}