1use crate::path::{LocalAddress, RemoteAddress};
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7#[cfg(feature = "alloc")]
8pub use bytes::{Bytes, BytesMut};
9use core::{any::Any, fmt::Debug, net::SocketAddr};
10use zerocopy::{FromBytes, IntoBytes, Unaligned};
11
12mod error;
13pub use error::Error;
14
15#[cfg(any(test, feature = "testing"))]
16pub mod testing;
17
18#[cfg(all(feature = "alloc", any(test, feature = "testing")))]
19pub mod null;
20
21#[cfg(feature = "alloc")]
22pub mod slow_tls;
23
24#[cfg(feature = "std")]
25pub mod offload;
26
27#[derive(Debug, Clone, Copy)]
30#[non_exhaustive]
31pub struct ConnectionInfo {
32 pub local_address: SocketAddr,
33 pub remote_address: SocketAddr,
34}
35
36impl ConnectionInfo {
37 #[doc(hidden)]
38 pub fn new(local_address: LocalAddress, remote_address: RemoteAddress) -> Self {
39 Self {
40 local_address: local_address.into(),
41 remote_address: remote_address.into(),
42 }
43 }
44}
45
/// Parameters handed to the connection alongside TLS events
/// (see `Context::on_client_application_params`, `on_zero_rtt_keys` and `on_one_rtt_keys`).
#[derive(Debug)]
pub struct ApplicationParameters<'a> {
    /// Encoded transport parameters as raw bytes, borrowed for `'a`.
    pub transport_parameters: &'a [u8],
}
52
/// A key exchange group, as reported through `Context::on_key_exchange_group`.
///
/// Equality ignores ASCII case in the group name, so `"X25519"` and `"x25519"`
/// are considered the same group (provided `contains_kem` also matches).
#[derive(Debug, Eq)]
pub struct NamedGroup {
    /// Textual name of the group.
    pub group_name: &'static str,
    /// Whether the group includes a key encapsulation mechanism (KEM).
    pub contains_kem: bool,
}

impl PartialEq for NamedGroup {
    fn eq(&self, other: &Self) -> bool {
        // Compare the cheap flag first; names are only compared when it agrees.
        if self.contains_kem != other.contains_kem {
            return false;
        }

        self.group_name.eq_ignore_ascii_case(other.group_name)
    }
}
71
/// Error returned by `TlsSession::tls_exporter` when keying material cannot be exported.
#[derive(Debug)]
#[non_exhaustive]
pub enum TlsExportError {
    /// The export did not succeed; no further detail is carried.
    #[non_exhaustive]
    Failure,
}

impl TlsExportError {
    /// Constructs the generic [`TlsExportError::Failure`] variant.
    ///
    /// Needed because `#[non_exhaustive]` prevents other crates from naming the variant directly.
    pub fn failure() -> Self {
        Self::Failure
    }
}
84
/// Error returned by `TlsSession::peer_cert_chain_der` when the peer chain cannot be produced.
#[derive(Debug)]
#[non_exhaustive]
pub enum ChainError {
    /// Retrieval did not succeed; no further detail is carried.
    #[non_exhaustive]
    Failure,
}

impl ChainError {
    /// Constructs the generic [`ChainError::Failure`] variant.
    ///
    /// Needed because `#[non_exhaustive]` prevents other crates from naming the variant directly.
    pub fn failure() -> Self {
        Self::Failure
    }
}
97
98pub trait TlsSession: Send {
99 fn tls_exporter(
101 &self,
102 label: &[u8],
103 context: &[u8],
104 output: &mut [u8],
105 ) -> Result<(), TlsExportError>;
106
107 fn cipher_suite(&self) -> CipherSuite;
108
109 #[cfg(feature = "alloc")]
110 fn peer_cert_chain_der(&self) -> Result<Vec<Vec<u8>>, ChainError>;
111}
112
/// Callbacks a TLS [`Session`] uses to exchange handshake data, keys and events with
/// the connection that drives it.
///
/// Every fallible callback returns a [`crate::transport::Error`]. Presumably the
/// session aborts the handshake when one is returned — confirm against the concrete
/// `Session` implementations.
#[cfg(feature = "alloc")]
pub trait Context<Crypto: crate::crypto::CryptoSuite> {
    /// Receives the peer client's `ApplicationParameters`. Presumably the
    /// implementation encodes the local (server) transport parameters into
    /// `server_params` — TODO confirm with implementors.
    fn on_client_application_params(
        &mut self,
        client_params: ApplicationParameters<'_>,
        server_params: &mut Vec<u8>,
    ) -> Result<(), crate::transport::Error>;

