wtransport_proto_lightyear_patch/
stream_header.rs

1use crate::bytes::BufferReader;
2use crate::bytes::BufferWriter;
3use crate::bytes::BytesReader;
4use crate::bytes::BytesWriter;
5use crate::bytes::EndOfBuffer;
6use crate::ids::InvalidSessionId;
7use crate::ids::SessionId;
8use crate::varint::VarInt;
9
10#[cfg(feature = "async")]
11use crate::bytes::AsyncRead;
12
13#[cfg(feature = "async")]
14use crate::bytes::AsyncWrite;
15
16#[cfg(feature = "async")]
17use crate::bytes;
18
/// Error during stream header parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseError {
    /// The stream type is neither a known stream type nor a reserved
    /// (exercise) stream type.
    UnknownStream,

    /// The session ID carried by a WebTransport stream header is not valid.
    InvalidSessionId,
}
28
/// Failure while reading a stream header from an asynchronous reader.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug)]
pub enum IoReadError {
    /// The bytes were read successfully but do not form a valid stream header.
    Parse(ParseError),

    /// The underlying I/O read failed (for example, the stream ended early).
    IO(bytes::IoReadError),
}
40
#[cfg(feature = "async")]
impl From<bytes::IoReadError> for IoReadError {
    /// Wraps a low-level read error into [`IoReadError::IO`], so that `?` can
    /// propagate it out of stream header readers.
    #[inline(always)]
    fn from(error: bytes::IoReadError) -> Self {
        Self::IO(error)
    }
}
48
/// An error during a stream header I/O write operation.
///
/// This is an alias of the low-level `bytes::IoWriteError`.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub type IoWriteError = bytes::IoWriteError;
52
53/// An HTTP3 stream type.
54#[derive(Copy, Clone, Debug)]
55pub enum StreamKind {
56    /// CONTROL stream type.
57    Control,
58
59    /// QPACK Encoder stream type.
60    QPackEncoder,
61
62    /// QPACK Decoder stream type.
63    QPackDecoder,
64
65    /// WebTransport stream type.
66    WebTransport,
67
68    /// Exercise stream.
69    Exercise(VarInt),
70}
71
72impl StreamKind {
73    /// Checks whether an `id` is valid for a [`StreamKind::Exercise`].
74    #[inline(always)]
75    pub const fn is_id_exercise(id: VarInt) -> bool {
76        id.into_inner() >= 0x21 && ((id.into_inner() - 0x21) % 0x1f == 0)
77    }
78
79    const fn parse(id: VarInt) -> Option<Self> {
80        match id {
81            stream_type_ids::CONTROL_STREAM => Some(StreamKind::Control),
82            stream_type_ids::QPACK_ENCODER_STREAM => Some(StreamKind::QPackEncoder),
83            stream_type_ids::QPACK_DECODER_STREAM => Some(StreamKind::QPackDecoder),
84            stream_type_ids::WEBTRANSPORT_STREAM => Some(StreamKind::WebTransport),
85            id if StreamKind::is_id_exercise(id) => Some(StreamKind::Exercise(id)),
86            _ => None,
87        }
88    }
89
90    const fn id(self) -> VarInt {
91        match self {
92            StreamKind::Control => stream_type_ids::CONTROL_STREAM,
93            StreamKind::QPackEncoder => stream_type_ids::QPACK_ENCODER_STREAM,
94            StreamKind::QPackDecoder => stream_type_ids::QPACK_DECODER_STREAM,
95            StreamKind::WebTransport => stream_type_ids::WEBTRANSPORT_STREAM,
96            StreamKind::Exercise(id) => id,
97        }
98    }
99}
100
101/// HTTP3 stream type.
102///
103/// *Unidirectional* HTTP3 streams have an header encoding the type.
104pub struct StreamHeader {
105    kind: StreamKind,
106    session_id: Option<SessionId>,
107}
108
109impl StreamHeader {
110    /// Maximum number of bytes a [`StreamHeader`] can take over network.
111    pub const MAX_SIZE: usize = 16;
112
113    /// Creates a new stream header of type [`StreamKind::Control`].
114    #[inline(always)]
115    pub fn new_control() -> Self {
116        Self::new(StreamKind::Control, None)
117    }
118
119    /// Creates a new stream header of type [`StreamKind::WebTransport`].
120    #[inline(always)]
121    pub fn new_webtransport(session_id: SessionId) -> Self {
122        Self::new(StreamKind::WebTransport, Some(session_id))
123    }
124
125    /// Reads a [`StreamHeader`] from a [`BytesReader`].
126    ///
127    /// It returns [`None`] if the `bytes_reader` does not contain enough bytes
128    /// to parse an entire header.
129    ///
130    /// In case [`None`] or [`Err`], `bytes_reader` might be partially read.
131    pub fn read<'a, R>(bytes_reader: &mut R) -> Result<Option<Self>, ParseError>
132    where
133        R: BytesReader<'a>,
134    {
135        let kind = match bytes_reader.get_varint() {
136            Some(kind_id) => StreamKind::parse(kind_id).ok_or(ParseError::UnknownStream)?,
137            None => return Ok(None),
138        };
139
140        let session_id = if matches!(kind, StreamKind::WebTransport) {
141            let session_id = match bytes_reader.get_varint() {
142                Some(session_id) => SessionId::try_from_varint(session_id)
143                    .map_err(|InvalidSessionId| ParseError::InvalidSessionId)?,
144                None => return Ok(None),
145            };
146
147            Some(session_id)
148        } else {
149            None
150        };
151
152        Ok(Some(Self::new(kind, session_id)))
153    }
154
155    /// Reads a [`StreamHeader`] from a `reader`.
156    #[cfg(feature = "async")]
157    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
158    pub async fn read_async<R>(reader: &mut R) -> Result<Self, IoReadError>
159    where
160        R: AsyncRead + Unpin + ?Sized,
161    {
162        use crate::bytes::BytesReaderAsync;
163
164        let kind_id = reader.get_varint().await?;
165        let kind =
166            StreamKind::parse(kind_id).ok_or(IoReadError::Parse(ParseError::UnknownStream))?;
167
168        let session_id = if matches!(kind, StreamKind::WebTransport) {
169            let session_id =
170                SessionId::try_from_varint(reader.get_varint().await.map_err(|e| match e {
171                    bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
172                    _ => e,
173                })?)
174                .map_err(|InvalidSessionId| IoReadError::Parse(ParseError::InvalidSessionId))?;
175
176            Some(session_id)
177        } else {
178            None
179        };
180
181        Ok(Self::new(kind, session_id))
182    }
183
184    /// Reads a [`StreamHeader`] from a [`BufferReader`].
185    ///
186    /// It returns [`None`] if the `buffer_reader` does not contain enough bytes
187    /// to parse an entire header.
188    ///
189    /// In case [`None`] or [`Err`], `buffer_reader` offset if not advanced.
190    pub fn read_from_buffer(buffer_reader: &mut BufferReader) -> Result<Option<Self>, ParseError> {
191        let mut buffer_reader_child = buffer_reader.child();
192
193        match Self::read(&mut *buffer_reader_child)? {
194            Some(header) => {
195                buffer_reader_child.commit();
196                Ok(Some(header))
197            }
198            None => Ok(None),
199        }
200    }
201
202    /// Writes a [`StreamHeader`] into a [`BytesWriter`].
203    ///
204    /// It returns [`Err`] if the `bytes_writer` does not have enough capacity
205    /// to write the entire header.
206    /// See [`Self::write_size`] to retrieve the exact amount of required capacity.
207    ///
208    /// In case [`Err`], `bytes_writer` might be partially written.
209    pub fn write<W>(&self, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
210    where
211        W: BytesWriter,
212    {
213        bytes_writer.put_varint(self.kind.id())?;
214
215        if let Some(session_id) = self.session_id() {
216            bytes_writer.put_varint(session_id.into_varint())?;
217        }
218
219        Ok(())
220    }
221
222    /// Writes a [`StreamHeader`] into a `writer`.
223    #[cfg(feature = "async")]
224    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
225    pub async fn write_async<W>(&self, writer: &mut W) -> Result<(), IoWriteError>
226    where
227        W: AsyncWrite + Unpin + ?Sized,
228    {
229        use crate::bytes::BytesWriterAsync;
230
231        writer.put_varint(self.kind.id()).await?;
232
233        if let Some(session_id) = self.session_id() {
234            writer.put_varint(session_id.into_varint()).await?;
235        }
236
237        Ok(())
238    }
239
240    /// Writes this [`StreamHeader`] into a buffer via [`BufferWriter`].
241    ///
242    /// In case [`Err`], `buffer_writer` is not advanced.
243    pub fn write_to_buffer(&self, buffer_writer: &mut BufferWriter) -> Result<(), EndOfBuffer> {
244        if buffer_writer.capacity() < self.write_size() {
245            return Err(EndOfBuffer);
246        }
247
248        self.write(buffer_writer)
249            .expect("Enough capacity for header");
250
251        Ok(())
252    }
253
254    /// Returns the needed capacity to write this stream header into a buffer.
255    pub fn write_size(&self) -> usize {
256        if let Some(session_id) = self.session_id() {
257            self.kind.id().size() + session_id.into_varint().size()
258        } else {
259            self.kind.id().size()
260        }
261    }
262
263    /// Returns the [`StreamKind`].
264    #[inline(always)]
265    pub const fn kind(&self) -> StreamKind {
266        self.kind
267    }
268
269    /// Returns the [`SessionId`] if stream is [`StreamKind::WebTransport`],
270    /// otherwise returns [`None`].
271    #[inline(always)]
272    pub fn session_id(&self) -> Option<SessionId> {
273        matches!(self.kind, StreamKind::WebTransport).then(|| {
274            self.session_id
275                .expect("WebTransport stream header contains session id")
276        })
277    }
278
279    fn new(kind: StreamKind, session_id: Option<SessionId>) -> Self {
280        if let StreamKind::Exercise(id) = kind {
281            debug_assert!(StreamKind::is_id_exercise(id));
282            debug_assert!(session_id.is_none());
283        } else if let StreamKind::WebTransport = kind {
284            debug_assert!(session_id.is_some());
285        } else {
286            debug_assert!(session_id.is_none());
287        }
288
289        Self { kind, session_id }
290    }
291
292    #[cfg(test)]
293    pub(crate) fn serialize_any(kind: VarInt) -> Vec<u8> {
294        let mut buffer = Vec::new();
295
296        Self {
297            kind: StreamKind::Exercise(kind),
298            session_id: None,
299        }
300        .write(&mut buffer)
301        .unwrap();
302
303        buffer
304    }
305
306    #[cfg(test)]
307    pub(crate) fn serialize_webtransport(session_id: SessionId) -> Vec<u8> {
308        let mut buffer = Vec::new();
309
310        Self {
311            kind: StreamKind::WebTransport,
312            session_id: Some(session_id),
313        }
314        .write(&mut buffer)
315        .unwrap();
316
317        buffer
318    }
319}
320
321mod stream_type_ids {
322    use crate::varint::VarInt;
323
324    pub const CONTROL_STREAM: VarInt = VarInt::from_u32(0x0);
325    pub const QPACK_ENCODER_STREAM: VarInt = VarInt::from_u32(0x02);
326    pub const QPACK_DECODER_STREAM: VarInt = VarInt::from_u32(0x03);
327    pub const WEBTRANSPORT_STREAM: VarInt = VarInt::from_u32(0x54);
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn control() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    // Async tests depend on `read_async`/`write_async`, `IoReadError` and the
    // `bytes` import, which only exist with the `async` feature enabled.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn control_async() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    #[test]
    fn webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[test]
    fn exercise_ids() {
        assert!(StreamKind::is_id_exercise(VarInt::from_u32(0x21)));
        assert!(StreamKind::is_id_exercise(VarInt::from_u32(0x21 + 0x1f)));
        assert!(StreamKind::is_id_exercise(VarInt::from_u32(0x21 + 0x1f * 7)));
        assert!(!StreamKind::is_id_exercise(VarInt::from_u32(0x00)));
        assert!(!StreamKind::is_id_exercise(VarInt::from_u32(0x20)));
        assert!(!StreamKind::is_id_exercise(VarInt::from_u32(0x22)));
        assert!(!StreamKind::is_id_exercise(VarInt::from_u32(0x54)));
    }

    #[test]
    fn exercise() {
        let id = VarInt::from_u32(0x21 + 0x1f * 3);
        let buffer = StreamHeader::serialize_any(id);

        let mut reader = buffer.as_slice();
        let stream_header = StreamHeader::read(&mut reader).unwrap().unwrap();
        assert!(reader.is_empty());
        assert!(matches!(stream_header.kind(), StreamKind::Exercise(x) if x == id));
        assert!(stream_header.session_id().is_none());
    }

    #[test]
    fn read_eof() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        // Every strict prefix of the header is incomplete.
        for len in 0..buffer.len() {
            assert!(StreamHeader::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[test]
    fn read_eof_webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = StreamHeader::serialize_webtransport(session_id);

        // Includes the prefix ending right after the stream type, where the
        // session ID is still missing.
        for len in 0..buffer.len() {
            assert!(StreamHeader::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = StreamHeader::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[test]
    fn unknown_stream() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::UnknownStream)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn unknown_stream_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::UnknownStream))
        ));
    }

    #[test]
    fn invalid_session_id() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::InvalidSessionId)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn invalid_session_id_async() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::InvalidSessionId))
        ));
    }

    mod utils {
        use super::*;

        /// Writes `stream_header`, checks the size invariants, reads it back
        /// and asserts the whole buffer was consumed.
        pub fn assert_serde(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write(&mut buffer).unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read(&mut buffer).unwrap().unwrap();
            assert!(buffer.is_empty());

            stream_header
        }

        /// Async counterpart of [`assert_serde`].
        #[cfg(feature = "async")]
        pub async fn assert_serde_async(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write_async(&mut buffer).await.unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read_async(&mut buffer).await.unwrap();
            assert!(buffer.is_empty());

            stream_header
        }
    }
}