1#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6#[cfg(feature = "alloc")]
7pub use bytes::{Bytes, BytesMut};
8use core::{any::Any, fmt::Debug};
9use zerocopy::{FromBytes, IntoBytes, Unaligned};
10
11mod error;
12pub use error::Error;
13
14#[cfg(any(test, feature = "testing"))]
15pub mod testing;
16
17#[cfg(all(feature = "alloc", any(test, feature = "testing")))]
18pub mod null;
19
20#[cfg(feature = "alloc")]
21pub mod slow_tls;
22
/// Application-level parameters exchanged during the TLS handshake.
///
/// Borrows the raw encoded bytes; no parsing is performed on them here.
#[derive(Debug)]
pub struct ApplicationParameters<'a> {
    /// Encoded transport parameters, exactly as received or to be sent.
    pub transport_parameters: &'a [u8],
}
29
/// A key-exchange group reported by the TLS handshake.
///
/// Equality compares `group_name` ASCII case-insensitively and requires the
/// `contains_kem` flags to match.
#[derive(Debug, Eq)]
pub struct NamedGroup {
    /// Name of the group; compared without regard to ASCII case.
    pub group_name: &'static str,
    /// Presumably whether the group includes a key-encapsulation mechanism — confirm with producers.
    pub contains_kem: bool,
}

impl PartialEq for NamedGroup {
    fn eq(&self, other: &Self) -> bool {
        // Cheap flag comparison first; both sides are pure so order does not matter.
        if self.contains_kem != other.contains_kem {
            return false;
        }
        self.group_name.eq_ignore_ascii_case(other.group_name)
    }
}
48
/// Error returned by [`TlsSession::tls_exporter`].
#[derive(Debug)]
#[non_exhaustive]
pub enum TlsExportError {
    /// Exporting keying material did not succeed.
    #[non_exhaustive]
    Failure,
}

impl TlsExportError {
    /// Builds the generic [`TlsExportError::Failure`] value.
    ///
    /// Exists because the `#[non_exhaustive]` variant cannot be constructed directly
    /// outside this crate.
    pub fn failure() -> Self {
        Self::Failure
    }
}
61
/// Error returned when the peer certificate chain cannot be produced.
#[derive(Debug)]
#[non_exhaustive]
pub enum ChainError {
    /// Retrieving the chain did not succeed.
    #[non_exhaustive]
    Failure,
}

impl ChainError {
    /// Builds the generic [`ChainError::Failure`] value.
    ///
    /// Exists because the `#[non_exhaustive]` variant cannot be constructed directly
    /// outside this crate.
    pub fn failure() -> Self {
        Self::Failure
    }
}
74
75pub trait TlsSession: Send {
76 fn tls_exporter(
78 &self,
79 label: &[u8],
80 context: &[u8],
81 output: &mut [u8],
82 ) -> Result<(), TlsExportError>;
83
84 fn cipher_suite(&self) -> CipherSuite;
85
86 #[cfg(feature = "alloc")]
87 fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, ChainError>;
88}
89
/// Callbacks a TLS [`Session`] uses to hand keys, parameters and handshake bytes to the
/// transport.
///
/// `Crypto` determines the concrete key and header-key types delivered for each level.
#[cfg(feature = "alloc")]
pub trait Context<Crypto: crate::crypto::CryptoSuite> {
    /// Receives the client's application parameters; the server's own parameters are
    /// written into `server_params`.
    ///
    /// NOTE(review): whether `server_params` is appended to or expected empty is not
    /// visible here — confirm with implementations.
    fn on_client_application_params(
        &mut self,
        client_params: ApplicationParameters<'_>,
        server_params: &mut Vec<u8>,
    ) -> Result<(), crate::transport::Error>;

    /// Installs the handshake-level packet key and header key.
    fn on_handshake_keys(
        &mut self,
        key: Crypto::HandshakeKey,
        header_key: Crypto::HandshakeHeaderKey,
    ) -> Result<(), crate::transport::Error>;

    /// Installs the 0-RTT packet key and header key along with the associated parameters.
    fn on_zero_rtt_keys(
        &mut self,
        key: Crypto::ZeroRttKey,
        header_key: Crypto::ZeroRttHeaderKey,
        application_parameters: ApplicationParameters<'_>,
    ) -> Result<(), crate::transport::Error>;