    /// Delivers the handshake-level packet key and header key.
    fn on_handshake_keys(
        &mut self,
        key: Crypto::HandshakeKey,
        header_key: Crypto::HandshakeHeaderKey,
    ) -> Result<(), crate::transport::Error>;

    /// Delivers the 0-RTT packet key and header key, along with the accompanying
    /// application parameters.
    fn on_zero_rtt_keys(
        &mut self,
        key: Crypto::ZeroRttKey,
        header_key: Crypto::ZeroRttHeaderKey,
        application_parameters: ApplicationParameters<'_>,
    ) -> Result<(), crate::transport::Error>;

    /// Delivers the 1-RTT packet key and header key, along with the accompanying
    /// application parameters.
    fn on_one_rtt_keys(
        &mut self,
        key: Crypto::OneRttKey,
        header_key: Crypto::OneRttHeaderKey,
        application_parameters: ApplicationParameters<'_>,
    ) -> Result<(), crate::transport::Error>;

    /// Reports the server name associated with the session.
    fn on_server_name(
        &mut self,
        server_name: crate::application::ServerName,
    ) -> Result<(), crate::transport::Error>;

    /// Reports the application protocol as raw bytes (presumably the ALPN value).
    fn on_application_protocol(
        &mut self,
        application_protocol: Bytes,
    ) -> Result<(), crate::transport::Error>;

    /// Reports the key exchange group in use.
    fn on_key_exchange_group(
        &mut self,
        named_group: NamedGroup,
    ) -> Result<(), crate::transport::Error>;

    /// Signals that the TLS handshake has completed.
    fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error>;

    /// Hands over an opaque, type-erased TLS context object.
    // The trait itself is already gated on `feature = "alloc"`, so no extra `cfg` is needed here.
    fn on_tls_context(&mut self, _context: alloc::boxed::Box<dyn Any + Send>);

    /// Called with a [`TlsSession`] from which keying material can be exported.
    fn on_tls_exporter_ready(
        &mut self,
        session: &impl TlsSession,
    ) -> Result<(), crate::transport::Error>;

