1use crate::cipher::{Iv, IvLen};
3pub use crate::client::ClientQuicExt;
4use crate::conn::CommonState;
5use crate::error::Error;
6use crate::msgs::enums::AlertDescription;
7pub use crate::server::ServerQuicExt;
8use crate::suites::BulkAlgorithm;
9use crate::tls13::key_schedule::hkdf_expand;
10use crate::tls13::{Tls13CipherSuite, TLS13_AES_128_GCM_SHA256_INTERNAL};
11use std::fmt::Debug;
12
13use ring::{aead, hkdf};
14
/// Client and server traffic secrets for one QUIC connection, plus the cipher
/// suite they belong to.
///
/// Used to derive 1-RTT packet keys and to step the secrets forward on each
/// key update (see `next_packet_keys`).
#[derive(Clone, Debug)]
pub struct Secrets {
    /// Secret belonging to the client side; this is the *local* secret when
    /// `is_client` is true and the *remote* one otherwise.
    client: hkdf::Prk,
    /// Secret belonging to the server side; this is the *local* secret when
    /// `is_client` is false and the *remote* one otherwise.
    server: hkdf::Prk,
    /// Cipher suite whose HKDF and AEAD algorithms are used with these secrets.
    suite: &'static Tls13CipherSuite,
    /// Whether this endpoint is the client; selects which secret is local.
    is_client: bool,
}
26
27impl Secrets {
28 pub(crate) fn new(
29 client: hkdf::Prk,
30 server: hkdf::Prk,
31 suite: &'static Tls13CipherSuite,
32 is_client: bool,
33 ) -> Self {
34 Self {
35 client,
36 server,
37 suite,
38 is_client,
39 }
40 }
41
42 pub fn next_packet_keys(&mut self) -> PacketKeySet {
44 let keys = PacketKeySet::new(self);
45 self.update();
46 keys
47 }
48
49 fn update(&mut self) {
50 let hkdf_alg = self.suite.hkdf_algorithm;
51 self.client = hkdf_expand(&self.client, hkdf_alg, b"quic ku", &[]);
52 self.server = hkdf_expand(&self.server, hkdf_alg, b"quic ku", &[]);
53 }
54
55 fn local_remote(&self) -> (&hkdf::Prk, &hkdf::Prk) {
56 if self.is_client {
57 (&self.client, &self.server)
58 } else {
59 (&self.server, &self.client)
60 }
61 }
62}
63
/// Methods shared by QUIC client and server sessions.
pub trait QuicExt {
    /// Return the transport parameters received from the peer, if any.
    ///
    /// NOTE(review): presumably these are the raw TLS-encoded parameters from
    /// the peer's handshake and should not be fully trusted until the
    /// handshake completes; confirm against the implementations.
    fn quic_transport_parameters(&self) -> Option<&[u8]>;

    /// Return keys for encrypting/decrypting 0-RTT packets, if available.
    fn zero_rtt_keys(&self) -> Option<DirectionalKeys>;

    /// Consume unencrypted TLS handshake data received from the peer.
    ///
    /// NOTE(review): after this returns `Err(_)`, `alert` presumably reports
    /// the fatal alert to send; confirm.
    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error>;

    /// Append outgoing unencrypted TLS handshake data to `buf`.
    ///
    /// A `Some(KeyChange)` return carries new keys; presumably the caller
    /// should switch to them before sending further handshake data.
    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange>;

