1#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6#[cfg(feature = "alloc")]
7pub use bytes::{Bytes, BytesMut};
8use core::{any::Any, fmt::Debug};
9use zerocopy::{FromBytes, IntoBytes, Unaligned};
10
11mod error;
12pub use error::Error;
13
14#[cfg(any(test, feature = "testing"))]
15pub mod testing;
16
17#[cfg(all(feature = "alloc", any(test, feature = "testing")))]
18pub mod null;
19
20#[cfg(feature = "alloc")]
21pub mod slow_tls;
22
23#[cfg(feature = "std")]
24pub mod offload;
25
/// Application-level parameters carried through the TLS handshake.
///
/// The data is only borrowed, so a value of this type cannot outlive the buffer it points into.
#[derive(Debug)]
pub struct ApplicationParameters<'params> {
    /// Raw transport-parameter bytes.
    /// NOTE(review): presumably the encoded QUIC transport parameters extension — confirm.
    pub transport_parameters: &'params [u8],
}
32
/// A key-exchange group, identified by name.
///
/// Two groups compare equal when their names match ignoring ASCII case and they agree on
/// `contains_kem`.
#[derive(Debug, Eq)]
pub struct NamedGroup {
    /// Human-readable group name. Equality ignores ASCII case.
    pub group_name: &'static str,
    /// Whether the group involves a key-encapsulation mechanism (per the field name).
    pub contains_kem: bool,
}

impl PartialEq for NamedGroup {
    fn eq(&self, other: &Self) -> bool {
        // Compare the cheap flag first, then fall back to a case-insensitive name compare.
        let same_kem = self.contains_kem == other.contains_kem;
        same_kem && other.group_name.eq_ignore_ascii_case(self.group_name)
    }
}
51
/// Error returned by [`TlsSession::tls_exporter`].
///
/// Marked `#[non_exhaustive]` so further variants can be added without breaking callers.
#[non_exhaustive]
#[derive(Debug)]
pub enum TlsExportError {
    /// The keying material could not be exported.
    #[non_exhaustive]
    Failure,
}

impl TlsExportError {
    /// Builds the generic [`TlsExportError::Failure`] variant.
    ///
    /// This constructor exists because `#[non_exhaustive]` prevents other crates from naming the variant directly.
    pub fn failure() -> Self {
        Self::Failure
    }
}
64
/// Error returned by `TlsSession::peer_cert_chain_der`.
///
/// Marked `#[non_exhaustive]` so further variants can be added without breaking callers.
#[non_exhaustive]
#[derive(Debug)]
pub enum ChainError {
    /// The peer certificate chain could not be produced.
    #[non_exhaustive]
    Failure,
}

impl ChainError {
    /// Builds the generic [`ChainError::Failure`] variant.
    ///
    /// This constructor exists because `#[non_exhaustive]` prevents other crates from naming the variant directly.
    pub fn failure() -> Self {
        Self::Failure
    }
}
77
/// Read-only view of a TLS session.
///
/// A [`Context`] receives one through `on_tls_exporter_ready` and `on_tls_handshake_failed`.
pub trait TlsSession: Send {
    /// Writes exported keying material into `output`, parameterized by `label` and `context`.
    ///
    /// NOTE(review): presumably the TLS exporter interface (RFC 8446 §7.5) — confirm against
    /// the implementations.
    ///
    /// # Errors
    /// Returns [`TlsExportError`] if the material cannot be exported.
    fn tls_exporter(
        &self,
        label: &[u8],
        context: &[u8],
        output: &mut [u8],
    ) -> Result<(), TlsExportError>;

    /// Returns the cipher suite of this session.
    ///
    /// [`CipherSuite::Unknown`] is available for suites the enum does not model.
    fn cipher_suite(&self) -> CipherSuite;