    /// Called when the TLS handshake fails, with the session and the underlying error.
    fn on_tls_handshake_failed(
        &mut self,
        session: &impl TlsSession,
        error: &(dyn core::error::Error + Send + Sync + 'static),
    ) -> Result<(), crate::transport::Error>;

    /// Takes buffered Initial-space data for the TLS session, limited to `max_len`
    /// bytes when given; `None` when nothing is available (presumably — confirm).
    fn receive_initial(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Same as `receive_initial`, for the Handshake space.
    fn receive_handshake(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Same as `receive_initial`, for the application (1-RTT) space.
    fn receive_application(&mut self, max_len: Option<usize>) -> Option<Bytes>;

    /// Whether `send_initial` may currently be called.
    fn can_send_initial(&self) -> bool;
    /// Queues TLS output for the Initial space.
    fn send_initial(&mut self, transmission: Bytes);

    /// Whether `send_handshake` may currently be called.
    fn can_send_handshake(&self) -> bool;
    /// Queues TLS output for the Handshake space.
    fn send_handshake(&mut self, transmission: Bytes);

    /// Whether `send_application` may currently be called.
    fn can_send_application(&self) -> bool;
    /// Queues TLS output for the application (1-RTT) space.
    fn send_application(&mut self, transmission: Bytes);

    /// Waker used to re-schedule the session when it returns `Poll::Pending`
    /// (presumably — confirm with `Session::poll` implementations).
    fn waker(&self) -> &core::task::Waker;
}
221
/// Factory that creates TLS [`Session`]s for new client and server connections.
#[cfg(feature = "alloc")]
pub trait Endpoint: 'static + Sized + Send {
    /// Session type produced by this endpoint.
    type Session: Session;

    /// Creates a server-side session that will advertise the encoded
    /// `transport_parameters` for the connection described by `connection_info`.
    fn new_server_session<Params>(
        &mut self,
        transport_parameters: &Params,
        connection_info: ConnectionInfo,
    ) -> Self::Session
    where
        Params: s2n_codec::EncoderValue;

    /// Creates a client-side session for `server_name` that will advertise the
    /// encoded `transport_parameters`.
    fn new_client_session<Params>(
        &mut self,
        transport_parameters: &Params,
        server_name: crate::application::ServerName,
    ) -> Self::Session
    where
        Params: s2n_codec::EncoderValue;

    /// Returns the maximum authentication tag length (presumably in bytes) this
    /// endpoint's ciphers may produce.
    fn max_tag_length(&self) -> usize;
}
241
/// A TLS session driven by the connection through a [`Context`].
#[cfg(feature = "alloc")]
pub trait Session: crate::crypto::CryptoSuite + Sized + Send + Debug {
    /// Advances the session, using `context` for data exchange and key delivery.
    fn poll<C>(
        &mut self,
        context: &mut C,
    ) -> core::task::Poll<Result<(), crate::transport::Error>>
    where
        C: Context<Self>;

    /// Hook for post-handshake messages. The default ignores them and returns `Ok(())`.
    fn process_post_handshake_message<C>(
        &mut self,
        _context: &mut C,
    ) -> Result<(), crate::transport::Error>
    where
        C: Context<Self>,
    {
        Ok(())
    }

    /// Returns whether the session may be discarded. Presumably this is consulted
    /// after the handshake finishes (confirm with callers). Defaults to `true`.
    fn should_discard_session(&self) -> bool {
        true
    }

    /// Inspects the handshake header at the start of `header_chunk` and locates the
    /// payload of a `msg_type` hello message.
    ///
    /// Returns:
    /// * `Ok(None)` if `header_chunk` is too short to hold a header, or if
    ///   `total_received_len` is smaller than header + declared payload.
    /// * `Ok(Some(offsets))` once the whole message has been received.
    ///
    /// # Errors
    /// * `PROTOCOL_VIOLATION` if the header's type is unknown or is not `msg_type`.
    /// * `CRYPTO_BUFFER_EXCEEDED` if the declared payload length exceeds `max_hello_size`.
    #[inline]
    fn parse_hello(
        msg_type: HandshakeType,
        header_chunk: &[u8],
        total_received_len: u64,
        max_hello_size: u64,
    ) -> Result<Option<HelloOffsets>, crate::transport::Error> {
        const HEADER_LEN: u64 = core::mem::size_of::<HandshakeHeader>() as u64;

        // A partial header is not an error: more bytes may still arrive.
        let Ok((header, _)) =
            s2n_codec::DecoderBuffer::new(header_chunk).decode::<HandshakeHeader>()
        else {
            return Ok(None);
        };

        if header.msg_type() != Some(msg_type) {
            return Err(crate::transport::Error::PROTOCOL_VIOLATION
                .with_reason("first TLS message should be a hello message"));
        }

        let body_len = header.len() as u64;
        if body_len > max_hello_size {
            // NOTE(review): the reason text says "16k" whatever `max_hello_size` is.
            return Err(crate::transport::Error::CRYPTO_BUFFER_EXCEEDED
                .with_reason("hello message cannot exceed 16k"));
        }

        // Wait until header + body have both been received.
        let message_len = HEADER_LEN + body_len;
        if total_received_len < message_len {
            return Ok(None);
        }

        Ok(Some(HelloOffsets {
            payload_offset: HEADER_LEN as usize,
            payload_len: body_len as usize,
        }))
    }
}
306
/// Location of a hello message's payload within the received byte stream,
/// as produced by `Session::parse_hello`.
#[derive(Copy, Clone, Debug)]
pub struct HelloOffsets {
    /// Number of leading bytes (the handshake header) that precede the payload.
    pub payload_offset: usize,
    /// Number of payload bytes following the header.
    pub payload_len: usize,
}

impl HelloOffsets {
    /// Narrows a sequence of received chunks down to just the payload bytes.
    ///
    /// The first `payload_offset` bytes are skipped (even when they span several
    /// chunks), after which at most `payload_len` bytes are yielded. Chunks that end
    /// up empty after trimming are omitted.
    #[inline]
    pub fn trim_chunks<'a, I: Iterator<Item = &'a [u8]>>(
        &self,
        chunks: I,
    ) -> impl Iterator<Item = &'a [u8]> {
        // Running counters, decremented as chunks are consumed.
        let mut to_skip = self.payload_offset;
        let mut to_take = self.payload_len;

        chunks.filter_map(move |chunk| {
            // Drop whatever part of this chunk still belongs to the header.
            let skipped = to_skip.min(chunk.len());
            let rest = &chunk[skipped..];
            to_skip -= skipped;

            // Still inside the header, or the payload has already been emitted.
            if to_skip != 0 || to_take == 0 {
                return None;
            }

            let taken = to_take.min(rest.len());
            to_take -= taken;
            let payload = &rest[..taken];

            (!payload.is_empty()).then_some(payload)
        })
    }
}
347
/// Cipher suites that can be reported for a TLS session (see `TlsSession::cipher_suite`).
///
/// `Unknown` is the `Default` and stands for any suite not listed here.
#[derive(Copy, Clone, Debug, Default)]
// Variant names keep the standard `TLS_*` spelling, which is not CamelCase.
#[allow(non_camel_case_types)]
pub enum CipherSuite {
    TLS_AES_128_GCM_SHA256,
    TLS_AES_256_GCM_SHA384,
    TLS_CHACHA20_POLY1305_SHA256,
    #[default]
    Unknown,
}
357
358impl crate::event::IntoEvent<crate::event::builder::CipherSuite> for CipherSuite {
359 #[inline]
360 fn into_event(self) -> crate::event::builder::CipherSuite {
361 use crate::event::builder::CipherSuite::*;
362 match self {
363 Self::TLS_AES_128_GCM_SHA256 => TLS_AES_128_GCM_SHA256 {},
364 Self::TLS_AES_256_GCM_SHA384 => TLS_AES_256_GCM_SHA384 {},
365 Self::TLS_CHACHA20_POLY1305_SHA256 => TLS_CHACHA20_POLY1305_SHA256 {},
366 Self::Unknown => Unknown {},
367 }
368 }
369}
370
371impl crate::event::IntoEvent<crate::event::api::CipherSuite> for CipherSuite {
372 #[inline]
373 fn into_event(self) -> crate::event::api::CipherSuite {
374 let builder: crate::event::builder::CipherSuite = self.into_event();
375 builder.into_event()
376 }
377}
378
/// Declares `HandshakeType` with the given `Variant(code)` pairs, plus a
/// `TryFrom<u8>` impl that maps each code back to its variant.
macro_rules! handshake_type {
    ($($variant:ident($value:literal)),* $(,)?) => {
        /// Handshake message types, identified by the `msg_type` byte of a handshake header.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, IntoBytes, Unaligned)]
        #[cfg_attr(any(test, feature = "bolero-generator"), derive(bolero_generator::TypeGenerator))]
        #[repr(u8)]
        pub enum HandshakeType {
            $($variant = $value),*
        }

        impl TryFrom<u8> for HandshakeType {
            type Error = ();

            /// Maps a `msg_type` byte to its variant; any unlisted byte yields `Err(())`.
            #[inline]
            fn try_from(value: u8) -> Result<Self, Self::Error> {
                let ty = match value {
                    $($value => Self::$variant,)*
                    _ => return Err(()),
                };
                Ok(ty)
            }
        }
    };
}
401
// Handshake message types and their `msg_type` wire codes. Only these codes are
// accepted by `HandshakeType::try_from`; every other byte is rejected.
// NOTE(review): this does not appear to be the complete handshake-type registry
// (e.g. no EncryptedExtensions); confirm the subset is intentional.
handshake_type!(
    HelloRequest(0),
    ClientHello(1),
    ServerHello(2),
    Certificate(11),
    ServerKeyExchange(12),
    CertificateRequest(13),
    ServerHelloDone(14),
    CertificateVerify(15),
    ClientKeyExchange(16),
    Finished(20),
);
423
424#[derive(Clone, Copy, Debug, IntoBytes, FromBytes, Unaligned)]
442#[repr(C)]
443pub struct HandshakeHeader {
444 msg_type: u8,
445 length: [u8; 3],
446}
447
448impl HandshakeHeader {
449 #[inline]
450 pub fn msg_type(self) -> Option<HandshakeType> {
451 HandshakeType::try_from(self.msg_type).ok()
452 }
453
454 #[inline]
455 pub fn len(self) -> usize {
456 let mut len = [0u8; 4];
457 len[1..].copy_from_slice(&self.length);
458 let len = u32::from_be_bytes(len);
459 len as _
460 }
461
462 #[inline]
463 pub fn is_empty(self) -> bool {
464 self.len() == 0
465 }
466}
467
468s2n_codec::zerocopy_value_codec!(HandshakeHeader);
469
/// Unit and property tests for `Session::parse_hello` and `HelloOffsets::trim_chunks`.
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;
    use hex_literal::hex;