    /// Return the description of a fatal TLS alert, if one has arisen.
    fn alert(&self) -> Option<AlertDescription>;
}
93
94pub struct DirectionalKeys {
96 pub header: HeaderProtectionKey,
98 pub packet: PacketKey,
100}
101
102impl DirectionalKeys {
103 pub(crate) fn new(suite: &'static Tls13CipherSuite, secret: &hkdf::Prk) -> Self {
104 Self {
105 header: HeaderProtectionKey::new(suite, secret),
106 packet: PacketKey::new(suite, secret),
107 }
108 }
109}
110
/// A QUIC header-protection key, wrapping ring's `aead::quic::HeaderProtectionKey`.
pub struct HeaderProtectionKey(aead::quic::HeaderProtectionKey);
113
114impl HeaderProtectionKey {
115 fn new(suite: &'static Tls13CipherSuite, secret: &hkdf::Prk) -> Self {
116 let alg = match suite.common.bulk {
117 BulkAlgorithm::Aes128Gcm => &aead::quic::AES_128,
118 BulkAlgorithm::Aes256Gcm => &aead::quic::AES_256,
119 BulkAlgorithm::Chacha20Poly1305 => &aead::quic::CHACHA20,
120 };
121
122 Self(hkdf_expand(secret, alg, b"quic hp", &[]))
123 }
124
125 #[inline]
146 pub fn encrypt_in_place(
147 &self,
148 sample: &[u8],
149 first: &mut u8,
150 packet_number: &mut [u8],
151 ) -> Result<(), Error> {
152 self.xor_in_place(sample, first, packet_number, false)
153 }
154
155 #[inline]
177 pub fn decrypt_in_place(
178 &self,
179 sample: &[u8],
180 first: &mut u8,
181 packet_number: &mut [u8],
182 ) -> Result<(), Error> {
183 self.xor_in_place(sample, first, packet_number, true)
184 }
185
186 fn xor_in_place(
187 &self,
188 sample: &[u8],
189 first: &mut u8,
190 packet_number: &mut [u8],
191 masked: bool,
192 ) -> Result<(), Error> {
193 let mask = self
196 .0
197 .new_mask(sample)
198 .map_err(|_| Error::General("sample of invalid length".into()))?;
199
200 let (first_mask, pn_mask) = mask.split_first().unwrap();
203
204 if packet_number.len() > pn_mask.len() {
207 return Err(Error::General("packet number too long".into()));
208 }
209
210 const LONG_HEADER_FORM: u8 = 0x80;
214 let bits = match *first & LONG_HEADER_FORM == LONG_HEADER_FORM {
215 true => 0x0f, false => 0x1f, };
218
219 let first_plain = match masked {
220 true => *first ^ (first_mask & bits),
222 false => *first,
224 };
225 let pn_len = (first_plain & 0x03) as usize + 1;
226
227 *first ^= first_mask & bits;
228 for (dst, m) in packet_number
229 .iter_mut()
230 .zip(pn_mask)
231 .take(pn_len)
232 {
233 *dst ^= m;
234 }
235
236 Ok(())
237 }
238
239 #[inline]
241 pub fn sample_len(&self) -> usize {
242 self.0.algorithm().sample_len()
243 }
244}
245
/// Keys to encrypt or decrypt the payload of a QUIC packet.
pub struct PacketKey {
    /// AEAD key used to seal and open packet payloads.
    key: aead::LessSafeKey,
    /// IV combined with each packet number to form that packet's nonce
    /// (see `nonce_for`).
    iv: Iv,
    /// Cipher suite this key was derived for; also supplies the usage limits.
    suite: &'static Tls13CipherSuite,
}
255
256impl PacketKey {
257 fn new(suite: &'static Tls13CipherSuite, secret: &hkdf::Prk) -> Self {
258 Self {
259 key: aead::LessSafeKey::new(hkdf_expand(
260 secret,
261 suite.common.aead_algorithm,
262 b"quic key",
263 &[],
264 )),
265 iv: hkdf_expand(secret, IvLen, b"quic iv", &[]),
266 suite,
267 }
268 }
269
270 pub fn encrypt_in_place(
278 &self,
279 packet_number: u64,
280 header: &[u8],
281 payload: &mut [u8],
282 ) -> Result<Tag, Error> {
283 let aad = aead::Aad::from(header);
284 let nonce = nonce_for(packet_number, &self.iv);
285 let tag = self
286 .key
287 .seal_in_place_separate_tag(nonce, aad, payload)
288 .map_err(|_| Error::EncryptError)?;
289 Ok(Tag(tag))
290 }
291
292 pub fn decrypt_in_place<'a>(
300 &self,
301 packet_number: u64,
302 header: &[u8],
303 payload: &'a mut [u8],
304 ) -> Result<&'a [u8], Error> {
305 let payload_len = payload.len();
306 let aad = aead::Aad::from(header);
307 let nonce = nonce_for(packet_number, &self.iv);
308 self.key
309 .open_in_place(nonce, aad, payload)
310 .map_err(|_| Error::DecryptError)?;
311
312 let plain_len = payload_len - self.key.algorithm().tag_len();
313 Ok(&payload[..plain_len])
314 }
315
316 #[inline]
320 pub fn confidentiality_limit(&self) -> u64 {
321 self.suite.confidentiality_limit
322 }
323
324 #[inline]
328 pub fn integrity_limit(&self) -> u64 {
329 self.suite.integrity_limit
330 }
331
332 #[inline]
334 pub fn tag_len(&self) -> usize {
335 self.key.algorithm().tag_len()
336 }
337}
338
339pub struct Tag(aead::Tag);
341
342impl AsRef<[u8]> for Tag {
343 #[inline]
344 fn as_ref(&self) -> &[u8] {
345 self.0.as_ref()
346 }
347}
348
349pub struct PacketKeySet {
351 pub local: PacketKey,
353 pub remote: PacketKey,
355}
356
357impl PacketKeySet {
358 fn new(secrets: &Secrets) -> Self {
359 let (local, remote) = secrets.local_remote();
360 Self {
361 local: PacketKey::new(secrets.suite, local),
362 remote: PacketKey::new(secrets.suite, remote),
363 }
364 }
365}
366
367pub struct Keys {
369 pub local: DirectionalKeys,
371 pub remote: DirectionalKeys,
373}
374
375impl Keys {
376 pub fn initial(version: Version, client_dst_connection_id: &[u8], is_client: bool) -> Self {
378 const CLIENT_LABEL: &[u8] = b"client in";
379 const SERVER_LABEL: &[u8] = b"server in";
380 let salt = version.initial_salt();
381 let hs_secret = hkdf::Salt::new(hkdf::HKDF_SHA256, salt).extract(client_dst_connection_id);
382
383 let secrets = Secrets {
384 client: hkdf_expand(&hs_secret, hkdf::HKDF_SHA256, CLIENT_LABEL, &[]),
385 server: hkdf_expand(&hs_secret, hkdf::HKDF_SHA256, SERVER_LABEL, &[]),
386 suite: TLS13_AES_128_GCM_SHA256_INTERNAL,
387 is_client,
388 };
389 Self::new(&secrets)
390 }
391
392 fn new(secrets: &Secrets) -> Self {
393 let (local, remote) = secrets.local_remote();
394 Self {
395 local: DirectionalKeys::new(secrets.suite, local),
396 remote: DirectionalKeys::new(secrets.suite, remote),
397 }
398 }
399}
400
401pub(crate) fn write_hs(this: &mut CommonState, buf: &mut Vec<u8>) -> Option<KeyChange> {
402 while let Some((_, msg)) = this.quic.hs_queue.pop_front() {
403 buf.extend_from_slice(&msg);
404 if let Some(&(true, _)) = this.quic.hs_queue.front() {
405 if this.quic.hs_secrets.is_some() {
406 break;
408 }
409 }
410 }
411
412 if let Some(secrets) = this.quic.hs_secrets.take() {
413 return Some(KeyChange::Handshake {
414 keys: Keys::new(&secrets),
415 });
416 }
417
418 if let Some(mut secrets) = this.quic.traffic_secrets.take() {
419 if !this.quic.returned_traffic_keys {
420 this.quic.returned_traffic_keys = true;
421 let keys = Keys::new(&secrets);
422 secrets.update();
423 return Some(KeyChange::OneRtt {
424 keys,
425 next: secrets,
426 });
427 }
428 }
429
430 None
431}
432
/// New key material produced by `write_hs` when the keys in use must change.
// `OneRtt` carries a full `Secrets` in addition to `Keys`, so it is much
// larger than `Handshake`; the size imbalance is accepted here instead of
// boxing the larger variant.
#[allow(clippy::large_enum_variant)]
pub enum KeyChange {
    /// Keys derived from the handshake secrets.
    Handshake {
        /// Header and packet keys for both directions.
        keys: Keys,
    },
    /// Keys derived from the 1-RTT traffic secrets.
    OneRtt {
        /// Header and packet keys for both directions, from the current secrets.
        keys: Keys,
        /// Traffic secrets already advanced by one key update. Call
        /// `Secrets::next_packet_keys` on them to obtain keys for later key
        /// generations.
        next: Secrets,
    },
}
461
462fn nonce_for(packet_number: u64, iv: &Iv) -> ring::aead::Nonce {
464 let mut out = [0; aead::NONCE_LEN];
465 out[4..].copy_from_slice(&packet_number.to_be_bytes());
466 for (out, inp) in out.iter_mut().zip(iv.0.iter()) {
467 *out ^= inp;
468 }
469 aead::Nonce::assume_unique_for_key(out)
470}
471
/// QUIC protocol version.
///
/// Selects version-specific parameters such as the salt `Keys::initial` uses
/// to derive the initial secret.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum Version {
    /// Pre-RFC draft versions of QUIC v1.
    /// NOTE(review): the salt used for this variant is the draft-29 salt;
    /// confirm which draft versions are meant to map here.
    V1Draft,
    /// QUIC version 1.
    V1,
}
483
484impl Version {
485 fn initial_salt(self) -> &'static [u8; 20] {
486 match self {
487 Self::V1Draft => &[
488 0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61,
490 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99,
491 ],
492 Self::V1 => &[
493 0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
495 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
496 ],
497 }
498 }
499}
500
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips payload and header protection for a short-header packet
    /// using ChaCha20-Poly1305, checking against the RFC 9001 Appendix A.5
    /// test vector.
    #[test]
    fn short_packet_header_protection() {
        const PN: u64 = 654360564;
        const SECRET: &[u8] = &[
            0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad,
            0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3,
            0x0f, 0x21, 0x63, 0x2b,
        ];

        let secret = hkdf::Prk::new_less_safe(hkdf::HKDF_SHA256, SECRET);
        use crate::tls13::TLS13_CHACHA20_POLY1305_SHA256_INTERNAL;
        let hpk = HeaderProtectionKey::new(TLS13_CHACHA20_POLY1305_SHA256_INTERNAL, &secret);
        let packet = PacketKey::new(TLS13_CHACHA20_POLY1305_SHA256_INTERNAL, &secret);

        // 4-byte short header (first byte 0x42 => 3-byte packet number)
        // followed by a 1-byte payload.
        const PLAIN: &[u8] = &[0x42, 0x00, 0xbf, 0xf4, 0x01];

        let mut buf = PLAIN.to_vec();
        let (header, payload) = buf.split_at_mut(4);
        let tag = packet
            .encrypt_in_place(PN, &*header, payload)
            .unwrap();
        buf.extend(tag.as_ref());

        // The header-protection sample starts 4 bytes after the start of the
        // packet number.
        let pn_offset = 1;
        let (header, sample) = buf.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let sample = &sample[..hpk.sample_len()];
        hpk.encrypt_in_place(sample, &mut first[0], rest)
            .unwrap();

        const PROTECTED: &[u8] = &[
            0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
            0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
        ];

        assert_eq!(&buf, PROTECTED);

        // Undo header protection, then payload protection, and recover the input.
        let (header, sample) = buf.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let sample = &sample[..hpk.sample_len()];
        hpk.decrypt_in_place(sample, &mut first[0], rest)
            .unwrap();

        let (header, payload_tag) = buf.split_at_mut(4);
        let plain = packet
            .decrypt_in_place(PN, &*header, payload_tag)
            .unwrap();

        assert_eq!(plain, &PLAIN[4..]);
    }

    /// Checks that one `Secrets::update` step produces the expected
    /// next-generation client and server secrets.
    #[test]
    fn key_update_test_vector() {
        // `hkdf::Prk` is opaque, so compare two PRKs by expanding both with
        // the same info and comparing the output bytes.
        fn equal_prk(x: &hkdf::Prk, y: &hkdf::Prk) -> bool {
            let mut x_data = [0; 16];
            let mut y_data = [0; 16];
            let x_okm = x
                .expand(&[b"info"], &aead::quic::AES_128)
                .unwrap();
            x_okm.fill(&mut x_data[..]).unwrap();
            let y_okm = y
                .expand(&[b"info"], &aead::quic::AES_128)
                .unwrap();
            y_okm.fill(&mut y_data[..]).unwrap();
            x_data == y_data
        }

        let mut secrets = Secrets {
            client: hkdf::Prk::new_less_safe(
                hkdf::HKDF_SHA256,
                &[
                    0xb8, 0x76, 0x77, 0x08, 0xf8, 0x77, 0x23, 0x58, 0xa6, 0xea, 0x9f, 0xc4, 0x3e,
                    0x4a, 0xdd, 0x2c, 0x96, 0x1b, 0x3f, 0x52, 0x87, 0xa6, 0xd1, 0x46, 0x7e, 0xe0,
                    0xae, 0xab, 0x33, 0x72, 0x4d, 0xbf,
                ],
            ),
            server: hkdf::Prk::new_less_safe(
                hkdf::HKDF_SHA256,
                &[
                    0x42, 0xdc, 0x97, 0x21, 0x40, 0xe0, 0xf2, 0xe3, 0x98, 0x45, 0xb7, 0x67, 0x61,
                    0x34, 0x39, 0xdc, 0x67, 0x58, 0xca, 0x43, 0x25, 0x9b, 0x87, 0x85, 0x06, 0x82,
                    0x4e, 0xb1, 0xe4, 0x38, 0xd8, 0x55,
                ],
            ),
            suite: TLS13_AES_128_GCM_SHA256_INTERNAL,
            is_client: true,
        };
        secrets.update();

        assert!(equal_prk(
            &secrets.client,
            &hkdf::Prk::new_less_safe(
                hkdf::HKDF_SHA256,
                &[
                    0x42, 0xca, 0xc8, 0xc9, 0x1c, 0xd5, 0xeb, 0x40, 0x68, 0x2e, 0x43, 0x2e, 0xdf,
                    0x2d, 0x2b, 0xe9, 0xf4, 0x1a, 0x52, 0xca, 0x6b, 0x22, 0xd8, 0xe6, 0xcd, 0xb1,
                    0xe8, 0xac, 0xa9, 0x6, 0x1f, 0xce
                ]
            )
        ));
        assert!(equal_prk(
            &secrets.server,
            &hkdf::Prk::new_less_safe(
                hkdf::HKDF_SHA256,
                &[
                    0xeb, 0x7f, 0x5e, 0x2a, 0x12, 0x3f, 0x40, 0x7d, 0xb4, 0x99, 0xe3, 0x61, 0xca,
                    0xe5, 0x90, 0xd4, 0xd9, 0x92, 0xe1, 0x4b, 0x7a, 0xce, 0x3, 0xc2, 0x44, 0xe0,
                    0x42, 0x21, 0x15, 0xb6, 0xd3, 0x8a
                ]
            )
        ));
    }
}