// tokio_par_util/stream/parallel_buffered.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_util::stream::Buffered;
7use futures_util::Stream;
8use tokio_util::sync::CancellationToken;
9
10use crate::stream::into_tasks::IntoTasks;
11use crate::stream::parallel_buffer::ParallelBuffer;
12#[cfg(doc)]
13use crate::stream::StreamParExt;
14
/// Given a [`Stream`] where every item is a [`Future`], this stream buffers up
/// to a certain number of futures, runs them in parallel on separate tasks, and
/// returns the results of the futures in the order in which they appeared in
/// the input stream.
///
/// If any of the futures panics, all other futures and tasks are immediately
/// cancelled and the panic gets returned immediately.
///
/// This stream is **cancellation safe** if the inner stream and generated
/// futures are also cancellation safe.  This means that dropping this stream
/// will also cancel any outstanding tasks and drop the relevant
/// futures/streams.
///
/// You can use [`ParallelBuffered::awaiting_completion`] to control
/// whether we wait for all tasks to fully terminate before this stream is
/// considered to have ended.
#[must_use = "streams do nothing unless polled"]
#[pin_project::pin_project]
// Newtype over `ParallelBuffer` driving a `Buffered` stream of futures that
// `IntoTasks` has turned into spawned tasks; `Buffered` is what preserves the
// input ordering of the results.
pub struct ParallelBuffered<St>(#[pin] ParallelBuffer<St, Buffered<IntoTasks<St>>>)
where
    St: Stream,
    St::Item: Future + Send + 'static,
    <St::Item as Future>::Output: Send;
38
39impl<St> ParallelBuffered<St>
40where
41    St: Stream,
42    St::Item: Future + Send + 'static,
43    <St::Item as Future>::Output: Send,
44{
45    /// Whether to always await completion for the tasks from the
46    /// [`StreamParExt::parallel_buffered`] or
47    /// [`StreamParExt::parallel_buffered_with_token`] call, even when
48    /// one of the futures has panicked or been cancelled.
49    pub fn awaiting_completion(self, value: bool) -> Self {
50        Self(self.0.awaiting_completion(value))
51    }
52
53    pub(crate) fn new(
54        stream: St,
55        cancellation_token: CancellationToken,
56        limit: usize,
57    ) -> ParallelBuffered<St> {
58        Self(ParallelBuffer::new(stream, cancellation_token, limit))
59    }
60}
61
62impl<St> Stream for ParallelBuffered<St>
63where
64    St: Stream,
65    St::Item: Future + Send,
66    <St::Item as Future>::Output: Send,
67{
68    type Item = <<St as Stream>::Item as Future>::Output;
69
70    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71        self.project().0.poll_next(cx)
72    }
73
74    fn size_hint(&self) -> (usize, Option<usize>) {
75        self.0.size_hint()
76    }
77}
78
79impl<St> fmt::Debug for ParallelBuffered<St>
80where
81    St: fmt::Debug + Stream,
82    St::Item: fmt::Debug + Future + Send,
83    <St::Item as Future>::Output: Send,
84{
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        f.debug_tuple("ParallelBuffered").field(&self.0).finish()
87    }
88}
89
#[cfg(test)]
mod tests {
    use std::future;
    use std::sync::Arc;

    use futures_util::{stream, StreamExt};
    use scopeguard::defer;
    use tokio::sync::Semaphore;
    use tokio::task;
    use tokio_util::sync::CancellationToken;

    use crate::stream::StreamParExt;

    /// Happy path: four immediately-ready futures; the results must come back
    /// in the original input order.
    #[tokio::test]
    async fn test_parallel_buffered() -> anyhow::Result<()> {
        let result_vec: Vec<u32> = stream::iter([1, 2, 3, 4])
            .map(move |elem| async move { elem + 1 })
            .parallel_buffered(4)
            .collect()
            .await;

        assert_eq!(result_vec, &[2, 3, 4, 5]);

        Ok(())
    }

    /// Aborting the task driving the stream must cancel all four buffered
    /// futures, which we observe via the `defer!` scope guards running.
    #[tokio::test]
    async fn test_parallel_buffered_await_cancellation() -> anyhow::Result<()> {
        // Records which elements had their scope guard executed (i.e. whose
        // future was dropped/cancelled).
        let drop_set = Arc::new(dashmap::DashSet::new());
        // Each future adds a permit once it has started; acquiring 4 permits
        // proves all futures are running before we abort.
        let semaphore = Arc::new(Semaphore::new(0));

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);

