s2n_quic_core/packet/
decoding.rs1use crate::{
5 connection,
6 connection::id::ConnectionInfo,
7 crypto::ProtectedPayload,
8 packet::{
9 long::{
10 validate_destination_connection_id_range, validate_source_connection_id_range,
11 DestinationConnectionIdLen, SourceConnectionIdLen, Version,
12 },
13 number::ProtectedPacketNumber,
14 Tag,
15 },
16 varint::VarInt,
17};
18use core::mem::size_of;
19use s2n_codec::{CheckedRange, DecoderBuffer, DecoderBufferMut, DecoderError, DecoderValue};
20
21pub struct HeaderDecoder<'a> {
22 initial_buffer_len: usize,
23 peek: DecoderBuffer<'a>,
24}
25
26impl<'a> HeaderDecoder<'a> {
27 pub fn new_long<'b>(buffer: &'a DecoderBufferMut<'b>) -> Self {
28 let initial_buffer_len = buffer.len();
29 let peek = buffer.peek();
30 let peek = peek
31 .skip(size_of::<Tag>() + size_of::<Version>())
32 .expect("tag and version already verified");
33 Self {
34 initial_buffer_len,
35 peek,
36 }
37 }
38
39 pub fn new_short<'b>(buffer: &'a DecoderBufferMut<'b>) -> Self {
40 let initial_buffer_len = buffer.len();
41 let peek = buffer.peek();
42 let peek = peek.skip(size_of::<Tag>()).expect("tag already verified");
43 Self {
44 initial_buffer_len,
45 peek,
46 }
47 }
48
49 pub fn decode_destination_connection_id(
50 &mut self,
51 buffer: &DecoderBufferMut<'_>,
52 ) -> Result<CheckedRange, DecoderError> {
53 let destination_connection_id =
54 self.decode_checked_range::<DestinationConnectionIdLen>(buffer)?;
55 validate_destination_connection_id_range(&destination_connection_id)?;
56 Ok(destination_connection_id)
57 }
58
59 pub fn decode_short_destination_connection_id<Validator: connection::id::Validator>(
60 &mut self,
61 buffer: &DecoderBufferMut<'_>,
62 connection_info: &ConnectionInfo,
63 connection_id_validator: &Validator,
64 ) -> Result<CheckedRange, DecoderError> {
65 let destination_connection_id_len = if let Some(len) = connection_id_validator
66 .validate(connection_info, self.peek.peek().into_less_safe_slice())
67 {
68 len
69 } else {
70 return Err(DecoderError::InvariantViolation("invalid connection id"));
71 };
72
73 let (destination_connection_id, peek) = self
74 .peek
75 .skip_into_range(destination_connection_id_len, buffer)?;
76 self.peek = peek;
77 validate_destination_connection_id_range(&destination_connection_id)?;
78 Ok(destination_connection_id)
79 }
80
81 pub fn decode_source_connection_id(
82 &mut self,
83 buffer: &DecoderBufferMut<'_>,
84 ) -> Result<CheckedRange, DecoderError> {
85 let source_connection_id = self.decode_checked_range::<SourceConnectionIdLen>(buffer)?;
86 validate_source_connection_id_range(&source_connection_id)?;
87 Ok(source_connection_id)
88 }
89
90 pub fn decode_checked_range<Len: DecoderValue<'a> + TryInto<usize>>(
91 &mut self,
92 buffer: &DecoderBufferMut<'_>,
93 ) -> Result<CheckedRange, DecoderError> {
94 let (value, peek) = self.peek.skip_into_range_with_len_prefix::<Len>(buffer)?;
95 self.peek = peek;
96 Ok(value)
97 }
98
99 pub fn finish_long(mut self) -> Result<HeaderDecoderResult, DecoderError> {
100 let (payload_len, peek) = self.peek.decode::<VarInt>()?;
101 self.peek = peek;
102 let header_len = self.decoded_len();
103
104 self.peek = peek.skip(*payload_len as usize)?;
105 let packet_len = self.decoded_len();
106
107 Ok(HeaderDecoderResult {
108 packet_len,
109 header_len,
110 })
111 }
112
113 pub fn finish_short(self) -> Result<HeaderDecoderResult, DecoderError> {
114 let header_len = self.decoded_len();
115 let packet_len = self.initial_buffer_len;
116
117 Ok(HeaderDecoderResult {
118 packet_len,
119 header_len,
120 })
121 }
122
123 pub fn decoded_len(&self) -> usize {
124 self.initial_buffer_len - self.peek.len()
125 }
126}
127
/// Byte lengths computed once a packet header has been decoded.
///
/// Plain data (two `usize`s), so it is cheap to copy and compare.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HeaderDecoderResult {
    /// Total packet length in bytes (header plus payload), from the start of the buffer.
    pub packet_len: usize,
    /// Header length in bytes, from the start of the buffer.
    pub header_len: usize,
}
133
134impl HeaderDecoderResult {
135 pub fn split_off_packet<'a>(
136 &self,
137 buffer: DecoderBufferMut<'a>,
138 ) -> Result<
139 (
140 ProtectedPayload<'a>,
141 ProtectedPacketNumber,
142 DecoderBufferMut<'a>,
143 ),
144 DecoderError,
145 > {
146 let (payload, remaining) = buffer.decode_slice(self.packet_len)?;
147 let packet_number = ProtectedPacketNumber;
148 let payload = ProtectedPayload::new(self.header_len, payload.into_less_safe_slice());
149
150 Ok((payload, packet_number, remaining))
151 }
152}