s2n_quic/stream/receive.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use s2n_quic_transport::stream;
5
/// A QUIC stream that is only allowed to receive data.
///
/// This is a thin wrapper around the transport-level
/// `s2n_quic_transport::stream::ReceiveStream`; the receive operations generated for it
/// forward to the wrapped value.
#[derive(Debug)]
pub struct ReceiveStream(stream::ReceiveStream);
9
// Generates the receive-side API (`receive`, `poll_receive`, `receive_vectored`,
// `poll_receive_vectored`, `stop_sending`, `rx_request` and `receive_chunks`) inside an
// `impl` block.
//
// Invocations have the shape `|stream, dispatch| <body>`:
// - `stream` is bound to `self` right before `<body>` is evaluated, so the body can reach
//   the wrapped transport stream (e.g. `stream.0`).
// - `dispatch` names a helper macro that each dispatching method defines locally.
//   `<body>` calls `dispatch!(expr)` to forward the operation to `expr`, or `dispatch!()`
//   when there is nothing readable, which produces a `non_readable` error.
//
// `receive_chunks` does not dispatch directly; it goes through `rx_request`.
macro_rules! impl_receive_stream_api {
    (| $stream:ident, $dispatch:ident | $dispatch_body:expr) => {
        /// Receives a chunk of data from the stream.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Ok(Some(chunk))` if the stream is open and data was available.
        /// - `Ok(None)` if the stream was finished and all of the data was consumed.
        /// - `Err(e)` if the stream encountered a [`stream::Error`](crate::stream::Error).
        ///
        /// # Examples
        ///
        /// ```rust,no_run
        /// # async fn test() -> s2n_quic::stream::Result<()> {
        /// #   let mut stream: s2n_quic::stream::ReceiveStream = todo!();
        /// #
        /// while let Some(chunk) = stream.receive().await? {
        ///     println!("received: {:?}", chunk);
        /// }
        ///
        /// println!("finished");
        /// #
        /// #   Ok(())
        /// # }
        /// ```
        #[inline]
        pub async fn receive(&mut self) -> $crate::stream::Result<Option<bytes::Bytes>> {
            ::futures::future::poll_fn(|cx| self.poll_receive(cx)).await
        }

        /// Poll for a chunk of data from the stream.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Poll::Pending` if the stream is waiting to receive data from the peer. In this case,
        ///   the caller should retry receiving after the [`Waker`](core::task::Waker) on the provided
        ///   [`Context`](core::task::Context) is notified.
        /// - `Poll::Ready(Ok(Some(chunk)))` if the stream is open and data was available.
        /// - `Poll::Ready(Ok(None))` if the stream was finished and all of the data was consumed.
        /// - `Poll::Ready(Err(e))` if the stream encountered a [`stream::Error`](crate::stream::Error).
        #[inline]
        pub fn poll_receive(
            &mut self,
            cx: &mut core::task::Context,
        ) -> core::task::Poll<$crate::stream::Result<Option<bytes::Bytes>>> {
            // `$dispatch!()`: nothing readable -> immediately ready with an error.
            // `$dispatch!(s)`: forward to `s`, wrapped in `debug_assert_contract`.
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable()).into()
                };
                ($variant: expr) => {
                    s2n_quic_core::task::waker::debug_assert_contract(cx, |cx| {
                        $variant.poll_receive(cx)
                    })
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Receives a slice of chunks of data from the stream.
        ///
        /// This can be more efficient than calling [`receive`](Self::receive) for each chunk,
        /// especially when receiving large amounts of data.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Ok((count, is_open))` if the stream received data into the slice,
        ///   where `count` is the number of chunks received and `is_open` indicates whether
        ///   the stream is still open. If `is_open == true`, `count` will be at least `1`. If
        ///   `is_open == false`, future calls to [`receive_vectored`](Self::receive_vectored)
        ///   will always return `Ok((0, false))`.
        /// - `Err(e)` if the stream encountered a [`stream::Error`](crate::stream::Error).
        ///
        /// # Examples
        ///
        /// ```rust,no_run
        /// # async fn test() -> s2n_quic::stream::Result<()> {
        /// #   let mut stream: s2n_quic::stream::ReceiveStream = todo!();
        /// #
        /// # use bytes::Bytes;
        /// #
        /// loop {
        ///     let mut chunks = [Bytes::new(), Bytes::new(), Bytes::new()];
        ///     let (count, is_open) = stream.receive_vectored(&mut chunks).await?;
        ///
        ///     for chunk in &chunks[..count] {
        ///         println!("received: {:?}", chunk);
        ///     }
        ///
        ///     if !is_open {
        ///         break;
        ///     }
        /// }
        ///
        /// println!("finished");
        /// #
        /// #   Ok(())
        /// # }
        /// ```
        #[inline]
        pub async fn receive_vectored(
            &mut self,
            chunks: &mut [bytes::Bytes],
        ) -> $crate::stream::Result<(usize, bool)> {
            ::futures::future::poll_fn(|cx| self.poll_receive_vectored(chunks, cx)).await
        }

        /// Polls for receiving a slice of chunks of data from the stream.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Poll::Pending` if the stream is waiting to receive data from the peer. In this case,
        ///   the caller should retry receiving after the [`Waker`](core::task::Waker) on the provided
        ///   [`Context`](core::task::Context) is notified.
        /// - `Poll::Ready(Ok((count, is_open)))` if the stream received data into the slice,
        ///   where `count` is the number of chunks received and `is_open` indicates whether
        ///   the stream is still open. If `is_open == true`, `count` will be at least `1`. If
        ///   `is_open == false`, future calls to
        ///   [`poll_receive_vectored`](Self::poll_receive_vectored) will always return
        ///   `Poll::Ready(Ok((0, false)))`.
        /// - `Poll::Ready(Err(e))` if the stream encountered a [`stream::Error`](crate::stream::Error).
        #[inline]
        pub fn poll_receive_vectored(
            &mut self,
            chunks: &mut [bytes::Bytes],
            cx: &mut core::task::Context,
        ) -> core::task::Poll<$crate::stream::Result<(usize, bool)>> {
            // `$dispatch!()`: nothing readable -> immediately ready with an error.
            // `$dispatch!(s)`: forward to `s`, wrapped in `debug_assert_contract`.
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable()).into()
                };
                ($variant: expr) => {
                    s2n_quic_core::task::waker::debug_assert_contract(cx, |cx| {
                        $variant.poll_receive_vectored(chunks, cx)
                    })
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Notifies the peer to stop sending data on the stream.
        ///
        /// This requests the peer to finish the stream as soon as possible
        /// by issuing a [`reset`](crate::stream::SendStream::reset) with the
        /// provided [`error_code`](crate::application::Error).
        ///
        /// Since this is merely a request for the peer to reset the stream, the
        /// stream will not immediately be in a reset state after issuing this
        /// call.
        ///
        /// If the stream has already been reset by the peer or if all data has
        /// been received, the call will not trigger any action.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Ok(())` if the stop sending message was enqueued for the peer.
        /// - `Err(e)` if the stream encountered a [`stream::Error`](crate::stream::Error).
        ///
        /// # Examples
        ///
        /// ```rust,no_run
        /// # async fn test() -> s2n_quic::stream::Result<()> {
        /// #   let mut connection: s2n_quic::connection::Connection = todo!();
        /// #
        /// while let Some(mut stream) = connection.accept_receive_stream().await? {
        ///     stream.stop_sending(123u8.into())?;
        /// }
        /// #
        /// #   Ok(())
        /// # }
        /// ```
        #[inline]
        pub fn stop_sending(
            &mut self,
            error_code: $crate::application::Error,
        ) -> $crate::stream::Result<()> {
            // `$dispatch!()`: nothing readable -> `non_readable` error.
            // `$dispatch!(s)`: forward the request to `s`.
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable())
                };
                ($variant: expr) => {
                    $variant.stop_sending(error_code)
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Creates a batch request for receiving data from the stream.
        ///
        /// Returns `Err` with a `non_readable` [`stream::Error`](crate::stream::Error) when
        /// the dispatch body reports that there is nothing readable.
        #[inline]
        pub(crate) fn rx_request(
            &mut self,
        ) -> $crate::stream::Result<s2n_quic_transport::stream::RxRequest<'_, '_>> {
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable())
                };
                ($variant: expr) => {
                    $variant.rx_request()
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Polls the stream to receive data into `chunks`, consuming at most
        /// `high_watermark` bytes.
        ///
        /// Callers size `high_watermark` to the capacity of their destination buffer so that
        /// every consumed byte can be copied out of `chunks`.
        ///
        /// Returns `Poll::Pending` when the underlying request is not ready yet (the context
        /// is handed to the request). Otherwise it returns the transport's response. The
        /// `chunks.consumed` and `bytes.consumed` fields of that response describe how much
        /// of `chunks` was filled.
        #[inline]
        pub(crate) fn receive_chunks(
            &mut self,
            cx: &mut core::task::Context,
            chunks: &mut [bytes::Bytes],
            high_watermark: usize,
        ) -> core::task::Poll<$crate::stream::Result<s2n_quic_transport::stream::ops::rx::Response>>
        {
            s2n_quic_core::task::waker::debug_assert_contract(cx, |cx| {
                let response = core::task::ready!(self
                    .rx_request()?
                    .receive(chunks)
                    // don't receive more than we're capable of storing
                    .with_high_watermark(high_watermark)
                    .poll(Some(cx))?
                    .into_poll());

                core::task::Poll::Ready(Ok(response))
            })
        }
    };
}
251
// Implements the streaming/reading traits for `$name`. The type must provide the API
// generated by `impl_receive_stream_api!` (`poll_receive` and `receive_chunks`). The
// generated impls are:
// - `futures::stream::Stream`, yielding received chunks until the stream finishes;
// - `futures::io::AsyncRead` (`poll_read` and `poll_read_vectored`);
// - `tokio::io::AsyncRead`.
//
// The `|stream, dispatch| <body>` arguments are accepted for symmetry with
// `impl_receive_stream_api!` but are not used by the generated code.
macro_rules! impl_receive_stream_trait {
    ($name:ident, | $stream:ident, $dispatch:ident | $dispatch_body:expr) => {
        impl futures::stream::Stream for $name {
            type Item = $crate::stream::Result<bytes::Bytes>;

            // `Ok(None)` (stream finished, all data consumed) ends the `Stream`;
            // errors are yielded as items.
            #[inline]
            fn poll_next(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
            ) -> core::task::Poll<Option<Self::Item>> {
                match core::task::ready!(self.poll_receive(cx)) {
                    Ok(Some(v)) => Some(Ok(v)),
                    Ok(None) => None,
                    Err(err) => Some(Err(err)),
                }
                .into()
            }
        }

        impl futures::io::AsyncRead for $name {
            fn poll_read(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
                buf: &mut [u8],
            ) -> core::task::Poll<std::io::Result<usize>> {
                use bytes::Bytes;

                // a zero-capacity read completes immediately instead of asking the
                // transport for data that could never be handed back
                if buf.is_empty() {
                    return Ok(0).into();
                }

                // create some chunks on the stack to receive into
                // TODO investigate a better default number
                let mut chunks = [
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                ];

                // never consume more than `buf` can hold so no received data is lost
                let high_watermark = buf.len();

                let response =
                    core::task::ready!(self.receive_chunks(cx, &mut chunks, high_watermark))?;

                let chunks = &chunks[..response.chunks.consumed];
                let mut bufs = [buf];
                let copied_len = s2n_quic_core::slice::vectored_copy(chunks, &mut bufs);

                debug_assert_eq!(
                    copied_len, response.bytes.consumed,
                    "the consumed bytes should always have enough capacity in bufs"
                );

                // report the number of bytes actually written into `buf` (consistent with
                // `poll_read_vectored`), never more than was copied
                Ok(copied_len).into()
            }

            fn poll_read_vectored(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
                bufs: &mut [futures::io::IoSliceMut],
            ) -> core::task::Poll<std::io::Result<usize>> {
                use bytes::Bytes;

                // total destination capacity; never consume more than this from the stream
                let high_watermark: usize = bufs.iter().map(|buf| buf.len()).sum();

                // Like `poll_read`, complete zero-capacity reads immediately. This covers both
                // an empty `bufs` slice and a slice made up only of empty buffers, which would
                // otherwise reach the transport with a zero watermark.
                if high_watermark == 0 {
                    return Ok(0).into();
                }

                // create some chunks on the stack to receive into
                // TODO investigate a better default number
                let mut chunks = [
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                ];

                let response =
                    core::task::ready!(self.receive_chunks(cx, &mut chunks, high_watermark))?;

                let chunks = &chunks[..response.chunks.consumed];
                let copied_len = s2n_quic_core::slice::vectored_copy(chunks, bufs);

                debug_assert_eq!(
                    copied_len, response.bytes.consumed,
                    "the consumed bytes should always have enough capacity in bufs"
                );

                Ok(copied_len).into()
            }
        }

        impl tokio::io::AsyncRead for $name {
            fn poll_read(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf,
            ) -> core::task::Poll<std::io::Result<()>> {
                use bytes::Bytes;

                // a zero-capacity read completes immediately
                if buf.remaining() == 0 {
                    return Ok(()).into();
                }

                // create some chunks on the stack to receive into
                // TODO investigate a better default number
                let mut chunks = [
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                    Bytes::new(),
                ];

                // `ReadBuf::put_slice` panics if the data doesn't fit, so cap the receive at
                // the remaining capacity of `buf`
                let high_watermark = buf.remaining();

                let response =
                    core::task::ready!(self.receive_chunks(cx, &mut chunks, high_watermark))?;

                for chunk in &chunks[..response.chunks.consumed] {
                    buf.put_slice(chunk);
                }

                Ok(()).into()
            }
        }
    };
}
389
390impl ReceiveStream {
391    #[inline]
392    pub(crate) const fn new(stream: stream::ReceiveStream) -> Self {
393        Self(stream)
394    }
395
396    /// Returns the stream's identifier
397    ///
398    /// This value is unique to a particular connection. The format follows the same as what is
399    /// defined in the
400    /// [QUIC Transport RFC](https://www.rfc-editor.org/rfc/rfc9000.html#name-stream-types-and-identifier).
401    ///
402    /// # Examples
403    ///
404    /// ```rust,no_run
405    /// # async fn test() -> s2n_quic::stream::Result<()> {
406    /// #   let connection: s2n_quic::connection::Connection = todo!();
407    /// #
408    /// while let Some(stream) = connection.accept_receive_stream().await? {
409    ///     println!("New stream's id: {}", stream.id());
410    /// }
411    /// #
412    /// #   Ok(())
413    /// # }
414    /// ```
415    #[inline]
416    pub fn id(&self) -> u64 {
417        self.0.id().into()
418    }
419
420    impl_connection_api!(|stream| crate::connection::Handle(stream.0.connection().clone()));
421
422    impl_receive_stream_api!(|stream, dispatch| dispatch!(stream.0));
423}
424
// Splitting yields `(Some(stream), None)`.
// NOTE(review): presumably the tuple is `(receive half, send half)` and a receive-only stream
// has no send half; confirm against the `impl_splittable_stream_trait!` definition.
impl_splittable_stream_trait!(ReceiveStream, |stream| (Some(stream), None));
// `futures::stream::Stream`, `futures::io::AsyncRead` and `tokio::io::AsyncRead` impls, built on
// the `poll_receive`/`receive_chunks` methods of `ReceiveStream`. The `|stream, dispatch| ...`
// arguments are not used by `impl_receive_stream_trait!`.
impl_receive_stream_trait!(ReceiveStream, |stream, dispatch| dispatch!(stream.0));