s2n_quic/stream/receive.rs
1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use s2n_quic_transport::stream;
5
6/// A QUIC stream that is only allowed to receive data.
7#[derive(Debug)]
8pub struct ReceiveStream(stream::ReceiveStream);
9
macro_rules! impl_receive_stream_api {
    (| $stream:ident, $dispatch:ident | $dispatch_body:expr) => {
        /// Receives a chunk of data from the stream.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Ok(Some(chunk))` if the stream is open and data was available.
        /// - `Ok(None)` if the stream was finished and all of the data was consumed.
        /// - `Err(e)` if the stream encountered a [`stream::Error`](crate::stream::Error).
        ///
        /// # Examples
        ///
        /// ```rust,no_run
        /// # async fn test() -> s2n_quic::stream::Result<()> {
        /// #   let mut stream: s2n_quic::stream::ReceiveStream = todo!();
        /// #
        /// while let Some(chunk) = stream.receive().await? {
        ///     println!("received: {:?}", chunk);
        /// }
        ///
        /// println!("finished");
        /// #
        /// #   Ok(())
        /// # }
        /// ```
        #[inline]
        pub async fn receive(&mut self) -> $crate::stream::Result<Option<bytes::Bytes>> {
            ::futures::future::poll_fn(|cx| self.poll_receive(cx)).await
        }

        /// Poll for a chunk of data from the stream.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Poll::Pending` if the stream is waiting to receive data from the peer. In this case,
        ///   the caller should retry receiving after the [`Waker`](core::task::Waker) on the provided
        ///   [`Context`](core::task::Context) is notified.
        /// - `Poll::Ready(Ok(Some(chunk)))` if the stream is open and data was available.
        /// - `Poll::Ready(Ok(None))` if the stream was finished and all of the data was consumed.
        /// - `Poll::Ready(Err(e))` if the stream encountered a [`stream::Error`](crate::stream::Error).
        #[inline]
        pub fn poll_receive(
            &mut self,
            cx: &mut core::task::Context,
        ) -> core::task::Poll<$crate::stream::Result<Option<bytes::Bytes>>> {
            // The caller-named dispatch macro: the empty arm is used when there is no readable
            // half, the `$variant` arm forwards to the underlying stream.
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable()).into()
                };
                ($variant: expr) => {
                    s2n_quic_core::task::waker::debug_assert_contract(cx, |cx| {
                        $variant.poll_receive(cx)
                    })
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Receives a slice of chunks of data from the stream.
        ///
        /// This can be more efficient than calling [`receive`](Self::receive) for each chunk,
        /// especially when receiving large amounts of data.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Ok((count, is_open))` if the stream received data into the slice,
        ///   where `count` is the number of chunks received and `is_open` indicates whether the
        ///   stream is still open. If `is_open == true`, `count` will be at least `1`. If
        ///   `is_open == false`, future calls to [`receive_vectored`](Self::receive_vectored)
        ///   will always return `Ok((0, false))`.
        /// - `Err(e)` if the stream encountered a [`stream::Error`](crate::stream::Error).
        ///
        /// # Examples
        ///
        /// ```rust,no_run
        /// # async fn test() -> s2n_quic::stream::Result<()> {
        /// #   let mut stream: s2n_quic::stream::ReceiveStream = todo!();
        /// #
        /// # use bytes::Bytes;
        /// #
        /// loop {
        ///     let mut chunks = [Bytes::new(), Bytes::new(), Bytes::new()];
        ///     let (count, is_open) = stream.receive_vectored(&mut chunks).await?;
        ///
        ///     for chunk in &chunks[..count] {
        ///         println!("received: {:?}", chunk);
        ///     }
        ///
        ///     if !is_open {
        ///         break;
        ///     }
        /// }
        ///
        /// println!("finished");
        /// #
        /// #   Ok(())
        /// # }
        /// ```
        #[inline]
        pub async fn receive_vectored(
            &mut self,
            chunks: &mut [bytes::Bytes],
        ) -> $crate::stream::Result<(usize, bool)> {
            ::futures::future::poll_fn(|cx| self.poll_receive_vectored(chunks, cx)).await
        }

        /// Polls for receiving a slice of chunks of data from the stream.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Poll::Pending` if the stream is waiting to receive data from the peer. In this case,
        ///   the caller should retry receiving after the [`Waker`](core::task::Waker) on the provided
        ///   [`Context`](core::task::Context) is notified.
        /// - `Poll::Ready(Ok((count, is_open)))` if the stream received data into the slice,
        ///   where `count` is the number of chunks received and `is_open` indicates whether the
        ///   stream is still open. If `is_open == true`, `count` will be at least `1`. If
        ///   `is_open == false`, future calls to
        ///   [`poll_receive_vectored`](Self::poll_receive_vectored) will always return
        ///   `Poll::Ready(Ok((0, false)))`.
        /// - `Poll::Ready(Err(e))` if the stream encountered a [`stream::Error`](crate::stream::Error).
        #[inline]
        pub fn poll_receive_vectored(
            &mut self,
            chunks: &mut [bytes::Bytes],
            cx: &mut core::task::Context,
        ) -> core::task::Poll<$crate::stream::Result<(usize, bool)>> {
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable()).into()
                };
                ($variant: expr) => {
                    s2n_quic_core::task::waker::debug_assert_contract(cx, |cx| {
                        $variant.poll_receive_vectored(chunks, cx)
                    })
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Notifies the peer to stop sending data on the stream.
        ///
        /// This requests the peer to finish the stream as soon as possible
        /// by issuing a [`reset`](crate::stream::SendStream::reset) with the
        /// provided [`error_code`](crate::application::Error).
        ///
        /// Since this is merely a request for the peer to reset the stream, the
        /// stream will not immediately be in a reset state after issuing this
        /// call.
        ///
        /// If the stream has already been reset by the peer or if all data has
        /// been received, the call will not trigger any action.
        ///
        /// # Return value
        ///
        /// The function returns:
        ///
        /// - `Ok(())` if the stop sending message was enqueued for the peer.
        /// - `Err(e)` if the stream encountered a [`stream::Error`](crate::stream::Error).
        ///
        /// # Examples
        ///
        /// ```rust,no_run
        /// # async fn test() -> s2n_quic::stream::Result<()> {
        /// #   let mut connection: s2n_quic::connection::Connection = todo!();
        /// #
        /// while let Some(mut stream) = connection.accept_receive_stream().await? {
        ///     stream.stop_sending(123u8.into());
        /// }
        /// #
        /// #   Ok(())
        /// # }
        /// ```
        #[inline]
        pub fn stop_sending(
            &mut self,
            error_code: $crate::application::Error,
        ) -> $crate::stream::Result<()> {
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable())
                };
                ($variant: expr) => {
                    $variant.stop_sending(error_code)
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Create a batch request for receiving data.
        ///
        /// Returns a [`stream::Error`](crate::stream::Error) built with `non_readable()` when the
        /// dispatch resolves to no readable stream.
        #[inline]
        pub(crate) fn rx_request(
            &mut self,
        ) -> $crate::stream::Result<s2n_quic_transport::stream::RxRequest<'_, '_>> {
            macro_rules! $dispatch {
                () => {
                    Err($crate::stream::Error::non_readable())
                };
                ($variant: expr) => {
                    $variant.rx_request()
                };
            }

            let $stream = self;
            $dispatch_body
        }

        /// Polls the stream to receive data into `chunks`.
        ///
        /// `high_watermark` is passed to the request's `with_high_watermark` so that no more data
        /// is received than the caller can store. The `AsyncRead` impls pass the remaining
        /// capacity of their destination buffer(s).
        /// Returns `Poll::Pending` when the request's response is not ready, and propagates any
        /// stream error.
        #[inline]
        pub(crate) fn receive_chunks(
            &mut self,
            cx: &mut core::task::Context,
            chunks: &mut [bytes::Bytes],
            high_watermark: usize,
        ) -> core::task::Poll<$crate::stream::Result<s2n_quic_transport::stream::ops::rx::Response>>
        {
            s2n_quic_core::task::waker::debug_assert_contract(cx, |cx| {
                let response = core::task::ready!(self
                    .rx_request()?
                    .receive(chunks)
                    // don't receive more than we're capable of storing
                    .with_high_watermark(high_watermark)
                    .poll(Some(cx))?
                    .into_poll());

                core::task::Poll::Ready(Ok(response))
            })
        }
    };
}
251
macro_rules! impl_receive_stream_trait {
    ($name:ident, | $stream:ident, $dispatch:ident | $dispatch_body:expr) => {
        impl futures::stream::Stream for $name {
            type Item = $crate::stream::Result<bytes::Bytes>;

            // Each received chunk becomes one item. `Ok(None)` from `poll_receive` (finished and
            // fully consumed) ends the `Stream`. Errors are yielded as `Some(Err(_))`.
            #[inline]
            fn poll_next(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
            ) -> core::task::Poll<Option<Self::Item>> {
                self.poll_receive(cx).map(Result::transpose)
            }
        }

        impl futures::io::AsyncRead for $name {
            fn poll_read(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
                buf: &mut [u8],
            ) -> core::task::Poll<std::io::Result<usize>> {
                use bytes::Bytes;

                // Nothing can be stored, so report a zero-length read right away.
                if buf.is_empty() {
                    return Ok(0).into();
                }

                // create some chunks on the stack to receive into
                // TODO investigate a better default number
                let mut chunks: [Bytes; 5] = Default::default();

                // never pull more data out of the stream than `buf` can hold
                let high_watermark = buf.len();

                let response =
                    core::task::ready!(self.receive_chunks(cx, &mut chunks, high_watermark))?;

                let chunks = &chunks[..response.chunks.consumed];
                let mut bufs = [buf];
                let copied_len = s2n_quic_core::slice::vectored_copy(chunks, &mut bufs);

                debug_assert_eq!(
                    copied_len, response.bytes.consumed,
                    "the consumed bytes should always have enough capacity in bufs"
                );

                // Report the number of bytes actually written into `buf`. This matches
                // `poll_read_vectored` and never over-reports in release builds.
                Ok(copied_len).into()
            }

            fn poll_read_vectored(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
                bufs: &mut [futures::io::IoSliceMut],
            ) -> core::task::Poll<std::io::Result<usize>> {
                use bytes::Bytes;

                // total capacity across every destination buffer
                let high_watermark: usize = bufs.iter().map(|buf| buf.len()).sum();

                // Covers both an empty `bufs` slice and a slice made only of empty buffers.
                // Either way nothing can be stored, so this matches the zero-capacity handling
                // in `poll_read`.
                if high_watermark == 0 {
                    return Ok(0).into();
                }

                // create some chunks on the stack to receive into
                // TODO investigate a better default number
                let mut chunks: [Bytes; 10] = Default::default();

                let response =
                    core::task::ready!(self.receive_chunks(cx, &mut chunks, high_watermark))?;

                let chunks = &chunks[..response.chunks.consumed];
                let copied_len = s2n_quic_core::slice::vectored_copy(chunks, bufs);

                debug_assert_eq!(
                    copied_len, response.bytes.consumed,
                    "the consumed bytes should always have enough capacity in bufs"
                );

                Ok(copied_len).into()
            }
        }

        impl tokio::io::AsyncRead for $name {
            fn poll_read(
                mut self: core::pin::Pin<&mut Self>,
                cx: &mut core::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf,
            ) -> core::task::Poll<std::io::Result<()>> {
                use bytes::Bytes;

                // never pull more data out of the stream than `buf` has room for
                let high_watermark = buf.remaining();

                if high_watermark == 0 {
                    return Ok(()).into();
                }

                // create some chunks on the stack to receive into
                // TODO investigate a better default number
                let mut chunks: [Bytes; 5] = Default::default();

                let response =
                    core::task::ready!(self.receive_chunks(cx, &mut chunks, high_watermark))?;

                for chunk in &chunks[..response.chunks.consumed] {
                    // `put_slice` panics if `chunk` exceeds `remaining()`. The high watermark
                    // above is meant to keep the received data within that capacity.
                    buf.put_slice(chunk);
                }

                Ok(()).into()
            }
        }
    };
}
389
390impl ReceiveStream {
391 #[inline]
392 pub(crate) const fn new(stream: stream::ReceiveStream) -> Self {
393 Self(stream)
394 }
395
396 /// Returns the stream's identifier
397 ///
398 /// This value is unique to a particular connection. The format follows the same as what is
399 /// defined in the
400 /// [QUIC Transport RFC](https://www.rfc-editor.org/rfc/rfc9000.html#name-stream-types-and-identifier).
401 ///
402 /// # Examples
403 ///
404 /// ```rust,no_run
405 /// # async fn test() -> s2n_quic::stream::Result<()> {
406 /// # let connection: s2n_quic::connection::Connection = todo!();
407 /// #
408 /// while let Some(stream) = connection.accept_receive_stream().await? {
409 /// println!("New stream's id: {}", stream.id());
410 /// }
411 /// #
412 /// # Ok(())
413 /// # }
414 /// ```
415 #[inline]
416 pub fn id(&self) -> u64 {
417 self.0.id().into()
418 }
419
420 impl_connection_api!(|stream| crate::connection::Handle(stream.0.connection().clone()));
421
422 impl_receive_stream_api!(|stream, dispatch| dispatch!(stream.0));
423}
424
425impl_splittable_stream_trait!(ReceiveStream, |stream| (Some(stream), None));
426impl_receive_stream_trait!(ReceiveStream, |stream, dispatch| dispatch!(stream.0));