wtransport_proto_lightyear_patch/
frame.rs1use crate::bytes::BufferReader;
2use crate::bytes::BufferWriter;
3use crate::bytes::BytesReader;
4use crate::bytes::BytesWriter;
5use crate::bytes::EndOfBuffer;
6use crate::ids::InvalidSessionId;
7use crate::ids::SessionId;
8use crate::varint::VarInt;
9use std::borrow::Cow;
10
11#[cfg(feature = "async")]
12use crate::bytes::AsyncRead;
13
14#[cfg(feature = "async")]
15use crate::bytes::AsyncWrite;
16
17#[cfg(feature = "async")]
18use crate::bytes;
19
/// Reasons why bytes could not be decoded into a [`Frame`].
#[derive(Debug)]
pub enum ParseError {
    /// The frame type ID is not one recognized by [`FrameKind`].
    UnknownFrame,

    /// A WebTransport frame carries a session ID rejected by `SessionId::try_from_varint`.
    InvalidSessionId,

    /// The announced payload length exceeds the parser's limit (4096 bytes).
    PayloadTooBig,
}
32
/// Error returned by [`Frame::read_async`].
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug)]
pub enum IoReadError {
    /// The bytes were read but do not form a valid frame.
    Parse(ParseError),

    /// The underlying reader failed or the stream ended (see `bytes::IoReadError`).
    IO(bytes::IoReadError),
}
44
#[cfg(feature = "async")]
impl From<bytes::IoReadError> for IoReadError {
    /// Wraps a low-level read error as [`IoReadError::IO`], enabling `?` on reader calls.
    #[inline(always)]
    fn from(error: bytes::IoReadError) -> Self {
        Self::IO(error)
    }
}
52
/// Error returned by [`Frame::write_async`]; an alias of the low-level `bytes::IoWriteError`.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub type IoWriteError = bytes::IoWriteError;
56
57pub type FrameOwned = Frame<'static>;
59
60#[derive(Copy, Clone, Debug)]
62pub enum FrameKind {
63 Data,
65
66 Headers,
68
69 Settings,
71
72 WebTransport,
74
75 Exercise(VarInt),
77}
78
79impl FrameKind {
80 #[inline(always)]
82 pub const fn is_id_exercise(id: VarInt) -> bool {
83 id.into_inner() >= 0x21 && ((id.into_inner() - 0x21) % 0x1f == 0)
84 }
85
86 const fn parse(id: VarInt) -> Option<Self> {
87 match id {
88 frame_kind_ids::DATA => Some(FrameKind::Data),
89 frame_kind_ids::HEADERS => Some(FrameKind::Headers),
90 frame_kind_ids::SETTINGS => Some(FrameKind::Settings),
91 frame_kind_ids::WEBTRANSPORT_STREAM => Some(FrameKind::WebTransport),
92 id if FrameKind::is_id_exercise(id) => Some(FrameKind::Exercise(id)),
93 _ => None,
94 }
95 }
96
97 const fn id(self) -> VarInt {
98 match self {
99 FrameKind::Data => frame_kind_ids::DATA,
100 FrameKind::Headers => frame_kind_ids::HEADERS,
101 FrameKind::Settings => frame_kind_ids::SETTINGS,
102 FrameKind::WebTransport => frame_kind_ids::WEBTRANSPORT_STREAM,
103 FrameKind::Exercise(id) => id,
104 }
105 }
106}
107
108pub struct Frame<'a> {
110 kind: FrameKind,
111 payload: Cow<'a, [u8]>,
112 session_id: Option<SessionId>,
113}
114
115impl<'a> Frame<'a> {
116 const MAX_PARSE_PAYLOAD_ALLOWED: usize = 4096;
117
118 #[inline(always)]
124 pub fn new_headers(payload: Cow<'a, [u8]>) -> Self {
125 Self::new(FrameKind::Headers, payload, None)
126 }
127
128 #[inline(always)]
134 pub fn new_settings(payload: Cow<'a, [u8]>) -> Self {
135 Self::new(FrameKind::Settings, payload, None)
136 }
137
138 #[inline(always)]
140 pub fn new_webtransport(session_id: SessionId) -> Self {
141 Self::new(
142 FrameKind::WebTransport,
143 Cow::Owned(Default::default()),
144 Some(session_id),
145 )
146 }
147
148 #[inline(always)]
155 pub fn new_exercise(id: VarInt, payload: Cow<'a, [u8]>) -> Self {
156 assert!(FrameKind::is_id_exercise(id));
157 Self::new(FrameKind::Exercise(id), payload, None)
158 }
159
160 pub fn read<R>(bytes_reader: &mut R) -> Result<Option<Self>, ParseError>
167 where
168 R: BytesReader<'a>,
169 {
170 let kind = match bytes_reader.get_varint() {
171 Some(kind_id) => FrameKind::parse(kind_id).ok_or(ParseError::UnknownFrame)?,
172 None => return Ok(None),
173 };
174
175 if matches!(kind, FrameKind::WebTransport) {
176 let session_id = match bytes_reader.get_varint() {
177 Some(session_id) => SessionId::try_from_varint(session_id)
178 .map_err(|InvalidSessionId| ParseError::InvalidSessionId)?,
179 None => return Ok(None),
180 };
181
182 Ok(Some(Self::new_webtransport(session_id)))
183 } else {
184 let payload_len = match bytes_reader.get_varint() {
185 Some(payload_len) => payload_len.into_inner() as usize,
186 None => return Ok(None),
187 };
188
189 if payload_len > Self::MAX_PARSE_PAYLOAD_ALLOWED {
190 return Err(ParseError::PayloadTooBig);
191 }
192
193 let payload = match bytes_reader.get_bytes(payload_len) {
194 Some(payload) => payload,
195 None => return Ok(None),
196 };
197
198 Ok(Some(Self::new(kind, Cow::Borrowed(payload), None)))
199 }
200 }
201
202 #[cfg(feature = "async")]
204 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
205 pub async fn read_async<R>(reader: &mut R) -> Result<Frame<'a>, IoReadError>
206 where
207 R: AsyncRead + Unpin + ?Sized,
208 {
209 use crate::bytes::BytesReaderAsync;
210
211 let kind_id = reader.get_varint().await?;
212 let kind = FrameKind::parse(kind_id).ok_or(IoReadError::Parse(ParseError::UnknownFrame))?;
213
214 if matches!(kind, FrameKind::WebTransport) {
215 let session_id =
216 SessionId::try_from_varint(reader.get_varint().await.map_err(|e| match e {
217 bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
218 _ => e,
219 })?)
220 .map_err(|InvalidSessionId| IoReadError::Parse(ParseError::InvalidSessionId))?;
221
222 Ok(Self::new_webtransport(session_id))
223 } else {
224 let payload_len = reader
225 .get_varint()
226 .await
227 .map_err(|e| match e {
228 bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
229 _ => e,
230 })?
231 .into_inner() as usize;
232
233 if payload_len > Self::MAX_PARSE_PAYLOAD_ALLOWED {
234 return Err(IoReadError::Parse(ParseError::PayloadTooBig));
235 }
236
237 let mut payload = vec![0; payload_len];
238
239 reader.get_buffer(&mut payload).await.map_err(|e| match e {
240 bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
241 _ => e,
242 })?;
243
244 payload.shrink_to_fit();
245
246 Ok(Self::new(kind, Cow::Owned(payload), None))
247 }
248 }
249
250 pub fn read_from_buffer(
257 buffer_reader: &mut BufferReader<'a>,
258 ) -> Result<Option<Self>, ParseError> {
259 let mut buffer_reader_child = buffer_reader.child();
260
261 match Self::read(&mut *buffer_reader_child)? {
262 Some(frame) => {
263 buffer_reader_child.commit();
264 Ok(Some(frame))
265 }
266 None => Ok(None),
267 }
268 }
269
270 pub fn write<W>(&self, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
282 where
283 W: BytesWriter,
284 {
285 bytes_writer.put_varint(self.kind.id())?;
286
287 if let Some(session_id) = self.session_id() {
288 bytes_writer.put_varint(session_id.into_varint())?;
289 } else {
290 bytes_writer.put_varint(
291 VarInt::try_from(self.payload.len() as u64)
292 .expect("Payload cannot be larger than varint max"),
293 )?;
294 bytes_writer.put_bytes(&self.payload)?;
295 }
296
297 Ok(())
298 }
299
300 #[cfg(feature = "async")]
306 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
307 pub async fn write_async<W>(&self, writer: &mut W) -> Result<(), IoWriteError>
308 where
309 W: AsyncWrite + Unpin + ?Sized,
310 {
311 use crate::bytes::BytesWriterAsync;
312
313 writer.put_varint(self.kind.id()).await?;
314
315 if let Some(session_id) = self.session_id() {
316 writer.put_varint(session_id.into_varint()).await?;
317 } else {
318 writer
319 .put_varint(
320 VarInt::try_from(self.payload.len() as u64)
321 .expect("Payload cannot be larger than varint max"),
322 )
323 .await?;
324 writer.put_buffer(&self.payload).await?;
325 }
326
327 Ok(())
328 }
329
330 pub fn write_to_buffer(&self, buffer_writer: &mut BufferWriter) -> Result<(), EndOfBuffer> {
338 if buffer_writer.capacity() < self.write_size() {
339 return Err(EndOfBuffer);
340 }
341
342 self.write(buffer_writer)
343 .expect("Enough capacity for frame");
344
345 Ok(())
346 }
347
348 pub fn write_size(&self) -> usize {
350 if let Some(session_id) = self.session_id() {
351 self.kind.id().size() + session_id.into_varint().size()
352 } else {
353 self.kind.id().size()
354 + VarInt::try_from(self.payload.len() as u64)
355 .expect("Payload cannot be larger than varint max")
356 .size()
357 + self.payload.len()
358 }
359 }
360
361 #[inline(always)]
363 pub const fn kind(&self) -> FrameKind {
364 self.kind
365 }
366
367 #[inline(always)]
369 pub fn payload(&self) -> &[u8] {
370 &self.payload
371 }
372
373 #[inline(always)]
376 pub fn session_id(&self) -> Option<SessionId> {
377 matches!(self.kind, FrameKind::WebTransport).then(|| {
378 self.session_id
379 .expect("WebTransport frame contains session id")
380 })
381 }
382
383 fn new(kind: FrameKind, payload: Cow<'a, [u8]>, session_id: Option<SessionId>) -> Self {
387 if let FrameKind::Exercise(id) = kind {
388 debug_assert!(FrameKind::is_id_exercise(id));
389 } else if let FrameKind::WebTransport = kind {
390 debug_assert!(payload.is_empty());
391 debug_assert!(session_id.is_some());
392 }
393
394 assert!(payload.len() <= VarInt::MAX.into_inner() as usize);
395
396 Self {
397 kind,
398 payload,
399 session_id,
400 }
401 }
402
403 #[cfg(test)]
404 pub(crate) fn into_owned<'b>(self) -> Frame<'b> {
405 Frame {
406 kind: self.kind,
407 payload: Cow::Owned(self.payload.into_owned()),
408 session_id: self.session_id,
409 }
410 }
411
412 #[cfg(test)]
413 pub(crate) fn serialize_any(kind: VarInt, payload: &[u8]) -> Vec<u8> {
414 let mut buffer = Vec::new();
415
416 Self {
417 kind: FrameKind::Exercise(kind),
418 payload: Cow::Owned(payload.to_vec()),
419 session_id: None,
420 }
421 .write(&mut buffer)
422 .unwrap();
423
424 buffer
425 }
426
427 #[cfg(test)]
428 pub(crate) fn serialize_webtransport(session_id: SessionId) -> Vec<u8> {
429 let mut buffer = Vec::new();
430
431 Self {
432 kind: FrameKind::WebTransport,
433 payload: Cow::Owned(Default::default()),
434 session_id: Some(session_id),
435 }
436 .write(&mut buffer)
437 .unwrap();
438
439 buffer
440 }
441}
442
443mod frame_kind_ids {
444 use crate::varint::VarInt;
445
446 pub const DATA: VarInt = VarInt::from_u32(0x00);
447 pub const HEADERS: VarInt = VarInt::from_u32(0x01);
448 pub const SETTINGS: VarInt = VarInt::from_u32(0x04);
449 pub const WEBTRANSPORT_STREAM: VarInt = VarInt::from_u32(0x41);
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::Headers;
    use crate::ids::StreamId;
    use crate::settings::Settings;

    // Every test using `read_async`/`write_async`, `bytes::IoReadError` or
    // `utils::assert_serde_async` is gated on the `async` feature, like those items themselves,
    // so the module also compiles without that feature.

    #[test]
    fn settings() {
        let settings = Settings::builder()
            .qpack_blocked_streams(VarInt::from_u32(1))
            .qpack_max_table_capacity(VarInt::from_u32(2))
            .enable_h3_datagrams()
            .enable_webtransport()
            .webtransport_max_sessions(VarInt::from_u32(3))
            .build();

        let frame = settings.generate_frame();
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Settings));

        let frame = utils::assert_serde(frame);
        Settings::with_frame(&frame).unwrap();
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn settings_async() {
        let settings = Settings::builder()
            .qpack_blocked_streams(VarInt::from_u32(1))
            .qpack_max_table_capacity(VarInt::from_u32(2))
            .enable_h3_datagrams()
            .enable_webtransport()
            .webtransport_max_sessions(VarInt::from_u32(3))
            .build();

        let frame = settings.generate_frame();
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Settings));

        let frame = utils::assert_serde_async(frame).await;
        Settings::with_frame(&frame).unwrap();
    }

    #[test]
    fn headers() {
        let stream_id = StreamId::new(VarInt::from_u32(0));
        let headers = Headers::from_iter([("key1", "value1")]);

        let frame = headers.generate_frame(stream_id);
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Headers));

        let frame = utils::assert_serde(frame);
        Headers::with_frame(&frame, stream_id).unwrap();
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn headers_async() {
        let stream_id = StreamId::new(VarInt::from_u32(0));
        let headers = Headers::from_iter([("key1", "value1")]);

        let frame = headers.generate_frame(stream_id);
        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Headers));

        let frame = utils::assert_serde_async(frame).await;
        Headers::with_frame(&frame, stream_id).unwrap();
    }

    #[test]
    fn webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let frame = Frame::new_webtransport(session_id);

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));

        let frame = utils::assert_serde(frame);

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let frame = Frame::new_webtransport(session_id);

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));

        let frame = utils::assert_serde_async(frame).await;

        assert!(frame.payload().is_empty());
        assert!(matches!(frame.session_id(), Some(x) if x == session_id));
        assert!(matches!(frame.kind(), FrameKind::WebTransport));
    }

    #[test]
    fn exercise() {
        let id = VarInt::from_u32(0x21);
        let frame = Frame::new_exercise(id, Cow::Borrowed(&b"exercise payload"[..]));

        assert!(frame.session_id().is_none());
        assert!(matches!(frame.kind(), FrameKind::Exercise(x) if x == id));

        let frame = utils::assert_serde(frame);

        assert!(matches!(frame.kind(), FrameKind::Exercise(x) if x == id));
        assert_eq!(frame.payload(), &b"exercise payload"[..]);
    }

    #[test]
    fn read_eof() {
        let buffer = Frame::serialize_any(FrameKind::Data.id(), b"This is a test payload");

        // Every strict prefix of a valid frame is "incomplete", never an error.
        for len in 0..buffer.len() {
            assert!(Frame::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[test]
    fn read_eof_webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = Frame::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            assert!(Frame::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_async() {
        let buffer = Frame::serialize_any(FrameKind::Data.id(), b"This is a test payload");

        for len in 0..buffer.len() {
            let result = Frame::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = Frame::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            let result = Frame::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[test]
    fn unknown_frame() {
        let buffer = Frame::serialize_any(VarInt::from_u32(0x0042_4242), b"This is a test payload");

        assert!(matches!(
            Frame::read(&mut buffer.as_slice()),
            Err(ParseError::UnknownFrame)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn unknown_frame_async() {
        let buffer = Frame::serialize_any(VarInt::from_u32(0x0042_4242), b"This is a test payload");

        assert!(matches!(
            Frame::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::UnknownFrame))
        ));
    }

    #[test]
    fn invalid_session_id() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = Frame::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            Frame::read(&mut buffer.as_slice()),
            Err(ParseError::InvalidSessionId)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn invalid_session_id_async() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = Frame::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            Frame::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::InvalidSessionId))
        ));
    }

    #[test]
    fn payload_at_limit() {
        let payload = vec![0x42; Frame::MAX_PARSE_PAYLOAD_ALLOWED];
        let buffer = Frame::serialize_any(FrameKind::Data.id(), &payload);

        let mut reader = buffer.as_slice();
        let frame = Frame::read(&mut reader).unwrap().unwrap();

        assert!(reader.is_empty());
        assert!(matches!(frame.kind(), FrameKind::Data));
        assert_eq!(frame.payload(), payload.as_slice());
    }

    #[test]
    fn payload_too_big() {
        let mut buffer = Vec::new();
        buffer.put_varint(FrameKind::Data.id()).unwrap();
        buffer
            .put_varint(VarInt::from_u32(
                Frame::MAX_PARSE_PAYLOAD_ALLOWED as u32 + 1,
            ))
            .unwrap();

        assert!(matches!(
            Frame::read_from_buffer(&mut BufferReader::new(&buffer)),
            Err(ParseError::PayloadTooBig)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn payload_too_big_async() {
        let mut buffer = Vec::new();
        buffer.put_varint(FrameKind::Data.id()).unwrap();
        buffer
            .put_varint(VarInt::from_u32(
                Frame::MAX_PARSE_PAYLOAD_ALLOWED as u32 + 1,
            ))
            .unwrap();

        assert!(matches!(
            Frame::read_async(&mut &*buffer).await,
            Err(IoReadError::Parse(ParseError::PayloadTooBig)),
        ));
    }

    mod utils {
        use super::*;

        /// Round-trips `frame` through `write`/`read`, checking `write_size` and that the
        /// whole encoding is consumed.
        pub fn assert_serde(frame: Frame) -> Frame {
            let mut buffer = Vec::new();

            frame.write(&mut buffer).unwrap();
            assert_eq!(buffer.len(), frame.write_size());

            let mut buffer = buffer.as_slice();
            let frame = Frame::read(&mut buffer).unwrap().unwrap();
            assert!(buffer.is_empty());

            frame.into_owned()
        }

        /// Async counterpart of [`assert_serde`].
        #[cfg(feature = "async")]
        pub async fn assert_serde_async(frame: Frame<'_>) -> Frame {
            let mut buffer = Vec::new();

            frame.write_async(&mut buffer).await.unwrap();
            assert_eq!(buffer.len(), frame.write_size());

            let mut buffer = buffer.as_slice();
            let frame = Frame::read_async(&mut buffer).await.unwrap();
            assert!(buffer.is_empty());

            frame.into_owned()
        }
    }
}