1use crate::{
5 crypto::{packet_protection, EncryptedPayload, InitialHeaderKey, InitialKey, ProtectedPayload},
6 packet::{
7 decoding::HeaderDecoder,
8 encoding::{PacketEncoder, PacketPayloadEncoder},
9 long::{
10 DestinationConnectionIdLen, LongPayloadEncoder, LongPayloadLenCursor,
11 SourceConnectionIdLen, Version,
12 },
13 number::{
14 PacketNumber, PacketNumberLen, PacketNumberSpace, ProtectedPacketNumber,
15 TruncatedPacketNumber,
16 },
17 KeyPhase, Tag,
18 },
19 varint::VarInt,
20};
21use s2n_codec::{CheckedRange, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderValue};
22
// Type bits of a QUIC Initial packet, later shifted into the upper nibble of
// the first header byte by `encode_header` (`initial_tag!() << 4`).
// Grouped as `11` (Header Form = 1, Fixed Bit = 1) and `00` (Long Packet
// Type = Initial), per RFC 9000, Section 17.2.2.
macro_rules! initial_tag {
    () => {
        0b11_00u8
    };
}
49
/// A QUIC Initial packet.
///
/// Every field except `version` is generic, so the same struct can describe
/// the packet at each processing stage. For decoding, see the
/// `ProtectedInitial`, `EncryptedInitial` and `CleartextInitial` aliases; for
/// encoding, see the `EncoderValue` and `PacketEncoder` impls.
#[derive(Debug)]
pub struct Initial<DCID, SCID, Token, PacketNumber, Payload> {
    /// Version taken from the long header.
    pub version: Version,
    /// Destination connection ID: a range into the packet while decoding, or
    /// the connection ID bytes/value once resolved or when encoding.
    pub destination_connection_id: DCID,
    /// Source connection ID: a range into the packet while decoding, or the
    /// connection ID bytes/value once resolved or when encoding.
    pub source_connection_id: SCID,
    /// Token field: a range into the packet while decoding, or the token
    /// bytes/value once resolved or when encoding.
    pub token: Token,
    /// Packet number. Depending on the stage it is protected, expanded, or
    /// truncated.
    pub packet_number: PacketNumber,
    /// Packet payload. Depending on the stage it is header-protected,
    /// encrypted, or cleartext.
    pub payload: Payload,
}
67
/// Initial packet as decoded from the wire.
///
/// Header fields are ranges into the packet buffer, and the packet number is
/// still header-protected.
pub type ProtectedInitial<'a> =
    Initial<CheckedRange, CheckedRange, CheckedRange, ProtectedPacketNumber, ProtectedPayload<'a>>;
/// Initial packet after header protection has been removed by
/// `ProtectedInitial::unprotect`.
///
/// The packet number is known, but the payload is still encrypted.
pub type EncryptedInitial<'a> =
    Initial<CheckedRange, CheckedRange, CheckedRange, PacketNumber, EncryptedPayload<'a>>;
/// Initial packet after payload decryption by `EncryptedInitial::decrypt`.
///
/// Header fields are resolved to byte slices, and the payload is cleartext.
pub type CleartextInitial<'a> =
    Initial<&'a [u8], &'a [u8], &'a [u8], PacketNumber, DecoderBufferMut<'a>>;
