tokio_par_util/stream/
parallel_buffered.rs1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_util::stream::Buffered;
7use futures_util::Stream;
8use tokio_util::sync::CancellationToken;
9
10use crate::stream::into_tasks::IntoTasks;
11use crate::stream::parallel_buffer::ParallelBuffer;
12#[cfg(doc)]
13use crate::stream::StreamParExt;
14
15#[must_use = "streams do nothing unless polled"]
32#[pin_project::pin_project]
33pub struct ParallelBuffered<St>(#[pin] ParallelBuffer<St, Buffered<IntoTasks<St>>>)
34where
35 St: Stream,
36 St::Item: Future + Send + 'static,
37 <St::Item as Future>::Output: Send;
38
39impl<St> ParallelBuffered<St>
40where
41 St: Stream,
42 St::Item: Future + Send + 'static,
43 <St::Item as Future>::Output: Send,
44{
45 pub fn awaiting_completion(self, value: bool) -> Self {
50 Self(self.0.awaiting_completion(value))
51 }
52
53 pub(crate) fn new(
54 stream: St,
55 cancellation_token: CancellationToken,
56 limit: usize,
57 ) -> ParallelBuffered<St> {
58 Self(ParallelBuffer::new(stream, cancellation_token, limit))
59 }
60}
61
62impl<St> Stream for ParallelBuffered<St>
63where
64 St: Stream,
65 St::Item: Future + Send,
66 <St::Item as Future>::Output: Send,
67{
68 type Item = <<St as Stream>::Item as Future>::Output;
69
70 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 self.project().0.poll_next(cx)
72 }
73
74 fn size_hint(&self) -> (usize, Option<usize>) {
75 self.0.size_hint()
76 }
77}
78
79impl<St> fmt::Debug for ParallelBuffered<St>
80where
81 St: fmt::Debug + Stream,
82 St::Item: fmt::Debug + Future + Send,
83 <St::Item as Future>::Output: Send,
84{
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 f.debug_tuple("ParallelBuffered").field(&self.0).finish()
87 }
88}
89
#[cfg(test)]
mod tests {
    use std::future;
    use std::sync::Arc;

    use futures_util::{stream, StreamExt};
    use scopeguard::defer;
    use tokio::sync::Semaphore;
    use tokio::task;
    use tokio_util::sync::CancellationToken;

    use crate::stream::StreamParExt;

    /// Every output is produced, in the order of the input stream.
    #[tokio::test]
    async fn test_parallel_buffered() -> anyhow::Result<()> {
        let outputs: Vec<u32> = stream::iter([1, 2, 3, 4])
            .map(|n| async move { n + 1 })
            .parallel_buffered(4)
            .collect()
            .await;

        assert_eq!(outputs, &[2, 3, 4, 5]);

        Ok(())
    }

    /// Aborting the consuming task drops every in-flight future before the
    /// aborted task's handle resolves.
    #[tokio::test]
    async fn test_parallel_buffered_await_cancellation() -> anyhow::Result<()> {
        let dropped = Arc::new(dashmap::DashSet::new());
        let started = Arc::new(Semaphore::new(0));

        let collecting = stream::iter([1, 2, 3, 4])
            .map({
                let dropped = Arc::clone(&dropped);
                let started = Arc::clone(&started);

                move |elem| {
                    let dropped = Arc::clone(&dropped);
                    let started = Arc::clone(&started);
                    async move {
                        // Records `elem` once this future's body is torn down.
                        defer! { dropped.insert(elem); }
                        // Signals that this future has begun running.
                        started.add_permits(1);
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffered(4)
            .collect::<Vec<_>>();
        let handle = task::spawn(collecting);

        // Block until all four futures have started.
        drop(started.acquire_many(4).await?);

        handle.abort();

        match handle.await {
            Err(err) => assert!(err.is_cancelled()),
            Ok(_) => panic!("expected task to be cancelled"),
        }

        for elem in 1..=4 {
            assert!(dropped.contains(&elem));
        }

        Ok(())
    }

    /// Cancelling the token ends the stream with no outputs and drops every
    /// in-flight future.
    #[tokio::test]
    async fn test_parallel_buffered_cancel_via_token() -> anyhow::Result<()> {
        let dropped = Arc::new(dashmap::DashSet::new());
        let started = Arc::new(Semaphore::new(0));
        let token = CancellationToken::new();

        let collecting = stream::iter([1, 2, 3, 4])
            .map({
                let dropped = Arc::clone(&dropped);
                let started = Arc::clone(&started);

                move |elem| {
                    let dropped = Arc::clone(&dropped);
                    let started = Arc::clone(&started);
                    async move {
                        defer! { dropped.insert(elem); }
                        started.add_permits(1);
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffered_with_token(4, token.clone())
            .collect::<Vec<_>>();
        let handle = task::spawn(collecting);

        // Block until all four futures have started.
        drop(started.acquire_many(4).await?);

        token.cancel();

        let outputs = handle.await?;
        assert!(outputs.is_empty());

        for elem in 1..=4 {
            assert!(dropped.contains(&elem));
        }

        Ok(())
    }

    /// A panic inside one future surfaces as a panic of the consuming task, and
    /// the remaining futures are dropped.
    #[tokio::test]
    async fn test_parallel_buffered_panic() -> anyhow::Result<()> {
        let dropped = Arc::new(dashmap::DashSet::new());
        let started = Arc::new(Semaphore::new(0));

        let collecting = stream::iter([1, 2, 3, 4])
            .map({
                let dropped = Arc::clone(&dropped);
                let started = Arc::clone(&started);

                move |elem| {
                    let dropped = Arc::clone(&dropped);
                    let started = Arc::clone(&started);
                    async move {
                        defer! { dropped.insert(elem); }
                        started.add_permits(1);
                        if elem == 2 {
                            panic!("allergic to the number 2")
                        } else if elem > 2 {
                            future::pending::<u32>().await;
                        }
                        elem + 1
                    }
                }
            })
            .parallel_buffered(4)
            .collect::<Vec<_>>();
        let handle = task::spawn(collecting);

        // Block until all four futures have started.
        drop(started.acquire_many(4).await?);

        let outcome = handle.await;

        for elem in 1..=4 {
            assert!(dropped.contains(&elem));
        }

        let payload = outcome.err().unwrap().into_panic();
        let panic_msg = payload.downcast_ref::<&'static str>().copied().unwrap();
        assert_eq!(panic_msg, "allergic to the number 2");

        Ok(())
    }
}