1use crate::{cipher_suite::TLS_AES_128_GCM_SHA256 as CipherSuite, header_key::HeaderKeyPair, hkdf};
5use s2n_quic_core::{
6 crypto::{
7 self,
8 label::{CLIENT_IN, SERVER_IN},
9 packet_protection, scatter, Key, INITIAL_SALT,
10 },
11 endpoint,
12};
13
// Declares the `InitialHeaderKey` type through the crate's `header_key!` macro
// (the macro itself is defined outside this file). `InitialKey::new` below
// constructs it as a tuple wrapper around a `HeaderKeyPair`.
header_key!(InitialHeaderKey);
15
// Empty impl: opts `InitialHeaderKey` into `crypto::InitialHeaderKey` without
// overriding anything.
// NOTE(review): presumably a marker trait identifying the header key that
// belongs to initial packets — confirm against s2n_quic_core.
impl crypto::InitialHeaderKey for InitialHeaderKey {}
17
/// Packet-protection keys used by one endpoint for initial packets.
///
/// Holds two separately keyed cipher-suite instances, one per direction. The
/// matching header-protection keys are returned alongside this value as an
/// `InitialHeaderKey` rather than being stored here.
#[derive(Debug)]
pub struct InitialKey {
    /// Key the `Key::encrypt` implementation delegates to.
    sealer: CipherSuite,
    /// Key the `Key::decrypt` implementation delegates to.
    opener: CipherSuite,
}
23
lazy_static::lazy_static! {
    // HKDF-SHA256 salt object built from `INITIAL_SALT`. `lazy_static` creates
    // it on first access and every `InitialKey::new` call reuses it.
    // NOTE(review): despite the "SIGNING_KEY" name, this file only uses the
    // value as the salt for the HKDF extract step and to read its algorithm.
    static ref INITIAL_SIGNING_KEY: hkdf::Salt = hkdf::Salt::new(hkdf::HKDF_SHA256, &INITIAL_SALT);
}
28
29impl InitialKey {
30 fn new(endpoint: endpoint::Type, connection_id: &[u8]) -> (Self, InitialHeaderKey) {
31 let initial_secret = INITIAL_SIGNING_KEY.extract(connection_id);
32 let digest = INITIAL_SIGNING_KEY.algorithm();
33
34 let client_secret = initial_secret
35 .expand(&[&CLIENT_IN], digest)
36 .expect("label size verified")
37 .into();
38
39 let server_secret = initial_secret
40 .expand(&[&SERVER_IN], digest)
41 .expect("label size verified")
42 .into();
43
44 let (sealer, opener) = match endpoint {
45 endpoint::Type::Client => (
46 CipherSuite::new(client_secret),
47 CipherSuite::new(server_secret),
48 ),
49 endpoint::Type::Server => (
50 CipherSuite::new(server_secret),
51 CipherSuite::new(client_secret),
52 ),
53 };
54
55 let (key_sealer, header_sealer) = sealer;
56 let (key_opener, header_opener) = opener;
57 let key = Self {
58 sealer: key_sealer,
59 opener: key_opener,
60 };
61 let header_key = InitialHeaderKey(HeaderKeyPair {
62 sealer: header_sealer,
63 opener: header_opener,
64 });
65
66 (key, header_key)
67 }
68}
69
70impl crypto::InitialKey for InitialKey {
71 type HeaderKey = InitialHeaderKey;
72
73 fn new_server(connection_id: &[u8]) -> (Self, Self::HeaderKey) {
74 Self::new(endpoint::Type::Server, connection_id)
75 }
76
77 fn new_client(connection_id: &[u8]) -> (Self, Self::HeaderKey) {
78 Self::new(endpoint::Type::Client, connection_id)
79 }
80}
81
82impl Key for InitialKey {
83 #[inline]
84 fn decrypt(
85 &self,
86 packet_number: u64,
87 header: &[u8],
88 payload: &mut [u8],
89 ) -> Result<(), packet_protection::Error> {
90 self.opener.decrypt(packet_number, header, payload)
91 }
92
93 #[inline]
94 fn encrypt(
95 &mut self,
96 packet_number: u64,
97 header: &[u8],
98 payload: &mut scatter::Buffer,
99 ) -> Result<(), packet_protection::Error> {
100 self.sealer.encrypt(packet_number, header, payload)
101 }
102
103 #[inline]
104 fn tag_len(&self) -> usize {
105 self.sealer.tag_len()
106 }
107
108 #[inline]
109 fn aead_confidentiality_limit(&self) -> u64 {
110 self.sealer.aead_confidentiality_limit()
111 }
112
113 #[inline]
114 fn aead_integrity_limit(&self) -> u64 {
115 self.opener.aead_integrity_limit()
116 }
117
118 #[inline]
119 fn cipher_suite(&self) -> s2n_quic_core::crypto::tls::CipherSuite {
120 self.opener.cipher_suite()
121 }
122}
123
// Round-trip tests for the initial keys, driven by the example packets and
// payloads exported by `s2n_quic_core::crypto::initial`.
// NOTE(review): the test names suggest these are the sample packets from the
// QUIC-TLS RFC — confirm where the constants are defined.
#[cfg(test)]
mod tests {
    use super::*;
    use s2n_codec::{DecoderBufferMut, EncoderBuffer};
    use s2n_quic_core::{
        connection::id::ConnectionInfo,
        crypto::{
            initial::{
                EXAMPLE_CLIENT_INITIAL_PAYLOAD, EXAMPLE_CLIENT_INITIAL_PROTECTED_PACKET,
                EXAMPLE_DCID, EXAMPLE_SERVER_INITIAL_PAYLOAD,
                EXAMPLE_SERVER_INITIAL_PROTECTED_PACKET,
            },
            InitialKey as _,
        },
        inet::SocketAddress,
        packet::{encoding::PacketEncoder, initial::CleartextInitial, ProtectedPacket},
    };