    /// Installs the 1-RTT packet key and header key along with the associated parameters.
    fn on_one_rtt_keys(
        &mut self,
        key: Crypto::OneRttKey,
        header_key: Crypto::OneRttHeaderKey,
        application_parameters: ApplicationParameters<'_>,
    ) -> Result<(), crate::transport::Error>;

    /// Reports the server name associated with the session.
    fn on_server_name(
        &mut self,
        server_name: crate::application::ServerName,
    ) -> Result<(), crate::transport::Error>;

    /// Reports the application protocol selected for the session.
    fn on_application_protocol(
        &mut self,
        application_protocol: Bytes,
    ) -> Result<(), crate::transport::Error>;

    /// Reports the key-exchange group used by the handshake.
    fn on_key_exchange_group(
        &mut self,
        named_group: NamedGroup,
    ) -> Result<(), crate::transport::Error>;

    /// Signals that the handshake has completed.
    fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error>;

    /// Hands over an opaque, type-erased TLS context object.
    #[cfg(feature = "alloc")]
    fn on_tls_context(&mut self, _context: alloc::boxed::Box<dyn Any + Send>);

    /// Called with the session once keying material can be exported from it.
    fn on_tls_exporter_ready(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Called with the session when the TLS handshake fails.
    fn on_tls_handshake_failed(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Returns received Initial-level handshake bytes, limited to `max_len` when given;
    /// `None` presumably means nothing is currently available — confirm.
    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Returns received Handshake-level bytes, limited to `max_len` when given.
    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Returns received application-level (1-RTT) handshake bytes, limited to `max_len`
    /// when given.
    fn receive_application(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Whether Initial-level bytes can currently be sent.
    fn can_send_initial(&self) -> bool;
    /// Queues `transmission` to be sent at the Initial level.
    fn send_initial(&mut self, transmission: Bytes);

    /// Whether Handshake-level bytes can currently be sent.
    fn can_send_handshake(&self) -> bool;
    /// Queues `transmission` to be sent at the Handshake level.
    fn send_handshake(&mut self, transmission: Bytes);

    /// Whether application-level bytes can currently be sent.
    fn can_send_application(&self) -> bool;
    /// Queues `transmission` to be sent at the application level.
    fn send_application(&mut self, transmission: Bytes);

    /// Returns the waker associated with this context.
    ///
    /// NOTE(review): presumably used to re-poll the session when it cannot progress.
    fn waker(&self) -> &core::task::Waker;
}
197
/// Factory for TLS [`Session`]s on the client and server side.
#[cfg(feature = "alloc")]
pub trait Endpoint: 'static + Sized + Send {
    /// Session type created by this endpoint.
    type Session: Session;

    /// Creates a server-side session, given the local transport parameters to encode.
    fn new_server_session<Params: s2n_codec::EncoderValue>(
        &mut self,
        transport_parameters: &Params,
    ) -> Self::Session;

    /// Creates a client-side session for `server_name`, given the local transport
    /// parameters to encode.
    fn new_client_session<Params: s2n_codec::EncoderValue>(
        &mut self,
        transport_parameters: &Params,
        server_name: crate::application::ServerName,
    ) -> Self::Session;

    /// Returns the maximum tag length, in bytes.
    ///
    /// NOTE(review): presumably the largest AEAD tag any negotiable suite produces — confirm.
    fn max_tag_length(&self) -> usize;
}
216
/// A TLS session that drives the handshake through a [`Context`].
#[cfg(feature = "alloc")]
pub trait Session: crate::crypto::CryptoSuite + Sized + Send + Debug {
    /// Advances the session, interacting with the transport through `context`.
    fn poll<C: Context<Self>>(
        &mut self,
        context: &mut C,
    ) -> core::task::Poll<Result<(), crate::transport::Error>>;

    /// Processes a post-handshake message; the default implementation does nothing.
    fn process_post_handshake_message<C: Context<Self>>(
        &mut self,
        _context: &mut C,
    ) -> Result<(), crate::transport::Error> {
        Ok(())
    }

    /// Whether this session may be discarded; defaults to `true`.
    fn should_discard_session(&self) -> bool {
        true
    }

    /// Locates the first hello message of type `msg_type` in the received CRYPTO bytes.
    ///
    /// `header_chunk` must start at the beginning of the message; `total_received_len` is
    /// the number of bytes received so far across all chunks.
    ///
    /// Returns:
    /// - `Ok(None)` when the header or the full payload has not arrived yet;
    /// - `Ok(Some(offsets))` locating the payload once it is complete;
    /// - `Err(PROTOCOL_VIOLATION)` when the message type differs from `msg_type`;
    /// - `Err(CRYPTO_BUFFER_EXCEEDED)` when the announced payload exceeds `max_hello_size`.
    #[inline]
    fn parse_hello(
        msg_type: HandshakeType,
        header_chunk: &[u8],
        total_received_len: u64,
        max_hello_size: u64,
    ) -> Result<Option<HelloOffsets>, crate::transport::Error> {
        let decoded = s2n_codec::DecoderBuffer::new(header_chunk).decode::<HandshakeHeader>();
        let header = match decoded {
            Ok((header, _remaining)) => header,
            // The first chunk does not yet hold a complete header.
            Err(_) => return Ok(None),
        };

        if header.msg_type() != Some(msg_type) {
            return Err(crate::transport::Error::PROTOCOL_VIOLATION
                .with_reason("first TLS message should be a hello message"));
        }

        let header_len = core::mem::size_of::<HandshakeHeader>();
        let payload_len = header.len();

        // Reject oversized hellos before waiting for their payload.
        if (payload_len as u64) > max_hello_size {
            return Err(crate::transport::Error::CRYPTO_BUFFER_EXCEEDED
                .with_reason("hello message cannot exceed 16k"));
        }

        // Keep buffering until the whole message is available.
        let message_len = header_len as u64 + payload_len as u64;
        if total_received_len < message_len {
            return Ok(None);
        }

        Ok(Some(HelloOffsets {
            payload_offset: header_len,
            payload_len,
        }))
    }
}
281
/// Location of a hello message's payload within the received byte stream.
#[derive(Copy, Clone, Debug)]
pub struct HelloOffsets {
    /// Number of leading bytes (the handshake header) to skip.
    pub payload_offset: usize,
    /// Number of payload bytes that follow the skipped prefix.
    pub payload_len: usize,
}

impl HelloOffsets {
    /// Narrows `chunks` to just the payload bytes.
    ///
    /// The first `payload_offset` bytes across all chunks are dropped, then at most
    /// `payload_len` bytes are yielded, split along the original chunk boundaries.
    /// Empty slices are never yielded.
    #[inline]
    pub fn trim_chunks<'a, I: Iterator<Item = &'a [u8]>>(
        &self,
        chunks: I,
    ) -> impl Iterator<Item = &'a [u8]> {
        let Self {
            mut payload_offset,
            mut payload_len,
        } = *self;

        chunks.filter_map(move |chunk| {
            // Consume whatever part of the skipped prefix lies in this chunk
            // (a no-op once the prefix is exhausted).
            let skip = payload_offset.min(chunk.len());
            payload_offset -= skip;
            let rest = &chunk[skip..];

            // Either still inside the prefix, or the payload was already fully yielded.
            if payload_offset != 0 || payload_len == 0 {
                return None;
            }

            let take = payload_len.min(rest.len());
            payload_len -= take;

            let payload = &rest[..take];
            (!payload.is_empty()).then_some(payload)
        })
    }
}
322
/// Cipher suite used by a TLS session.
///
/// [`CipherSuite::Unknown`] is the default when no suite has been identified.
#[derive(Copy, Clone, Debug, Default)]
// Variant names deliberately spell the TLS cipher suite identifiers verbatim.
#[allow(non_camel_case_types)]
pub enum CipherSuite {
    /// `TLS_AES_128_GCM_SHA256`
    TLS_AES_128_GCM_SHA256,
    /// `TLS_AES_256_GCM_SHA384`
    TLS_AES_256_GCM_SHA384,
    /// `TLS_CHACHA20_POLY1305_SHA256`
    TLS_CHACHA20_POLY1305_SHA256,
    /// No suite known (the default).
    #[default]
    Unknown,
}
332
333impl crate::event::IntoEvent<crate::event::builder::CipherSuite> for CipherSuite {
334 #[inline]
335 fn into_event(self) -> crate::event::builder::CipherSuite {
336 use crate::event::builder::CipherSuite::*;
337 match self {
338 Self::TLS_AES_128_GCM_SHA256 => TLS_AES_128_GCM_SHA256 {},
339 Self::TLS_AES_256_GCM_SHA384 => TLS_AES_256_GCM_SHA384 {},
340 Self::TLS_CHACHA20_POLY1305_SHA256 => TLS_CHACHA20_POLY1305_SHA256 {},
341 Self::Unknown => Unknown {},
342 }
343 }
344}
345
346impl crate::event::IntoEvent<crate::event::api::CipherSuite> for CipherSuite {
347 #[inline]
348 fn into_event(self) -> crate::event::api::CipherSuite {
349 let builder: crate::event::builder::CipherSuite = self.into_event();
350 builder.into_event()
351 }
352}
353
// Generates the `HandshakeType` enum from `Variant(code)` pairs, plus a `TryFrom<u8>`
// impl mapping each listed code to its variant and every other byte to `Err(())`.
macro_rules! handshake_type {
    ($($variant:ident($value:literal)),* $(,)?) => {
        /// Type byte of a TLS handshake message header.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, IntoBytes, Unaligned)]
        #[cfg_attr(any(test, feature = "bolero-generator"), derive(bolero_generator::TypeGenerator))]
        #[repr(u8)]
        pub enum HandshakeType {
            $($variant = $value),*
        }

        impl TryFrom<u8> for HandshakeType {
            type Error = ();

            #[inline]
            fn try_from(value: u8) -> Result<Self, Self::Error> {
                let ty = match value {
                    $($value => Self::$variant,)*
                    _ => return Err(()),
                };
                Ok(ty)
            }
        }
    };
}
376
377handshake_type!(
387 HelloRequest(0),
388 ClientHello(1),
389 ServerHello(2),
390 Certificate(11),
391 ServerKeyExchange(12),
392 CertificateRequest(13),
393 ServerHelloDone(14),
394 CertificateVerify(15),
395 ClientKeyExchange(16),
396 Finished(20),
397);
398
399#[derive(Clone, Copy, Debug, IntoBytes, FromBytes, Unaligned)]
417#[repr(C)]
418pub struct HandshakeHeader {
419 msg_type: u8,
420 length: [u8; 3],
421}
422
423impl HandshakeHeader {
424 #[inline]
425 pub fn msg_type(self) -> Option<HandshakeType> {
426 HandshakeType::try_from(self.msg_type).ok()
427 }
428
429 #[inline]
430 pub fn len(self) -> usize {
431 let mut len = [0u8; 4];
432 len[1..].copy_from_slice(&self.length);
433 let len = u32::from_be_bytes(len);
434 len as _
435 }
436
437 #[inline]
438 pub fn is_empty(self) -> bool {
439 self.len() == 0
440 }
441}
442
443s2n_codec::zerocopy_value_codec!(HandshakeHeader);
444
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;
    use hex_literal::hex;

