1use core::{mem::size_of, time::Duration};
12use hash_hasher::HashHasher;
13use s2n_codec::{DecoderBuffer, DecoderBufferMut};
14use s2n_quic_core::{
15 connection, event::api::SocketAddress, random, time::Timestamp, token::Source,
16};
17use s2n_quic_crypto::{constant_time, digest, hmac};
18use std::hash::{Hash, Hasher};
19use zerocopy::{FromBytes, IntoBytes, Unaligned};
20use zeroize::Zeroizing;
21
/// One slot of signing-key state used to tag and verify address validation tokens.
struct BaseKey {
    /// How long a freshly generated key remains usable; `poll_key` adds this to the
    /// current time to compute the key's expiration.
    active_duration: Duration,

    /// The current HMAC key together with the timestamp at which it expires.
    /// `None` until the first key is generated by `poll_key`.
    key: Option<(Timestamp, hmac::Key)>,

    /// Filter of tokens already accepted under the current key, created lazily on the
    /// first accepted token and dropped whenever `poll_key` generates a new key.
    duplicate_filter: Option<cuckoofilter::CuckooFilter<HashHasher>>,
}
33
34impl BaseKey {
35 pub fn new(active_duration: Duration) -> Self {
36 Self {
37 active_duration,
38 key: None,
39 duplicate_filter: None,
40 }
41 }
42
43 pub fn hasher(&mut self, random: &mut dyn random::Generator) -> Option<hmac::Context> {
44 let key = self.poll_key(random)?;
45 Some(hmac::Context::with_key(&key))
46 }
47
48 fn poll_key(&mut self, random: &mut dyn random::Generator) -> Option<hmac::Key> {
49 let now = s2n_quic_platform::time::now();
50
51 if let Some((expires_at, key)) = self.key.as_ref() {
55 if expires_at > &now {
56 return Some(key.clone());
58 }
59 }
60
61 let expires_at = now.checked_add(self.active_duration)?;
62
63 let mut key_material = Zeroizing::new([0; digest::SHA256_OUTPUT_LEN]);
66 random.private_random_fill(&mut key_material[..]);
67 let key = hmac::Key::new(hmac::HMAC_SHA256, key_material.as_ref());
68
69 self.duplicate_filter = None;
72
73 self.key = Some((expires_at, key));
74
75 self.key.as_ref().map(|key| key.1.clone())
76 }
77}
78
/// Default interval between key rotations: one second.
const DEFAULT_KEY_ROTATION_PERIOD: Duration = Duration::from_secs(1);
80
/// Configuration for the default address token provider.
#[derive(Debug)]
pub struct Provider {
    /// Interval at which the format switches between its two signing keys.
    /// Each key is kept alive for twice this period (see `Provider::start`).
    key_rotation_period: Duration,
}
92
93impl Default for Provider {
94 fn default() -> Self {
95 Self {
96 key_rotation_period: DEFAULT_KEY_ROTATION_PERIOD,
97 }
98 }
99}
100
101impl super::Provider for Provider {
102 type Format = Format;
103 type Error = core::convert::Infallible;
104
105 fn start(self) -> Result<Self::Format, Self::Error> {
106 let format = Format {
109 key_rotation_period: self.key_rotation_period,
110 current_key_rotates_at: s2n_quic_platform::time::now(),
111 current_key: 0,
112 keys: [
113 BaseKey::new(self.key_rotation_period * 2),
114 BaseKey::new(self.key_rotation_period * 2),
115 ],
116 };
117
118 Ok(format)
119 }
120}
121
/// Token format that issues and validates HMAC-tagged retry tokens using two
/// alternating keys.
pub struct Format {
    /// Interval added to the current time to schedule the next key switch.
    key_rotation_period: Duration,

    /// Once the current time passes this timestamp, `current_key()` flips to the
    /// other key slot.
    current_key_rotates_at: s2n_quic_core::time::Timestamp,

    /// Index (0 or 1) into `keys` of the slot used to sign newly generated tokens.
    current_key: u8,

