s2n_quic_core/crypto/application/keyset.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    connection::ProcessingError,
6    crypto::{application::limited, OneRttKey, ProtectedPayload},
7    packet::{
8        encoding::PacketEncodingError,
9        number::PacketNumber,
10        short::{CleartextShort, EncryptedShort},
11        KeyPhase,
12    },
13    time::{timer, Timer, Timestamp},
14    transport,
15};
16use core::ops;
17use s2n_codec::EncoderBuffer;
18
/// The 1-RTT packet protection keys of a connection: one key slot per [`KeyPhase`], plus the
/// bookkeeping needed to rotate between them and to enforce AEAD limits.
pub struct KeySet<K> {
    /// The current [`KeyPhase`]; selects the active slot in `crypto`.
    key_phase: KeyPhase,

    /// Armed when the key phase rotates on receipt of a packet protected with updated keys.
    /// While armed, the slot for the non-active phase still holds the previous keys; once it
    /// expires, the next keys are derived into that slot.
    key_derivation_timer: Timer,

    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
    //# In addition to counting packets sent, endpoints MUST count the number
    //# of received packets that fail authentication during the lifetime of a
    //# connection.
    packet_decryption_failures: u64,
    /// Number of decryption failures tolerated before the connection is closed with
    /// `AEAD_LIMIT_REACHED`; read from the initial key when the set is created.
    aead_integrity_limit: u64,
    /// The number of times the key has been rotated
    generation: u16,

    /// Set of keys for the current and next phase
    crypto: KeyArray<K>,

