1use crate::{
5 connection,
6 crypto::{
7 packet_protection, retry,
8 retry::{IntegrityTag, RetryKey},
9 },
10 inet::SocketAddress,
11 packet::{
12 decoding::HeaderDecoder,
13 initial::ProtectedInitial,
14 long::{DestinationConnectionIdLen, SourceConnectionIdLen, Version},
15 Tag,
16 },
17 random, token,
18};
19use core::{mem::size_of, ops::Range};
20use retry::INTEGRITY_TAG_LEN;
21use s2n_codec::{
22 decoder_invariant, DecoderBufferMut, DecoderBufferMutResult, Encoder, EncoderBuffer,
23 EncoderValue,
24};
25
// Upper nibble of the first byte of a Retry packet (RFC 9000 §17.2.5):
// Header Form = 1, Fixed Bit = 1, Long Packet Type = 0b11 (Retry).
// Callers shift it into the high four bits, e.g. `(retry_tag!() << 4) | unused_bits`.
macro_rules! retry_tag {
    () => {
        0x0f_u8
    };
}
48
/// A QUIC Retry packet whose fields borrow from an underlying buffer.
///
/// On the wire the fields appear in declaration order (see the `EncoderValue` impl):
/// first byte, version, length-prefixed DCID, length-prefixed SCID, retry token, and finally
/// the integrity tag.
#[derive(Debug)]
pub struct Retry<'a> {
    // First byte of the packet: the `retry_tag!()` bits in the high nibble plus four unused bits.
    pub tag: Tag,
    // QUIC version; `from_initial` copies it from the client's Initial packet.
    pub version: Version,
    // Destination Connection ID; `from_initial` sets it to the client's Source Connection ID.
    pub destination_connection_id: &'a [u8],
    // Source Connection ID chosen by the server (`local_connection_id` in `from_initial`).
    pub source_connection_id: &'a [u8],
    // Opaque retry token. It has no length prefix on the wire; it spans everything between the
    // header and the trailing integrity tag. `decode` rejects an empty token.
    pub retry_token: &'a [u8],
    // `INTEGRITY_TAG_LEN`-byte tag computed over the Retry pseudo-packet.
    pub retry_integrity_tag: &'a IntegrityTag,
}
66
/// The Retry pseudo-packet that the Retry integrity tag is computed over (RFC 9001 §5.8).
///
/// It consists of the client's original Destination Connection ID (length-prefixed) followed
/// by the Retry packet's fields *without* the integrity tag (see the `EncoderValue` impl).
#[derive(Debug)]
pub struct PseudoRetry<'a> {
    // Destination Connection ID of the client's original Initial packet; encoded first, with a
    // length prefix, and never sent on the wire.
    pub original_destination_connection_id: &'a [u8],
    // First byte of the Retry packet.
    pub tag: Tag,
    // QUIC version of the Retry packet.
    pub version: Version,
    // Destination Connection ID of the Retry packet.
    pub destination_connection_id: &'a [u8],
    // Source Connection ID of the Retry packet.
    pub source_connection_id: &'a [u8],
    // Retry token carried by the Retry packet.
    pub retry_token: &'a [u8],
}
92
93impl<'a> PseudoRetry<'a> {
94 pub fn new(
95 odcid: &'a [u8],
96 tag: Tag,
97 version: Version,
98 destination_connection_id: &'a [u8],
99 source_connection_id: &'a [u8],
100 retry_token: &'a [u8],
101 ) -> Self {
102 Self {
103 original_destination_connection_id: odcid,
104 tag,
105 version,
106 destination_connection_id,
107 source_connection_id,
108 retry_token,
109 }
110 }
111}
112
// Stage-specific names for a Retry packet. All stages share the single `Retry` representation,
// so these aliases exist only to keep naming consistent with the other packet types.
pub type ProtectedRetry<'a> = Retry<'a>;
pub type EncryptedRetry<'a> = Retry<'a>;
pub type CleartextRetry<'a> = Retry<'a>;
116
117impl<'a> Retry<'a> {
118 pub fn encode_packet<T: token::Format, C: RetryKey>(
119 remote_address: &SocketAddress,
120 packet: &ProtectedInitial,
121 local_connection_id: &connection::LocalId,
122 random: &mut dyn random::Generator,
123 token_format: &mut T,
124 packet_buf: &mut [u8],
125 ) -> Option<Range<usize>> {
126 debug_assert_ne!(
130 local_connection_id.as_ref(),
131 packet.destination_connection_id()
132 );
133 if local_connection_id.as_ref() == packet.destination_connection_id() {
134 return None;
135 }
136
137 let retry_packet = Retry::from_initial(packet, local_connection_id.as_ref());
138 let pseudo_packet = retry_packet.pseudo_packet(packet.destination_connection_id());
139
140 let mut buffer = EncoderBuffer::new(packet_buf);
141 pseudo_packet.encode(&mut buffer);
142
143 let destination_connection_id =
144 &connection::PeerId::try_from_bytes(retry_packet.destination_connection_id).unwrap();
145 let mut context = token::Context::new(remote_address, destination_connection_id, random);
146
147 let mut outcome = None;
148
149 buffer.write_sized(T::TOKEN_LEN, |token_buf| {
150 outcome = token_format.generate_retry_token(
151 &mut context,
152 &connection::InitialId::try_from_bytes(packet.destination_connection_id()).unwrap(),
153 token_buf,
154 );
155 });
156
157 outcome?;
158
159 let tag = C::generate_tag(buffer.as_mut_slice());
160 buffer.write_slice(&tag);
161 let end = buffer.len();
162 let start =
163 packet.destination_connection_id().len() + size_of::<DestinationConnectionIdLen>();
164
165 Some(start..end)
166 }
167
168 pub fn validate<Crypto, CreateBuf, Buf>(
169 &self,
170 odcid: &connection::InitialId,
171 create_buf: CreateBuf,
172 ) -> Result<(), packet_protection::Error>
173 where
174 Crypto: RetryKey,
175 CreateBuf: FnOnce(usize) -> Buf,
176 Buf: AsMut<[u8]>,
177 {
178 let pseudo_packet = self.pseudo_packet(odcid.as_ref());
179 let len = pseudo_packet.encoding_size();
180 let mut buf = create_buf(len);
181 let buf = buf.as_mut();
182
183 let mut buffer = EncoderBuffer::new(buf);
184 pseudo_packet.encode(&mut buffer);
185
186 Crypto::validate(buf, *self.retry_integrity_tag)?;
193
194 Ok(())
195 }
196
197 pub fn from_initial(
198 initial_packet: &'a ProtectedInitial,
199 local_connection_id: &'a [u8],
200 ) -> Self {
201 Self {
205 tag: (retry_tag!() << 4) | 0x0f,
212 version: initial_packet.version,
213 destination_connection_id: initial_packet.source_connection_id(),
214 source_connection_id: local_connection_id,
215 retry_token: &[][..],
216 retry_integrity_tag: {
217 static EMPTY_TAG: IntegrityTag = [0u8; INTEGRITY_TAG_LEN];
218 &EMPTY_TAG
219 },
220 }
221 }
222
223 #[inline]
224 pub(crate) fn decode(
225 tag: Tag,
226 version: Version,
227 buffer: DecoderBufferMut,
228 ) -> DecoderBufferMutResult<Retry> {
229 let mut decoder = HeaderDecoder::new_long(&buffer);
230
231 let destination_connection_id = decoder.decode_destination_connection_id(&buffer)?;
235 let source_connection_id = decoder.decode_source_connection_id(&buffer)?;
236
237 let header_len = decoder.decoded_len();
239 let (header, buffer) = buffer.decode_slice(header_len)?;
240 let header: &[u8] = header.into_less_safe_slice();
241
242 let destination_connection_id = destination_connection_id.get(header);
244 let source_connection_id = source_connection_id.get(header);
245
246 let buffer_len = buffer.len().saturating_sub(retry::INTEGRITY_TAG_LEN);
247
248 decoder_invariant!(buffer_len > 0, "Token cannot be empty");
252
253 let (retry_token, buffer) = buffer.decode_slice(buffer_len)?;
254 let retry_token: &[u8] = retry_token.into_less_safe_slice();
255
256 let (retry_integrity_tag, buffer) = buffer.decode_slice(retry::INTEGRITY_TAG_LEN)?;
257 let retry_integrity_tag: &[u8] = retry_integrity_tag.into_less_safe_slice();
258 let retry_integrity_tag: &IntegrityTag = retry_integrity_tag
259 .try_into()
260 .expect("tag length already checked");
261
262 let packet = Retry {
263 tag,
264 version,
265 destination_connection_id,
266 source_connection_id,
267 retry_token,
268 retry_integrity_tag,
269 };
270
271 Ok((packet, buffer))
272 }
273
274 #[inline]
275 pub fn destination_connection_id(&self) -> &[u8] {
276 self.destination_connection_id
277 }
278
279 #[inline]
280 pub fn source_connection_id(&self) -> &[u8] {
281 self.source_connection_id
282 }
283
284 #[inline]
285 pub fn retry_token(&self) -> &[u8] {
286 self.retry_token
287 }
288
289 #[inline]
290 fn pseudo_packet(&self, odcid: &'a [u8]) -> PseudoRetry<'_> {
291 PseudoRetry {
292 original_destination_connection_id: odcid,
293 tag: self.tag,
294 version: self.version,
295 destination_connection_id: self.destination_connection_id,
296 source_connection_id: self.source_connection_id,
297 retry_token: self.retry_token,
298 }
299 }
300}
301
302impl EncoderValue for Retry<'_> {
303 fn encode<E: Encoder>(&self, encoder: &mut E) {
304 let tag: u8 = self.tag;
305 tag.encode(encoder);
306
307 self.version.encode(encoder);
308
309 self.destination_connection_id
310 .encode_with_len_prefix::<DestinationConnectionIdLen, E>(encoder);
311 self.source_connection_id
312 .encode_with_len_prefix::<SourceConnectionIdLen, E>(encoder);
313 self.retry_token.encode(encoder);
314 self.retry_integrity_tag.as_ref().encode(encoder);
315 }
316}
317
318impl EncoderValue for PseudoRetry<'_> {
319 fn encode<E: Encoder>(&self, encoder: &mut E) {
320 self.original_destination_connection_id
321 .encode_with_len_prefix::<DestinationConnectionIdLen, E>(encoder);
322
323 self.tag.encode(encoder);
324
325 self.version.encode(encoder);
326
327 self.destination_connection_id
328 .encode_with_len_prefix::<DestinationConnectionIdLen, E>(encoder);
329 self.source_connection_id
330 .encode_with_len_prefix::<SourceConnectionIdLen, E>(encoder);
331 self.retry_token.encode(encoder);
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{crypto::retry, inet, packet};
    use s2n_codec::EncoderBuffer;

    // Encoding the `retry::example` fields must reproduce `retry::example::PACKET` exactly.
    #[test]
    fn test_encode() {
        let expected = retry::example::PACKET;
        let packet = Retry {
            tag: (retry_tag!() << 4) | 0x0f,
            version: retry::example::VERSION,
            destination_connection_id: &retry::example::DCID,
            source_connection_id: &retry::example::SCID,
            retry_token: &retry::example::TOKEN,
            retry_integrity_tag: &retry::example::EXPECTED_TAG,
        };

        let mut encoded = [0; retry::example::PACKET_LEN];
        packet.encode(&mut EncoderBuffer::new(&mut encoded));

        assert_eq!(expected[..], encoded[..]);
    }

    // Decoding `retry::example::PACKET` must yield a Retry packet with the example fields.
    #[test]
    fn test_decode() {
        let remote_address = inet::ip::SocketAddress::default();
        let connection_info = connection::id::ConnectionInfo::new(&remote_address);
        let mut buf = retry::example::PACKET;

        let (decoded, _) = packet::ProtectedPacket::decode(
            DecoderBufferMut::new(&mut buf),
            &connection_info,
            &20,
        )
        .unwrap();
        let retry_packet = if let packet::ProtectedPacket::Retry(inner) = decoded {
            inner
        } else {
            panic!("expected retry packet type")
        };

        assert_eq!(retry_packet.version, retry::example::VERSION);
        assert_eq!(retry_packet.destination_connection_id, retry::example::DCID);
        assert_eq!(retry_packet.source_connection_id, retry::example::SCID);
        assert_eq!(retry_packet.retry_token, retry::example::TOKEN);
        assert_eq!(
            retry_packet.retry_integrity_tag,
            &retry::example::EXPECTED_TAG
        );
    }

    // A Retry packet without a token must be rejected by the decoder.
    #[test]
    fn test_decode_no_token() {
        let remote_address = inet::ip::SocketAddress::default();
        let connection_info = connection::id::ConnectionInfo::new(&remote_address);
        let mut buf = retry::example::INVALID_PACKET_NO_TOKEN;

        let result = packet::ProtectedPacket::decode(
            DecoderBufferMut::new(&mut buf),
            &connection_info,
            &20,
        );

        assert!(result.is_err());
    }

    // The pseudo-packet built from the decoded example must carry the example fields and
    // encode to `retry::example::PSEUDO_PACKET`.
    #[test]
    fn test_pseudo_decode() {
        let remote_address = inet::ip::SocketAddress::default();
        let connection_info = connection::id::ConnectionInfo::new(&remote_address);
        let mut buf = retry::example::PACKET;

        let (decoded, _) = packet::ProtectedPacket::decode(
            DecoderBufferMut::new(&mut buf),
            &connection_info,
            &20,
        )
        .unwrap();
        let retry_packet = if let packet::ProtectedPacket::Retry(inner) = decoded {
            inner
        } else {
            panic!("expected retry packet type")
        };
        let pseudo_packet = retry_packet.pseudo_packet(&retry::example::ODCID);

        assert_eq!(
            pseudo_packet.original_destination_connection_id,
            retry::example::ODCID
        );
        assert_eq!(pseudo_packet.version, retry::example::VERSION);
        assert_eq!(
            pseudo_packet.destination_connection_id,
            retry::example::DCID
        );
        assert_eq!(pseudo_packet.source_connection_id, retry::example::SCID);
        assert_eq!(pseudo_packet.retry_token, retry::example::TOKEN);

        let mut pseudo_scratch: Vec<u8> = vec![0; pseudo_packet.encoding_size()];
        pseudo_packet.encode(&mut EncoderBuffer::new(&mut pseudo_scratch));

        assert_eq!(pseudo_scratch, retry::example::PSEUDO_PACKET);
    }
}