s2n_quic_core/crypto/
tls.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6#[cfg(feature = "alloc")]
7pub use bytes::{Bytes, BytesMut};
8use core::{any::Any, fmt::Debug};
9use zerocopy::{FromBytes, IntoBytes, Unaligned};
10
11mod error;
12pub use error::Error;
13
14#[cfg(any(test, feature = "testing"))]
15pub mod testing;
16
17#[cfg(all(feature = "alloc", any(test, feature = "testing")))]
18pub mod null;
19
20#[cfg(feature = "alloc")]
21pub mod slow_tls;
22
23#[cfg(feature = "std")]
24pub mod offload;
25
/// Application-level parameters carried inside the TLS handshake.
///
/// The struct only borrows the data; it does not own or copy it.
#[derive(Debug)]
pub struct ApplicationParameters<'a> {
    /// The transport parameters in their encoded (wire) form
    pub transport_parameters: &'a [u8],
}
32
/// The named group negotiated for key exchange during the TLS handshake.
///
/// `contains_kem` is `true` when the group includes a key encapsulation
/// mechanism.
#[derive(Debug, Eq)]
pub struct NamedGroup {
    pub group_name: &'static str,
    pub contains_kem: bool,
}

// Not every TLS implementation capitalizes group names the way the IANA
// specification does, so equality compares `group_name` without regard to
// ASCII case. `contains_kem` must still match exactly.
impl PartialEq for NamedGroup {
    fn eq(&self, other: &Self) -> bool {
        if self.contains_kem != other.contains_kem {
            return false;
        }

        self.group_name.eq_ignore_ascii_case(other.group_name)
    }
}
51
/// Error returned by [`TlsSession::tls_exporter`].
///
/// Both the enum and its variant are `#[non_exhaustive]`, so code outside
/// this crate has to obtain values through [`TlsExportError::failure`].
#[derive(Debug)]
#[non_exhaustive]
pub enum TlsExportError {
    #[non_exhaustive]
    Failure,
}

impl TlsExportError {
    /// Creates the [`TlsExportError::Failure`] value.
    pub fn failure() -> Self {
        Self::Failure
    }
}
64
/// Error returned by [`TlsSession::peer_cert_chain_der`].
///
/// Both the enum and its variant are `#[non_exhaustive]`, so code outside
/// this crate has to obtain values through [`ChainError::failure`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ChainError {
    #[non_exhaustive]
    Failure,
}

