tokio_par_util/try_stream/
try_parallel_buffer_unordered.rs

1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::stream::TryBufferUnordered;
6use futures_util::{Stream, TryFuture, TryStream};
7use tokio_util::sync::CancellationToken;
8
9use crate::try_stream::try_into_tasks::TryIntoTasks;
10use crate::try_stream::try_parallel_buffer::TryParallelBuffer;
11#[cfg(doc)]
12use crate::try_stream::TryStreamParExt;
13
14/// Given a [`TryStream`] where every item is a [`TryFuture`], this stream
15/// buffers up to a certain number of futures, runs them in parallel on separate
16/// tasks, and returns the results of the futures in completion order.
17///
18/// If any of the futures returns an error or panics, all other futures and
19/// tasks are immediately cancelled and the error/panic gets returned
20/// immediately.
21///
22/// This stream is **cancellation safe** if the inner stream and generated
23/// futures are also cancellation safe.  This means that dropping this stream
24/// will also cancel any outstanding tasks and drop the relevant
25/// futures/streams.
26///
27/// You can use [`TryParallelBufferUnordered::awaiting_completion`] to control
28/// whether we wait for all tasks to fully terminate before this stream is
29/// considered to have ended.
30#[must_use = "streams do nothing unless polled"]
31#[pin_project::pin_project]
32pub struct TryParallelBufferUnordered<St>(
33    #[pin] TryParallelBuffer<St, TryBufferUnordered<TryIntoTasks<St>>>,
34)
35where
36    St: TryStream,
37    St::Ok: TryFuture<Error = St::Error> + Send + 'static,
38    St::Error: Send,
39    <St::Ok as TryFuture>::Ok: Send,
40    <St::Ok as TryFuture>::Error: Send;
41
42impl<St> TryParallelBufferUnordered<St>
43where
44    St: TryStream,
45    St::Ok: TryFuture<Error = St::Error> + Send + 'static,
46    St::Error: Send,
47    <St::Ok as TryFuture>::Ok: Send,
48    <St::Ok as TryFuture>::Error: Send,
49{
50    /// Whether to always await completion for the tasks from the
51    /// [`TryStreamParExt::try_parallel_buffer_unordered`] or
52    /// [`TryStreamParExt::try_parallel_buffer_unordered_with_token`] call,
53    /// even when one of the futures has panicked or been cancelled.
54    pub fn awaiting_completion(self, value: bool) -> Self {
55        Self(self.0.awaiting_completion(value))
56    }
57
58    pub(crate) fn new(
59        stream: St,
60        cancellation_token: CancellationToken,
61        limit: usize,
62    ) -> TryParallelBufferUnordered<St> {
63        Self(TryParallelBuffer::new(stream, cancellation_token, limit))
64    }
65}
66
67impl<St> Stream for TryParallelBufferUnordered<St>
68where
69    St: TryStream,
70    St::Ok: TryFuture<Error = St::Error> + Send + 'static,
71    St::Error: Send,
72    <St::Ok as TryFuture>::Ok: Send,
73    <St::Ok as TryFuture>::Error: Send,
74{
75    type Item = Result<<<St as TryStream>::Ok as TryFuture>::Ok, <St as TryStream>::Error>;
76
77    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78        self.project().0.poll_next(cx)
79    }
80
81    fn size_hint(&self) -> (usize, Option<usize>) {
82        self.0.size_hint()
83    }
84}
85
86impl<St> fmt::Debug for TryParallelBufferUnordered<St>
87where
88    St: TryStream + fmt::Debug,
89    St::Ok: fmt::Debug + TryFuture<Error = St::Error> + Send,
90    <St::Ok as TryFuture>::Ok: fmt::Debug + Send,
91    <St::Ok as TryFuture>::Error: fmt::Debug + Send,
92{
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_tuple("TryParallelBufferUnordered")
95            .field(&self.0)
96            .finish()
97    }
98}
99
100#[cfg(test)]
101mod tests {
102    use std::collections::HashSet;
103    use std::future;
104    use std::future::poll_fn;
105    use std::pin::pin;
106    use std::sync::Arc;
107
108    use futures_util::{stream, Stream, StreamExt};
109    use scopeguard::defer;
110    use tokio::sync::Semaphore;
111    use tokio::task;
112    use tokio_util::sync::CancellationToken;
113
114    use crate::stream::StreamParExt;
115    use crate::try_stream::TryStreamParExt;
116
117    #[tokio::test]
118    async fn test_parallel_buffer_unordered() -> anyhow::Result<()> {
119        let result_set: HashSet<u32> = stream::iter([1, 2, 3, 4])
120            .map(move |elem| async move { elem + 1 })
121            .parallel_buffer_unordered(4)
122            .collect()
123            .await;
124
125        assert!(result_set.contains(&2));
126        assert!(result_set.contains(&3));
127        assert!(result_set.contains(&4));
128        assert!(result_set.contains(&5));
129
130        Ok(())
131    }
132
133    #[tokio::test]
134    async fn test_parallel_buffer_unordered_await_cancellation() -> anyhow::Result<()> {
135        let drop_set = Arc::new(dashmap::DashSet::new());
136        let semaphore = Arc::new(Semaphore::new(0));
137
138        let future = stream::iter([1, 2, 3, 4])
139            .map({
140                let drop_set = Arc::clone(&drop_set);
141                let semaphore = Arc::clone(&semaphore);
142
143                move |elem| {
144                    let drop_set = Arc::clone(&drop_set);
145                    let semaphore = Arc::clone(&semaphore);
146                    async move {
147                        defer! { drop_set.insert(elem); }
148                        semaphore.add_permits(1);
149                        // Block forever here
150                        future::pending::<u32>().await;
151                    }
152                }
153            })
154            .parallel_buffer_unordered(4)
155            .collect::<HashSet<_>>();
156        let task = task::spawn(future);
157
158        // Ensure all futures have made progress past `defer!`
159        drop(semaphore.acquire_many(4).await?);
160
161        task.abort();
162
163        if let Err(err) = task.await {
164            assert!(err.is_cancelled());
165        } else {
166            panic!("expected task to be cancelled")
167        }
168
169        // Check that `defer!` scope guards ran
170        assert!(drop_set.contains(&1));
171        assert!(drop_set.contains(&2));
172        assert!(drop_set.contains(&3));
173        assert!(drop_set.contains(&4));
174
175        Ok(())
176    }
177
178    #[tokio::test]
179    async fn test_parallel_buffer_unordered_cancel_via_token() -> anyhow::Result<()> {
180        let drop_set = Arc::new(dashmap::DashSet::new());
181        let semaphore = Arc::new(Semaphore::new(0));
182        let cancellation_token = CancellationToken::new();
183
184        let future = stream::iter([1, 2, 3, 4])
185            .map({
186                let drop_set = Arc::clone(&drop_set);
187                let semaphore = Arc::clone(&semaphore);
188
189                move |elem| {
190                    let drop_set = Arc::clone(&drop_set);
191                    let semaphore = Arc::clone(&semaphore);
192                    async move {
193                        defer! { drop_set.insert(elem); }
194                        semaphore.add_permits(1);
195                        // Block forever here
196                        future::pending::<u32>().await;
197                    }
198                }
199            })
200            .parallel_buffer_unordered_with_token(4, cancellation_token.clone())
201            .collect::<HashSet<_>>();
202        let task = task::spawn(future);
203
204        // Ensure all futures have made progress past `defer!`
205        drop(semaphore.acquire_many(4).await?);
206
207        cancellation_token.cancel();
208
209        // The result from the spawned task is `Ok(HashSet::new())`
210        let returned_set = task.await?;
211        assert!(returned_set.is_empty());
212
213        // Check that `defer!` scope guards ran
214        assert!(drop_set.contains(&1));
215        assert!(drop_set.contains(&2));
216        assert!(drop_set.contains(&3));
217        assert!(drop_set.contains(&4));
218
219        Ok(())
220    }
221
222    #[tokio::test]
223    async fn test_parallel_buffer_unordered_panic() -> anyhow::Result<()> {
224        let drop_set = Arc::new(dashmap::DashSet::new());
225        let semaphore = Arc::new(Semaphore::new(0));
226
227        let future = stream::iter([1, 2, 3, 4])
228            .map({
229                let drop_set = Arc::clone(&drop_set);
230                let semaphore = Arc::clone(&semaphore);
231
232                move |elem| {
233                    let drop_set = Arc::clone(&drop_set);
234                    let semaphore = Arc::clone(&semaphore);
235                    async move {
236                        defer! { drop_set.insert(elem); }
237                        semaphore.add_permits(1);
238                        if elem == 2 {
239                            panic!("allergic to the number 2")
240                        }
241                        // Block forever here
242                        future::pending::<u32>().await;
243                    }
244                }
245            })
246            .parallel_buffer_unordered(4)
247            .collect::<HashSet<_>>();
248        let task = task::spawn(future);
249
250        // Ensure all futures have made progress past `defer!`
251        drop(semaphore.acquire_many(4).await?);
252
253        // Expect a panic to be caught here
254        let res = task.await;
255
256        // Check that `defer!` scope guards ran
257        assert!(drop_set.contains(&1));
258        assert!(drop_set.contains(&2));
259        assert!(drop_set.contains(&3));
260        assert!(drop_set.contains(&4));
261
262        let err = res.err().unwrap();
263        let panic_msg = *err.into_panic().downcast_ref::<&'static str>().unwrap();
264        assert_eq!(panic_msg, "allergic to the number 2");
265
266        Ok(())
267    }
268
269    /// Regression test: TryParallelBuffer should not panic when poll_next is
270    /// called after it has returned Some(Err(...)). Previously, the internal
271    /// Wait future would be polled after completion, causing a panic with
272    /// "async fn resumed after completion".
273    #[tokio::test]
274    async fn test_try_parallel_buffer_unordered_no_panic_after_stream_error() {
275        // Create a stream that immediately yields an error
276        let input_stream = stream::iter([Err::<std::future::Ready<Result<u32, &str>>, _>(
277            "stream error",
278        )]);
279
280        let mut buffered = pin!(input_stream.try_parallel_buffer_unordered(4));
281
282        // Poll 1: Should get the stream error
283        let item1 = poll_fn(|cx| buffered.as_mut().poll_next(cx)).await;
284        assert!(
285            matches!(item1, Some(Err("stream error"))),
286            "expected Some(Err(\"stream error\")), got {:?}",
287            item1
288        );
289
290        // Poll 2: This previously caused a panic. Should now return None.
291        let item2 = poll_fn(|cx| buffered.as_mut().poll_next(cx)).await;
292        assert!(
293            item2.is_none(),
294            "expected None after error, got {:?}",
295            item2
296        );
297    }
298}