    /// The two key slots; a token's header records which slot signed it.
    keys: [BaseKey; 2],
}
144
145impl Format {
146 fn current_key(&mut self) -> u8 {
147 let now = s2n_quic_platform::time::now();
148 if now > self.current_key_rotates_at {
149 self.current_key ^= 1;
150 self.current_key_rotates_at = now + self.key_rotation_period;
151
152 }
155 self.current_key
156 }
157
158 fn tag_retry_token(
161 &mut self,
162 token: &Token,
163 context: &mut super::Context<'_>,
164 ) -> Option<hmac::Tag> {
165 let mut ctx = self.keys[token.header.key_id() as usize].hasher(context.random)?;
166
167 ctx.update(&token.original_destination_connection_id);
173 ctx.update(&token.nonce);
174 ctx.update(context.peer_connection_id);
175 match context.remote_address {
176 SocketAddress::IpV4 { ip, port, .. } => {
177 ctx.update(ip);
178 ctx.update(&port.to_be_bytes());
179 }
180 SocketAddress::IpV6 { ip, port, .. } => {
181 ctx.update(ip);
182 ctx.update(&port.to_be_bytes());
183 }
184 _ => {
185 return None;
187 }
188 };
189
190 Some(ctx.sign())
191 }
192
193 fn validate_retry_token(
195 &mut self,
196 context: &mut super::Context<'_>,
197 token: &Token,
198 ) -> Option<connection::InitialId> {
199 if self.keys[token.header.key_id() as usize]
200 .duplicate_filter
201 .as_ref()
202 .is_some_and(|f| f.contains(token))
203 {
204 return None;
205 }
206
207 let tag = self.tag_retry_token(token, context)?;
208
209 if constant_time::verify_slices_are_equal(&token.hmac, tag.as_ref()).is_ok() {
210 let _ = self.keys[token.header.key_id() as usize]
216 .duplicate_filter
217 .get_or_insert_with(|| {
218 cuckoofilter::CuckooFilter::with_capacity(cuckoofilter::DEFAULT_CAPACITY)
219 })
220 .add(token);
221
222 return token.original_destination_connection_id();
223 }
224
225 None
226 }
227}
228
229impl super::Format for Format {
230 const TOKEN_LEN: usize = size_of::<Token>();
231
232 fn generate_new_token(
234 &mut self,
235 _context: &mut super::Context<'_>,
236 _source_connection_id: &connection::LocalId,
237 _output_buffer: &mut [u8],
238 ) -> Option<()> {
239 None
273 }
274
275 fn generate_retry_token(
283 &mut self,
284 context: &mut super::Context<'_>,
285 original_destination_connection_id: &connection::InitialId,
286 output_buffer: &mut [u8],
287 ) -> Option<()> {
288 let buffer = DecoderBufferMut::new(output_buffer);
289 let (token, _) = buffer
290 .decode::<&mut Token>()
291 .expect("Provided output buffer did not match TOKEN_LEN");
292
293 let header = Header::new(Source::RetryPacket, self.current_key());
294
295 token.header = header;
296 token.original_destination_connection_id[..original_destination_connection_id.len()]
297 .copy_from_slice(original_destination_connection_id.as_bytes());
298 token.odcid_len = original_destination_connection_id.len() as u8;
299
300 for b in token
302 .original_destination_connection_id
303 .iter_mut()
304 .skip(original_destination_connection_id.len())
305 {
306 *b = 0;
307 }
308
309 context.random.public_random_fill(&mut token.nonce[..]);
311
312 let tag = self.tag_retry_token(token, context)?;
313
314 token.hmac.copy_from_slice(tag.as_ref());
315
316 Some(())
317 }
318
319 fn validate_token(
324 &mut self,
325 context: &mut super::Context<'_>,
326 token: &[u8],
327 ) -> Option<connection::InitialId> {
328 let buffer = DecoderBuffer::new(token);
329 let (token, remaining) = buffer.decode::<&Token>().ok()?;
330
331 remaining.ensure_empty().ok()?;
333
334 if token.header.version() != TOKEN_VERSION {
335 return None;
336 }
337
338 let source = token.header.token_source();
339
340 match source {
341 Source::RetryPacket => self.validate_retry_token(context, token),
342 Source::NewTokenFrame => None, }
344 }
357}
358
/// One-byte token header.
///
/// Bit 7 holds the token version, bit 6 the token source and bit 5 the id of the key
/// that signed the token; the remaining bits are not read by this module.
#[derive(Clone, Copy, Debug, FromBytes, IntoBytes, Unaligned)]
#[repr(C)]
pub(crate) struct Header(u8);
362
/// The only token version this module generates and accepts.
const TOKEN_VERSION: u8 = 0x00;

// Version: bit 7 of the header byte.
const VERSION_SHIFT: u8 = 7;
const VERSION_MASK: u8 = 0x80;

// Token source: bit 6 of the header byte (0 = NEW_TOKEN frame, 1 = Retry packet).
const TOKEN_SOURCE_SHIFT: u8 = 6;
const TOKEN_SOURCE_MASK: u8 = 0x40;