    /// Returns the peer's certificate chain, one inner `Vec` per certificate.
    ///
    /// The method name indicates each entry is DER-encoded. NOTE(review): ordering
    /// (leaf first?) is not visible here — confirm.
    ///
    /// # Errors
    /// Returns [`ChainError`] if the chain is not available.
    #[cfg(feature = "alloc")]
    fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, ChainError>;
}
92
/// Callbacks through which a TLS [`Session`] (via `Session::poll`) exchanges keys, negotiated
/// values, and handshake bytes with the surrounding transport.
///
/// `Crypto` supplies the concrete key and header-key types for each packet space.
/// NOTE(review): the order in which these callbacks fire is decided by the `Session`
/// implementations and is not visible here.
#[cfg(feature = "alloc")]
pub trait Context<Crypto: crate::crypto::CryptoSuite> {
    /// Receives the client's application parameters.
    ///
    /// `server_params` is an output buffer. NOTE(review): presumably the implementation
    /// writes the server's encoded transport parameters into it — confirm.
    fn on_client_application_params(
        &mut self,
        client_params: ApplicationParameters,
        server_params: &mut alloc::vec::Vec<u8>,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the handshake-space packet key and header-protection key.
    fn on_handshake_keys(
        &mut self,
        key: Crypto::HandshakeKey,
        header_key: Crypto::HandshakeHeaderKey,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the 0-RTT packet key and header key, together with the application
    /// parameters they apply to.
    fn on_zero_rtt_keys(
        &mut self,
        key: Crypto::ZeroRttKey,
        header_key: Crypto::ZeroRttHeaderKey,
        application_parameters: ApplicationParameters,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the 1-RTT packet key and header key, together with the application
    /// parameters they apply to.
    fn on_one_rtt_keys(
        &mut self,
        key: Crypto::OneRttKey,
        header_key: Crypto::OneRttHeaderKey,
        application_parameters: ApplicationParameters,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the server name associated with the handshake.
    fn on_server_name(
        &mut self,
        server_name: crate::application::ServerName,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the negotiated application protocol.
    ///
    /// NOTE(review): presumably the ALPN identifier.
    fn on_application_protocol(
        &mut self,
        application_protocol: Bytes,
    ) -> Result<(), crate::transport::Error>;

    /// Receives the key-exchange group used by the handshake.
    fn on_key_exchange_group(
        &mut self,
        named_group: NamedGroup,
    ) -> Result<(), crate::transport::Error>;

    /// Signals that the TLS handshake has completed.
    fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error>;

    /// Hands over an opaque, type-erased TLS context object.
    ///
    /// NOTE(review): this `cfg` is redundant because the whole trait is already gated on
    /// `alloc`.
    #[cfg(feature = "alloc")]
    fn on_tls_context(&mut self, _context: alloc::boxed::Box<dyn Any + Send>);

    /// Called with a session whose [`TlsSession::tls_exporter`] may be used.
    fn on_tls_exporter_ready(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Called with the session when the TLS handshake fails.
    fn on_tls_handshake_failed(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Pulls received Initial-space handshake bytes, or `None` if there are none.
    ///
    /// NOTE(review): `max_len`, when `Some`, presumably caps the returned length.
    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Pulls received Handshake-space bytes, or `None` if there are none.
    ///
    /// `max_len` behaves as for `receive_initial`.
    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Pulls received application-space bytes, or `None` if there are none.
    ///
    /// `max_len` behaves as for `receive_initial`.
    fn receive_application(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Whether `send_initial` may currently be called.
    fn can_send_initial(&self) -> bool;
    /// Queues `transmission` for the Initial space.
    fn send_initial(&mut self, transmission: Bytes);

    /// Whether `send_handshake` may currently be called.
    fn can_send_handshake(&self) -> bool;
    /// Queues `transmission` for the Handshake space.
    fn send_handshake(&mut self, transmission: Bytes);

    /// Whether `send_application` may currently be called.
    fn can_send_application(&self) -> bool;
    /// Queues `transmission` for the application space.
    fn send_application(&mut self, transmission: Bytes);

    /// Returns the waker associated with this context.
    ///
    /// NOTE(review): presumably used to reschedule the session when it returns `Pending`.
    fn waker(&self) -> &core::task::Waker;
}
200
/// Factory that creates TLS [`Session`]s for a QUIC endpoint.
#[cfg(feature = "alloc")]
pub trait Endpoint: 'static + Sized + Send {
    /// Session type produced by this endpoint.
    type Session: Session;

    /// Creates a server-side session.
    ///
    /// `transport_parameters` only needs to be encodable (`s2n_codec::EncoderValue`).
    /// NOTE(review): presumably it is encoded into the handshake for the peer.
    fn new_server_session<Params: s2n_codec::EncoderValue>(
        &mut self,
        transport_parameters: &Params,
    ) -> Self::Session;

    /// Creates a client-side session targeting `server_name`.
    ///
    /// `transport_parameters` is handled as in `new_server_session`.
    fn new_client_session<Params: s2n_codec::EncoderValue>(
        &mut self,
        transport_parameters: &Params,
        server_name: crate::application::ServerName,
    ) -> Self::Session;

    /// Returns a tag length.
    ///
    /// NOTE(review): presumably the largest AEAD tag length among the supported cipher
    /// suites — confirm.
    fn max_tag_length(&self) -> usize;
}
219
/// A TLS session that drives the handshake through a [`Context`].
#[cfg(feature = "alloc")]
pub trait Session: crate::crypto::CryptoSuite + Sized + Send + Debug {
    /// Advances the session, exchanging data and keys through `context`.
    fn poll<C: Context<Self>>(
        &mut self,
        context: &mut C,
    ) -> core::task::Poll<Result<(), crate::transport::Error>>;

    /// Handles TLS messages that arrive after the handshake.
    ///
    /// The default implementation does nothing and succeeds.
    fn process_post_handshake_message<C: Context<Self>>(
        &mut self,
        _context: &mut C,
    ) -> Result<(), crate::transport::Error> {
        Ok(())
    }

    /// Whether the session may be discarded; the default answer is `true`.
    fn should_discard_session(&self) -> bool {
        true
    }

    /// Checks whether a complete hello message of type `msg_type` has been received.
    ///
    /// The 4-byte handshake header is decoded from `header_chunk` alone, while
    /// `total_received_len` counts every byte received so far.
    ///
    /// Returns:
    /// * `Ok(None)` when the header or the payload is not yet complete.
    /// * `Ok(Some(offsets))` locating the payload once all of it has arrived.
    ///
    /// # Errors
    /// * `PROTOCOL_VIOLATION` if the header carries any other message type.
    /// * `CRYPTO_BUFFER_EXCEEDED` if the announced payload is longer than `max_hello_size`.
    #[inline]
    fn parse_hello(
        msg_type: HandshakeType,
        header_chunk: &[u8],
        total_received_len: u64,
        max_hello_size: u64,
    ) -> Result<Option<HelloOffsets>, crate::transport::Error> {
        // Without a full header in the first chunk there is nothing to validate yet.
        let header = match s2n_codec::DecoderBuffer::new(header_chunk).decode::<HandshakeHeader>() {
            Ok((decoded, _rest)) => decoded,
            Err(_) => return Ok(None),
        };

        let is_expected_type = header.msg_type() == Some(msg_type);
        if !is_expected_type {
            return Err(crate::transport::Error::PROTOCOL_VIOLATION
                .with_reason("first TLS message should be a hello message"));
        }

        let header_len = core::mem::size_of::<HandshakeHeader>() as u64;
        let payload_len = header.len() as u64;

        // NOTE(review): the reason text says 16k even though the limit is `max_hello_size`.
        if payload_len > max_hello_size {
            return Err(crate::transport::Error::CRYPTO_BUFFER_EXCEEDED
                .with_reason("hello message cannot exceed 16k"));
        }

        // The length field is 24 bits wide, so this sum cannot overflow a u64.
        let message_len = header_len + payload_len;
        if total_received_len < message_len {
            return Ok(None);
        }

        Ok(Some(HelloOffsets {
            payload_offset: header_len as usize,
            payload_len: payload_len as usize,
        }))
    }
}
284
/// Location of a hello message's payload within the received byte stream.
#[derive(Copy, Clone, Debug)]
pub struct HelloOffsets {
    /// Number of leading bytes (the handshake header) to skip before the payload.
    pub payload_offset: usize,
    /// Number of payload bytes that follow the header.
    pub payload_len: usize,
}

impl HelloOffsets {
    /// Slices `chunks` down to just the payload bytes.
    ///
    /// Header bytes are dropped from the front, even when they span several chunks.
    /// Everything past `payload_len` is dropped from the end. Chunks that end up empty are
    /// skipped.
    ///
    /// The input iterator is still consumed to its end after the payload has been fully
    /// yielded.
    #[inline]
    pub fn trim_chunks<'a, I: Iterator<Item = &'a [u8]>>(
        &self,
        chunks: I,
    ) -> impl Iterator<Item = &'a [u8]> {
        // Copy the counters out so the returned iterator does not borrow `self`.
        let HelloOffsets {
            payload_offset: mut skip,
            payload_len: mut remaining,
        } = *self;

        chunks.filter_map(move |chunk| {
            // Drop whatever part of this chunk still belongs to the header.
            let skipped = skip.min(chunk.len());
            skip -= skipped;
            let rest = &chunk[skipped..];

            // Either still inside the header, or the payload is already exhausted.
            if skip != 0 || remaining == 0 {
                return None;
            }

            let taken = remaining.min(rest.len());
            remaining -= taken;
            let payload = &rest[..taken];

            if payload.is_empty() {
                None
            } else {
                Some(payload)
            }
        })
    }
}
324}
325
326#[derive(Copy, Clone, Debug, Default)]
327#[allow(non_camel_case_types)]
328pub enum CipherSuite {
329 TLS_AES_128_GCM_SHA256,
330 TLS_AES_256_GCM_SHA384,
331 TLS_CHACHA20_POLY1305_SHA256,
332 #[default]
333 Unknown,
334}
335
336impl crate::event::IntoEvent<crate::event::builder::CipherSuite> for CipherSuite {
337 #[inline]
338 fn into_event(self) -> crate::event::builder::CipherSuite {
339 use crate::event::builder::CipherSuite::*;
340 match self {
341 Self::TLS_AES_128_GCM_SHA256 => TLS_AES_128_GCM_SHA256 {},
342 Self::TLS_AES_256_GCM_SHA384 => TLS_AES_256_GCM_SHA384 {},
343 Self::TLS_CHACHA20_POLY1305_SHA256 => TLS_CHACHA20_POLY1305_SHA256 {},
344 Self::Unknown => Unknown {},
345 }
346 }
347}
348
349impl crate::event::IntoEvent<crate::event::api::CipherSuite> for CipherSuite {
350 #[inline]
351 fn into_event(self) -> crate::event::api::CipherSuite {
352 let builder: crate::event::builder::CipherSuite = self.into_event();
353 builder.into_event()
354 }
355}
356
357macro_rules! handshake_type {
358 ($($variant:ident($value:literal)),* $(,)?) => {
359 #[derive(Clone, Copy, Debug, PartialEq, Eq, IntoBytes, Unaligned)]
360 #[cfg_attr(any(test, feature = "bolero-generator"), derive(bolero_generator::TypeGenerator))]
361 #[repr(u8)]
362 pub enum HandshakeType {
363 $($variant = $value),*
364 }
365
366 impl TryFrom<u8> for HandshakeType {
367 type Error = ();
368
369 #[inline]
370 fn try_from(value: u8) -> Result<Self, Self::Error> {
371 match value {
372 $($value => Ok(Self::$variant),)*
373 _ => Err(()),
374 }
375 }
376 }
377 };
378}
379
380handshake_type!(
390 HelloRequest(0),
391 ClientHello(1),
392 ServerHello(2),
393 Certificate(11),
394 ServerKeyExchange(12),
395 CertificateRequest(13),
396 ServerHelloDone(14),
397 CertificateVerify(15),
398 ClientKeyExchange(16),
399 Finished(20),
400);
401
402#[derive(Clone, Copy, Debug, IntoBytes, FromBytes, Unaligned)]
420#[repr(C)]
421pub struct HandshakeHeader {
422 msg_type: u8,
423 length: [u8; 3],
424}
425
426impl HandshakeHeader {
427 #[inline]
428 pub fn msg_type(self) -> Option<HandshakeType> {
429 HandshakeType::try_from(self.msg_type).ok()
430 }
431
432 #[inline]
433 pub fn len(self) -> usize {
434 let mut len = [0u8; 4];
435 len[1..].copy_from_slice(&self.length);
436 let len = u32::from_be_bytes(len);
437 len as _
438 }
439
440 #[inline]
441 pub fn is_empty(self) -> bool {
442 self.len() == 0
443 }
444}
445
446s2n_codec::zerocopy_value_codec!(HandshakeHeader);
447
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;
    use hex_literal::hex;

    /// Hello size limit used by these tests.
    ///
    /// It is kept tiny under Kani, presumably so the proof stays tractable.
    const MAX_HELLO_SIZE: u64 = if cfg!(kani) { 32 } else { 255 };

    /// Fuzz/proof input chunk with capacity `MAX_HELLO_SIZE + 2`.
    type Chunk = crate::testing::InlineVec<u8, { MAX_HELLO_SIZE as usize + 2 }>;

    /// `parse_hello` must never panic, whatever the type, chunk, and received length.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(36))]
    fn parse_hello_test() {
        check!()
            .with_type::<(HandshakeType, Chunk, u64)>()
            .for_each(|(ty, chunk, total_received_len)| {
                let _ =
                    testing::Session::parse_hello(*ty, chunk, *total_received_len, MAX_HELLO_SIZE);
            });
    }

    /// Builds a byte slice from a hex string literal.
    macro_rules! h {
        ($($tt:tt)*) => {
            &hex!($($tt)*)[..]
        }
    }

    /// Runs `parse_hello` on the first chunk and trims `input` down to the payload chunks.
    fn parse_hello<'a>(
        ty: HandshakeType,
        input: &'a [&'a [u8]],
    ) -> Result<Option<Vec<&'a [u8]>>, crate::transport::Error> {
        let total_received_len: usize = input.iter().map(|chunk| chunk.len()).sum();

        let empty = &[][..];
        let first = input.iter().copied().next().unwrap_or(empty);

        let outcome =
            testing::Session::parse_hello(ty, first, total_received_len as _, MAX_HELLO_SIZE)?;

        if let Some(offsets) = outcome {
            let payload = offsets.trim_chunks(input.iter().copied()).collect();
            Ok(Some(payload))
        } else {
            Ok(None)
        }
    }

    #[test]
    fn client_hello_valid_tests() {
        let tests = [
            (&[h!("01 00 00 02 aa bb cc")][..], &[h!("aa bb")][..]),
            (&[h!("01 00 00 01"), h!("aa bb cc dd")], &[h!("aa")]),
            (
                &[h!("01 00 00 02"), h!("aa"), h!("bb"), h!("cc")],
                &[h!("aa"), h!("bb")],
            ),
        ];

        for (input, expected) in tests {
            let output = parse_hello(HandshakeType::ClientHello, input)
                .unwrap()
                .unwrap();

            assert_eq!(&output[..], expected);
        }
    }

    #[test]
    fn server_hello_valid_tests() {
        let tests = [(&[h!("02 00 00 02 aa bb cc")][..], &[h!("aa bb")][..])];

        for (input, expected) in tests {
            let output = parse_hello(HandshakeType::ServerHello, input)
                .unwrap()
                .unwrap();

            assert_eq!(&output[..], expected);
        }
    }

    /// A payload of exactly `MAX_HELLO_SIZE` bytes is still accepted; the limit check is
    /// strictly greater-than.
    #[test]
    fn hello_at_max_size_is_accepted() {
        let len = (MAX_HELLO_SIZE as u32).to_be_bytes();
        let mut message = Vec::new();
        message.extend_from_slice(&[1, len[1], len[2], len[3]]);
        message.resize(4 + MAX_HELLO_SIZE as usize, 0xaa);

        let input = [&message[..]];
        let output = parse_hello(HandshakeType::ClientHello, &input)
            .unwrap()
            .unwrap();

        assert_eq!(&output[..], &[&message[4..]][..]);
    }

    #[test]
    fn client_hello_incomplete_tests() {
        let tests = [
            &[][..],
            // header not fully received
            &[h!("01 00 00")],
            // header complete, payload missing
            &[h!("01 00 00 01")],
            // payload only partially received
            &[h!("01 00 00 04"), h!("aa"), h!("bb")],
        ];

        for input in tests {
            assert_eq!(
                parse_hello(HandshakeType::ClientHello, input).unwrap(),
                None
            );
        }
    }

    #[test]
    fn client_hello_invalid_tests() {
        let tests = [
            // ServerHello where a ClientHello is expected
            &[h!("02 00 00 01 aa")],
            // payload length of 256 exceeds MAX_HELLO_SIZE
            &[h!("01 00 01 00 aa")],
            // maximal 24-bit payload length
            &[h!("01 ff ff ff aa")],
        ];

        for input in tests {
            assert!(parse_hello(HandshakeType::ClientHello, input).is_err());
        }
    }

    #[test]
    fn server_hello_invalid_tests() {
        let tests = [
            // ClientHello where a ServerHello is expected
            &[h!("01 00 00 01 aa")],
            // payload length of 256 exceeds MAX_HELLO_SIZE
            &[h!("02 00 01 00 aa")],
        ];

        for input in tests {
            assert!(parse_hello(HandshakeType::ServerHello, input).is_err());
        }
    }
}