// wtransport_proto_lightyear_patch/stream_header.rs
use crate::bytes::BufferReader;
use crate::bytes::BufferWriter;
use crate::bytes::BytesReader;
use crate::bytes::BytesWriter;
use crate::bytes::EndOfBuffer;
use crate::ids::InvalidSessionId;
use crate::ids::SessionId;
use crate::varint::VarInt;

#[cfg(feature = "async")]
use crate::bytes::AsyncRead;

#[cfg(feature = "async")]
use crate::bytes::AsyncWrite;

#[cfg(feature = "async")]
use crate::bytes;
18
/// Error produced when the bytes of a stream header cannot be parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseError {
    /// The stream-type id is neither one of the known stream types nor a
    /// reserved "exercise" id (`0x1f * N + 0x21`).
    UnknownStream,

    /// The session ID of a WebTransport stream header was rejected by
    /// `SessionId::try_from_varint`.
    InvalidSessionId,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::UnknownStream => "cannot parse stream header: unknown stream type",
            Self::InvalidSessionId => "cannot parse stream header: invalid session id",
        };
        f.write_str(message)
    }
}

// A leaf error: there is no underlying cause, so the default `source()` (None) is correct.
impl std::error::Error for ParseError {}
28
/// Error returned by `StreamHeader::read_async`.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug)]
pub enum IoReadError {
    /// Bytes were received but they do not form a valid stream header.
    Parse(ParseError),

    /// The underlying reader failed, including the stream finishing before
    /// (`ImmediateFin`) or in the middle of (`UnexpectedFin`) the header.
    IO(bytes::IoReadError),
}

#[cfg(feature = "async")]
impl From<bytes::IoReadError> for IoReadError {
    /// Wraps a low-level read error so that `?` can be used in async readers.
    #[inline(always)]
    fn from(error: bytes::IoReadError) -> Self {
        Self::IO(error)
    }
}