    /// Limits used to decide when the active key is due for an update.
    limits: limited::Limits,
}
39
40impl<K: OneRttKey> KeySet<K> {
41    pub fn new(crypto: K, limits: limited::Limits) -> Self {
42        //= https://www.rfc-editor.org/rfc/rfc9001#section-6
43        //# The Key Phase bit is initially set to 0 for the
44        //# first set of 1-RTT packets and toggled to signal each subsequent key
45        //# update.
46
47        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.3
48        //# Endpoints responding to an apparent key update MUST NOT generate a
49        //# timing side-channel signal that might indicate that the Key Phase bit
50        //# was invalid (see Section 9.4).
51
52        //= https://www.rfc-editor.org/rfc/rfc9001#section-5.4
53        //# The same header protection key is used for the duration of the
54        //# connection, with the value not changing after a key update (see
55        //# Section 6).  This allows header protection to be used to protect the
56        //# key phase.
57        // By pre-generating the next key, we can respond to a KeyUpdate without exposing a timing
58        // side channel.
59        let aead_integrity_limit = crypto.aead_integrity_limit();
60        let next_key = limited::Key::new(crypto.derive_next_key());
61        let active_key = limited::Key::new(crypto);
62
63        Self {
64            key_phase: KeyPhase::Zero,
65            key_derivation_timer: Default::default(),
66            packet_decryption_failures: 0,
67            aead_integrity_limit,
68            generation: 0,
69            crypto: KeyArray([active_key, next_key]),
70            limits,
71        }
72    }
73
74    /// Rotating the phase will switch the active key
75    fn rotate_phase(&mut self) {
76        self.generation += 1;
77        self.key_phase = KeyPhase::next_phase(self.key_phase);
78    }
79
80    /// Derive a new key based on the active key, and store it in the non-active slot
81    fn derive_and_store_next_key(&mut self) {
82        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.3
83        //# Once generated, the next set of packet protection keys SHOULD be
84        //# retained, even if the packet that was received was subsequently
85        //# discarded.
86
87        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.5
88        //# After this period, old read keys and their corresponding secrets
89        //# SHOULD be discarded.
90
91        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.5
92        //# These updated keys MAY replace the previous keys at that time.
93
94        let next_key = self.active_key().derive_next_key();
95        let next_phase = KeyPhase::next_phase(self.key_phase);
96        self.crypto[next_phase] = limited::Key::new(next_key);
97    }
98
99    /// Set the timer to derive a new key after timestamp
100    pub fn set_derivation_timer(&mut self, timestamp: Timestamp) {
101        self.key_derivation_timer.set(timestamp)
102    }
103
104    /// Returns whether there is a key update in progress.
105    pub fn key_update_in_progress(&self) -> bool {
106        self.key_derivation_timer.is_armed()
107    }
108
109    /// Passes the key for the the requested phase to a callback function. Integrity limits are
110    /// enforced.
111    ///
112    /// Returns the decrypted packet and generation if the key phase was rotated.
113    pub fn decrypt_packet<'a>(
114        &mut self,
115        packet: EncryptedShort<'a>,
116        largest_acknowledged_packet_number: PacketNumber,
117        pto: Timestamp,
118    ) -> Result<(CleartextShort<'a>, Option<u16>), ProcessingError> {
119        let mut phase_to_use = self.key_phase() as u8;
120        let packet_phase = packet.key_phase();
121        let phase_switch = phase_to_use != (packet_phase as u8);
122        phase_to_use ^= phase_switch as u8;
123
124        if self.key_update_in_progress() && phase_switch {
125            //= https://www.rfc-editor.org/rfc/rfc9001#section-6.5
126            //# An endpoint MAY allow a period of approximately the Probe Timeout
127            //# (PTO; see [QUIC-RECOVERY]) after promoting the next set of receive
128            //# keys to be current before it creates the subsequent set of packet
129            //# protection keys.
130
131            //= https://www.rfc-editor.org/rfc/rfc9001#section-6.4
132            //# Packets with higher packet numbers MUST be protected with either the
133            //# same or newer packet protection keys than packets with lower packet
134            //# numbers.
135            // During this PTO we can still process delayed packets, reducing retransmits
136            // required from the peer. We know the packets are delayed because they have a
137            // lower packet number than expected and the old key phase.
138            if packet.packet_number < largest_acknowledged_packet_number {
139                phase_to_use = packet.key_phase() as u8;
140            }
141        }
142
143        let key = &mut self.crypto[phase_to_use.into()];
144
145        let result = packet.decrypt(key.key_mut());
146
147        key.on_packet_decryption();
148
149        match result {
150            Ok(packet) => {
151                let generation = if packet_phase != self.key_phase() {
152                    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.2
153                    //# Sending keys MUST be updated before sending an
154                    //# acknowledgement for the packet that was received with updated keys.
155
156                    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.2
157                    //# The endpoint MUST update its
158                    //# send keys to the corresponding key phase in response, as described in
159                    //# Section 6.1.
160                    self.rotate_phase();
161
162                    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.3
163                    //# Endpoints responding to an apparent key update MUST NOT generate a
164                    //# timing side-channel signal that might indicate that the Key Phase bit
165                    //# was invalid (see Section 9.4).
166
167                    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.5
168                    //# An endpoint SHOULD retain old read keys for no more than three times
169                    //# the PTO after having received a packet protected using the new keys.
170
171                    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.1
172                    //# An endpoint SHOULD
173                    //# retain old keys for some time after unprotecting a packet sent using
174                    //# the new keys.
175                    self.set_derivation_timer(pto);
176                    Some(self.generation)
177                } else {
178                    None
179                };
180
181                Ok((packet, generation))
182            }
183            Err(err) => {
184                //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
185                //# In addition to counting packets sent, endpoints MUST count the number
186                //# of received packets that fail authentication during the lifetime of a
187                //# connection.
188                self.packet_decryption_failures += 1;
189
190                //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
191                //# If a key update is not possible or
192                //# integrity limits are reached, the endpoint MUST stop using the
193                //# connection and only send stateless resets in response to receiving
194                //# packets.
195
196                //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
197                //# If the total number of received packets that fail
198                //# authentication within the connection, across all keys, exceeds the
199                //# integrity limit for the selected AEAD, the endpoint MUST immediately
200                //# close the connection with a connection error of type
201                //# AEAD_LIMIT_REACHED and not process any more packets.
202                if self.decryption_error_count() > self.aead_integrity_limit {
203                    return Err(transport::Error::AEAD_LIMIT_REACHED.into());
204                }
205
206                Err(err)
207            }
208        }
209    }
210
211    /// This is the KeyPhase that should be used to encrypt a given packet.
212    pub fn encryption_phase(&self) -> KeyPhase {
213        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
214        //# Endpoints MUST initiate a key update
215        //# before sending more protected packets than the confidentiality limit
216        //# for the selected AEAD permits.
217        if self.active_key().needs_update(&self.limits) {
218            return KeyPhase::next_phase(self.key_phase());
219        }
220
221        self.key_phase()
222    }
223
224    pub fn encrypt_packet<'a, F>(
225        &mut self,
226        buffer: EncoderBuffer<'a>,
227        f: F,
228    ) -> Result<(ProtectedPayload<'a>, EncoderBuffer<'a>), PacketEncodingError<'a>>
229    where
230        F: FnOnce(
231            EncoderBuffer<'a>,
232            &mut K,
233            KeyPhase,
234        )
235            -> Result<(ProtectedPayload<'a>, EncoderBuffer<'a>), PacketEncodingError<'a>>,
236    {
237        let phase = self.encryption_phase();
238        if self.crypto[phase].expired() {
239            //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
240            //# If the total number of encrypted packets with the same key
241            //# exceeds the confidentiality limit for the selected AEAD, the endpoint
242            //# MUST stop using those keys.
243
244            //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
245            //# If a key update is not possible or
246            //# integrity limits are reached, the endpoint MUST stop using the
247            //# connection and only send stateless resets in response to receiving
248            //# packets.
249            return Err(PacketEncodingError::AeadLimitReached(buffer));
250        }
251
252        let r = f(buffer, self.crypto[phase].key_mut(), phase)?;
253
254        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
255        //# Endpoints MUST count the number of encrypted packets for each set of
256        //# keys.
257        self.crypto[phase].on_packet_encryption();
258
259        Ok(r)
260    }
261
262    pub fn on_timeout(&mut self, timestamp: Timestamp) {
263        if self
264            .key_derivation_timer
265            .poll_expiration(timestamp)
266            .is_ready()
267        {
268            //= https://www.rfc-editor.org/rfc/rfc9001#section-6.5
269            //# An endpoint SHOULD retain old read keys for no more than three times
270            //# the PTO after having received a packet protected using the new keys.
271            self.derive_and_store_next_key();
272        }
273    }
274
275    pub fn key_phase(&self) -> KeyPhase {
276        self.key_phase
277    }
278
279    pub fn active_key(&self) -> &limited::Key<K> {
280        &self.crypto[self.key_phase]
281    }
282
283    pub fn active_key_mut(&mut self) -> &mut limited::Key<K> {
284        &mut self.crypto[self.key_phase]
285    }
286
287    fn decryption_error_count(&self) -> u64 {
288        self.packet_decryption_failures
289    }
290
291    pub fn cipher_suite(&mut self) -> crate::crypto::tls::CipherSuite {
292        self.crypto.0[0].key_mut().cipher_suite()
293    }
294}
295
296impl<K> timer::Provider for KeySet<K> {
297    #[inline]
298    fn timers<Q: timer::Query>(&self, query: &mut Q) -> timer::Result {
299        self.key_derivation_timer.timers(query)?;
300        Ok(())
301    }
302}
303
304struct KeyArray<K>([limited::Key<K>; 2]);
305
306impl<K> ops::Index<KeyPhase> for KeyArray<K> {
307    type Output = limited::Key<K>;
308
309    #[inline]
310    fn index(&self, key_phase: KeyPhase) -> &Self::Output {
311        &self.0[(key_phase as u8) as usize]
312    }
313}
314
315impl<K> ops::IndexMut<KeyPhase> for KeyArray<K> {
316    #[inline]
317    fn index_mut(&mut self, key_phase: KeyPhase) -> &mut Self::Output {
318        &mut self.0[(key_phase as u8) as usize]
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        connection::id::ConnectionInfo,
        crypto::{
            testing::{HeaderKey as TestHeaderKey, Key as TestKey},
            ProtectedPayload,
        },
        inet::SocketAddress,
        packet::{
            encoding::PacketEncodingError, number::PacketNumberSpace, short::ProtectedShort,
            KeyPhase,
        },
        time::{testing::Clock, Clock as _},
        varint::VarInt,
    };
    use core::time::Duration;
    use s2n_codec::{DecoderBufferMut, EncoderBuffer};

    /// Decodes an all-zero short-header packet, removes its header protection and feeds it to
    /// [`KeySet::decrypt_packet`].
    ///
    /// Returns the generation reported by `decrypt_packet`, or its error.
    fn decrypt_zeroed_packet(
        keyset: &mut KeySet<TestKey>,
    ) -> Result<Option<u16>, ProcessingError> {
        let clock = Clock::default();
        let remote_address = SocketAddress::default();
        let connection_info = ConnectionInfo::new(&remote_address);
        let mut data = [0; 128];

        let (protected_packet, _remaining) = ProtectedShort::decode(
            0,
            DecoderBufferMut::new(&mut data),
            &connection_info,
            &20,
        )
        .unwrap();

        let encrypted_packet = protected_packet
            .unprotect(
                &TestHeaderKey::default(),
                PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(0)),
            )
            .unwrap();

        let outcome = keyset.decrypt_packet(
            encrypted_packet,
            PacketNumberSpace::ApplicationData.new_packet_number(VarInt::from_u8(0)),
            clock.get_time(),
        );
        outcome.map(|(_cleartext, generation)| generation)
    }

    /// How a single call to [`KeySet::encrypt_packet`] made by [`encrypt_empty_payload`] ended.
    #[derive(Debug, PartialEq, Eq)]
    enum EncryptOutcome {
        Encrypted,
        AeadLimitReached,
        OtherError,
    }

    /// Calls [`KeySet::encrypt_packet`] with a callback that emits an empty protected payload.
    fn encrypt_empty_payload(keyset: &mut KeySet<TestKey>) -> EncryptOutcome {
        let mut encoder_bytes = [0; 512];
        let mut payload_bytes = [0; 512];
        let buffer = EncoderBuffer::new(&mut encoder_bytes);

        let outcome = keyset.encrypt_packet(buffer, |buffer, _key, _phase| {
            Ok((ProtectedPayload::new(0, &mut payload_bytes), buffer))
        });

        match outcome {
            Ok(_) => EncryptOutcome::Encrypted,
            Err(PacketEncodingError::AeadLimitReached(_)) => EncryptOutcome::AeadLimitReached,
            Err(_) => EncryptOutcome::OtherError,
        }
    }

    #[test]
    fn test_key_derivation_timer() {
        let mut clock = Clock::default();
        let start = clock.get_time();
        let step = Duration::from_millis(8);

        let mut keyset = KeySet::new(TestKey::default(), Default::default());
        keyset.rotate_phase();
        keyset.set_derivation_timer(start + Duration::from_millis(10));

        // 8ms in: the 10ms timer has not expired yet.
        clock.inc_by(step);
        keyset.on_timeout(clock.get_time());
        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.1
        //= type=test
        //# An endpoint SHOULD
        //# retain old keys for some time after unprotecting a packet sent using
        //# the new keys.
        assert_eq!(keyset.crypto[KeyPhase::Zero].key_mut().derivations, 0);

        // 16ms in: the timer fired and the old slot was overwritten.
        clock.inc_by(step);
        keyset.on_timeout(clock.get_time());

        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.5
        //= type=test
        //# After this period, old read keys and their corresponding secrets
        //# SHOULD be discarded.
        assert_eq!(keyset.crypto[KeyPhase::Zero].key_mut().derivations, 2);
    }

    #[test]
    fn test_key_set() {
        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.3
        //= type=test
        //# For this reason, endpoints MUST be able to retain two sets of packet
        //# protection keys for receiving packets: the current and the next.

        let mut keyset = KeySet::new(TestKey::default(), Default::default());

        let current = keyset.crypto[KeyPhase::Zero].key_mut().derivations;
        let next = keyset.crypto[KeyPhase::One].key_mut().derivations;
        assert_eq!((current, next), (0, 1));
    }

    #[test]
    fn test_phase_rotation() {
        let mut keyset = KeySet::new(TestKey::default(), Default::default());

        let before = keyset.active_key_mut().key_mut().derivations;
        keyset.rotate_phase();
        let after = keyset.active_key_mut().key_mut().derivations;
        assert_eq!((before, after), (0, 1));
    }

    #[test]
    fn test_key_derivation() {
        let mut keyset = KeySet::new(TestKey::default(), Default::default());

        // Rotate, refill the freed slot, then rotate onto the freshly derived key.
        keyset.rotate_phase();
        keyset.derive_and_store_next_key();
        keyset.rotate_phase();

        assert_eq!(keyset.active_key_mut().key_mut().derivations, 2);
    }

    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
    //= type=test
    //# In addition to counting packets sent, endpoints MUST count the number
    //# of received packets that fail authentication during the lifetime of a
    //# connection.
    #[test]
    fn test_decryption_failure_counter() {
        let failing_key = TestKey {
            fail_on_decrypt: true,
            ..Default::default()
        };
        let mut keyset = KeySet::new(failing_key, Default::default());

        assert_eq!(keyset.decryption_error_count(), 0);
        assert!(decrypt_zeroed_packet(&mut keyset).is_err());
        assert_eq!(keyset.decryption_error_count(), 1);
    }

    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
    //= type=test
    //# If the total number of received packets that fail
    //# authentication within the connection, across all keys, exceeds the
    //# integrity limit for the selected AEAD, the endpoint MUST immediately
    //# close the connection with a connection error of type
    //# AEAD_LIMIT_REACHED and not process any more packets.
    #[test]
    fn test_decryption_failure_enforced_aead_limit() {
        let failing_key = TestKey {
            integrity_limit: 0,
            fail_on_decrypt: true,
            ..Default::default()
        };
        let mut keyset = KeySet::new(failing_key, Default::default());

        assert_eq!(keyset.decryption_error_count(), 0);
        assert_eq!(
            decrypt_zeroed_packet(&mut keyset).err(),
            Some(ProcessingError::ConnectionError(
                (transport::Error::AEAD_LIMIT_REACHED).into()
            ))
        );
    }

    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
    //= type=test
    //# Endpoints MUST count the number of encrypted packets for each set of
    //# keys.
    #[test]
    fn test_encrypted_packet_count_increased() {
        let mut keyset = KeySet::new(TestKey::default(), Default::default());

        assert_eq!(keyset.active_key().encrypted_packets(), 0);
        assert_eq!(encrypt_empty_payload(&mut keyset), EncryptOutcome::Encrypted);
        assert_eq!(keyset.active_key().encrypted_packets(), 1);
    }

    #[test]
    fn test_encrypted_packet_key_update_window() {
        let key = TestKey {
            confidentiality_limit: 10000,
            ..Default::default()
        };
        let mut keyset = KeySet::new(key, Default::default());

        // The first encryption should use the expected keyphase, and put us into the
        // KEY_UPDATE_WINDOW.
        assert_eq!(keyset.active_key().encrypted_packets(), 0);
        assert!(!keyset.active_key().needs_update(&keyset.limits));
        assert_eq!(encrypt_empty_payload(&mut keyset), EncryptOutcome::Encrypted);

        //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
        //= type=test
        //# Endpoints MUST initiate a key update
        //# before sending more protected packets than the confidentiality limit
        //# for the selected AEAD permits.

        // Subsequent encryptions should be in the next phase and our key should need an update.
        assert_eq!(keyset.encryption_phase(), KeyPhase::One);
        assert!(keyset.active_key().needs_update(&keyset.limits));
    }

    //= https://www.rfc-editor.org/rfc/rfc9001#section-6.6
    //= type=test
    //# If the total number of encrypted packets with the same key
    //# exceeds the confidentiality limit for the selected AEAD, the endpoint
    //# MUST stop using those keys.
    #[test]
    fn test_encrypted_packet_aead_limit() {
        let limit = 10_000;
        let key = TestKey {
            confidentiality_limit: limit,
            ..Default::default()
        };
        let mut keyset = KeySet::new(key, Default::default());

        // The KeySet chooses the appropriate key phase. Encrypting one more than the limit
        // attempts a key update after the first encryption, then fills the update window of
        // the next key (because the key update never completes).
        for _ in 0..=limit {
            assert_eq!(encrypt_empty_payload(&mut keyset), EncryptOutcome::Encrypted);

            // As long as the keyphase is constant, we have not initiated any KeyUpdate, and we
            // have not derived any new keys.
            assert_eq!(keyset.key_phase(), KeyPhase::Zero);
        }

        // Only the very first packet was encrypted with KeyPhase::Zero; every later request
        // fell inside the KEY_UPDATE_WINDOW and used the next key phase.
        assert_eq!(keyset.crypto[KeyPhase::Zero].encrypted_packets(), 1);
        assert_eq!(keyset.crypto[KeyPhase::One].encrypted_packets(), limit);

        // One more encryption pushes us over the AEAD limit and must be rejected.
        assert_eq!(
            encrypt_empty_payload(&mut keyset),
            EncryptOutcome::AeadLimitReached
        );
    }
}