                move |elem| {
                    let drop_set = Arc::clone(&drop_set);
                    let semaphore = Arc::clone(&semaphore);
                    async move {
                        // Runs when this future is dropped (on cancellation).
                        defer! { drop_set.insert(elem); }
                        semaphore.add_permits(1);
                        // Block forever here
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffered(4)
            .collect::<Vec<_>>();
        let task = task::spawn(future);

        // Ensure all futures have made progress past `defer!`
        drop(semaphore.acquire_many(4).await?);

        task.abort();

        // Abort must surface as a cancelled `JoinError`, never a clean result.
        if let Err(err) = task.await {
            assert!(err.is_cancelled());
        } else {
            panic!("expected task to be cancelled")
        }

        // Check that `defer!` scope guards ran
        assert!(drop_set.contains(&1));
        assert!(drop_set.contains(&2));
        assert!(drop_set.contains(&3));
        assert!(drop_set.contains(&4));

        Ok(())
    }

    /// Cancelling via the externally supplied `CancellationToken` ends the
    /// stream gracefully (empty result) while still dropping every future.
    #[tokio::test]
    async fn test_parallel_buffered_cancel_via_token() -> anyhow::Result<()> {
        // Same handshake machinery as the abort test above.
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));
        let cancellation_token = CancellationToken::new();

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);

                move |elem| {
                    let drop_set = Arc::clone(&drop_set);
                    let semaphore = Arc::clone(&semaphore);
                    async move {
                        defer! { drop_set.insert(elem); }
                        semaphore.add_permits(1);
                        // Block forever here
                        future::pending::<u32>().await;
                    }
                }
            })
            .parallel_buffered_with_token(4, cancellation_token.clone())
            .collect::<Vec<_>>();
        let task = task::spawn(future);

        // Ensure all futures have made progress past `defer!`
        drop(semaphore.acquire_many(4).await?);

        cancellation_token.cancel();

        // The result from the spawned task is `Ok(Vec::new())`
        let returned_vec = task.await?;
        assert!(returned_vec.is_empty());

        // Check that `defer!` scope guards ran
        assert!(drop_set.contains(&1));
        assert!(drop_set.contains(&2));
        assert!(drop_set.contains(&3));
        assert!(drop_set.contains(&4));

        Ok(())
    }

    /// A panic in one future must cancel the others (elements 3 and 4 would
    /// otherwise block forever) and propagate the panic payload to the caller.
    #[tokio::test]
    async fn test_parallel_buffered_panic() -> anyhow::Result<()> {
        let drop_set = Arc::new(dashmap::DashSet::new());
        let semaphore = Arc::new(Semaphore::new(0));

        let future = stream::iter([1, 2, 3, 4])
            .map({
                let drop_set = Arc::clone(&drop_set);
                let semaphore = Arc::clone(&semaphore);

                move |elem| {
                    let drop_set = Arc::clone(&drop_set);
                    let semaphore = Arc::clone(&semaphore);
                    async move {
                        defer! { drop_set.insert(elem); }
                        semaphore.add_permits(1);
                        // Element 2 panics; 1 completes; 3 and 4 hang and can
                        // only finish (i.e. run their guards) via cancellation.
                        if elem == 2 {
                            panic!("allergic to the number 2")
                        }
                        if elem > 2 {
                            // Block forever here
                            future::pending::<u32>().await;
                        }
                        elem + 1
                    }
                }
            })
            .parallel_buffered(4)
            .collect::<Vec<_>>();
        let task = task::spawn(future);

        // Ensure all futures have made progress past `defer!`
        drop(semaphore.acquire_many(4).await?);

        // Expect a panic to be caught here
        let res = task.await;

        // Check that `defer!` scope guards ran
        assert!(drop_set.contains(&1));
        assert!(drop_set.contains(&2));
        assert!(drop_set.contains(&3));
        assert!(drop_set.contains(&4));

        // The propagated `JoinError` must carry the original panic payload.
        let err = res.err().unwrap();
        let panic_msg = *err.into_panic().downcast_ref::<&'static str>().unwrap();
        assert_eq!(panic_msg, "allergic to the number 2");

        Ok(())
    }
}