sec_http3/
stream.rs

1use std::{
2    marker::PhantomData,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes};
8use futures_util::{future, ready};
9use pin_project_lite::pin_project;
10use tokio::io::ReadBuf;
11
12use crate::{
13    buf::BufList,
14    error::{Code, ErrorLevel},
15    frame::FrameStream,
16    proto::{
17        coding::{Decode as _, Encode},
18        frame::{Frame, Settings},
19        stream::StreamType,
20        varint::VarInt,
21    },
22    quic::{self, BidiStream, RecvStream, SendStream, SendStreamUnframed},
23    webtransport::SessionId,
24    Error,
25};
26
27#[inline]
28/// Transmits data by encoding in wire format.
29pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), Error>
30where
31    S: SendStream<B>,
32    D: Into<WriteBuf<B>>,
33    B: Buf,
34{
35    stream.send_data(data)?;
36    future::poll_fn(|cx| stream.poll_ready(cx)).await?;
37
38    Ok(())
39}
40
41const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_ENCODED_SIZE;
42
43/// Wrap frames to encode their header on the stack before sending them on the wire
44///
45/// Implements `Buf` so wire data is seamlessly available for transport layer transmits:
46/// `Buf::chunk()` will yield the encoded header, then the payload. For unidirectional streams,
47/// this type makes it possible to prefix wire data with the `StreamType`.
48///
49/// Conveying frames as `Into<WriteBuf>` makes it possible to encode only when generating wire-format
50/// data is necessary (say, in `quic::SendStream::send_data`). It also has a public API ergonomy
51/// advantage: `WriteBuf` doesn't have to appear in public associated types. On the other hand,
52/// QUIC implementers have to call `into()`, which will encode the header in `Self::buf`.
53pub struct WriteBuf<B> {
54    buf: [u8; WRITE_BUF_ENCODE_SIZE],
55    len: usize,
56    pos: usize,
57    frame: Option<Frame<B>>,
58}
59
60impl<B> WriteBuf<B>
61where
62    B: Buf,
63{
64    fn encode_stream_type(&mut self, ty: StreamType) {
65        let mut buf_mut = &mut self.buf[self.len..];
66
67        ty.encode(&mut buf_mut);
68        self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
69    }
70
71    fn encode_value(&mut self, value: impl Encode) {
72        let mut buf_mut = &mut self.buf[self.len..];
73        value.encode(&mut buf_mut);
74        self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
75    }
76
77    fn encode_frame_header(&mut self) {
78        if let Some(frame) = self.frame.as_ref() {
79            let mut buf_mut = &mut self.buf[self.len..];
80            frame.encode(&mut buf_mut);
81            self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
82        }
83    }
84}
85
86impl<B> From<StreamType> for WriteBuf<B>
87where
88    B: Buf,
89{
90    fn from(ty: StreamType) -> Self {
91        let mut me = Self {
92            buf: [0; WRITE_BUF_ENCODE_SIZE],
93            len: 0,
94            pos: 0,
95            frame: None,
96        };
97        me.encode_stream_type(ty);
98        me
99    }
100}
101
102impl<B> From<UniStreamHeader> for WriteBuf<B>
103where
104    B: Buf,
105{
106    fn from(header: UniStreamHeader) -> Self {
107        let mut this = Self {
108            buf: [0; WRITE_BUF_ENCODE_SIZE],
109            len: 0,
110            pos: 0,
111            frame: None,
112        };
113
114        this.encode_value(header);
115        this
116    }
117}
118
119pub enum UniStreamHeader {
120    Control(Settings),
121    WebTransportUni(SessionId),
122}
123
124impl Encode for UniStreamHeader {
125    fn encode<B: BufMut>(&self, buf: &mut B) {
126        match self {
127            Self::Control(settings) => {
128                StreamType::CONTROL.encode(buf);
129                settings.encode(buf);
130            }
131            Self::WebTransportUni(session_id) => {
132                StreamType::WEBTRANSPORT_UNI.encode(buf);
133                session_id.encode(buf);
134            }
135        }
136    }
137}
138
139impl<B> From<BidiStreamHeader> for WriteBuf<B>
140where
141    B: Buf,
142{
143    fn from(header: BidiStreamHeader) -> Self {
144        let mut this = Self {
145            buf: [0; WRITE_BUF_ENCODE_SIZE],
146            len: 0,
147            pos: 0,
148            frame: None,
149        };
150
151        this.encode_value(header);
152        this
153    }
154}
155
156pub enum BidiStreamHeader {
157    Control(Settings),
158    WebTransportBidi(SessionId),
159}
160
161impl Encode for BidiStreamHeader {
162    fn encode<B: BufMut>(&self, buf: &mut B) {
163        match self {
164            Self::Control(settings) => {
165                StreamType::CONTROL.encode(buf);
166                settings.encode(buf);
167            }
168            Self::WebTransportBidi(session_id) => {
169                StreamType::WEBTRANSPORT_BIDI.encode(buf);
170                session_id.encode(buf);
171            }
172        }
173    }
174}
175
176impl<B> From<Frame<B>> for WriteBuf<B>
177where
178    B: Buf,
179{
180    fn from(frame: Frame<B>) -> Self {
181        let mut me = Self {
182            buf: [0; WRITE_BUF_ENCODE_SIZE],
183            len: 0,
184            pos: 0,
185            frame: Some(frame),
186        };
187        me.encode_frame_header();
188        me
189    }
190}
191
192impl<B> From<(StreamType, Frame<B>)> for WriteBuf<B>
193where
194    B: Buf,
195{
196    fn from(ty_stream: (StreamType, Frame<B>)) -> Self {
197        let (ty, frame) = ty_stream;
198        let mut me = Self {
199            buf: [0; WRITE_BUF_ENCODE_SIZE],
200            len: 0,
201            pos: 0,
202            frame: Some(frame),
203        };
204        me.encode_value(ty);
205        me.encode_frame_header();
206        me
207    }
208}
209
210impl<B> Buf for WriteBuf<B>
211where
212    B: Buf,
213{
214    fn remaining(&self) -> usize {
215        self.len - self.pos
216            + self
217                .frame
218                .as_ref()
219                .and_then(|f| f.payload())
220                .map_or(0, |x| x.remaining())
221    }
222
223    fn chunk(&self) -> &[u8] {
224        if self.len - self.pos > 0 {
225            &self.buf[self.pos..self.len]
226        } else if let Some(payload) = self.frame.as_ref().and_then(|f| f.payload()) {
227            payload.chunk()
228        } else {
229            &[]
230        }
231    }
232
233    fn advance(&mut self, mut cnt: usize) {
234        let remaining_header = self.len - self.pos;
235        if remaining_header > 0 {
236            let advanced = usize::min(cnt, remaining_header);
237            self.pos += advanced;
238            cnt -= advanced;
239        }
240
241        if let Some(payload) = self.frame.as_mut().and_then(|f| f.payload_mut()) {
242            payload.advance(cnt);
243        }
244    }
245}
246
247pub(super) enum AcceptedRecvStream<S, B>
248where
249    S: quic::RecvStream,
250    B: Buf,
251{
252    Control(FrameStream<S, B>),
253    Push(u64, FrameStream<S, B>),
254    Encoder(BufRecvStream<S, B>),
255    Decoder(BufRecvStream<S, B>),
256    WebTransportUni(SessionId, BufRecvStream<S, B>),
257    Reserved,
258}
259
260/// Resolves an incoming streams type as well as `PUSH_ID`s and `SESSION_ID`s
261pub(super) struct AcceptRecvStream<S, B> {
262    stream: BufRecvStream<S, B>,
263    ty: Option<StreamType>,
264    /// push_id or session_id
265    id: Option<VarInt>,
266    expected: Option<usize>,
267}
268
269impl<S, B> AcceptRecvStream<S, B>
270where
271    S: RecvStream,
272    B: Buf,
273{
274    pub fn new(stream: S) -> Self {
275        Self {
276            stream: BufRecvStream::new(stream),
277            ty: None,
278            id: None,
279            expected: None,
280        }
281    }
282
283    pub fn into_stream(self) -> Result<AcceptedRecvStream<S, B>, Error> {
284        Ok(match self.ty.expect("Stream type not resolved yet") {
285            StreamType::CONTROL => AcceptedRecvStream::Control(FrameStream::new(self.stream)),
286            StreamType::PUSH => AcceptedRecvStream::Push(
287                self.id.expect("Push ID not resolved yet").into_inner(),
288                FrameStream::new(self.stream),
289            ),
290            StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream),
291            StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream),
292            StreamType::WEBTRANSPORT_UNI => AcceptedRecvStream::WebTransportUni(
293                SessionId::from_varint(self.id.expect("Session ID not resolved yet")),
294                self.stream,
295            ),
296            t if t.value() > 0x21 && (t.value() - 0x21) % 0x1f == 0 => AcceptedRecvStream::Reserved,
297
298            //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2
299            //# Recipients of unknown stream types MUST
300            //# either abort reading of the stream or discard incoming data without
301            //# further processing.
302
303            //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2
304            //# If reading is aborted, the recipient SHOULD use
305            //# the H3_STREAM_CREATION_ERROR error code or a reserved error code
306            //# (Section 8.1).
307
308            //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2
309            //= type=implication
310            //# The recipient MUST NOT consider unknown stream types
311            //# to be a connection error of any kind.
312            t => {
313                return Err(Code::H3_STREAM_CREATION_ERROR.with_reason(
314                    format!("unknown stream type 0x{:x}", t.value()),
315                    crate::error::ErrorLevel::ConnectionError,
316                ))
317            }
318        })
319    }
320
321    pub fn poll_type(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
322        loop {
323            // Return if all identification data is met
324            match self.ty {
325                Some(StreamType::PUSH | StreamType::WEBTRANSPORT_UNI) => {
326                    if self.id.is_some() {
327                        return Poll::Ready(Ok(()));
328                    }
329                }
330                Some(_) => return Poll::Ready(Ok(())),
331                None => (),
332            };
333
334            if ready!(self.stream.poll_read(cx))? {
335                return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason(
336                    "Stream closed before type received",
337                    ErrorLevel::ConnectionError,
338                )));
339            };
340
341            let mut buf = self.stream.buf_mut();
342            if self.expected.is_none() && buf.remaining() >= 1 {
343                self.expected = Some(VarInt::encoded_size(buf.chunk()[0]));
344            }
345
346            if let Some(expected) = self.expected {
347                // Poll for more data
348                if buf.remaining() < expected {
349                    continue;
350                }
351            } else {
352                continue;
353            }
354
355            // Parse ty and then id
356            if self.ty.is_none() {
357                // Parse StreamType
358                self.ty = Some(StreamType::decode(&mut buf).map_err(|_| {
359                    Code::H3_INTERNAL_ERROR.with_reason(
360                        "Unexpected end parsing stream type",
361                        ErrorLevel::ConnectionError,
362                    )
363                })?);
364                // Get the next VarInt for PUSH_ID on the next iteration
365                self.expected = None;
366            } else {
367                // Parse PUSH_ID
368                self.id = Some(VarInt::decode(&mut buf).map_err(|_| {
369                    Code::H3_INTERNAL_ERROR.with_reason(
370                        "Unexpected end parsing push or session id",
371                        ErrorLevel::ConnectionError,
372                    )
373                })?);
374            }
375        }
376    }
377}
378
379pin_project! {
380    /// A stream which allows partial reading of the data without data loss.
381    ///
382    /// This fixes the problem where `poll_data` returns more than the needed amount of bytes,
383    /// requiring correct implementations to hold on to that extra data and return it later.
384    ///
385    /// # Usage
386    ///
387    /// Implements `quic::RecvStream` which will first return buffered data, and then read from the
388    /// stream
389    pub struct BufRecvStream<S, B> {
390        buf: BufList<Bytes>,
391        // Indicates that the end of the stream has been reached
392        //
393        // Data may still be available as buffered
394        eos: bool,
395        stream: S,
396        _marker: PhantomData<B>,
397    }
398}
399
400impl<S, B> std::fmt::Debug for BufRecvStream<S, B> {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        f.debug_struct("BufRecvStream")
403            .field("buf", &self.buf)
404            .field("eos", &self.eos)
405            .field("stream", &"...")
406            .finish()
407    }
408}
409
410impl<S, B> BufRecvStream<S, B> {
411    pub fn new(stream: S) -> Self {
412        Self {
413            buf: BufList::new(),
414            eos: false,
415            stream,
416            _marker: PhantomData,
417        }
418    }
419}
420
421impl<B, S: RecvStream> BufRecvStream<S, B> {
422    /// Reads more data into the buffer, returning the number of bytes read.
423    ///
424    /// Returns `true` if the end of the stream is reached.
425    pub fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool, S::Error>> {
426        let data = ready!(self.stream.poll_data(cx))?;
427
428        if let Some(mut data) = data {
429            self.buf.push_bytes(&mut data);
430            Poll::Ready(Ok(false))
431        } else {
432            self.eos = true;
433            Poll::Ready(Ok(true))
434        }
435    }
436
437    /// Returns the currently buffered data, allowing it to be partially read
438    #[inline]
439    pub(crate) fn buf_mut(&mut self) -> &mut BufList<Bytes> {
440        &mut self.buf
441    }
442
443    /// Returns the next chunk of data from the stream
444    ///
445    /// Return `None` when there is no more buffered data; use [`Self::poll_read`].
446    pub fn take_chunk(&mut self, limit: usize) -> Option<Bytes> {
447        self.buf.take_chunk(limit)
448    }
449
450    /// Returns true if there is remaining buffered data
451    pub fn has_remaining(&mut self) -> bool {
452        self.buf.has_remaining()
453    }
454
455    #[inline]
456    pub(crate) fn buf(&self) -> &BufList<Bytes> {
457        &self.buf
458    }
459
460    pub fn is_eos(&self) -> bool {
461        self.eos
462    }
463}
464
465impl<S: RecvStream, B> RecvStream for BufRecvStream<S, B> {
466    type Buf = Bytes;
467
468    type Error = S::Error;
469
470    fn poll_data(
471        &mut self,
472        cx: &mut std::task::Context<'_>,
473    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
474        // There is data buffered, return that immediately
475        if let Some(chunk) = self.buf.take_first_chunk() {
476            return Poll::Ready(Ok(Some(chunk)));
477        }
478
479        if let Some(mut data) = ready!(self.stream.poll_data(cx))? {
480            Poll::Ready(Ok(Some(data.copy_to_bytes(data.remaining()))))
481        } else {
482            self.eos = true;
483            Poll::Ready(Ok(None))
484        }
485    }
486
487    fn stop_sending(&mut self, error_code: u64) {
488        self.stream.stop_sending(error_code)
489    }
490
491    fn recv_id(&self) -> quic::StreamId {
492        self.stream.recv_id()
493    }
494}
495
496impl<S, B> SendStream<B> for BufRecvStream<S, B>
497where
498    B: Buf,
499    S: SendStream<B>,
500{
501    type Error = S::Error;
502
503    fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
504        self.stream.poll_finish(cx)
505    }
506
507    fn reset(&mut self, reset_code: u64) {
508        self.stream.reset(reset_code)
509    }
510
511    fn send_id(&self) -> quic::StreamId {
512        self.stream.send_id()
513    }
514
515    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
516        self.stream.poll_ready(cx)
517    }
518
519    fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
520        self.stream.send_data(data)
521    }
522}
523
524impl<S, B> SendStreamUnframed<B> for BufRecvStream<S, B>
525where
526    B: Buf,
527    S: SendStreamUnframed<B>,
528{
529    #[inline]
530    fn poll_send<D: Buf>(
531        &mut self,
532        cx: &mut std::task::Context<'_>,
533        buf: &mut D,
534    ) -> Poll<Result<usize, Self::Error>> {
535        self.stream.poll_send(cx, buf)
536    }
537}
538
539impl<S, B> BidiStream<B> for BufRecvStream<S, B>
540where
541    B: Buf,
542    S: BidiStream<B>,
543{
544    type SendStream = BufRecvStream<S::SendStream, B>;
545
546    type RecvStream = BufRecvStream<S::RecvStream, B>;
547
548    fn split(self) -> (Self::SendStream, Self::RecvStream) {
549        let (send, recv) = self.stream.split();
550        (
551            BufRecvStream {
552                // Sending is not buffered
553                buf: BufList::new(),
554                eos: self.eos,
555                stream: send,
556                _marker: PhantomData,
557            },
558            BufRecvStream {
559                buf: self.buf,
560                eos: self.eos,
561                stream: recv,
562                _marker: PhantomData,
563            },
564        )
565    }
566}
567
568impl<S, B> futures_util::io::AsyncRead for BufRecvStream<S, B>
569where
570    B: Buf,
571    S: RecvStream,
572    S::Error: Into<std::io::Error>,
573{
574    fn poll_read(
575        mut self: Pin<&mut Self>,
576        cx: &mut Context<'_>,
577        buf: &mut [u8],
578    ) -> Poll<futures_util::io::Result<usize>> {
579        let p = &mut *self;
580        // Poll for data if the buffer is empty
581        //
582        // If there is data available *do not* poll for more data, as that may suspend indefinitely
583        // if no more data is sent, causing data loss.
584        if !p.has_remaining() {
585            let eos = ready!(p.poll_read(cx).map_err(Into::into))?;
586            if eos {
587                return Poll::Ready(Ok(0));
588            }
589        }
590
591        let chunk = p.buf_mut().take_chunk(buf.len());
592        if let Some(chunk) = chunk {
593            assert!(chunk.len() <= buf.len());
594            let len = chunk.len().min(buf.len());
595            // Write the subset into the destination
596            buf[..len].copy_from_slice(&chunk);
597            Poll::Ready(Ok(len))
598        } else {
599            Poll::Ready(Ok(0))
600        }
601    }
602}
603
604impl<S, B> tokio::io::AsyncRead for BufRecvStream<S, B>
605where
606    B: Buf,
607    S: RecvStream,
608    S::Error: Into<std::io::Error>,
609{
610    fn poll_read(
611        mut self: Pin<&mut Self>,
612        cx: &mut Context<'_>,
613        buf: &mut ReadBuf<'_>,
614    ) -> Poll<futures_util::io::Result<()>> {
615        let p = &mut *self;
616        // Poll for data if the buffer is empty
617        //
618        // If there is data available *do not* poll for more data, as that may suspend indefinitely
619        // if no more data is sent, causing data loss.
620        if !p.has_remaining() {
621            let eos = ready!(p.poll_read(cx).map_err(Into::into))?;
622            if eos {
623                return Poll::Ready(Ok(()));
624            }
625        }
626
627        let chunk = p.buf_mut().take_chunk(buf.remaining());
628        if let Some(chunk) = chunk {
629            assert!(chunk.len() <= buf.remaining());
630            // Write the subset into the destination
631            buf.put_slice(&chunk);
632            Poll::Ready(Ok(()))
633        } else {
634            Poll::Ready(Ok(()))
635        }
636    }
637}
638
639impl<S, B> futures_util::io::AsyncWrite for BufRecvStream<S, B>
640where
641    B: Buf,
642    S: SendStreamUnframed<B>,
643    S::Error: Into<std::io::Error>,
644{
645    fn poll_write(
646        mut self: Pin<&mut Self>,
647        cx: &mut Context<'_>,
648        mut buf: &[u8],
649    ) -> Poll<std::io::Result<usize>> {
650        let p = &mut *self;
651        p.poll_send(cx, &mut buf).map_err(Into::into)
652    }
653
654    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
655        Poll::Ready(Ok(()))
656    }
657
658    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
659        let p = &mut *self;
660        p.poll_finish(cx).map_err(Into::into)
661    }
662}
663
664impl<S, B> tokio::io::AsyncWrite for BufRecvStream<S, B>
665where
666    B: Buf,
667    S: SendStreamUnframed<B>,
668    S::Error: Into<std::io::Error>,
669{
670    fn poll_write(
671        mut self: Pin<&mut Self>,
672        cx: &mut Context<'_>,
673        mut buf: &[u8],
674    ) -> Poll<std::io::Result<usize>> {
675        let p = &mut *self;
676        p.poll_send(cx, &mut buf).map_err(Into::into)
677    }
678
679    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
680        Poll::Ready(Ok(()))
681    }
682
683    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
684        let p = &mut *self;
685        p.poll_finish(cx).map_err(Into::into)
686    }
687}
688
689#[cfg(test)]
690mod tests {
691    use quinn_proto::coding::BufExt;
692
693    use super::*;
694
695    #[test]
696    fn write_wt_uni_header() {
697        let mut w = WriteBuf::<Bytes>::from(UniStreamHeader::WebTransportUni(
698            SessionId::from_varint(VarInt(5)),
699        ));
700
701        let ty = w.get_var().unwrap();
702        println!("Got type: {ty} {ty:#x}");
703        assert_eq!(ty, 0x54);
704
705        let id = w.get_var().unwrap();
706        println!("Got id: {id}");
707    }
708
709    #[test]
710    fn write_buf_encode_streamtype() {
711        let wbuf = WriteBuf::<Bytes>::from(StreamType::ENCODER);
712
713        assert_eq!(wbuf.chunk(), b"\x02");
714        assert_eq!(wbuf.len, 1);
715    }
716
717    #[test]
718    fn write_buf_encode_frame() {
719        let wbuf = WriteBuf::<Bytes>::from(Frame::Goaway(VarInt(2)));
720
721        assert_eq!(wbuf.chunk(), b"\x07\x01\x02");
722        assert_eq!(wbuf.len, 3);
723    }
724
725    #[test]
726    fn write_buf_encode_streamtype_then_frame() {
727        let wbuf = WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Goaway(VarInt(2))));
728
729        assert_eq!(wbuf.chunk(), b"\x02\x07\x01\x02");
730    }
731
732    #[test]
733    fn write_buf_advances() {
734        let mut wbuf =
735            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));
736
737        assert_eq!(wbuf.chunk(), b"\x02\x00\x03");
738        wbuf.advance(3);
739        assert_eq!(wbuf.remaining(), 3);
740        assert_eq!(wbuf.chunk(), b"hey");
741        wbuf.advance(2);
742        assert_eq!(wbuf.chunk(), b"y");
743        wbuf.advance(1);
744        assert_eq!(wbuf.remaining(), 0);
745    }
746
747    #[test]
748    fn write_buf_advance_jumps_header_and_payload_start() {
749        let mut wbuf =
750            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));
751
752        wbuf.advance(4);
753        assert_eq!(wbuf.chunk(), b"ey");
754    }
755}