74
75impl<'a> ProtectedInitial<'a> {
76 #[inline]
77 pub(crate) fn decode(
78 _tag: Tag,
79 version: Version,
80 buffer: DecoderBufferMut,
81 ) -> DecoderBufferMutResult<ProtectedInitial> {
82 let mut decoder = HeaderDecoder::new_long(&buffer);
83
84 let destination_connection_id =
91 decoder.decode_checked_range::<DestinationConnectionIdLen>(&buffer)?;
92 let source_connection_id =
93 decoder.decode_checked_range::<SourceConnectionIdLen>(&buffer)?;
94 let token = decoder.decode_checked_range::<VarInt>(&buffer)?;
95
96 let (payload, packet_number, remaining) =
97 decoder.finish_long()?.split_off_packet(buffer)?;
98
99 let packet = Initial {
100 version,
101 destination_connection_id,
102 source_connection_id,
103 token,
104 packet_number,
105 payload,
106 };
107
108 Ok((packet, remaining))
109 }
110
111 pub fn unprotect<H: InitialHeaderKey>(
112 self,
113 header_key: &H,
114 largest_acknowledged_packet_number: PacketNumber,
115 ) -> Result<EncryptedInitial<'a>, packet_protection::Error> {
116 let Initial {
117 version,
118 destination_connection_id,
119 source_connection_id,
120 token,
121 payload,
122 ..
123 } = self;
124
125 let (truncated_packet_number, payload) =
126 crate::crypto::unprotect(header_key, PacketNumberSpace::Initial, payload)?;
127
128 let packet_number = truncated_packet_number.expand(largest_acknowledged_packet_number);
129
130 Ok(Initial {
131 version,
132 destination_connection_id,
133 source_connection_id,
134 token,
135 packet_number,
136 payload,
137 })
138 }
139
140 #[inline]
141 pub fn destination_connection_id(&self) -> &[u8] {
142 self.payload
143 .get_checked_range(&self.destination_connection_id)
144 .into_less_safe_slice()
145 }
146
147 #[inline]
148 pub fn source_connection_id(&self) -> &[u8] {
149 self.payload
150 .get_checked_range(&self.source_connection_id)
151 .into_less_safe_slice()
152 }
153
154 #[inline]
155 pub fn token(&self) -> &[u8] {
156 self.payload
157 .get_checked_range(&self.token)
158 .into_less_safe_slice()
159 }
160}
161
162impl<'a> EncryptedInitial<'a> {
163 pub fn decrypt<C: InitialKey>(
164 self,
165 crypto: &C,
166 ) -> Result<CleartextInitial<'a>, packet_protection::Error> {
167 let Initial {
168 version,
169 destination_connection_id,
170 source_connection_id,
171 token,
172 packet_number,
173 payload,
174 } = self;
175
176 let (header, payload) = crate::crypto::decrypt(crypto, packet_number, payload)?;
177
178 let header = header.into_less_safe_slice();
179
180 let destination_connection_id = destination_connection_id.get(header);
181 let source_connection_id = source_connection_id.get(header);
182 let token = token.get(header);
183
184 Ok(Initial {
185 version,
186 destination_connection_id,
187 source_connection_id,
188 token,
189 packet_number,
190 payload,
191 })
192 }
193
194 #[inline]
195 pub fn destination_connection_id(&self) -> &[u8] {
196 self.payload
197 .get_checked_range(&self.destination_connection_id)
198 .into_less_safe_slice()
199 }
200
201 #[inline]
202 pub fn source_connection_id(&self) -> &[u8] {
203 self.payload
204 .get_checked_range(&self.source_connection_id)
205 .into_less_safe_slice()
206 }
207
208 #[inline]
209 pub fn token(&self) -> &[u8] {
210 self.payload
211 .get_checked_range(&self.token)
212 .into_less_safe_slice()
213 }
214
215 #[inline]
217 pub fn key_phase(&self) -> KeyPhase {
218 KeyPhase::Zero
219 }
220}
221
222impl CleartextInitial<'_> {
223 #[inline]
224 pub fn destination_connection_id(&self) -> &[u8] {
225 self.destination_connection_id
226 }
227
228 #[inline]
229 pub fn source_connection_id(&self) -> &[u8] {
230 self.source_connection_id
231 }
232
233 #[inline]
234 pub fn token(&self) -> &[u8] {
235 self.token
236 }
237}
238
239impl<DCID: EncoderValue, SCID: EncoderValue, Token: EncoderValue, Payload: EncoderValue>
240 EncoderValue for Initial<DCID, SCID, Token, TruncatedPacketNumber, Payload>
241{
242 fn encode<E: Encoder>(&self, encoder: &mut E) {
243 self.encode_header(self.packet_number.len(), encoder);
244 LongPayloadEncoder {
245 packet_number: self.packet_number,
246 payload: &self.payload,
247 }
248 .encode_with_len_prefix::<VarInt, E>(encoder)
249 }
250}
251
252impl<DCID: EncoderValue, SCID: EncoderValue, Token: EncoderValue, PacketNumber, Payload>
253 Initial<DCID, SCID, Token, PacketNumber, Payload>
254{
255 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
256 let mut tag: u8 = initial_tag!() << 4;
257 tag |= packet_number_len.into_packet_tag_mask();
258 tag.encode(encoder);
259
260 self.version.encode(encoder);
261
262 self.destination_connection_id
263 .encode_with_len_prefix::<DestinationConnectionIdLen, E>(encoder);
264 self.source_connection_id
265 .encode_with_len_prefix::<SourceConnectionIdLen, E>(encoder);
266 self.token.encode_with_len_prefix::<VarInt, E>(encoder);
267 }
268}
269
270impl<
271 DCID: EncoderValue,
272 SCID: EncoderValue,
273 Token: EncoderValue,
274 Payload: PacketPayloadEncoder,
275 K: InitialKey,
276 H: InitialHeaderKey,
277 > PacketEncoder<K, H, Payload> for Initial<DCID, SCID, Token, PacketNumber, Payload>
278{
279 type PayloadLenCursor = LongPayloadLenCursor;
280
281 fn packet_number(&self) -> PacketNumber {
282 self.packet_number
283 }
284
285 fn encode_header<E: Encoder>(&self, packet_number_len: PacketNumberLen, encoder: &mut E) {
286 Initial::encode_header(self, packet_number_len, encoder);
287 }
288
289 fn payload(&mut self) -> &mut Payload {
290 &mut self.payload
291 }
292}