wtransport_proto_lightyear_patch/
frame.rs

1use crate::bytes::BufferReader;
2use crate::bytes::BufferWriter;
3use crate::bytes::BytesReader;
4use crate::bytes::BytesWriter;
5use crate::bytes::EndOfBuffer;
6use crate::ids::InvalidSessionId;
7use crate::ids::SessionId;
8use crate::varint::VarInt;
9use std::borrow::Cow;
10
11#[cfg(feature = "async")]
12use crate::bytes::AsyncRead;
13
14#[cfg(feature = "async")]
15use crate::bytes::AsyncWrite;
16
17#[cfg(feature = "async")]
18use crate::bytes;
19
/// Error raised while parsing a [`Frame`].
#[derive(Debug)]
pub enum ParseError {
    /// The frame type identifier is not one that [`FrameKind`] recognizes.
    UnknownFrame,

    /// A WebTransport frame carries a session ID that is not valid.
    InvalidSessionId,

    /// The announced payload length exceeds the maximum accepted when parsing.
    PayloadTooBig,
}
32
/// Error returned by an asynchronous frame read (`Frame::read_async`).
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug)]
pub enum IoReadError {
    /// The bytes were received, but they do not form a valid frame.
    Parse(ParseError),

    /// The underlying I/O read failed (this includes the stream ending early).
    IO(bytes::IoReadError),
}
44
#[cfg(feature = "async")]
impl From<bytes::IoReadError> for IoReadError {
    /// Wraps a low-level read error so `?` can propagate it out of frame reads.
    #[inline(always)]
    fn from(error: bytes::IoReadError) -> Self {
        Self::IO(error)
    }
}
52
/// An error during a frame I/O write operation.
///
/// This is the same error type produced by the underlying byte writer.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub type IoWriteError = bytes::IoWriteError;
56
57/// Alias for [`Frame<'static>`](Frame);
58pub type FrameOwned = Frame<'static>;
59
60/// An HTTP3 [`Frame`] type.
61#[derive(Copy, Clone, Debug)]
62pub enum FrameKind {
63    /// DATA frame type.
64    Data,
65
66    /// HEADERS frame type.
67    Headers,
68
69    /// SETTINGS frame type.
70    Settings,
71
72    /// WebTransport frame type.
73    WebTransport,
74
75    /// Exercise frame.
76    Exercise(VarInt),
77}
78
79impl FrameKind {
80    /// Checks whether an `id` is valid for a [`FrameKind::Exercise`].
81    #[inline(always)]
82    pub const fn is_id_exercise(id: VarInt) -> bool {
83        id.into_inner() >= 0x21 && ((id.into_inner() - 0x21) % 0x1f == 0)
84    }
85
86    const fn parse(id: VarInt) -> Option<Self> {
87        match id {
88            frame_kind_ids::DATA => Some(FrameKind::Data),
89            frame_kind_ids::HEADERS => Some(FrameKind::Headers),
90            frame_kind_ids::SETTINGS => Some(FrameKind::Settings),
91            frame_kind_ids::WEBTRANSPORT_STREAM => Some(FrameKind::WebTransport),
92            id if FrameKind::is_id_exercise(id) => Some(FrameKind::Exercise(id)),
93            _ => None,
94        }
95    }
96
97    const fn id(self) -> VarInt {
98        match self {
99            FrameKind::Data => frame_kind_ids::DATA,
100            FrameKind::Headers => frame_kind_ids::HEADERS,
101            FrameKind::Settings => frame_kind_ids::SETTINGS,
102            FrameKind::WebTransport => frame_kind_ids::WEBTRANSPORT_STREAM,
103            FrameKind::Exercise(id) => id,
104        }
105    }
106}
107
108/// An HTTP3 frame.
109pub struct Frame<'a> {
110    kind: FrameKind,
111    payload: Cow<'a, [u8]>,
112    session_id: Option<SessionId>,
113}
114
115impl<'a> Frame<'a> {
116    const MAX_PARSE_PAYLOAD_ALLOWED: usize = 4096;
117
118    /// Creates a new frame of type [`FrameKind::Headers`].
119    ///
120    /// # Panics
121    ///
122    /// Panics if the `payload` size if greater than [`VarInt::MAX`].
123    #[inline(always)]
124    pub fn new_headers(payload: Cow<'a, [u8]>) -> Self {
125        Self::new(FrameKind::Headers, payload, None)
126    }
127
128    /// Creates a new frame of type [`FrameKind::Settings`].
129    ///
130    /// # Panics
131    ///
132    /// Panics if the `payload` size if greater than [`VarInt::MAX`].
133    #[inline(always)]
134    pub fn new_settings(payload: Cow<'a, [u8]>) -> Self {
135        Self::new(FrameKind::Settings, payload, None)
136    }
137
138    /// Creates a new frame of type [`FrameKind::WebTransport`].
139    #[inline(always)]
140    pub fn new_webtransport(session_id: SessionId) -> Self {
141        Self::new(
142            FrameKind::WebTransport,
143            Cow::Owned(Default::default()),
144            Some(session_id),
145        )
146    }
147
148    /// Creates a new frame of type [`FrameKind::Exercise`].
149    ///
150    /// # Panics
151    ///
152    /// * Panics if the `payload` size if greater than [`VarInt::MAX`].
153    /// * Panics if `id` is not a valid exercise (see [`FrameKind::is_id_exercise`]).
154    #[inline(always)]
155    pub fn new_exercise(id: VarInt, payload: Cow<'a, [u8]>) -> Self {
156        assert!(FrameKind::is_id_exercise(id));
157        Self::new(FrameKind::Exercise(id), payload, None)
158    }
159
160    /// Reads a [`Frame`] from a [`BytesReader`].
161    ///
162    /// It returns [`None`] if the `bytes_reader` does not contain enough bytes
163    /// to parse an entire frame.
164    ///
165    /// In case [`None`] or [`Err`], `bytes_reader` might be partially read.
166    pub fn read<R>(bytes_reader: &mut R) -> Result<Option<Self>, ParseError>
167    where
168        R: BytesReader<'a>,
169    {
170        let kind = match bytes_reader.get_varint() {
171            Some(kind_id) => FrameKind::parse(kind_id).ok_or(ParseError::UnknownFrame)?,
172            None => return Ok(None),
173        };
174
175        if matches!(kind, FrameKind::WebTransport) {
176            let session_id = match bytes_reader.get_varint() {
177                Some(session_id) => SessionId::try_from_varint(session_id)
178                    .map_err(|InvalidSessionId| ParseError::InvalidSessionId)?,
179                None => return Ok(None),
180            };
181
182            Ok(Some(Self::new_webtransport(session_id)))
183        } else {
184            let payload_len = match bytes_reader.get_varint() {
185                Some(payload_len) => payload_len.into_inner() as usize,
186                None => return Ok(None),
187            };
188
189            if payload_len > Self::MAX_PARSE_PAYLOAD_ALLOWED {
190                return Err(ParseError::PayloadTooBig);
191            }
192
193            let payload = match bytes_reader.get_bytes(payload_len) {
194                Some(payload) => payload,
195                None => return Ok(None),
196            };
197
198            Ok(Some(Self::new(kind, Cow::Borrowed(payload), None)))
199        }
200    }
201
202    /// Reads a [`Frame`] from a `reader`.
203    #[cfg(feature = "async")]
204    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
205    pub async fn read_async<R>(reader: &mut R) -> Result<Frame<'a>, IoReadError>
206    where
207        R: AsyncRead + Unpin + ?Sized,
208    {
209        use crate::bytes::BytesReaderAsync;
210
211        let kind_id = reader.get_varint().await?;
212        let kind = FrameKind::parse(kind_id).ok_or(IoReadError::Parse(ParseError::UnknownFrame))?;
213
214        if matches!(kind, FrameKind::WebTransport) {
215            let session_id =
216                SessionId::try_from_varint(reader.get_varint().await.map_err(|e| match e {
217                    bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
218                    _ => e,
219                })?)
220                .map_err(|InvalidSessionId| IoReadError::Parse(ParseError::InvalidSessionId))?;
221
222            Ok(Self::new_webtransport(session_id))
223        } else {
224            let payload_len = reader
225                .get_varint()
226                .await
227                .map_err(|e| match e {
228                    bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
229                    _ => e,
230                })?
231                .into_inner() as usize;
232
233            if payload_len > Self::MAX_PARSE_PAYLOAD_ALLOWED {
234                return Err(IoReadError::Parse(ParseError::PayloadTooBig));
235            }
236
237            let mut payload = vec![0; payload_len];
238
239            reader.get_buffer(&mut payload).await.map_err(|e| match e {
240                bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
241                _ => e,
242            })?;
243
244            payload.shrink_to_fit();
245
246            Ok(Self::new(kind, Cow::Owned(payload), None))
247        }
248    }
249
250    /// Reads a [`Frame`] from a [`BufferReader`].
251    ///
252    /// It returns [`None`] if the `buffer_reader` does not contain enough bytes
253    /// to parse an entire frame.
254    ///
255    /// In case [`None`] or [`Err`], `buffer_reader` offset if not advanced.
256    pub fn read_from_buffer(
257        buffer_reader: &mut BufferReader<'a>,
258    ) -> Result<Option<Self>, ParseError> {
259        let mut buffer_reader_child = buffer_reader.child();
260
261        match Self::read(&mut *buffer_reader_child)? {
262            Some(frame) => {
263                buffer_reader_child.commit();
264                Ok(Some(frame))
265            }
266            None => Ok(None),
267        }
268    }
269
270    /// Writes a [`Frame`] into a [`BytesWriter`].
271    ///
272    /// It returns [`Err`] if the `bytes_writer` does not have enough capacity
273    /// to write the entire frame.
274    /// See [`Self::write_size`] to retrieve the exact amount of required capacity.
275    ///
276    /// In case [`Err`], `bytes_writer` might be partially written.
277    ///
278    /// # Panics
279    ///
280    /// Panics if the payload size if greater than [`VarInt::MAX`].
281    pub fn write<W>(&self, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
282    where
283        W: BytesWriter,
284    {
285        bytes_writer.put_varint(self.kind.id())?;
286
287        if let Some(session_id) = self.session_id() {
288            bytes_writer.put_varint(session_id.into_varint())?;
289        } else {
290            bytes_writer.put_varint(
291                VarInt::try_from(self.payload.len() as u64)
292                    .expect("Payload cannot be larger than varint max"),
293            )?;
294            bytes_writer.put_bytes(&self.payload)?;
295        }
296
297        Ok(())
298    }
299
300    /// Writes a [`Frame`] into a `writer`.
301    ///
302    /// # Panics
303    ///
304    /// Panics if the payload size if greater than [`VarInt::MAX`].
305    #[cfg(feature = "async")]
306    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
307    pub async fn write_async<W>(&self, writer: &mut W) -> Result<(), IoWriteError>
308    where
309        W: AsyncWrite + Unpin + ?Sized,
310    {
311        use crate::bytes::BytesWriterAsync;
312
313        writer.put_varint(self.kind.id()).await?;
314
315        if let Some(session_id) = self.session_id() {
316            writer.put_varint(session_id.into_varint()).await?;
317        } else {
318            writer
319                .put_varint(
320                    VarInt::try_from(self.payload.len() as u64)
321                        .expect("Payload cannot be larger than varint max"),
322                )
323                .await?;
324            writer.put_buffer(&self.payload).await?;
325        }
326
327        Ok(())
328    }
329
330    /// Writes this [`Frame`] into a buffer via [`BufferWriter`].
331    ///
332    /// In case [`Err`], `buffer_writer` is not advanced.
333    ///
334    /// # Panics
335    ///
336    /// Panics if the payload size if greater than [`VarInt::MAX`].
337    pub fn write_to_buffer(&self, buffer_writer: &mut BufferWriter) -> Result<(), EndOfBuffer> {
338        if buffer_writer.capacity() < self.write_size() {
339            return Err(EndOfBuffer);
340        }
341
342        self.write(buffer_writer)
343            .expect("Enough capacity for frame");
344
345        Ok(())
346    }
347
348    /// Returns the needed capacity to write this frame into a buffer.
349    pub fn write_size(&self) -> usize {
350        if let Some(session_id) = self.session_id() {
351            self.kind.id().size() + session_id.into_varint().size()
352        } else {
353            self.kind.id().size()
354                + VarInt::try_from(self.payload.len() as u64)
355                    .expect("Payload cannot be larger than varint max")
356                    .size()
357                + self.payload.len()
358        }
359    }
360
361    /// Returns the [`FrameKind`] of this [`Frame`].
362    #[inline(always)]
363    pub const fn kind(&self) -> FrameKind {
364        self.kind
365    }
366
367    /// Returns the payload of this [`Frame`].
368    #[inline(always)]
369    pub fn payload(&self) -> &[u8] {
370        &self.payload
371    }
372
373    /// Returns the [`SessionId`] if frame is [`FrameKind::WebTransport`],
374    /// otherwise returns [`None`].
375    #[inline(always)]
376    pub fn session_id(&self) -> Option<SessionId> {
377        matches!(self.kind, FrameKind::WebTransport).then(|| {
378            self.session_id
379                .expect("WebTransport frame contains session id")
380        })
381    }
382
383    /// # Panics
384    ///
385    /// Panics if the `payload` size if greater than [`VarInt::MAX`].
386    fn new(kind: FrameKind, payload: Cow<'a, [u8]>, session_id: Option<SessionId>) -> Self {
387        if let FrameKind::Exercise(id) = kind {
388            debug_assert!(FrameKind::is_id_exercise(id));
389        } else if let FrameKind::WebTransport = kind {
390            debug_assert!(payload.is_empty());
391            debug_assert!(session_id.is_some());
392        }
393
394        assert!(payload.len() <= VarInt::MAX.into_inner() as usize);
395
396        Self {
397            kind,
398            payload,
399            session_id,
400        }
401    }
402
403    #[cfg(test)]
404    pub(crate) fn into_owned<'b>(self) -> Frame<'b> {
405        Frame {
406            kind: self.kind,
407            payload: Cow::Owned(self.payload.into_owned()),
408            session_id: self.session_id,
409        }
410    }
411
412    #[cfg(test)]
413    pub(crate) fn serialize_any(kind: VarInt, payload: &[u8]) -> Vec<u8> {
414        let mut buffer = Vec::new();
415
416        Self {
417            kind: FrameKind::Exercise(kind),
418            payload: Cow::Owned(payload.to_vec()),
419            session_id: None,
420        }
421        .write(&mut buffer)
422        .unwrap();
423
424        buffer
425    }
426
427    #[cfg(test)]
428    pub(crate) fn serialize_webtransport(session_id: SessionId) -> Vec<u8> {
429        let mut buffer = Vec::new();
430
431        Self {
432            kind: FrameKind::WebTransport,
433            payload: Cow::Owned(Default::default()),
434            session_id: Some(session_id),
435        }
436        .write(&mut buffer)
437        .unwrap();
438
439        buffer
440    }
441}
442
443mod frame_kind_ids {
444    use crate::varint::VarInt;
445
446    pub const DATA: VarInt = VarInt::from_u32(0x00);
447    pub const HEADERS: VarInt = VarInt::from_u32(0x01);
448    pub const SETTINGS: VarInt = VarInt::from_u32(0x04);
449    pub const WEBTRANSPORT_STREAM: VarInt = VarInt::from_u32(0x41);
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::Headers;
    use crate::ids::StreamId;
    use crate::settings::Settings;

    // Async tests are gated on the `async` feature: `read_async`, `write_async`,
    // `IoReadError`, the `bytes` module alias and `utils::assert_serde_async`
    // only exist when that feature is enabled.

    #[test]
    fn settings() {
        let settings = Settings::builder()
            .qpack_blocked_streams(VarInt::from_u32(1))
            .qpack_max_table_capacity(VarInt::from_u32(2))
            .enable_h3_datagrams()
            .enable_webtransport()
            .webtransport_max_sessions(VarInt::from_u32(3))
            .build();

        let frame = settings.generate_frame();
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Settings));

        let frame = utils::assert_serde(frame);
        Settings::with_frame(&frame).unwrap();
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn settings_async() {
        let settings = Settings::builder()
            .qpack_blocked_streams(VarInt::from_u32(1))
            .qpack_max_table_capacity(VarInt::from_u32(2))
            .enable_h3_datagrams()
            .enable_webtransport()
            .webtransport_max_sessions(VarInt::from_u32(3))
            .build();

        let frame = settings.generate_frame();
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Settings));

        let frame = utils::assert_serde_async(frame).await;
        Settings::with_frame(&frame).unwrap();
    }

    #[test]
    fn headers() {
        let stream_id = StreamId::new(VarInt::from_u32(0));
        let headers = Headers::from_iter([("key1", "value1")]);

        let frame = headers.generate_frame(stream_id);
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Headers));

        let frame = utils::assert_serde(frame);
        Headers::with_frame(&frame, stream_id).unwrap();
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn headers_async() {
        let stream_id = StreamId::new(VarInt::from_u32(0));
        let headers = Headers::from_iter([("key1", "value1")]);

        let frame = headers.generate_frame(stream_id);
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Headers));

        let frame = utils::assert_serde_async(frame).await;
        Headers::with_frame(&frame, stream_id).unwrap();
    }

    #[test]
    fn webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let frame = Frame::new_webtransport(session_id);

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));

        let frame = utils::assert_serde(frame);

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let frame = Frame::new_webtransport(session_id);

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));

        let frame = utils::assert_serde_async(frame).await;

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));
    }

    #[test]
    fn read_eof() {
        let buffer = Frame::serialize_any(FrameKind::Data.id(), b"This is a test payload");

        // Every strict prefix of a valid frame is incomplete, not an error.
        for len in 0..buffer.len() {
            assert!(Frame::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_async() {
        let buffer = Frame::serialize_any(FrameKind::Data.id(), b"This is a test payload");

        for len in 0..buffer.len() {
            let result = Frame::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = Frame::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            let result = Frame::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[test]
    fn unknown_frame() {
        let buffer = Frame::serialize_any(VarInt::from_u32(0x0042_4242), b"This is a test payload");

        assert!(matches!(
            Frame::read(&mut buffer.as_slice()),
            Err(ParseError::UnknownFrame)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn unknown_frame_async() {
        let buffer = Frame::serialize_any(VarInt::from_u32(0x0042_4242), b"This is a test payload");

        assert!(matches!(
            Frame::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::UnknownFrame))
        ));
    }

    #[test]
    fn exercise_ids() {
        for id in [0x21, 0x40, 0x5f] {
            assert!(FrameKind::is_id_exercise(VarInt::from_u32(id)));
        }

        for id in [0x00, 0x01, 0x04, 0x20, 0x22, 0x41] {
            assert!(!FrameKind::is_id_exercise(VarInt::from_u32(id)));
        }
    }

    #[test]
    fn invalid_session_id() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = Frame::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            Frame::read(&mut buffer.as_slice()),
            Err(ParseError::InvalidSessionId)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn invalid_session_id_async() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = Frame::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            Frame::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::InvalidSessionId))
        ));
    }

    #[test]
    fn payload_max_allowed() {
        let payload = vec![0xab_u8; Frame::MAX_PARSE_PAYLOAD_ALLOWED];
        let buffer = Frame::serialize_any(FrameKind::Data.id(), &payload);

        let mut reader = buffer.as_slice();
        let frame = Frame::read(&mut reader).unwrap().unwrap();

        assert!(matches!(frame.kind(), FrameKind::Data));
        assert_eq!(frame.payload(), payload.as_slice());
        assert!(reader.is_empty());
    }

    #[test]
    fn payload_too_big() {
        let mut buffer = Vec::new();
        buffer.put_varint(FrameKind::Data.id()).unwrap();
        buffer
            .put_varint(VarInt::from_u32(
                Frame::MAX_PARSE_PAYLOAD_ALLOWED as u32 + 1,
            ))
            .unwrap();

        assert!(matches!(
            Frame::read_from_buffer(&mut BufferReader::new(&buffer)),
            Err(ParseError::PayloadTooBig)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn payload_too_big_async() {
        let mut buffer = Vec::new();
        buffer.put_varint(FrameKind::Data.id()).unwrap();
        buffer
            .put_varint(VarInt::from_u32(
                Frame::MAX_PARSE_PAYLOAD_ALLOWED as u32 + 1,
            ))
            .unwrap();

        assert!(matches!(
            Frame::read_async(&mut &*buffer).await,
            Err(IoReadError::Parse(ParseError::PayloadTooBig)),
        ));
    }

    mod utils {
        use super::*;

        /// Round-trips `frame` through `write`/`read` and checks `write_size`.
        pub fn assert_serde(frame: Frame) -> Frame {
            let mut buffer = Vec::new();

            frame.write(&mut buffer).unwrap();
            assert_eq!(buffer.len(), frame.write_size());

            let mut buffer = buffer.as_slice();
            let frame = Frame::read(&mut buffer).unwrap().unwrap();
            assert!(buffer.is_empty());

            frame.into_owned()
        }

        /// Round-trips `frame` through `write_async`/`read_async` and checks `write_size`.
        #[cfg(feature = "async")]
        pub async fn assert_serde_async(frame: Frame<'_>) -> Frame {
            let mut buffer = Vec::new();

            frame.write_async(&mut buffer).await.unwrap();
            assert_eq!(buffer.len(), frame.write_size());

            let mut buffer = buffer.as_slice();
            let frame = Frame::read_async(&mut buffer).await.unwrap();
            assert!(buffer.is_empty());

            frame.into_owned()
        }
    }
}