/// Error returned by `StreamHeader::write_async`; identical to the
/// low-level async write error.
#[cfg(feature = "async")]
pub type IoWriteError = bytes::IoWriteError;
52
53#[derive(Copy, Clone, Debug)]
55pub enum StreamKind {
56 Control,
58
59 QPackEncoder,
61
62 QPackDecoder,
64
65 WebTransport,
67
68 Exercise(VarInt),
70}
71
72impl StreamKind {
73 #[inline(always)]
75 pub const fn is_id_exercise(id: VarInt) -> bool {
76 id.into_inner() >= 0x21 && ((id.into_inner() - 0x21) % 0x1f == 0)
77 }
78
79 const fn parse(id: VarInt) -> Option<Self> {
80 match id {
81 stream_type_ids::CONTROL_STREAM => Some(StreamKind::Control),
82 stream_type_ids::QPACK_ENCODER_STREAM => Some(StreamKind::QPackEncoder),
83 stream_type_ids::QPACK_DECODER_STREAM => Some(StreamKind::QPackDecoder),
84 stream_type_ids::WEBTRANSPORT_STREAM => Some(StreamKind::WebTransport),
85 id if StreamKind::is_id_exercise(id) => Some(StreamKind::Exercise(id)),
86 _ => None,
87 }
88 }
89
90 const fn id(self) -> VarInt {
91 match self {
92 StreamKind::Control => stream_type_ids::CONTROL_STREAM,
93 StreamKind::QPackEncoder => stream_type_ids::QPACK_ENCODER_STREAM,
94 StreamKind::QPackDecoder => stream_type_ids::QPACK_DECODER_STREAM,
95 StreamKind::WebTransport => stream_type_ids::WEBTRANSPORT_STREAM,
96 StreamKind::Exercise(id) => id,
97 }
98 }
99}
100
101pub struct StreamHeader {
105 kind: StreamKind,
106 session_id: Option<SessionId>,
107}
108
109impl StreamHeader {
110 pub const MAX_SIZE: usize = 16;
112
113 #[inline(always)]
115 pub fn new_control() -> Self {
116 Self::new(StreamKind::Control, None)
117 }
118
119 #[inline(always)]
121 pub fn new_webtransport(session_id: SessionId) -> Self {
122 Self::new(StreamKind::WebTransport, Some(session_id))
123 }
124
125 pub fn read<'a, R>(bytes_reader: &mut R) -> Result<Option<Self>, ParseError>
132 where
133 R: BytesReader<'a>,
134 {
135 let kind = match bytes_reader.get_varint() {
136 Some(kind_id) => StreamKind::parse(kind_id).ok_or(ParseError::UnknownStream)?,
137 None => return Ok(None),
138 };
139
140 let session_id = if matches!(kind, StreamKind::WebTransport) {
141 let session_id = match bytes_reader.get_varint() {
142 Some(session_id) => SessionId::try_from_varint(session_id)
143 .map_err(|InvalidSessionId| ParseError::InvalidSessionId)?,
144 None => return Ok(None),
145 };
146
147 Some(session_id)
148 } else {
149 None
150 };
151
152 Ok(Some(Self::new(kind, session_id)))
153 }
154
155 #[cfg(feature = "async")]
157 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
158 pub async fn read_async<R>(reader: &mut R) -> Result<Self, IoReadError>
159 where
160 R: AsyncRead + Unpin + ?Sized,
161 {
162 use crate::bytes::BytesReaderAsync;
163
164 let kind_id = reader.get_varint().await?;
165 let kind =
166 StreamKind::parse(kind_id).ok_or(IoReadError::Parse(ParseError::UnknownStream))?;
167
168 let session_id = if matches!(kind, StreamKind::WebTransport) {
169 let session_id =
170 SessionId::try_from_varint(reader.get_varint().await.map_err(|e| match e {
171 bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
172 _ => e,
173 })?)
174 .map_err(|InvalidSessionId| IoReadError::Parse(ParseError::InvalidSessionId))?;
175
176 Some(session_id)
177 } else {
178 None
179 };
180
181 Ok(Self::new(kind, session_id))
182 }
183
184 pub fn read_from_buffer(buffer_reader: &mut BufferReader) -> Result<Option<Self>, ParseError> {
191 let mut buffer_reader_child = buffer_reader.child();
192
193 match Self::read(&mut *buffer_reader_child)? {
194 Some(header) => {
195 buffer_reader_child.commit();
196 Ok(Some(header))
197 }
198 None => Ok(None),
199 }
200 }
201
202 pub fn write<W>(&self, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
210 where
211 W: BytesWriter,
212 {
213 bytes_writer.put_varint(self.kind.id())?;
214
215 if let Some(session_id) = self.session_id() {
216 bytes_writer.put_varint(session_id.into_varint())?;
217 }
218
219 Ok(())
220 }
221
222 #[cfg(feature = "async")]
224 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
225 pub async fn write_async<W>(&self, writer: &mut W) -> Result<(), IoWriteError>
226 where
227 W: AsyncWrite + Unpin + ?Sized,
228 {
229 use crate::bytes::BytesWriterAsync;
230
231 writer.put_varint(self.kind.id()).await?;
232
233 if let Some(session_id) = self.session_id() {
234 writer.put_varint(session_id.into_varint()).await?;
235 }
236
237 Ok(())
238 }
239
240 pub fn write_to_buffer(&self, buffer_writer: &mut BufferWriter) -> Result<(), EndOfBuffer> {
244 if buffer_writer.capacity() < self.write_size() {
245 return Err(EndOfBuffer);
246 }
247
248 self.write(buffer_writer)
249 .expect("Enough capacity for header");
250
251 Ok(())
252 }
253
254 pub fn write_size(&self) -> usize {
256 if let Some(session_id) = self.session_id() {
257 self.kind.id().size() + session_id.into_varint().size()
258 } else {
259 self.kind.id().size()
260 }
261 }
262
263 #[inline(always)]
265 pub const fn kind(&self) -> StreamKind {
266 self.kind
267 }
268
269 #[inline(always)]
272 pub fn session_id(&self) -> Option<SessionId> {
273 matches!(self.kind, StreamKind::WebTransport).then(|| {
274 self.session_id
275 .expect("WebTransport stream header contains session id")
276 })
277 }
278
279 fn new(kind: StreamKind, session_id: Option<SessionId>) -> Self {
280 if let StreamKind::Exercise(id) = kind {
281 debug_assert!(StreamKind::is_id_exercise(id));
282 debug_assert!(session_id.is_none());
283 } else if let StreamKind::WebTransport = kind {
284 debug_assert!(session_id.is_some());
285 } else {
286 debug_assert!(session_id.is_none());
287 }
288
289 Self { kind, session_id }
290 }
291
292 #[cfg(test)]
293 pub(crate) fn serialize_any(kind: VarInt) -> Vec<u8> {
294 let mut buffer = Vec::new();
295
296 Self {
297 kind: StreamKind::Exercise(kind),
298 session_id: None,
299 }
300 .write(&mut buffer)
301 .unwrap();
302
303 buffer
304 }
305
306 #[cfg(test)]
307 pub(crate) fn serialize_webtransport(session_id: SessionId) -> Vec<u8> {
308 let mut buffer = Vec::new();
309
310 Self {
311 kind: StreamKind::WebTransport,
312 session_id: Some(session_id),
313 }
314 .write(&mut buffer)
315 .unwrap();
316
317 buffer
318 }
319}
320
321mod stream_type_ids {
322 use crate::varint::VarInt;
323
324 pub const CONTROL_STREAM: VarInt = VarInt::from_u32(0x0);
325 pub const QPACK_ENCODER_STREAM: VarInt = VarInt::from_u32(0x02);
326 pub const QPACK_DECODER_STREAM: VarInt = VarInt::from_u32(0x03);
327 pub const WEBTRANSPORT_STREAM: VarInt = VarInt::from_u32(0x54);
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn control() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    // Async tests depend on `read_async`/`write_async`/`bytes`, which only exist
    // with the "async" feature; gate them so `cargo test` builds without it.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn control_async() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    #[test]
    fn webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[test]
    fn read_eof() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));
        assert!(StreamHeader::read(&mut &buffer[..buffer.len() - 1])
            .unwrap()
            .is_none());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = StreamHeader::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[test]
    fn unknown_stream() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::UnknownStream)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn unknown_stream_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::UnknownStream))
        ));
    }

    #[test]
    fn invalid_session_id() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::InvalidSessionId)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn invalid_session_id_async() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::InvalidSessionId))
        ));
    }

    mod utils {
        use super::*;

        /// Round-trips `stream_header` through `write`/`read`, checking the encoded
        /// size against `write_size` and `MAX_SIZE`, and returns the decoded header.
        pub fn assert_serde(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write(&mut buffer).unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read(&mut buffer).unwrap().unwrap();
            assert!(buffer.is_empty());

            stream_header
        }

        /// Async counterpart of `assert_serde`.
        #[cfg(feature = "async")]
        pub async fn assert_serde_async(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write_async(&mut buffer).await.unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read_async(&mut buffer).await.unwrap();
            assert!(buffer.is_empty());

            stream_header
        }
    }
}