// Key id: bit 5 of the header byte, selecting one of the two key slots.
const KEY_ID_SHIFT: u8 = 5;
const KEY_ID_MASK: u8 = 0x20;
373
374impl Header {
375 fn new(source: Source, key_id: u8) -> Header {
376 let mut header: u8 = 0;
377 header |= TOKEN_VERSION << VERSION_SHIFT;
378 header |= match source {
383 Source::NewTokenFrame => 0 << TOKEN_SOURCE_SHIFT,
384 Source::RetryPacket => 1 << TOKEN_SOURCE_SHIFT,
385 };
386
387 debug_assert!(key_id <= 1);
389 header |= (key_id & 0x01) << KEY_ID_SHIFT;
390
391 Header(header)
392 }
393
394 fn version(self) -> u8 {
395 (self.0 & VERSION_MASK) >> VERSION_SHIFT
396 }
397
398 fn key_id(self) -> u8 {
399 (self.0 & KEY_ID_MASK) >> KEY_ID_SHIFT
400 }
401
402 fn token_source(self) -> Source {
408 match (self.0 & TOKEN_SOURCE_MASK) >> TOKEN_SOURCE_SHIFT {
409 0 => Source::NewTokenFrame,
410 1 => Source::RetryPacket,
411 _ => Source::NewTokenFrame,
412 }
413 }
414}
415
/// Wire layout of an address validation token.
///
/// The struct is `repr(C)` and consists solely of byte-sized fields, so its size
/// (used as `TOKEN_LEN`) is the sum of the field sizes: 1 + 1 + 20 + 32 + 32 = 86 bytes.
#[derive(Copy, Clone, Debug, FromBytes, IntoBytes, Unaligned)]
#[repr(C)]
struct Token {
    /// Version, token source and signing key id.
    header: Header,

    /// Number of meaningful bytes at the start of `original_destination_connection_id`.
    odcid_len: u8,
    /// Original destination connection id, zero-padded to 20 bytes.
    original_destination_connection_id: [u8; 20],

    /// Random bytes filled in from the public random generator when the token is created.
    nonce: [u8; 32],

    /// HMAC tag computed by `Format::tag_retry_token`.
    hmac: [u8; 32],
}

