// tokio_stream_util/try_stream/ext/try_chunks.rs

1use alloc::vec::Vec;
2use core::fmt;
3use core::mem;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6
7#[cfg(feature = "sink")]
8use async_sink::Sink;
9use tokio_stream::Stream;
10
11use super::IntoFuseStream;
12
13use super::{FusedStream, TryStream};
14
15/// Stream for the [`try_chunks`](super::TryStreamExt::try_chunks) method.
16#[must_use = "streams do nothing unless polled"]
17pub struct TryChunks<St: TryStream> {
18    stream: IntoFuseStream<St>,
19    items: Vec<St::Ok>,
20    cap: usize,
21    pending_error: Option<St::Error>,
22}
23
24impl<St> fmt::Debug for TryChunks<St>
25where
26    St: TryStream + fmt::Debug,
27    St::Ok: fmt::Debug,
28{
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("TryChunks")
31            .field("stream", &self.stream)
32            .field("items", &self.items)
33            .field("cap", &self.cap)
34            .finish()
35    }
36}
37
38struct TryChunksProj<'pin, St: TryStream> {
39    stream: Pin<&'pin mut IntoFuseStream<St>>,
40    #[allow(dead_code)]
41    items: &'pin mut Vec<St::Ok>,
42    #[allow(dead_code)]
43    cap: &'pin usize,
44}
45
46impl<St: TryStream + Unpin> Unpin for TryChunks<St> {}
47
48impl<St: TryStream> TryChunks<St> {
49    pub(super) fn new(stream: St, capacity: usize) -> Self {
50        assert!(capacity > 0);
51
52        Self {
53            stream: IntoFuseStream::new(stream),
54            items: Vec::with_capacity(capacity),
55            cap: capacity,
56            pending_error: None,
57        }
58    }
59
60    fn take(self: Pin<&mut Self>) -> Vec<St::Ok> {
61        let this = unsafe { self.get_unchecked_mut() };
62        let cap = this.cap;
63        mem::replace(&mut this.items, Vec::with_capacity(cap))
64    }
65
66    /// Acquires a reference to the underlying stream that this combinator is
67    /// pulling from.
68    pub fn get_ref(&self) -> &St {
69        self.stream.get_ref()
70    }
71
72    /// Acquires a mutable reference to the underlying stream that this
73    /// combinator is pulling from.
74    ///
75    /// Note that care must be taken to avoid tampering with the state of the
76    /// stream which may otherwise confuse this combinator.
77    pub fn get_mut(&mut self) -> &mut St {
78        self.stream.get_mut()
79    }
80
81    /// Acquires a pinned mutable reference to the underlying stream that this
82    /// combinator is pulling from.
83    ///
84    /// Note that care must be taken to avoid tampering with the state of the
85    /// stream which may otherwise confuse this combinator.
86    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut St> {
87        self.project().stream.get_pin_mut()
88    }
89
90    fn project<'pin>(self: Pin<&'pin mut Self>) -> TryChunksProj<'pin, St> {
91        unsafe {
92            let this = self.get_unchecked_mut();
93            TryChunksProj {
94                stream: Pin::new_unchecked(&mut this.stream),
95                items: &mut this.items,
96                cap: &this.cap,
97            }
98        }
99    }
100
101    /// Consumes this combinator, returning the underlying stream.
102    ///
103    /// Note that this may discard intermediate state of this combinator, so
104    /// care should be taken to avoid losing resources when this is called.
105    pub fn into_inner(self) -> St {
106        self.stream.into_inner()
107    }
108}
109
110type TryChunksStreamError<St> = TryChunksError<<St as TryStream>::Ok, <St as TryStream>::Error>;
111
112impl<St: TryStream> Stream for TryChunks<St> {
113    type Item = Result<Vec<St::Ok>, TryChunksStreamError<St>>;
114
115    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116        if let Some(err) = unsafe { self.as_mut().get_unchecked_mut() }
117            .pending_error
118            .take()
119        {
120            return Poll::Ready(Some(Err(TryChunksError(Vec::new(), err))));
121        }
122        loop {
123            let this = self.as_mut();
124            let stream = this.project().stream;
125            match stream.poll_next(cx) {
126                // Push the item into the buffer and check whether it is full.
127                // If so, replace our buffer with a new and empty one and return
128                // the full one.
129                Poll::Ready(Some(Ok(item))) => {
130                    let this = unsafe { self.as_mut().get_unchecked_mut() };
131                    this.items.push(item);
132                    if this.items.len() >= this.cap {
133                        let items = mem::replace(&mut this.items, Vec::with_capacity(this.cap));
134                        break Poll::Ready(Some(Ok(items)));
135                    }
136                }
137                Poll::Ready(Some(Err(e))) => {
138                    let this = unsafe { self.as_mut().get_unchecked_mut() };
139                    if this.items.is_empty() {
140                        break Poll::Ready(Some(Err(TryChunksError(Vec::new(), e))));
141                    } else {
142                        // stash error and yield the buffered items first
143                        this.pending_error = Some(e);
144                        let items = mem::replace(&mut this.items, Vec::with_capacity(this.cap));
145                        break Poll::Ready(Some(Ok(items)));
146                    }
147                }
148
149                // Since the underlying stream ran out of values, break what we
150                // have buffered, if we have anything.
151                Poll::Ready(None) => {
152                    let this = unsafe { self.as_mut().get_unchecked_mut() };
153                    let last = if this.items.is_empty() {
154                        None
155                    } else {
156                        Some(mem::take(&mut this.items))
157                    };
158
159                    break Poll::Ready(last.map(Ok));
160                }
161                Poll::Pending => {
162                    break if self.items.is_empty() {
163                        Poll::Pending
164                    } else {
165                        let items = self.take();
166                        Poll::Ready(Some(Ok(items)))
167                    }
168                }
169            }
170        }
171    }
172
173    fn size_hint(&self) -> (usize, Option<usize>) {
174        let chunk_len = if !self.items.is_empty() { 1 } else { 0 };
175        let (lower, upper) = self.stream.size_hint();
176        let lower = (lower / self.cap).saturating_add(chunk_len);
177        let upper = match upper {
178            Some(x) => x.checked_add(chunk_len),
179            None => None,
180        };
181        (lower, upper)
182    }
183}
184
185impl<St: TryStream> FusedStream for TryChunks<St> {
186    fn is_terminated(&self) -> bool {
187        self.stream.is_terminated() && self.items.is_empty()
188    }
189}
190
// Forwarding impl of Sink from the underlying stream: `TryChunks` only reshapes
// what is read, so every sink operation is passed straight to the wrapped stream
// and does not interact with the items currently buffered.
#[cfg(feature = "sink")]
impl<St, Item> Sink<Item> for TryChunks<St>
where
    St: TryStream + Sink<Item>,
{
    type Error = <St as Sink<Item>>::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <St as Sink<Item>>::poll_ready(self.get_pin_mut(), cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
        <St as Sink<Item>>::start_send(self.get_pin_mut(), item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <St as Sink<Item>>::poll_flush(self.get_pin_mut(), cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <St as Sink<Item>>::poll_close(self.get_pin_mut(), cx)
    }
}

/// Error indicating that the inner stream produced an error while a chunk was
/// being collected.
///
/// The first field holds items collected before the error occurred; the second
/// holds the stream's error itself. Note that [`TryChunks`] yields any
/// already-buffered items as a separate `Ok` chunk *before* reporting the error,
/// so errors produced by that stream always carry an empty `Vec`.
#[derive(PartialEq, Eq)]
pub struct TryChunksError<T, E>(pub Vec<T>, pub E);

222impl<T, E: fmt::Debug> fmt::Debug for TryChunksError<T, E> {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        self.1.fmt(f)
225    }
226}
227
228impl<T, E: fmt::Display> fmt::Display for TryChunksError<T, E> {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        self.1.fmt(f)
231    }
232}
233
234impl<T, E: fmt::Debug + fmt::Display> core::error::Error for TryChunksError<T, E> {}