impl ChainError {
    /// Creates the [`ChainError::Failure`] value.
    pub fn failure() -> Self {
        Self::Failure
    }
}
77
78pub trait TlsSession: Send {
79    /// See <https://datatracker.ietf.org/doc/html/rfc5705> and <https://www.rfc-editor.org/rfc/rfc8446>.
80    fn tls_exporter(
81        &self,
82        label: &[u8],
83        context: &[u8],
84        output: &mut [u8],
85    ) -> Result<(), TlsExportError>;
86
87    fn cipher_suite(&self) -> CipherSuite;
88
89    #[cfg(feature = "alloc")]
90    fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, ChainError>;
91}
92
/// Callbacks and I/O hooks a TLS [`Session`] uses to talk to its owner.
///
/// The `on_*` methods report handshake progress, the `receive_*` methods
/// supply inbound bytes per packet space and the `send_*` methods accept
/// outbound bytes per packet space.
#[cfg(feature = "alloc")]
pub trait Context<Crypto: crate::crypto::CryptoSuite> {
    /// Called when the client's application parameters are available, prior
    /// to completion of the handshake.
    ///
    /// `server_params` holds the encoded server transport parameters as a
    /// mutable `Vec<u8>`, so that extra parameters which depend on
    /// `client_params` can be appended before they are sent to the client.
    ///
    /// Transport parameters are not authenticated until the handshake
    /// completes, so nothing done here may rely on their authenticity.
    ///
    /// NOTE: This function is not currently supported
    ///       for the `s2n-quic-rustls` provider
    fn on_client_application_params(
        &mut self,
        client_params: ApplicationParameters,
        server_params: &mut Vec<u8>,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the handshake packet key together with its header key.
    fn on_handshake_keys(
        &mut self,
        key: Crypto::HandshakeKey,
        header_key: Crypto::HandshakeHeaderKey,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the 0-RTT packet key, its header key and the application
    /// parameters that accompany them.
    fn on_zero_rtt_keys(
        &mut self,
        key: Crypto::ZeroRttKey,
        header_key: Crypto::ZeroRttHeaderKey,
        application_parameters: ApplicationParameters,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the 1-RTT packet key, its header key and the application
    /// parameters that accompany them.
    fn on_one_rtt_keys(
        &mut self,
        key: Crypto::OneRttKey,
        header_key: Crypto::OneRttHeaderKey,
        application_parameters: ApplicationParameters,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the server name associated with the session.
    fn on_server_name(
        &mut self,
        server_name: crate::application::ServerName,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the application protocol as raw bytes.
    // NOTE(review): presumably the negotiated ALPN value - confirm with providers.
    fn on_application_protocol(
        &mut self,
        application_protocol: Bytes,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the named group that was used for key exchange.
    fn on_key_exchange_group(
        &mut self,
        named_group: NamedGroup,
    ) -> Result<(), crate::transport::Error>;

    //= https://www.rfc-editor.org/rfc/rfc9001#section-4.1.1
    //# The TLS handshake is considered complete when the
    //# TLS stack has reported that the handshake is complete.  This happens
    //# when the TLS stack has both sent a Finished message and verified the
    //# peer's Finished message.
    fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error>;

    /// Set TLS context and transfer from TLS provider to application layer.
    #[cfg(feature = "alloc")]
    fn on_tls_context(&mut self, _context: alloc::boxed::Box<dyn Any + Send>);

    /// Gives the context access to `session` once its exporter can be used.
    // NOTE(review): exact timing is decided by the TLS provider.
    fn on_tls_exporter_ready(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Gives the context access to `session` after the TLS handshake failed.
    fn on_tls_handshake_failed(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Receives data from the initial packet space
    ///
    /// `max_len`, when provided, states how many bytes the TLS
    /// implementation is willing to buffer.
    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Receives data from the handshake packet space
    ///
    /// `max_len`, when provided, states how many bytes the TLS
    /// implementation is willing to buffer.
    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Receives data from the application packet space
    ///
    /// `max_len`, when provided, states how many bytes the TLS
    /// implementation is willing to buffer.
    fn receive_application(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Reports whether data can be sent in the initial packet space
    fn can_send_initial(&self) -> bool;
    /// Sends `transmission` in the initial packet space
    fn send_initial(&mut self, transmission: Bytes);

    /// Reports whether data can be sent in the handshake packet space
    fn can_send_handshake(&self) -> bool;
    /// Sends `transmission` in the handshake packet space
    fn send_handshake(&mut self, transmission: Bytes);

    /// Reports whether data can be sent in the application packet space
    fn can_send_application(&self) -> bool;
    /// Sends `transmission` in the application packet space
    fn send_application(&mut self, transmission: Bytes);

    /// Returns the waker held by this context.
    fn waker(&self) -> &core::task::Waker;
}
200
/// Factory for TLS sessions on either side of a connection.
#[cfg(feature = "alloc")]
pub trait Endpoint: Send + Sized + 'static {
    /// The session type produced by this endpoint
    type Session: Session;

    /// Creates a server-side session that carries `transport_parameters`.
    fn new_server_session<Params>(&mut self, transport_parameters: &Params) -> Self::Session
    where
        Params: s2n_codec::EncoderValue;

    /// Creates a client-side session that carries `transport_parameters`
    /// and targets `server_name`.
    fn new_client_session<Params>(
        &mut self,
        transport_parameters: &Params,
        server_name: crate::application::ServerName,
    ) -> Self::Session
    where
        Params: s2n_codec::EncoderValue;

    /// The maximum length of a tag for any algorithm that may be negotiated
    fn max_tag_length(&self) -> usize;
}
219
/// A TLS session that is driven through a [`Context`].
#[cfg(feature = "alloc")]
pub trait Session: crate::crypto::CryptoSuite + Sized + Send + Debug {
    /// Makes progress on the session, using `context` for I/O and callbacks.
    fn poll<C: Context<Self>>(
        &mut self,
        context: &mut C,
    ) -> core::task::Poll<Result<(), crate::transport::Error>>;

    /// Handles a message that arrives after the handshake.
    ///
    /// The default implementation does nothing and reports success.
    fn process_post_handshake_message<C: Context<Self>>(
        &mut self,
        _context: &mut C,
    ) -> Result<(), crate::transport::Error> {
        Ok(())
    }

    /// Tells the caller whether the session should be discarded.
    ///
    /// The default implementation always answers `true`.
    fn should_discard_session(&self) -> bool {
        true
    }

    /// Parses a hello message of the provided type
    ///
    /// The default implementation of this function assumes TLS messages are being exchanged.
    ///
    /// * `Ok(None)` - the header cannot be decoded from `header_chunk` yet, or
    ///   fewer than header + payload bytes have been received in total.
    /// * `Ok(Some(offsets))` - the whole message is available; `offsets`
    ///   locates the payload behind the handshake header.
    /// * `Err(_)` - the message is not of type `msg_type`, or its payload is
    ///   longer than `max_hello_size`.
    #[inline]
    fn parse_hello(
        msg_type: HandshakeType,
        header_chunk: &[u8],
        total_received_len: u64,
        max_hello_size: u64,
    ) -> Result<Option<HelloOffsets>, crate::transport::Error> {
        // number of bytes occupied by the handshake header itself
        const HEADER_LEN: u64 = core::mem::size_of::<HandshakeHeader>() as u64;

        let decoded = s2n_codec::DecoderBuffer::new(header_chunk).decode::<HandshakeHeader>();

        let header = match decoded {
            Ok((header, _remaining)) => header,
            // the header is still incomplete; try again once more data arrived
            Err(_) => return Ok(None),
        };

        if header.msg_type() != Some(msg_type) {
            return Err(crate::transport::Error::PROTOCOL_VIOLATION
                .with_reason("first TLS message should be a hello message"));
        }

        let payload_len = header.len() as u64;

        if payload_len > max_hello_size {
            return Err(crate::transport::Error::CRYPTO_BUFFER_EXCEEDED
                .with_reason("hello message cannot exceed 16k"));
        }

        // the message is only usable once header and payload are both buffered
        let required_len = payload_len + HEADER_LEN;
        if total_received_len < required_len {
            return Ok(None);
        }

        Ok(Some(HelloOffsets {
            payload_offset: HEADER_LEN as usize,
            payload_len: payload_len as usize,
        }))
    }
}
284
/// Location of a hello message's payload within the received byte stream.
#[derive(Copy, Clone, Debug)]
pub struct HelloOffsets {
    /// Number of leading bytes that come before the payload
    pub payload_offset: usize,
    /// Number of payload bytes that follow the offset
    pub payload_len: usize,
}

impl HelloOffsets {
    /// Narrows `chunks` down to the payload bytes only.
    ///
    /// The first `payload_offset` bytes are dropped, the next `payload_len`
    /// bytes are yielded and anything after that is ignored. The skipped
    /// prefix and the payload may each span several chunks. Chunks that
    /// contribute no payload bytes are omitted from the output altogether.
    #[inline]
    pub fn trim_chunks<'a, I: Iterator<Item = &'a [u8]>>(
        &self,
        chunks: I,
    ) -> impl Iterator<Item = &'a [u8]> {
        // running counters, owned by the closure below
        let mut prefix_left = self.payload_offset;
        let mut payload_left = self.payload_len;

        chunks.filter_map(move |chunk| {
            // consume as much of the remaining prefix as this chunk holds
            let skipped = prefix_left.min(chunk.len());
            prefix_left -= skipped;

            // either the prefix continues in the next chunk, or the payload
            // has already been handed out completely
            if prefix_left > 0 || payload_left == 0 {
                return None;
            }

            let rest = &chunk[skipped..];

            // cap the chunk at the number of payload bytes still owed
            let taken = payload_left.min(rest.len());
            payload_left -= taken;

            (taken > 0).then(|| &rest[..taken])
        })
    }
}
325
/// Cipher suites that a TLS session can report.
///
/// [`CipherSuite::Unknown`] is the default value.
#[derive(Copy, Clone, Debug, Default)]
// the variants intentionally keep the upper-case TLS cipher suite spelling
#[allow(non_camel_case_types)]
pub enum CipherSuite {
    TLS_AES_128_GCM_SHA256,
    TLS_AES_256_GCM_SHA384,
    TLS_CHACHA20_POLY1305_SHA256,
    /// A cipher suite that is not one of the variants listed above
    #[default]
    Unknown,
}
335
336impl crate::event::IntoEvent<crate::event::builder::CipherSuite> for CipherSuite {
337    #[inline]
338    fn into_event(self) -> crate::event::builder::CipherSuite {
339        use crate::event::builder::CipherSuite::*;
340        match self {
341            Self::TLS_AES_128_GCM_SHA256 => TLS_AES_128_GCM_SHA256 {},
342            Self::TLS_AES_256_GCM_SHA384 => TLS_AES_256_GCM_SHA384 {},
343            Self::TLS_CHACHA20_POLY1305_SHA256 => TLS_CHACHA20_POLY1305_SHA256 {},
344            Self::Unknown => Unknown {},
345        }
346    }
347}
348
349impl crate::event::IntoEvent<crate::event::api::CipherSuite> for CipherSuite {
350    #[inline]
351    fn into_event(self) -> crate::event::api::CipherSuite {
352        let builder: crate::event::builder::CipherSuite = self.into_event();
353        builder.into_event()
354    }
355}
356
// Generates the `HandshakeType` enum from a list of `Name(code)` pairs,
// along with a `TryFrom<u8>` impl that rejects every code not in the list.
macro_rules! handshake_type {
    ($($name:ident($code:literal)),* $(,)?) => {
        /// TLS handshake message types, using their one-byte wire codes as
        /// discriminants
        #[derive(Clone, Copy, Debug, PartialEq, Eq, IntoBytes, Unaligned)]
        #[cfg_attr(any(test, feature = "bolero-generator"), derive(bolero_generator::TypeGenerator))]
        #[repr(u8)]
        pub enum HandshakeType {
            $($name = $code),*
        }

        impl TryFrom<u8> for HandshakeType {
            type Error = ();

            /// Looks up the variant for `raw`; unlisted codes give `Err(())`.
            #[inline]
            fn try_from(raw: u8) -> Result<Self, Self::Error> {
                match raw {
                    $($code => Ok(Self::$name),)*
                    _ => Err(()),
                }
            }
        }
    };
}
379
380//= https://www.rfc-editor.org/rfc/rfc5246#A.4
381//# enum {
382//#     hello_request(0), client_hello(1), server_hello(2),
383//#     certificate(11), server_key_exchange (12),
384//#     certificate_request(13), server_hello_done(14),
385//#     certificate_verify(15), client_key_exchange(16),
386//#     finished(20)
387//#     (255)
388//# } HandshakeType;
389handshake_type!(
390    HelloRequest(0),
391    ClientHello(1),
392    ServerHello(2),
393    Certificate(11),
394    ServerKeyExchange(12),
395    CertificateRequest(13),
396    ServerHelloDone(14),
397    CertificateVerify(15),
398    ClientKeyExchange(16),
399    Finished(20),
400);
401
402//= https://www.rfc-editor.org/rfc/rfc5246#A.4
403//# struct {
404//#     HandshakeType msg_type;
405//#     uint24 length;
406//#     select (HandshakeType) {
407//#         case hello_request:       HelloRequest;
408//#         case client_hello:        ClientHello;
409//#         case server_hello:        ServerHello;
410//#         case certificate:         Certificate;
411//#         case server_key_exchange: ServerKeyExchange;
412//#         case certificate_request: CertificateRequest;
413//#         case server_hello_done:   ServerHelloDone;
414//#         case certificate_verify:  CertificateVerify;
415//#         case client_key_exchange: ClientKeyExchange;
416//#         case finished:            Finished;
417//#   } body;
418//# } Handshake;
419#[derive(Clone, Copy, Debug, IntoBytes, FromBytes, Unaligned)]
420#[repr(C)]
421pub struct HandshakeHeader {
422    msg_type: u8,
423    length: [u8; 3],
424}
425
426impl HandshakeHeader {
427    #[inline]
428    pub fn msg_type(self) -> Option<HandshakeType> {
429        HandshakeType::try_from(self.msg_type).ok()
430    }
431
432    #[inline]
433    pub fn len(self) -> usize {
434        let mut len = [0u8; 4];
435        len[1..].copy_from_slice(&self.length);
436        let len = u32::from_be_bytes(len);
437        len as _
438    }
439
440    #[inline]
441    pub fn is_empty(self) -> bool {
442        self.len() == 0
443    }
444}
445
446s2n_codec::zerocopy_value_codec!(HandshakeHeader);
447
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;
    use hex_literal::hex;

    // a smaller bound is used under kani
    const MAX_HELLO_SIZE: u64 = if cfg!(kani) { 32 } else { 255 };

    type Chunk = crate::testing::InlineVec<u8, { MAX_HELLO_SIZE as usize + 2 }>;

    /// the hello parser must never panic, whatever it is given
    #[test]
    #[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(36))]
    fn parse_hello_test() {
        check!()
            .with_type::<(HandshakeType, Chunk, u64)>()
            .for_each(|(msg_type, bytes, received_len)| {
                // only the absence of a panic matters here
                let _ =
                    testing::Session::parse_hello(*msg_type, bytes, *received_len, MAX_HELLO_SIZE);
            });
    }

    /// Shorthand for a hex literal viewed as a byte slice
    macro_rules! h {
        ($($tt:tt)*) => {
            &hex!($($tt)*)[..]
        }
    }

    /// Runs `parse_hello` over `input`, passing the first chunk as the header
    /// chunk, and collects the trimmed payload chunks on success.
    fn parse_hello<'a>(
        ty: HandshakeType,
        input: &'a [&'a [u8]],
    ) -> Result<Option<Vec<&'a [u8]>>, crate::transport::Error> {
        let mut received_len = 0usize;
        for chunk in input {
            received_len += chunk.len();
        }

        // an input without any chunk behaves like an empty header chunk
        let header_chunk: &[u8] = input.first().copied().unwrap_or(&[]);

        let parsed =
            testing::Session::parse_hello(ty, header_chunk, received_len as _, MAX_HELLO_SIZE)?;

        Ok(parsed.map(|offsets| {
            offsets
                .trim_chunks(input.iter().copied())
                .collect::<Vec<_>>()
        }))
    }

    #[test]
    fn client_hello_valid_tests() {
        let cases = [
            (&[h!("01 00 00 02 aa bb cc")][..], &[h!("aa bb")][..]),
            (&[h!("01 00 00 01"), h!("aa bb cc dd")], &[h!("aa")]),
            (
                &[h!("01 00 00 02"), h!("aa"), h!("bb"), h!("cc")],
                &[h!("aa"), h!("bb")],
            ),
        ];

        for (chunks, payload) in cases {
            let parsed = parse_hello(HandshakeType::ClientHello, chunks)
                .unwrap()
                .unwrap();

            assert_eq!(parsed.as_slice(), payload);
        }
    }

    #[test]
    fn server_hello_valid_tests() {
        let cases = [(&[h!("02 00 00 02 aa bb cc")][..], &[h!("aa bb")][..])];

        for (chunks, payload) in cases {
            let parsed = parse_hello(HandshakeType::ServerHello, chunks)
                .unwrap()
                .unwrap();

            assert_eq!(parsed.as_slice(), payload);
        }
    }

    #[test]
    fn client_hello_incomplete_tests() {
        let cases = [
            &[][..],
            // missing header
            &[h!("01 00 00")],
            // missing entire payload
            &[h!("01 00 00 01")],
            // missing partial payload
            &[h!("01 00 00 04"), h!("aa"), h!("bb")],
        ];

        for chunks in cases {
            let parsed = parse_hello(HandshakeType::ClientHello, chunks).unwrap();
            assert!(parsed.is_none());
        }
    }

    #[test]
    fn client_hello_invalid_tests() {
        let cases = [
            // invalid message
            &[h!("02 00 00 01 aa")],
            // invalid size - too big
            &[h!("01 00 01 00 aa")],
            // invalid size - too big
            &[h!("01 ff ff ff aa")],
        ];

        for chunks in cases {
            let outcome = parse_hello(HandshakeType::ClientHello, chunks);
            assert!(outcome.is_err());
        }
    }
}