1use crate::{aws_lc_aead as aead, constant_time};
5use s2n_quic_core::crypto::{
6 self, packet_protection,
7 retry::{IntegrityTag, NONCE_BYTES, SECRET_KEY_BYTES},
8};
9
10lazy_static::lazy_static! {
11 static ref SECRET_KEY: aead::LessSafeKey = aead::LessSafeKey::new(
13 aead::UnboundKey::new(&aead::AES_128_GCM, &SECRET_KEY_BYTES).unwrap(),
14 );
15}
16
/// Stateless provider of QUIC Retry packet integrity tags.
///
/// The type carries no data; all behavior lives in its `crypto::RetryKey` implementation,
/// which uses the AES-128-GCM primitives from this crate's `aws_lc_aead` module.
#[derive(Debug)]
pub struct RetryKey;
19
20impl crypto::RetryKey for RetryKey {
21 fn generate_tag(pseudo_packet: &[u8]) -> IntegrityTag {
22 let nonce = aead::Nonce::assume_unique_for_key(NONCE_BYTES);
23 let tag = SECRET_KEY
24 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
25 .expect("in_out len is 0 and should always be less than the nonce max bytes");
26
27 tag.as_ref()
28 .try_into()
29 .expect("AES_128_GCM tag len should always be 128 bits")
30 }
31
32 fn validate(pseudo_packet: &[u8], tag: IntegrityTag) -> Result<(), packet_protection::Error> {
33 let expected = Self::generate_tag(pseudo_packet);
34
35 constant_time::verify_slices_are_equal(&expected, &tag)
36 .map_err(|_| packet_protection::Error::DECRYPT_ERROR)
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use s2n_codec::{DecoderBufferMut, Encoder, EncoderBuffer};
    use s2n_quic_core::{
        connection,
        connection::id::ConnectionInfo,
        crypto::{retry, RetryKey as _},
        inet, packet,
        packet::number::{PacketNumberSpace, TruncatedPacketNumber},
        random, token,
        varint::VarInt,
    };

    /// The example pseudo-packet must produce the example tag and validate against it, while an
    /// unrelated tag must be rejected.
    #[test]
    fn test_tag_validation() {
        let invalid_tag: [u8; 16] = hex!("00112233445566778899aabbccddeeff");

        // Asserting the generated tag directly gives a clearer failure than `validate` alone.
        assert_eq!(
            RetryKey::generate_tag(&retry::example::PSEUDO_PACKET),
            retry::example::EXPECTED_TAG
        );
        assert!(
            RetryKey::validate(&retry::example::PSEUDO_PACKET, retry::example::EXPECTED_TAG)
                .is_ok()
        );
        assert!(RetryKey::validate(&retry::example::PSEUDO_PACKET, invalid_tag).is_err());
    }

    /// Builds packet number 1 in `space` and truncates it relative to itself.
    fn pn(space: PacketNumberSpace) -> TruncatedPacketNumber {
        let pn = space.new_packet_number(VarInt::new(0x1).unwrap());
        pn.truncate(pn).unwrap()
    }

    /// Encoding a Retry in response to the example Initial must reproduce the example Retry
    /// packet byte-for-byte.
    ///
    /// Every step is required to succeed: previously a non-Initial decode or a `None` from
    /// `encode_packet` skipped the assertion and let the test pass without checking anything.
    #[test]
    fn test_packet_encode() {
        let remote_address = inet::ip::SocketAddress::default();
        let mut token_format = token::testing::Format::new();
        // Client Initial built from the `retry::example` constants, QUIC version 1.
        let packet = packet::initial::Initial {
            version: 0x01,
            destination_connection_id: &retry::example::ODCID[..],
            source_connection_id: &retry::example::DCID[..],
            token: &retry::example::TOKEN[..],
            packet_number: pn(PacketNumberSpace::Initial),
            payload: &[1u8, 2, 3, 4, 5][..],
        };

        let mut buf = vec![0u8; 1200];
        let mut encoder = EncoderBuffer::new(&mut buf);
        encoder.encode(&packet);
        let len = encoder.len();
        let decoder = DecoderBufferMut::new(&mut buf[..len]);
        let connection_info = ConnectionInfo::new(&remote_address);
        let mut output_buf = vec![0u8; 1200];

        // NOTE(review): `&3` is passed as the decoder's connection-id validator argument;
        // its exact meaning is defined in s2n_quic_core — confirm there.
        let packet =
            match packet::ProtectedPacket::decode(decoder, &connection_info, &3).unwrap() {
                (packet::ProtectedPacket::Initial(packet), _) => packet,
                _ => panic!("the encoded Initial packet should decode as an Initial packet"),
            };

        let local_conn_id = connection::LocalId::try_from_bytes(&retry::example::SCID).unwrap();
        let range = packet::retry::Retry::encode_packet::<_, RetryKey>(
            &remote_address,
            &packet,
            &local_conn_id,
            &mut random::testing::Generator(5),
            &mut token_format,
            &mut output_buf,
        )
        .expect("encoding a Retry in response to the example Initial should succeed");

        assert_eq!(&output_buf[range], &retry::example::PACKET[..]);
    }

    /// Using the Initial's destination connection id (the ODCID) as the local connection id
    /// must not produce a Retry.
    ///
    /// NOTE(review): this test expects a panic, presumably from an assertion inside
    /// `Retry::encode_packet` that the local id differs from the ODCID — confirm in
    /// s2n_quic_core.
    #[test]
    #[should_panic]
    fn test_odcid_different_from_local_cid() {
        let remote_address = inet::ip::SocketAddress::default();
        let mut token_format = token::testing::Format::new();
        let packet = packet::initial::Initial {
            version: 0xff00_0020,
            destination_connection_id: &retry::example::ODCID[..],
            source_connection_id: &retry::example::DCID[..],
            token: &retry::example::TOKEN[..],
            packet_number: pn(PacketNumberSpace::Initial),
            payload: &[1u8, 2, 3, 4, 5][..],
        };

        let mut buf = vec![0u8; 1200];
        let mut encoder = EncoderBuffer::new(&mut buf);
        encoder.encode(&packet);
        let len = encoder.len();
        let decoder = DecoderBufferMut::new(&mut buf[..len]);

        let connection_info = ConnectionInfo::new(&remote_address);
        let mut output_buf = vec![0u8; 1200];
        // Deliberately `if let` rather than panicking on a non-Initial decode: a panic there
        // would wrongly satisfy `#[should_panic]`.
        if let Some(packet) =
            match packet::ProtectedPacket::decode(decoder, &connection_info, &3).unwrap() {
                (packet::ProtectedPacket::Initial(packet), _) => Some(packet),
                _ => None,
            }
        {
            // Reuse the ODCID as the local connection id, which a Retry must not do.
            let local_conn_id =
                connection::LocalId::try_from_bytes(&retry::example::ODCID).unwrap();
            assert!(packet::retry::Retry::encode_packet::<_, RetryKey>(
                &remote_address,
                &packet,
                &local_conn_id,
                &mut random::testing::Generator(5),
                &mut token_format,
                &mut output_buf,
            )
            .is_none());
        }
    }
}