tokio_par_util/stream/parallel_buffer_unordered.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_util::stream::BufferUnordered;
7use futures_util::Stream;
8use tokio_util::sync::CancellationToken;
9
10use crate::stream::into_tasks::IntoTasks;
11use crate::stream::parallel_buffer::ParallelBuffer;
12#[cfg(doc)]
13use crate::stream::StreamParExt;
14
15/// Given a [`Stream`] where every item is a [`Future`], this stream buffers up
16/// to a certain number of futures, runs them in parallel on separate tasks, and
17/// returns the results of the futures in completion order.
18///
19/// If any of the futures panics, all other futures and tasks are immediately
20/// cancelled and the panic gets returned immediately.
21///
22/// This stream is **cancellation safe** if the inner stream and generated
23/// futures are also cancellation safe.  This means that dropping this stream
24/// will also cancel any outstanding tasks and drop the relevant
25/// futures/streams.
26///
27/// You can use [`ParallelBufferUnordered::awaiting_completion`] to control
28/// whether we wait for all tasks to fully terminate before this stream is
29/// considered to have ended.
30#[must_use = "streams do nothing unless polled"]
31#[pin_project::pin_project]
32pub struct ParallelBufferUnordered<St>(#[pin] ParallelBuffer<St, BufferUnordered<IntoTasks<St>>>)
33where
34    St: Stream,
35    St::Item: Future + Send + 'static,
36    <St::Item as Future>::Output: Send;
37
38impl<St> ParallelBufferUnordered<St>
39where
40    St: Stream,
41    St::Item: Future + Send + 'static,
42    <St::Item as Future>::Output: Send,
43{
44    /// Whether to always await completion for the tasks from the
45    /// [`StreamParExt::parallel_buffer_unordered`] or
46    /// [`StreamParExt::parallel_buffer_unordered_with_token`] call, even when
47    /// one of the futures has panicked or been cancelled.
48    pub fn awaiting_completion(self, value: bool) -> Self {
49        Self(self.0.awaiting_completion(value))
50    }
51
52    pub(crate) fn new(
53        stream: St,
54        cancellation_token: CancellationToken,
55        limit: usize,
56    ) -> ParallelBufferUnordered<St> {
57        Self(ParallelBuffer::new(stream, cancellation_token, limit))
58    }
59}
60
61impl<St> Stream for ParallelBufferUnordered<St>
62where
63    St: Stream,
64    St::Item: Future + Send + 'static,
65    <St::Item as Future>::Output: Send,
66{
67    type Item = <<St as Stream>::Item as Future>::Output;
68
69    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70        self.project().0.poll_next(cx)
71    }
72
73    fn size_hint(&self) -> (usize, Option<usize>) {
74        self.0.size_hint()
75    }
76}
77
78impl<St> fmt::Debug for ParallelBufferUnordered<St>
79where
80    St: Stream + fmt::Debug,
81    St::Item: fmt::Debug + Future + Send,
82    <St::Item as Future>::Output: fmt::Debug + Send,
83{
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_tuple("ParallelBufferUnordered")
86            .field(&self.0)
87            .finish()
88    }
89}
90
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::future;
    use std::sync::Arc;

    use futures_util::{stream, StreamExt};
    use scopeguard::defer;
    use tokio::sync::Semaphore;
    use tokio::task;
    use tokio_util::sync::CancellationToken;

    use crate::stream::StreamParExt;

    /// Builds the future used for `elem` by the cancellation and panic tests.
    ///
    /// The future inserts `elem` into `drop_set` from a `defer!` guard (so the
    /// insert happens when the running future is dropped or unwinds). It adds
    /// one permit to `semaphore` once it has started running. If
    /// `panic_on_two` is set and `elem == 2` it panics; otherwise it never
    /// completes.
    fn tracked_pending(
        elem: i32,
        drop_set: Arc<dashmap::DashSet<i32>>,
        semaphore: Arc<Semaphore>,
        panic_on_two: bool,
    ) -> impl future::Future<Output = ()> + Send + 'static {
        async move {
            defer! { drop_set.insert(elem); }
            semaphore.add_permits(1);
            if panic_on_two && elem == 2 {
                panic!("allergic to the number 2")
            }
            // Block forever here
            future::pending::<()>().await;
        }
    }

    /// Asserts that the `defer!` guard of every tracked future (1..=4) has run.
    fn assert_all_dropped(drop_set: &dashmap::DashSet<i32>) {
        for elem in [1, 2, 3, 4] {
            assert!(drop_set.contains(&elem), "future {elem} was never dropped");
        }
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered() -> anyhow::Result<()> {
        let result_set: HashSet<u32> = stream::iter([1, 2, 3, 4])
            .map(|elem| async move { elem + 1 })
            .parallel_buffer_unordered(4)
            .collect()
            .await;

        // Completion order is unspecified, so compare as a set. Exact equality
        // (rather than `contains` checks) also rejects unexpected extra outputs.
        assert_eq!(result_set, HashSet::from([2, 3, 4, 5]));

        Ok(())
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered_await_cancellation() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);
                move |elem| {
                    tracked_pending(elem, Arc::clone(&drop_set), Arc::clone(&semaphore), false)
                }
            })
            .parallel_buffer_unordered(4)
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Ensure all futures have made progress past `defer!`
        drop(semaphore.acquire_many(4).await?);

        task.abort();

        let err = task.await.expect_err("expected task to be cancelled");
        assert!(err.is_cancelled());

        // Check that `defer!` scope guards ran
        assert_all_dropped(&drop_set);

        Ok(())
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered_cancel_via_token() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));
        let cancellation_token = CancellationToken::new();

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);
                move |elem| {
                    tracked_pending(elem, Arc::clone(&drop_set), Arc::clone(&semaphore), false)
                }
            })
            .parallel_buffer_unordered_with_token(4, cancellation_token.clone())
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Ensure all futures have made progress past `defer!`
        drop(semaphore.acquire_many(4).await?);

        cancellation_token.cancel();

        // The result from the spawned task is `Ok(HashSet::new())`
        let returned_set = task.await?;
        assert!(returned_set.is_empty());

        // Check that `defer!` scope guards ran
        assert_all_dropped(&drop_set);

        Ok(())
    }

    #[tokio::test]
    async fn test_parallel_buffer_unordered_panic() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);
                move |elem| {
                    tracked_pending(elem, Arc::clone(&drop_set), Arc::clone(&semaphore), true)
                }
            })
            .parallel_buffer_unordered(4)
            .collect::<HashSet<_>>();
        let task = task::spawn(future);

        // Ensure all futures have made progress past `defer!`
        drop(semaphore.acquire_many(4).await?);

        // Expect a panic to be caught here
        let res = task.await;

        // Check that `defer!` scope guards ran
        assert_all_dropped(&drop_set);

        let err = res.expect_err("expected task to panic");
        let panic_msg = *err
            .into_panic()
            .downcast_ref::<&'static str>()
            .expect("panic payload should be a string literal");
        assert_eq!(panic_msg, "allergic to the number 2");

        Ok(())
    }
}