    // Hello size limit used by these tests; a smaller value is used under kani
    // (presumably to keep proofs tractable).
    const MAX_HELLO_SIZE: u64 = if cfg!(kani) { 32 } else { 255 };

    // Fixed-capacity byte buffer used as fuzzer input; capacity derives from MAX_HELLO_SIZE.
    type Chunk = crate::testing::InlineVec<u8, { MAX_HELLO_SIZE as usize + 2 }>;

    /// Feeds arbitrary type/chunk/length combinations to `parse_hello`; the result is
    /// discarded, so this only checks that the call completes without panicking.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(36))]
    fn parse_hello_test() {
        check!()
            .with_type::<(HandshakeType, Chunk, u64)>()
            .for_each(|(ty, chunk, total_received_len)| {
                let _ =
                    testing::Session::parse_hello(*ty, chunk, *total_received_len, MAX_HELLO_SIZE);
            });
    }

    // Decodes a hex literal into a byte slice.
    macro_rules! h {
        ($($tt:tt)*) => {
            &hex!($($tt)*)[..]
        }
    }

    /// Runs `parse_hello` on the first chunk with the combined length of all chunks
    /// and, on success, returns the payload pieces selected by `trim_chunks`.
    fn parse_hello<'a>(
        ty: HandshakeType,
        input: &'a [&'a [u8]],
    ) -> Result<Option<Vec<&'a [u8]>>, crate::transport::Error> {
        let total_received_len: usize = input.iter().map(|chunk| chunk.len()).sum();

        // With no chunks at all, parse an empty header chunk instead.
        let empty = &[][..];
        let first = input.iter().copied().next().unwrap_or(empty);

        let outcome =
            testing::Session::parse_hello(ty, first, total_received_len as _, MAX_HELLO_SIZE)?;

        if let Some(offsets) = outcome {
            let payload = offsets.trim_chunks(input.iter().copied()).collect();
            Ok(Some(payload))
        } else {
            Ok(None)
        }
    }

    /// Complete ClientHello messages yield exactly their payload, however they are chunked.
    #[test]
    fn client_hello_valid_tests() {
        let tests = [
            // payload fits in the first chunk; trailing byte is ignored
            (&[h!("01 00 00 02 aa bb cc")][..], &[h!("aa bb")][..]),
            // header and payload in separate chunks
            (&[h!("01 00 00 01"), h!("aa bb cc dd")], &[h!("aa")]),
            // payload split across several one-byte chunks
            (
                &[h!("01 00 00 02"), h!("aa"), h!("bb"), h!("cc")],
                &[h!("aa"), h!("bb")],
            ),
        ];

        for (input, expected) in tests {
            let output = parse_hello(HandshakeType::ClientHello, input)
                .unwrap()
                .unwrap();

            assert_eq!(&output[..], expected);
        }
    }

    /// A complete ServerHello yields its payload.
    #[test]
    fn server_hello_valid_tests() {
        let tests = [(&[h!("02 00 00 02 aa bb cc")][..], &[h!("aa bb")][..])];

        for (input, expected) in tests {
            let output = parse_hello(HandshakeType::ServerHello, input)
                .unwrap()
                .unwrap();

            assert_eq!(&output[..], expected);
        }
    }

    /// Inputs that are not yet complete report `Ok(None)` rather than an error.
    #[test]
    fn client_hello_incomplete_tests() {
        let tests = [
            // no data at all
            &[][..],
            // only 3 of the 4 header bytes
            &[h!("01 00 00")],
            // header declares 1 payload byte, none received
            &[h!("01 00 00 01")],
            // header declares 4 payload bytes, only 2 received
            &[h!("01 00 00 04"), h!("aa"), h!("bb")],
        ];

        for input in tests {
            assert_eq!(
                parse_hello(HandshakeType::ClientHello, input).unwrap(),
                None
            );
        }
    }

    /// Wrong message types and oversized payloads are rejected with an error.
    #[test]
    fn client_hello_invalid_tests() {
        let tests = [
            // ServerHello type where a ClientHello is expected
            &[h!("02 00 00 01 aa")],
            // declared payload (256 bytes) exceeds MAX_HELLO_SIZE
            &[h!("01 00 01 00 aa")],
            // maximum 24-bit length, far above MAX_HELLO_SIZE
            &[h!("01 ff ff ff aa")],
        ];

        for input in tests {
            assert!(parse_hello(HandshakeType::ClientHello, input).is_err());
        }
    }
}