    // Kept small under kani so the bounded proof stays tractable.
    const MAX_HELLO_SIZE: u64 = if cfg!(kani) { 32 } else { 255 };

    type Chunk = crate::testing::InlineVec<u8, { MAX_HELLO_SIZE as usize + 2 }>;

    /// `parse_hello` must not panic for any combination of inputs.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(36))]
    fn parse_hello_test() {
        check!()
            .with_type::<(HandshakeType, Chunk, u64)>()
            .for_each(|(ty, chunk, total_received_len)| {
                let _ =
                    testing::Session::parse_hello(*ty, chunk, *total_received_len, MAX_HELLO_SIZE);
            });
    }

    // Builds a `&[u8]` from a hex string literal.
    macro_rules! h {
        ($($tt:tt)*) => {
            &hex!($($tt)*)[..]
        }
    }

    /// Runs `parse_hello` with the first chunk as the header chunk and, on success,
    /// returns the payload trimmed out of all chunks.
    fn parse_hello<'a>(
        ty: HandshakeType,
        input: &'a [&'a [u8]],
    ) -> Result<Option<Vec<&'a [u8]>>, crate::transport::Error> {
        let total_received_len: usize = input.iter().map(|chunk| chunk.len()).sum();

        let empty = &[][..];
        let first = input.iter().copied().next().unwrap_or(empty);

        let outcome =
            testing::Session::parse_hello(ty, first, total_received_len as _, MAX_HELLO_SIZE)?;

        if let Some(offsets) = outcome {
            let payload = offsets.trim_chunks(input.iter().copied()).collect();
            Ok(Some(payload))
        } else {
            Ok(None)
        }
    }

    #[test]
    fn client_hello_valid_tests() {
        let tests = [
            (&[h!("01 00 00 02 aa bb cc")][..], &[h!("aa bb")][..]),
            (&[h!("01 00 00 01"), h!("aa bb cc dd")], &[h!("aa")]),
            (
                &[h!("01 00 00 02"), h!("aa"), h!("bb"), h!("cc")],
                &[h!("aa"), h!("bb")],
            ),
        ];

        for (input, expected) in tests {
            let output = parse_hello(HandshakeType::ClientHello, input)
                .unwrap()
                .unwrap();

            assert_eq!(&output[..], expected);
        }
    }

    #[test]
    fn server_hello_valid_tests() {
        let tests = [(&[h!("02 00 00 02 aa bb cc")][..], &[h!("aa bb")][..])];

        for (input, expected) in tests {
            let output = parse_hello(HandshakeType::ServerHello, input)
                .unwrap()
                .unwrap();

            assert_eq!(&output[..], expected);
        }
    }

    #[test]
    fn client_hello_incomplete_tests() {
        let tests = [
            &[][..],
            &[h!("01 00 00")],
            &[h!("01 00 00 01")],
            &[h!("01 00 00 04"), h!("aa"), h!("bb")],
        ];

        for input in tests {
            assert_eq!(
                parse_hello(HandshakeType::ClientHello, input).unwrap(),
                None
            );
        }
    }

    #[test]
    fn client_hello_invalid_tests() {
        let tests = [
            &[h!("02 00 00 01 aa")],
            &[h!("01 00 01 00 aa")],
            &[h!("01 ff ff ff aa")],
        ];

        for input in tests {
            assert!(parse_hello(HandshakeType::ClientHello, input).is_err());
        }
    }

    /// A payload of exactly `MAX_HELLO_SIZE` bytes is accepted; one more byte is rejected
    /// as soon as the header is seen, before the payload arrives.
    #[test]
    fn client_hello_size_limit_tests() {
        let [_, len_hi, len_mid, len_lo] = (MAX_HELLO_SIZE as u32).to_be_bytes();
        let header = [1, len_hi, len_mid, len_lo];
        let payload = [0xaa; MAX_HELLO_SIZE as usize];
        let input = [&header[..], &payload[..]];
        let output = parse_hello(HandshakeType::ClientHello, &input)
            .unwrap()
            .unwrap();
        assert_eq!(&output[..], &[&payload[..]][..]);

        let [_, len_hi, len_mid, len_lo] = (MAX_HELLO_SIZE as u32 + 1).to_be_bytes();
        let header = [1, len_hi, len_mid, len_lo];
        let input = [&header[..]];
        assert!(parse_hello(HandshakeType::ClientHello, &input).is_err());
    }

    #[test]
    fn handshake_header_fields_test() {
        let header = HandshakeHeader {
            msg_type: 1,
            length: [0x01, 0x02, 0x03],
        };
        assert_eq!(header.msg_type(), Some(HandshakeType::ClientHello));
        assert_eq!(header.len(), 0x01_02_03);
        assert!(!header.is_empty());

        // 3 is not a listed handshake type.
        let header = HandshakeHeader {
            msg_type: 3,
            length: [0; 3],
        };
        assert_eq!(header.msg_type(), None);
        assert!(header.is_empty());
    }

    /// `trim_chunks` must handle a prefix spanning several chunks and skip empty chunks;
    /// `parse_hello` never produces this case because the header must fit in the first chunk.
    #[test]
    fn trim_chunks_offset_spanning_chunks_test() {
        let offsets = HelloOffsets {
            payload_offset: 4,
            payload_len: 3,
        };
        let input = [h!("01 00"), &[][..], h!("00 03 aa"), h!("bb cc dd"), h!("ee")];
        let output: Vec<_> = offsets.trim_chunks(input.iter().copied()).collect();
        assert_eq!(&output[..], &[h!("aa"), h!("bb cc")][..]);
    }
}