// NOTE(review): presumably generates the codec impls that let `&Token` / `&mut Token`
// be decoded directly from a byte buffer — confirm against s2n_codec.
s2n_codec::zerocopy_value_codec!(Token);
446
447impl Hash for Token {
448 fn hash<H: Hasher>(&self, state: &mut H) {
450 state.write(&self.hmac);
451 }
452}
453
454impl Token {
455 pub fn original_destination_connection_id(&self) -> Option<connection::InitialId> {
456 let dcid = self
457 .original_destination_connection_id
458 .get(..self.odcid_len as usize)?;
459 connection::InitialId::try_from_bytes(dcid)
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use s2n_quic_core::{
        inet::SocketAddress,
        random,
        token::{Context, Format as FormatTrait, Source},
    };
    use s2n_quic_platform::time;
    use std::{net::SocketAddr, sync::Arc};

    /// Rotation period used by every test format.
    const TEST_KEY_ROTATION_PERIOD: Duration = Duration::from_millis(1000);

    /// Builds a `Format` equivalent to what `Provider::start` produces, using the test
    /// rotation period and whatever clock is currently installed.
    fn get_test_format() -> Format {
        Format {
            key_rotation_period: TEST_KEY_ROTATION_PERIOD,
            keys: [
                BaseKey::new(TEST_KEY_ROTATION_PERIOD * 2),
                BaseKey::new(TEST_KEY_ROTATION_PERIOD * 2),
            ],
            current_key_rotates_at: time::now(),
            current_key: 0,
        }
    }

    /// Every combination of source and key id must round-trip through the header byte.
    #[test]
    fn test_header() {
        for source in &[Source::NewTokenFrame, Source::RetryPacket] {
            for key_id in [0, 1] {
                let header = Header::new(*source, key_id);
                assert_eq!(header.version(), TOKEN_VERSION);
                assert_eq!(header.token_source(), *source);
                assert_eq!(header.key_id(), key_id);
            }
        }
    }

    /// Tokens issued for two different peer connection ids validate for their own peer
    /// after one rotation period has elapsed.
    #[test]
    fn test_valid_retry_tokens() {
        let clock = Arc::new(time::testing::MockClock::new());
        time::testing::set_local_clock(clock.clone());

        let mut format = get_test_format();
        let first_conn_id = connection::PeerId::try_from_bytes(&[2, 4, 6, 8, 10]).unwrap();
        let second_conn_id = connection::PeerId::try_from_bytes(&[1, 3, 5, 7, 9]).unwrap();
        let orig_conn_id =
            connection::InitialId::try_from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
        let addr = SocketAddress::default();
        let mut first_token = [0; Format::TOKEN_LEN];
        let mut second_token = [0; Format::TOKEN_LEN];
        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&addr, &first_conn_id, &mut random);

        format
            .generate_retry_token(&mut context, &orig_conn_id, &mut first_token)
            .unwrap();

        context = Context::new(&addr, &second_conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &orig_conn_id, &mut second_token)
            .unwrap();

        // Advance time by one rotation period; the signing key lives for two.
        clock.adjust_by(TEST_KEY_ROTATION_PERIOD);
        context = Context::new(&addr, &first_conn_id, &mut random);
        assert_eq!(
            format.validate_token(&mut context, &first_token),
            Some(orig_conn_id)
        );
        context = Context::new(&addr, &second_conn_id, &mut random);
        assert_eq!(
            format.validate_token(&mut context, &second_token),
            Some(orig_conn_id)
        );
        // NOTE(review): `second_token` was accepted just above, so this rejection is
        // likely produced by the duplicate filter rather than by the connection id
        // mismatch — confirm which property this assertion is meant to cover.
        context = Context::new(&addr, &first_conn_id, &mut random);
        assert_eq!(format.validate_token(&mut context, &second_token), None);
    }

    /// A token is bound to the remote IP address and port it was issued for.
    #[test]
    fn test_retry_ip_port_validation() {
        let mut format = get_test_format();
        let conn_id = connection::PeerId::try_from_bytes(&[2, 4, 6, 8, 10]).unwrap();
        let orig_conn_id =
            connection::InitialId::try_from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap();

        let mut token = [0; Format::TOKEN_LEN];
        let ip_address = "127.0.0.1:443";
        let addr: SocketAddr = ip_address.parse().unwrap();
        let correct_address: SocketAddress = addr.into();
        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&correct_address, &conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &orig_conn_id, &mut token)
            .unwrap();

        // Same port, different IP address: must be rejected.
        let ip_address = "127.0.0.2:443";
        let addr: SocketAddr = ip_address.parse().unwrap();
        let incorrect_address: SocketAddress = addr.into();
        context = Context::new(&incorrect_address, &conn_id, &mut random);
        assert_eq!(format.validate_token(&mut context, &token), None);

        // Same IP address, different port: must be rejected.
        let ip_address = "127.0.0.1:444";
        let addr: SocketAddr = ip_address.parse().unwrap();
        let incorrect_port: SocketAddress = addr.into();
        context = Context::new(&incorrect_port, &conn_id, &mut random);
        assert_eq!(format.validate_token(&mut context, &token), None);

        // The original address still validates; the failed attempts above did not
        // consume the token.
        context = Context::new(&correct_address, &conn_id, &mut random);
        assert!(format.validate_token(&mut context, &token).is_some());
    }

    /// A token stays valid after one rotation period and is rejected after two.
    #[test]
    fn test_key_rotation() {
        let clock = Arc::new(time::testing::MockClock::new());
        time::testing::set_local_clock(clock.clone());

        let mut format = get_test_format();
        let conn_id = connection::PeerId::TEST_ID;
        let orig_conn_id = connection::InitialId::TEST_ID;
        let addr = SocketAddress::default();
        let mut buf = [0; Format::TOKEN_LEN];
        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&addr, &conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &orig_conn_id, &mut buf)
            .unwrap();

        clock.adjust_by(TEST_KEY_ROTATION_PERIOD);
        assert!(format.validate_token(&mut context, &buf).is_some());

        // NOTE(review): the token was already accepted once above, and
        // `validate_retry_token` consults the duplicate filter before the key is
        // polled, so this rejection likely comes from replay detection rather than
        // key expiry. `test_expired_retry_token` covers expiry without a prior
        // validation.
        clock.adjust_by(TEST_KEY_ROTATION_PERIOD);
        assert!(format.validate_token(&mut context, &buf).is_none());
    }

    /// A token that is first presented after the key's full active duration (two
    /// rotation periods) is rejected.
    #[test]
    fn test_expired_retry_token() {
        let clock = Arc::new(time::testing::MockClock::new());
        time::testing::set_local_clock(clock.clone());

        let mut format = get_test_format();
        let conn_id = connection::PeerId::TEST_ID;
        let orig_conn_id = connection::InitialId::TEST_ID;
        let addr = SocketAddress::default();
        let mut buf = [0; Format::TOKEN_LEN];
        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&addr, &conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &orig_conn_id, &mut buf)
            .unwrap();

        clock.adjust_by(TEST_KEY_ROTATION_PERIOD * 2);
        assert!(format.validate_token(&mut context, &buf).is_none());
    }

    /// A freshly issued token validates and yields the original destination
    /// connection id it was created with.
    #[test]
    fn test_retry_validation_default_format() {
        let clock = Arc::new(time::testing::MockClock::new());
        time::testing::set_local_clock(clock);

        let mut format = get_test_format();
        let conn_id = connection::PeerId::TEST_ID;
        let odcid = connection::InitialId::try_from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
        let addr = SocketAddress::default();
        let mut buf = [0; Format::TOKEN_LEN];
        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&addr, &conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &odcid, &mut buf)
            .unwrap();

        assert_eq!(format.validate_token(&mut context, &buf), Some(odcid));

        // NOTE(review): the token was accepted just above, so the duplicate filter
        // likely rejects it here before the connection id mismatch is evaluated.
        let wrong_conn_id = connection::PeerId::try_from_bytes(&[0, 1, 2]).unwrap();
        context = Context::new(&addr, &wrong_conn_id, &mut random);
        assert!(format.validate_token(&mut context, &buf).is_none());
    }

    /// A token can be validated successfully only once.
    #[test]
    fn test_duplicate_token_detection() {
        let mut format = get_test_format();
        let conn_id = connection::PeerId::TEST_ID;
        let odcid = connection::InitialId::try_from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
        let addr = SocketAddress::default();
        let mut buf = [0; Format::TOKEN_LEN];
        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&addr, &conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &odcid, &mut buf)
            .unwrap();

        assert_eq!(format.validate_token(&mut context, &buf), Some(odcid));

        // Presenting the same token a second time must fail.
        assert!(format.validate_token(&mut context, &buf).is_none());
    }

    /// Inverting any single byte of a valid token must make validation fail.
    #[test]
    fn test_token_modification_detection() {
        let mut format = get_test_format();
        let conn_id = connection::PeerId::try_from_bytes(&[2, 4, 6, 8, 10]).unwrap();
        let orig_conn_id =
            connection::InitialId::try_from_bytes(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
        let addr = SocketAddress::default();
        let mut token = [0; Format::TOKEN_LEN];

        let mut random = random::testing::Generator(5);
        let mut context = Context::new(&addr, &conn_id, &mut random);
        format
            .generate_retry_token(&mut context, &orig_conn_id, &mut token)
            .unwrap();

        for i in 0..Format::TOKEN_LEN {
            random = random::testing::Generator(5);
            context = Context::new(&addr, &conn_id, &mut random);
            // Invert one byte, check for rejection, then restore the byte.
            token[i] = !token[i];
            assert!(format.validate_token(&mut context, &token).is_none());
            token[i] = !token[i];
        }
    }

    /// Generated inputs must never validate.
    ///
    /// NOTE(review): presumably bolero's default generator yields byte slices of
    /// arbitrary length, which is what exercises the length check — confirm.
    #[test]
    fn test_token_length_check() {
        let mut format = get_test_format();
        let conn_id = connection::PeerId::try_from_bytes(&[2, 4, 6, 8, 10]).unwrap();
        let addr = SocketAddress::default();

        bolero::check!().for_each(move |token| {
            let mut random = random::testing::Generator(5);
            let mut context = Context::new(&addr, &conn_id, &mut random);
            assert!(format.validate_token(&mut context, token).is_none())
        });
    }

    /// Generated inputs of exactly `TOKEN_LEN` bytes must never validate.
    #[test]
    fn test_token_falsification_detection() {
        let mut format = get_test_format();
        let conn_id = connection::PeerId::try_from_bytes(&[2, 4, 6, 8, 10]).unwrap();
        let addr = SocketAddress::default();

        let generator = bolero::generator::produce::<Vec<u8>>()
            .with()
            .len(Format::TOKEN_LEN);
        bolero::check!()
            .with_generator(generator)
            .for_each(move |token| {
                let mut random = random::testing::Generator(5);
                let mut context = Context::new(&addr, &conn_id, &mut random);
                assert!(format.validate_token(&mut context, token).is_none())
            });
    }
}