tokio_stream_util/try_stream/ext/try_buffer_unordered.rs

1use super::{IntoFuseStream, TryStream};
2use crate::{FusedStream, FuturesUnordered};
3use core::{
4    fmt,
5    num::NonZeroUsize,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use futures_core::future::TryFuture;
10use futures_util::future::{IntoFuture, TryFutureExt};
11use tokio_stream::Stream;
12
13/// Stream for the
14/// [`try_buffer_unordered`](super::TryStreamExt::try_buffer_unordered) method.
15#[must_use = "streams do nothing unless polled"]
16pub struct TryBufferUnordered<St>
17where
18    St: TryStream,
19{
20    stream: IntoFuseStream<St>,
21    in_progress_queue: FuturesUnordered<IntoFuture<St::Ok>>,
22    max: Option<NonZeroUsize>,
23}
24
25impl<St> fmt::Debug for TryBufferUnordered<St>
26where
27    St: TryStream + fmt::Debug,
28{
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("TryBufferUnordered")
31            .field("stream", &self.stream)
32            .field("in_progress_queue", &self.in_progress_queue)
33            .field("max", &self.max)
34            .finish()
35    }
36}
37
38impl<St> TryBufferUnordered<St>
39where
40    St: TryStream,
41{
42    pub(super) fn new(stream: St, n: Option<usize>) -> Self {
43        Self {
44            stream: IntoFuseStream::new(stream),
45            in_progress_queue: FuturesUnordered::new(),
46            max: n.and_then(NonZeroUsize::new),
47        }
48    }
49}
50
51impl<St> Stream for TryBufferUnordered<St>
52where
53    St: TryStream,
54    St::Ok: TryFuture<Error = <St as crate::try_stream::TryStream>::Error>,
55{
56    type Item = Result<<St::Ok as TryFuture>::Ok, St::Error>;
57
58    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59        let this = unsafe { self.get_unchecked_mut() };
60        let in_progress_queue_len = this.in_progress_queue.len();
61        let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
62
63        // First up, try to spawn off as many futures as possible by filling up
64        // our queue of futures. Propagate errors from the stream immediately.
65        while this
66            .max
67            .map(|max| in_progress_queue_len < max.get())
68            .unwrap_or(true)
69        {
70            match stream.as_mut().poll_next(cx) {
71                Poll::Ready(Some(Ok(fut))) => {
72                    this.in_progress_queue.push(fut.into_future());
73                }
74                Poll::Ready(Some(Err(e))) => {
75                    return Poll::Ready(Some(Err(e)));
76                }
77                Poll::Ready(None) | Poll::Pending => break,
78            }
79        }
80
81        // Attempt to pull the next value from the in_progress_queue
82        match unsafe { Pin::new_unchecked(&mut this.in_progress_queue) }.poll_next(cx) {
83            Poll::Pending => return Poll::Pending,
84            Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
85            Poll::Ready(None) => {}
86        }
87
88        // If more values are still coming from the stream, we're not done yet
89        if stream.is_terminated() {
90            Poll::Ready(None)
91        } else {
92            Poll::Pending
93        }
94    }
95}
96
97#[cfg(feature = "sink")]
98use async_sink::Sink;
#[cfg(feature = "sink")]
// Forwarding impl of Sink from the underlying stream: every method projects
// `self` onto the `IntoFuseStream` field, unwraps it to the inner `St` via
// `get_pin_mut`, and delegates the call unchanged.
impl<St, Item> Sink<Item> for TryBufferUnordered<St>
where
    St: TryStream + Sink<Item>,
    St::Ok: TryFuture<Error = <St as crate::try_stream::TryStream>::Error>,
{
    type Error = <St as async_sink::Sink<Item>>::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // SAFETY: `stream` is structurally pinned; the projection only
        // re-borrows it in place and never moves it out of `self`.
        unsafe { self.map_unchecked_mut(|inner| &mut inner.stream) }
            .get_pin_mut()
            .poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
        // SAFETY: same structural projection of `stream` as in `poll_ready`.
        unsafe { self.map_unchecked_mut(|inner| &mut inner.stream) }
            .get_pin_mut()
            .start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // SAFETY: same structural projection of `stream` as in `poll_ready`.
        unsafe { self.map_unchecked_mut(|inner| &mut inner.stream) }
            .get_pin_mut()
            .poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // SAFETY: same structural projection of `stream` as in `poll_ready`.
        unsafe { self.map_unchecked_mut(|inner| &mut inner.stream) }
            .get_pin_mut()
            .poll_close(cx)
    }
}