    /// Server-side view: server-role keys open the example client initial
    /// packet, and client-role keys re-seal it.
    #[test]
    fn rfc_example_server_test() {
        test_round_trip(
            &mut InitialKey::new_client(&EXAMPLE_DCID),
            &InitialKey::new_server(&EXAMPLE_DCID),
            &EXAMPLE_CLIENT_INITIAL_PROTECTED_PACKET,
            &EXAMPLE_CLIENT_INITIAL_PAYLOAD,
        );
    }

    /// Client-side view: client-role keys open the example server initial
    /// packet, and server-role keys re-seal it.
    #[test]
    fn rfc_example_client_test() {
        test_round_trip(
            &mut InitialKey::new_server(&EXAMPLE_DCID),
            &InitialKey::new_client(&EXAMPLE_DCID),
            &EXAMPLE_SERVER_INITIAL_PROTECTED_PACKET,
            &EXAMPLE_SERVER_INITIAL_PAYLOAD,
        );
    }

    /// Opens `protected_packet` with `opener`, re-encodes the cleartext packet
    /// with `sealer`, then opens the re-sealed bytes with `opener` again.
    ///
    /// Both decrypt passes check the payload prefix against
    /// `cleartext_payload`. The second pass also checks that the version,
    /// connection ids and token survived the round trip. The re-sealed bytes
    /// are never compared byte-for-byte with `protected_packet`.
    fn test_round_trip(
        sealer: &mut (InitialKey, InitialHeaderKey),
        opener: &(InitialKey, InitialHeaderKey),
        protected_packet: &[u8],
        cleartext_payload: &[u8],
    ) {
        let (sealer_key, sealer_header_key) = sealer;
        let (opener_key, opener_header_key) = opener;
        // First pass: open the example packet, then re-seal it inside the
        // callback while the cleartext packet is still available.
        let (version, dcid, scid, token, sealed_packet) = decrypt(
            opener_key,
            opener_header_key,
            protected_packet.to_vec(),
            cleartext_payload,
            |packet| {
                // Copy the header fields out before `encode_packet` consumes
                // the packet.
                let version = packet.version;
                let dcid = packet.destination_connection_id.to_vec();
                let scid = packet.source_connection_id.to_vec();
                let token = packet.token.to_vec();

                // The output buffer is sized to match the example packet.
                let mut output_buffer = vec![0; protected_packet.len()];
                packet
                    .encode_packet(
                        sealer_key,
                        sealer_header_key,
                        Default::default(),
                        None,
                        EncoderBuffer::new(&mut output_buffer),
                    )
                    .unwrap();

                (version, dcid, scid, token, output_buffer)
            },
        );

        // Second pass: the packet we sealed ourselves must open with the same
        // opener keys and carry the same header fields.
        decrypt(
            opener_key,
            opener_header_key,
            sealed_packet,
            cleartext_payload,
            |packet| {
                assert_eq!(packet.version, version);
                assert_eq!(packet.destination_connection_id, &dcid[..]);
                assert_eq!(packet.source_connection_id, &scid[..]);
                assert_eq!(packet.token, &token[..]);
            },
        );
    }

    /// Decodes `protected_packet` as an initial packet, removes header
    /// protection, decrypts it, and asserts that the decrypted payload starts
    /// with `cleartext_payload`. The cleartext packet is then handed to
    /// `on_decrypt`, whose result is returned.
    ///
    /// Panics (failing the test) if decoding, unprotecting or decrypting fails
    /// or if the packet is not an initial packet.
    fn decrypt<F: FnOnce(CleartextInitial) -> O, O>(
        opener_key: &InitialKey,
        opener_header_key: &InitialHeaderKey,
        mut protected_packet: Vec<u8>,
        cleartext_payload: &[u8],
        on_decrypt: F,
    ) -> O {
        let decoder = DecoderBufferMut::new(&mut protected_packet);
        let remote_address = SocketAddress::default();
        let connection_info = ConnectionInfo::new(&remote_address);
        // NOTE(review): `&20` is presumably the connection-id length/validator
        // argument — confirm against `ProtectedPacket::decode`.
        let (packet, _) = ProtectedPacket::decode(decoder, &connection_info, &20).unwrap();

        let packet = match packet {
            ProtectedPacket::Initial(initial) => initial,
            _ => panic!("expected initial packet type"),
        };

        // Header protection comes off first, then the payload is decrypted.
        let packet = packet
            .unprotect(opener_header_key, Default::default())
            .unwrap();
        let packet = packet.decrypt(opener_key).unwrap();

        // Only the first `cleartext_payload.len()` bytes are compared; any
        // trailing bytes (presumably padding) are ignored.
        let actual_payload = &packet.payload.as_less_safe_slice()[..cleartext_payload.len()];
        assert_eq!(cleartext_payload, actual_payload);

        on_decrypt(packet)
    }
}