1#![allow(clippy::missing_safety_doc)] #[cfg(feature = "unstable-cert_authorities")]
7use crate::cert_authorities::CertRequestState;
8#[cfg(feature = "unstable-renegotiate")]
9use crate::renegotiate::RenegotiateState;
10use crate::{
11 callbacks::*,
12 cert_chain::{CertificateChain, CertificateChainHandle},
13 config::Config,
14 enums::*,
15 error::{Error, Fallible, Pollable},
16 psk::Psk,
17 security,
18 utilities::cstr_to_str,
19};
20
21use core::{
22 convert::TryInto,
23 fmt,
24 mem::{self, ManuallyDrop, MaybeUninit},
25 pin::Pin,
26 ptr::NonNull,
27 task::{Poll, Waker},
28 time::Duration,
29};
30use libc::c_void;
31use s2n_tls_sys::*;
32use std::{
33 any::{Any, TypeId},
34 collections::HashMap,
35 ffi::CStr,
36};
37
38mod builder;
39pub use builder::*;
40
/// Converts a NUL-terminated C string pointer into a `&str`.
///
/// Expands to `CStr::from_ptr(ptr).to_str()`, turning any UTF-8 failure into
/// `Error::INVALID_INPUT`. Because the expansion calls `CStr::from_ptr`, every use
/// must sit inside an `unsafe` block and pass a non-null pointer to a valid C string.
macro_rules! const_str {
    ($c_chars:expr) => {
        CStr::from_ptr($c_chars)
            .to_str()
            .or(Err(Error::INVALID_INPUT))
    };
}
55
/// Counts of KeyUpdate messages exchanged on a connection.
///
/// Returned by `Connection::key_update_counts`.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub struct KeyUpdateCount {
    /// Number of KeyUpdate messages this endpoint has sent.
    pub send_key_updates: u8,
    /// Number of KeyUpdate messages this endpoint has received.
    pub recv_key_updates: u8,
}
64
65pub struct Connection {
67 connection: NonNull<s2n_connection>,
68}
69
70impl fmt::Debug for Connection {
71 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
72 let mut debug = f.debug_struct("Connection");
73 if let Ok(handshake) = self.handshake_type() {
74 debug.field("handshake_type", &handshake);
75 }
76 if let Ok(cipher) = self.cipher_suite() {
77 debug.field("cipher_suite", &cipher);
78 }
79 if let Ok(version) = self.actual_protocol_version() {
80 debug.field("actual_protocol_version", &version);
81 }
82 if let Some(group_name) = self.selected_key_exchange_group() {
83 debug.field("selected_key_exchange_group", &group_name);
84 }
85 debug.finish_non_exhaustive()
86 }
87}
88
// SAFETY (review note): `Connection` exclusively owns its `s2n_connection`. The
// Rust-side `Context` hanging off it is asserted `Send` by `context_send_test`.
// TODO confirm that s2n-tls allows a connection to be moved across threads.
unsafe impl Send for Connection {}

// SAFETY (review note): `Context` is asserted `Sync` by `context_sync_test`.
// NOTE(review): soundness also relies on every `&self` method performing only
// operations that s2n-tls permits concurrently on one connection — confirm.
unsafe impl Sync for Connection {}
108
109impl Connection {
110 pub fn new(mode: Mode) -> Self {
122 crate::init::init();
123
124 let connection = unsafe { s2n_connection_new(mode.into()).into_result() }.unwrap();
125
126 unsafe {
127 debug_assert! {
128 s2n_connection_get_config(connection.as_ptr(), &mut core::ptr::null_mut())
129 .into_result()
130 .is_err()
131 }
132 }
133
134 let mut connection = Self { connection };
135 connection.init_context(mode);
136 connection
137 }
138
139 fn init_context(&mut self, mode: Mode) {
140 let context = Box::new(Context::new(mode));
141 let context = Box::into_raw(context) as *mut c_void;
142 unsafe {
144 debug_assert!(s2n_connection_get_ctx(self.connection.as_ptr())
146 .into_result()
147 .is_err());
148
149 s2n_connection_set_ctx(self.connection.as_ptr(), context)
150 .into_result()
151 .unwrap();
152 }
153 }
154
155 pub fn new_client() -> Self {
156 Self::new(Mode::Client)
157 }
158
159 pub fn new_server() -> Self {
160 Self::new(Mode::Server)
161 }
162
163 pub(crate) fn as_ptr(&mut self) -> *mut s2n_connection {
164 self.connection.as_ptr()
165 }
166
167 #[cfg(s2n_tls_external_build)]
175 pub fn unstable_as_ptr(&mut self) -> *mut s2n_connection {
176 self.as_ptr()
177 }
178
179 pub(crate) unsafe fn from_raw(connection: NonNull<s2n_connection>) -> Self {
183 Self { connection }
184 }
185
186 pub(crate) fn mode(&self) -> Mode {
187 self.context().mode
188 }
189
190 pub fn set_blinding(&mut self, blinding: Blinding) -> Result<&mut Self, Error> {
196 unsafe {
197 s2n_connection_set_blinding(self.connection.as_ptr(), blinding.into()).into_result()
198 }?;
199 Ok(self)
200 }
201
202 pub fn remaining_blinding_delay(&self) -> Result<Duration, Error> {
211 let nanos = unsafe { s2n_connection_get_delay(self.connection.as_ptr()).into_result() }?;
212 Ok(Duration::from_nanos(nanos))
213 }
214
215 pub fn set_client_auth_type(
223 &mut self,
224 client_auth_type: ClientAuthType,
225 ) -> Result<&mut Self, Error> {
226 unsafe {
227 s2n_connection_set_client_auth_type(self.connection.as_ptr(), client_auth_type.into())
228 .into_result()
229 }?;
230 Ok(self)
231 }
232
233 unsafe fn drop_config(&mut self) -> Result<(), Error> {
240 let mut prev_config = core::ptr::null_mut();
241
242 if s2n_connection_get_config(self.connection.as_ptr(), &mut prev_config)
245 .into_result()
246 .is_ok()
247 {
248 let prev_config = NonNull::new(prev_config).expect(
249 "config should exist since the call to s2n_connection_get_config was successful",
250 );
251 drop(Config::from_raw(prev_config));
252 }
253
254 Ok(())
255 }
256
257 pub fn set_config(&mut self, mut config: Config) -> Result<&mut Self, Error> {
261 unsafe {
262 self.drop_config()?;
264
265 s2n_connection_set_config(self.connection.as_ptr(), config.as_mut_ptr())
266 .into_result()?;
267
268 debug_assert! {
269 s2n_connection_get_config(self.connection.as_ptr(), &mut core::ptr::null_mut()).into_result().is_ok(),
270 "s2n_connection_set_config was successful"
271 };
272
273 mem::forget(config);
276 }
277
278 Ok(self)
279 }
280
281 pub(crate) fn config(&self) -> Option<Config> {
282 let mut raw = core::ptr::null_mut();
283 let config = unsafe {
284 s2n_connection_get_config(self.connection.as_ptr(), &mut raw)
285 .into_result()
286 .ok()?;
287 let raw = NonNull::new(raw)?;
288 Config::from_raw(raw)
289 };
290 let _ = ManuallyDrop::new(config.clone());
293 Some(config)
294 }
295
296 pub fn set_security_policy(&mut self, policy: &security::Policy) -> Result<&mut Self, Error> {
298 unsafe {
299 s2n_connection_set_cipher_preferences(
300 self.connection.as_ptr(),
301 policy.as_cstr().as_ptr(),
302 )
303 .into_result()
304 }?;
305 Ok(self)
306 }
307
308 pub fn set_dynamic_record_threshold(
316 &mut self,
317 resize_threshold: u32,
318 timeout_threshold: u16,
319 ) -> Result<&mut Self, Error> {
320 unsafe {
321 s2n_connection_set_dynamic_record_threshold(
322 self.connection.as_ptr(),
323 resize_threshold,
324 timeout_threshold,
325 )
326 .into_result()
327 }?;
328 Ok(self)
329 }
330
331 pub fn request_key_update(&mut self, peer_request: PeerKeyUpdate) -> Result<&mut Self, Error> {
345 unsafe {
346 s2n_connection_request_key_update(self.connection.as_ptr(), peer_request.into())
347 .into_result()
348 }?;
349 Ok(self)
350 }
351
352 #[cfg(feature = "unstable-ktls")]
358 pub fn key_update_counts(&self) -> Result<KeyUpdateCount, Error> {
359 let mut send_key_updates = 0;
360 let mut recv_key_updates = 0;
361 unsafe {
362 s2n_connection_get_key_update_counts(
363 self.connection.as_ptr(),
364 &mut send_key_updates,
365 &mut recv_key_updates,
366 )
367 .into_result()?;
368 }
369 Ok(KeyUpdateCount {
370 send_key_updates,
371 recv_key_updates,
372 })
373 }
374
375 pub fn set_application_protocol_preference<P: IntoIterator<Item = I>, I: AsRef<[u8]>>(
385 &mut self,
386 protocols: P,
387 ) -> Result<&mut Self, Error> {
388 unsafe {
390 s2n_connection_set_protocol_preferences(self.connection.as_ptr(), core::ptr::null(), 0)
391 .into_result()
392 }?;
393
394 for protocol in protocols {
395 self.append_application_protocol_preference(protocol.as_ref())?;
396 }
397
398 Ok(self)
399 }
400
401 pub fn append_application_protocol_preference(
403 &mut self,
404 protocol: &[u8],
405 ) -> Result<&mut Self, Error> {
406 unsafe {
407 s2n_connection_append_protocol_preference(
408 self.connection.as_ptr(),
409 protocol.as_ptr(),
410 protocol
411 .len()
412 .try_into()
413 .map_err(|_| Error::INVALID_INPUT)?,
414 )
415 .into_result()
416 }?;
417 Ok(self)
418 }
419
420 pub fn set_receive_callback(&mut self, callback: s2n_recv_fn) -> Result<&mut Self, Error> {
424 unsafe { s2n_connection_set_recv_cb(self.connection.as_ptr(), callback).into_result() }?;
425 Ok(self)
426 }
427
428 pub unsafe fn set_receive_context(&mut self, context: *mut c_void) -> Result<&mut Self, Error> {
434 s2n_connection_set_recv_ctx(self.connection.as_ptr(), context).into_result()?;
435 Ok(self)
436 }
437
438 pub fn set_send_callback(&mut self, callback: s2n_send_fn) -> Result<&mut Self, Error> {
442 unsafe { s2n_connection_set_send_cb(self.connection.as_ptr(), callback).into_result() }?;
443 Ok(self)
444 }
445
446 pub unsafe fn set_send_context(&mut self, context: *mut c_void) -> Result<&mut Self, Error> {
452 s2n_connection_set_send_ctx(self.connection.as_ptr(), context).into_result()?;
453 Ok(self)
454 }
455
456 pub fn set_verify_host_callback<T: 'static + VerifyHostNameCallback>(
464 &mut self,
465 handler: T,
466 ) -> Result<&mut Self, Error> {
467 unsafe extern "C" fn verify_host_cb_fn(
468 host_name: *const ::libc::c_char,
469 host_name_len: usize,
470 context: *mut ::libc::c_void,
471 ) -> u8 {
472 let context = &mut *(context as *mut Context);
473 let handler = context.verify_host_callback.as_mut().unwrap();
474 verify_host(host_name, host_name_len, handler)
475 }
476
477 self.context_mut().verify_host_callback = Some(Box::new(handler));
478 unsafe {
479 s2n_connection_set_verify_host_callback(
480 self.connection.as_ptr(),
481 Some(verify_host_cb_fn),
482 self.context_mut() as *mut Context as *mut c_void,
483 )
484 .into_result()
485 }?;
486 Ok(self)
487 }
488
489 pub fn prefer_low_latency(&mut self) -> Result<&mut Self, Error> {
494 unsafe { s2n_connection_prefer_low_latency(self.connection.as_ptr()).into_result() }?;
495 Ok(self)
496 }
497
498 pub fn prefer_throughput(&mut self) -> Result<&mut Self, Error> {
502 unsafe { s2n_connection_prefer_throughput(self.connection.as_ptr()).into_result() }?;
503 Ok(self)
504 }
505
506 pub fn set_receive_buffering(&mut self, enabled: bool) -> Result<&mut Self, Error> {
510 unsafe {
511 s2n_connection_set_recv_buffering(self.connection.as_ptr(), enabled).into_result()
512 }?;
513 Ok(self)
514 }
515
516 pub fn release_buffers(&mut self) -> Result<&mut Self, Error> {
523 unsafe { s2n_connection_release_buffers(self.connection.as_ptr()).into_result() }?;
524 Ok(self)
525 }
526
527 pub fn use_corked_io(&mut self) -> Result<&mut Self, Error> {
529 unsafe { s2n_connection_use_corked_io(self.connection.as_ptr()).into_result() }?;
530 Ok(self)
531 }
532
533 pub(crate) fn wipe_method<F, T>(&mut self, wipe: F) -> Result<(), Error>
534 where
535 F: FnOnce(&mut Self) -> Result<T, Error>,
536 {
537 let mode = self.mode();
538
539 unsafe { self.drop_context()? };
542
543 let result = wipe(self);
544 self.init_context(mode);
547 result?;
548
549 Ok(())
550 }
551
552 pub fn wipe(&mut self) -> Result<&mut Self, Error> {
561 self.wipe_method(|conn| unsafe { s2n_connection_wipe(conn.as_ptr()).into_result() })?;
562 Ok(self)
563 }
564
565 fn trigger_initializer(&mut self) {
566 if !core::mem::replace(&mut self.context_mut().connection_initialized, true) {
567 if let Some(config) = self.config() {
568 if let Some(callback) = config.context().connection_initializer.as_ref() {
569 let future = callback.initialize_connection(self);
570 AsyncCallback::trigger(future, self);
571 }
572 }
573 }
574 }
575
576 fn poll_async_task(&mut self) -> Option<Poll<Result<(), Error>>> {
580 self.take_async_callback().map(|mut callback| {
581 let waker = self.waker().ok_or(Error::MISSING_WAKER)?.clone();
582 let mut ctx = core::task::Context::from_waker(&waker);
583 match Pin::new(&mut callback).poll(self, &mut ctx) {
584 Poll::Ready(result) => Poll::Ready(result),
585 Poll::Pending => {
586 self.set_async_callback(callback);
588 Poll::Pending
589 }
590 }
591 })
592 }
593
594 pub(crate) fn poll_negotiate_method<F, T>(
595 &mut self,
596 mut negotiate: F,
597 ) -> Poll<Result<(), Error>>
598 where
599 F: FnMut(&mut Connection) -> Poll<Result<T, Error>>,
600 {
601 self.trigger_initializer();
602
603 loop {
604 match self.poll_async_task().unwrap_or(Poll::Ready(Ok(()))) {
606 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
607 Poll::Pending => return Poll::Pending,
608 Poll::Ready(Ok(_)) => {}
609 };
610
611 match negotiate(self) {
612 Poll::Ready(res) => return Poll::Ready(res.map(|_| ())),
613 Poll::Pending => {
614 if self.context_mut().async_callback.is_some() {
620 continue;
622 }
623
624 return Poll::Pending;
626 }
627 }
628 }
629 }
630
631 pub fn poll_negotiate(&mut self) -> Poll<Result<&mut Self, Error>> {
643 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
644 self.poll_negotiate_method(|conn| unsafe {
645 s2n_negotiate(conn.as_ptr(), &mut blocked).into_poll()
646 })
647 .map_ok(|_| self)
648 }
649
650 #[cfg(not(feature = "unstable-renegotiate"))]
657 pub fn poll_send(&mut self, buf: &[u8]) -> Poll<Result<usize, Error>> {
658 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
659 let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
660 let buf_ptr = buf.as_ptr() as *const ::libc::c_void;
661 unsafe { s2n_send(self.connection.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() }
662 }
663
664 #[cfg(not(feature = "unstable-renegotiate"))]
665 pub(crate) fn poll_recv_raw(
666 &mut self,
667 buf_ptr: *mut ::libc::c_void,
668 buf_len: isize,
669 ) -> Poll<Result<usize, Error>> {
670 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
671 unsafe { s2n_recv(self.connection.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() }
672 }
673
674 pub fn poll_recv(&mut self, buf: &mut [u8]) -> Poll<Result<usize, Error>> {
682 let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
683 let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
684 self.poll_recv_raw(buf_ptr, buf_len)
685 }
686
687 pub fn poll_recv_uninitialized(
701 &mut self,
702 buf: &mut [MaybeUninit<u8>],
703 ) -> Poll<Result<usize, Error>> {
704 let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
705 let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
706
707 self.poll_recv_raw(buf_ptr, buf_len)
713 }
714
715 pub fn poll_flush(&mut self) -> Poll<Result<&mut Self, Error>> {
724 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
725 unsafe {
726 s2n_flush(self.connection.as_ptr(), &mut blocked)
727 .into_poll()
728 .map_ok(|_| self)
729 }
730 }
731
732 pub fn peek_len(&self) -> usize {
736 unsafe { s2n_peek(self.connection.as_ptr()) as usize }
737 }
738
739 pub fn peek_buffered_len(&self) -> usize {
750 unsafe { s2n_peek_buffered(self.connection.as_ptr()) as usize }
751 }
752
753 pub fn poll_shutdown(&mut self) -> Poll<Result<&mut Self, Error>> {
761 if !self.remaining_blinding_delay()?.is_zero() {
762 return Poll::Pending;
763 }
764 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
765 unsafe {
766 s2n_shutdown(self.connection.as_ptr(), &mut blocked)
767 .into_poll()
768 .map_ok(|_| self)
769 }
770 }
771
772 pub fn poll_shutdown_send(&mut self) -> Poll<Result<&mut Self, Error>> {
779 if !self.remaining_blinding_delay()?.is_zero() {
780 return Poll::Pending;
781 }
782 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
783 unsafe {
784 s2n_shutdown_send(self.connection.as_ptr(), &mut blocked)
785 .into_poll()
786 .map_ok(|_| self)
787 }
788 }
789
790 pub fn alert(&self) -> Option<u8> {
794 let alert =
795 unsafe { s2n_connection_get_alert(self.connection.as_ptr()).into_result() }.ok()?;
796 Some(alert as u8)
797 }
798
799 pub fn set_server_name(&mut self, server_name: &str) -> Result<&mut Self, Error> {
803 let server_name = std::ffi::CString::new(server_name).map_err(|_| Error::INVALID_INPUT)?;
804 unsafe {
805 s2n_set_server_name(self.connection.as_ptr(), server_name.as_ptr()).into_result()
806 }?;
807 Ok(self)
808 }
809
810 pub fn server_name(&self) -> Option<&str> {
814 unsafe {
815 let server_name = s2n_get_server_name(self.connection.as_ptr());
816 match server_name.into_result() {
817 Ok(server_name) => CStr::from_ptr(server_name).to_str().ok(),
818 Err(_) => None,
819 }
820 }
821 }
822
823 pub fn set_session_ticket(&mut self, session: &[u8]) -> Result<&mut Self, Error> {
827 unsafe {
828 s2n_connection_set_session(self.connection.as_ptr(), session.as_ptr(), session.len())
829 .into_result()
830 }?;
831 Ok(self)
832 }
833
834 pub fn session_ticket_length(&self) -> Result<usize, Error> {
838 let len =
839 unsafe { s2n_connection_get_session_length(self.connection.as_ptr()).into_result()? };
840 Ok(len.try_into().unwrap())
841 }
842
843 pub fn session_ticket(&self, output: &mut [u8]) -> Result<usize, Error> {
855 if output.len() < self.session_ticket_length()? {
856 return Err(Error::INVALID_INPUT);
857 }
858 let written = unsafe {
859 s2n_connection_get_session(self.connection.as_ptr(), output.as_mut_ptr(), output.len())
860 .into_result()?
861 };
862 Ok(written.try_into().unwrap())
863 }
864
865 pub fn set_waker(&mut self, waker: Option<&Waker>) -> Result<&mut Self, Error> {
867 let ctx = self.context_mut();
868
869 if let Some(waker) = waker {
870 if let Some(prev_waker) = ctx.waker.as_mut() {
871 if !prev_waker.will_wake(waker) {
873 prev_waker.clone_from(waker);
874 }
875 } else {
876 ctx.waker = Some(waker.clone());
877 }
878 } else {
879 ctx.waker = None;
880 }
881 Ok(self)
882 }
883
884 pub fn waker(&self) -> Option<&Waker> {
886 let ctx = self.context();
887 ctx.waker.as_ref()
888 }
889
890 fn take_async_callback(&mut self) -> Option<AsyncCallback> {
896 let ctx = self.context_mut();
897 ctx.async_callback.take()
898 }
899
900 pub(crate) fn set_async_callback(&mut self, callback: AsyncCallback) {
902 let ctx = self.context_mut();
903 debug_assert!(ctx.async_callback.is_none());
904 ctx.async_callback = Some(callback);
905 }
906
907 fn context_mut(&mut self) -> &mut Context {
909 unsafe {
910 let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
911 .into_result()
912 .unwrap();
913 &mut *(ctx.as_ptr() as *mut Context)
914 }
915 }
916
917 fn context(&self) -> &Context {
919 unsafe {
920 let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
921 .into_result()
922 .unwrap();
923 &*(ctx.as_ptr() as *mut Context)
924 }
925 }
926
927 unsafe fn drop_context(&mut self) -> Result<(), Error> {
933 let ctx = s2n_connection_get_ctx(self.connection.as_ptr()).into_result();
934 if let Ok(ctx) = ctx {
935 drop(Box::from_raw(ctx.as_ptr() as *mut Context));
936 }
937 s2n_connection_set_ctx(self.connection.as_ptr(), core::ptr::null_mut()).into_result()?;
941 Ok(())
942 }
943
944 pub fn server_name_extension_used(&mut self) {
948 unsafe {
951 s2n_connection_server_name_extension_used(self.connection.as_ptr())
952 .into_result()
953 .unwrap();
954 }
955 }
956
957 pub fn client_cert_used(&self) -> bool {
963 unsafe { s2n_connection_client_cert_used(self.connection.as_ptr()) == 1 }
964 }
965
966 pub fn client_cert_chain_bytes(&self) -> Result<Option<&[u8]>, Error> {
970 if !self.client_cert_used() {
971 return Ok(None);
972 }
973
974 let mut chain = std::ptr::null_mut();
975 let mut len = 0;
976 unsafe {
977 s2n_connection_get_client_cert_chain(self.connection.as_ptr(), &mut chain, &mut len)
978 .into_result()?;
979 }
980
981 if chain.is_null() || len == 0 {
982 return Ok(None);
983 }
984
985 unsafe { Ok(Some(std::slice::from_raw_parts(chain, len as usize))) }
986 }
987
988 pub fn client_hello(&self) -> Result<&crate::client_hello::ClientHello, Error> {
1020 let mut handle =
1021 unsafe { s2n_connection_get_client_hello(self.connection.as_ptr()).into_result()? };
1022 Ok(crate::client_hello::ClientHello::from_ptr(unsafe {
1023 handle.as_mut()
1024 }))
1025 }
1026
1027 pub(crate) fn mark_client_hello_cb_done(&mut self) -> Result<(), Error> {
1029 unsafe {
1030 s2n_client_hello_cb_done(self.connection.as_ptr()).into_result()?;
1031 }
1032 Ok(())
1033 }
1034
1035 pub fn actual_protocol_version(&self) -> Result<Version, Error> {
1039 let version = unsafe {
1040 s2n_connection_get_actual_protocol_version(self.connection.as_ptr()).into_result()?
1041 };
1042 version.try_into()
1043 }
1044
1045 pub fn client_hello_is_sslv2(&self) -> Result<bool, Error> {
1055 let version = unsafe {
1056 s2n_connection_get_client_hello_version(self.connection.as_ptr()).into_result()?
1057 };
1058 let version: Version = version.try_into()?;
1059 Ok(version == Version::SSLV2)
1060 }
1061
1062 pub fn handshake_type(&self) -> Result<&str, Error> {
1064 let handshake = unsafe {
1065 s2n_connection_get_handshake_type_name(self.connection.as_ptr()).into_result()?
1066 };
1067 unsafe {
1068 const_str!(handshake)
1073 }
1074 }
1075
1076 pub fn cipher_suite(&self) -> Result<&str, Error> {
1078 let cipher = unsafe { s2n_connection_get_cipher(self.connection.as_ptr()).into_result()? };
1079 unsafe {
1080 const_str!(cipher)
1085 }
1086 }
1087
1088 #[deprecated = "PQ TLS 1.2 KEM Names are no longer supported. Use kem_group_name() to retrieve PQ TLS 1.3 Group name."]
1090 pub fn kem_name(&self) -> Option<&str> {
1091 let name_bytes = {
1092 let name = unsafe { s2n_connection_get_kem_name(self.connection.as_ptr()) };
1093 if name.is_null() {
1094 return None;
1095 }
1096 name
1097 };
1098
1099 let name_str = unsafe {
1100 const_str!(name_bytes)
1105 };
1106
1107 match name_str {
1108 Ok("NONE") => None,
1109 Ok(name) => Some(name),
1110 Err(_) => {
1111 None
1114 }
1115 }
1116 }
1117
1118 pub fn kem_group_name(&self) -> Option<&str> {
1120 let name_bytes = {
1121 let name = unsafe { s2n_connection_get_kem_group_name(self.connection.as_ptr()) };
1122 if name.is_null() {
1123 return None;
1124 }
1125 name
1126 };
1127
1128 let name_str = unsafe {
1129 const_str!(name_bytes)
1134 };
1135
1136 match name_str {
1137 Ok("NONE") => None,
1138 Ok(name) => Some(name),
1139 Err(_) => {
1140 None
1143 }
1144 }
1145 }
1146
1147 #[deprecated = "Use selected_key_exchange_group instead"]
1149 pub fn selected_curve(&self) -> Result<&str, Error> {
1150 let curve = unsafe { s2n_connection_get_curve(self.connection.as_ptr()).into_result()? };
1151 unsafe {
1152 const_str!(curve)
1157 }
1158 }
1159
1160 pub fn selected_key_exchange_group(&self) -> Option<&str> {
1162 let mut group_name = core::ptr::null();
1163 unsafe {
1164 s2n_connection_get_key_exchange_group(self.connection.as_ptr(), &mut group_name)
1165 .into_result()
1166 .ok()
1167 }?;
1168
1169 unsafe {
1170 const_str!(group_name).ok()
1176 }
1177 }
1178
1179 pub fn selected_signature_algorithm(&self) -> Result<SignatureAlgorithm, Error> {
1181 let mut sig_alg = s2n_tls_signature_algorithm::ANONYMOUS;
1182 unsafe {
1183 s2n_connection_get_selected_signature_algorithm(self.connection.as_ptr(), &mut sig_alg)
1184 .into_result()?;
1185 }
1186 sig_alg.try_into()
1187 }
1188
1189 pub fn signature_scheme(&self) -> Option<&'static str> {
1198 let mut sig_alg: *const std::ffi::c_char = std::ptr::null();
1199 unsafe {
1200 s2n_connection_get_signature_scheme(self.connection.as_ptr(), &mut sig_alg)
1201 .into_result()
1202 .ok()?;
1203 }
1204 let result = unsafe {
1205 cstr_to_str(sig_alg)
1208 };
1209 if result == "none" {
1210 None
1211 } else {
1212 Some(result)
1213 }
1214 }
1215
1216 pub fn selected_hash_algorithm(&self) -> Result<HashAlgorithm, Error> {
1218 let mut hash_alg = s2n_tls_hash_algorithm::NONE;
1219 unsafe {
1220 s2n_connection_get_selected_digest_algorithm(self.connection.as_ptr(), &mut hash_alg)
1221 .into_result()?;
1222 }
1223 hash_alg.try_into()
1224 }
1225
1226 pub fn certificate_match(&self) -> Result<CertSNIMatch, Error> {
1228 let mut cert_match = s2n_cert_sni_match::SNI_NO_MATCH;
1229 unsafe {
1230 s2n_connection_get_certificate_match(self.connection.as_ptr(), &mut cert_match)
1231 .into_result()?;
1232 }
1233 cert_match.try_into()
1234 }
1235
1236 pub fn selected_client_signature_algorithm(&self) -> Result<Option<SignatureAlgorithm>, Error> {
1238 let mut sig_alg = s2n_tls_signature_algorithm::ANONYMOUS;
1239 unsafe {
1240 s2n_connection_get_selected_client_cert_signature_algorithm(
1241 self.connection.as_ptr(),
1242 &mut sig_alg,
1243 )
1244 .into_result()?;
1245 }
1246 Ok(match sig_alg {
1247 s2n_tls_signature_algorithm::ANONYMOUS => None,
1248 sig_alg => Some(sig_alg.try_into()?),
1249 })
1250 }
1251
1252 pub fn selected_client_hash_algorithm(&self) -> Result<Option<HashAlgorithm>, Error> {
1254 let mut hash_alg = s2n_tls_hash_algorithm::NONE;
1255 unsafe {
1256 s2n_connection_get_selected_client_cert_digest_algorithm(
1257 self.connection.as_ptr(),
1258 &mut hash_alg,
1259 )
1260 .into_result()?;
1261 }
1262 Ok(match hash_alg {
1263 s2n_tls_hash_algorithm::NONE => None,
1264 hash_alg => Some(hash_alg.try_into()?),
1265 })
1266 }
1267
1268 pub fn application_protocol(&self) -> Option<&[u8]> {
1270 let protocol = unsafe { s2n_get_application_protocol(self.connection.as_ptr()) };
1271 if protocol.is_null() {
1272 return None;
1273 }
1274 Some(unsafe { CStr::from_ptr(protocol).to_bytes() })
1275 }
1276
1277 pub fn tls_exporter(
1285 &self,
1286 label: &[u8],
1287 context: &[u8],
1288 output: &mut [u8],
1289 ) -> Result<(), Error> {
1290 unsafe {
1291 s2n_connection_tls_exporter(
1292 self.connection.as_ptr(),
1293 label.as_ptr(),
1294 label.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1295 context.as_ptr(),
1296 context.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1297 output.as_mut_ptr(),
1298 output.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1299 )
1300 .into_result()
1301 .map(|_| ())
1302 }
1303 }
1304
1305 pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
1311 unsafe {
1312 let chain_handle = CertificateChainHandle::allocate()?;
1313 s2n_connection_get_peer_cert_chain(
1314 self.connection.as_ptr(),
1315 chain_handle.cert.as_ptr(),
1316 )
1317 .into_result()
1318 .map(|_| ())?;
1319 Ok(CertificateChain::from_allocated(chain_handle))
1320 }
1321 }
1322
1323 pub fn selected_cert(&self) -> Option<CertificateChain<'_>> {
1333 unsafe {
1334 #[allow(clippy::manual_map)]
1337 if let Some(ptr) =
1338 NonNull::new(s2n_connection_get_selected_cert(self.connection.as_ptr()))
1339 {
1340 Some(CertificateChain::from_ptr_reference(ptr))
1341 } else {
1342 None
1343 }
1344 }
1345 }
1346
1347 pub fn master_secret(&self) -> Result<Vec<u8>, Error> {
1349 let mut secret = vec![0; 48];
1351 unsafe {
1352 s2n_connection_get_master_secret(
1353 self.connection.as_ptr(),
1354 secret.as_mut_ptr(),
1355 secret.len(),
1356 )
1357 .into_result()?;
1358 }
1359 Ok(secret)
1360 }
1361
1362 pub fn serialization_length(&self) -> Result<usize, Error> {
1366 unsafe {
1367 let mut length = 0;
1368 s2n_connection_serialization_length(self.connection.as_ptr(), &mut length)
1369 .into_result()?;
1370 Ok(length.try_into().unwrap())
1371 }
1372 }
1373
1374 pub fn serialize(&self, output: &mut [u8]) -> Result<(), Error> {
1378 unsafe {
1379 s2n_connection_serialize(
1380 self.connection.as_ptr(),
1381 output.as_mut_ptr(),
1382 output.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1383 )
1384 .into_result()?;
1385 Ok(())
1386 }
1387 }
1388
1389 pub fn deserialize(&mut self, input: &[u8]) -> Result<(), Error> {
1394 let size = input.len();
1395 let input = input.as_ptr() as *mut u8;
1398 unsafe {
1399 s2n_connection_deserialize(
1400 self.as_ptr(),
1401 input,
1402 size.try_into().map_err(|_| Error::INVALID_INPUT)?,
1403 )
1404 .into_result()?;
1405 Ok(())
1406 }
1407 }
1408
1409 pub fn resumed(&self) -> bool {
1413 unsafe { s2n_connection_is_session_resumed(self.connection.as_ptr()) == 1 }
1414 }
1415
1416 pub fn append_psk(&mut self, psk: &Psk) -> Result<(), Error> {
1422 unsafe {
1423 s2n_connection_append_psk(self.as_ptr(), psk.ptr.as_ptr()).into_result()?
1427 };
1428 Ok(())
1429 }
1430
1431 pub fn negotiated_psk_identity_length(&self) -> Result<usize, Error> {
1433 let mut length = 0;
1434 unsafe {
1435 s2n_connection_get_negotiated_psk_identity_length(self.connection.as_ptr(), &mut length)
1436 .into_result()?
1437 };
1438 Ok(length as usize)
1439 }
1440
1441 pub fn negotiated_psk_identity(&self, destination: &mut [u8]) -> Result<(), Error> {
1446 unsafe {
1447 s2n_connection_get_negotiated_psk_identity(
1448 self.connection.as_ptr(),
1449 destination.as_mut_ptr(),
1450 destination.len().min(u16::MAX as usize) as u16,
1451 )
1452 .into_result()?;
1453 }
1454 Ok(())
1455 }
1456
1457 pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
1464 let context_type_id = TypeId::of::<T>();
1465 self.context_mut()
1466 .app_context
1467 .insert(context_type_id, Box::new(app_context));
1468 }
1469
1470 pub fn remove_application_context<T: Send + Sync + 'static>(
1475 &mut self,
1476 ) -> Option<Box<dyn Any + Send + Sync>> {
1477 let context_type_id = TypeId::of::<T>();
1478 self.context_mut().app_context.remove(&context_type_id)
1479 }
1480
1481 pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
1490 let context_type_id = TypeId::of::<T>();
1491 self.context()
1495 .app_context
1496 .get(&context_type_id)
1497 .and_then(|app_context| app_context.downcast_ref::<T>())
1498 }
1499
1500 pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
1509 let context_type_id = TypeId::of::<T>();
1510 self.context_mut()
1511 .app_context
1512 .get_mut(&context_type_id)
1513 .and_then(|app_context| app_context.downcast_mut::<T>())
1514 }
1515
1516 #[cfg(feature = "unstable-cert_authorities")]
1517 pub(crate) fn cert_request_state(&mut self) -> &mut CertRequestState {
1518 &mut self.context_mut().cert_request_state
1519 }
1520
1521 #[cfg(feature = "unstable-renegotiate")]
1522 pub(crate) fn renegotiate_state_mut(&mut self) -> &mut RenegotiateState {
1523 &mut self.context_mut().renegotiate_state
1524 }
1525
1526 #[cfg(feature = "unstable-renegotiate")]
1527 pub(crate) fn renegotiate_state(&self) -> &RenegotiateState {
1528 &self.context().renegotiate_state
1529 }
1530}
1531
1532struct Context {
1533 mode: Mode,
1534 waker: Option<Waker>,
1535 async_callback: Option<AsyncCallback>,
1536 verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
1537 connection_initialized: bool,
1538 app_context: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
1539 #[cfg(feature = "unstable-renegotiate")]
1540 pub(crate) renegotiate_state: RenegotiateState,
1541 #[cfg(feature = "unstable-cert_authorities")]
1542 pub(crate) cert_request_state: CertRequestState,
1543}
1544
1545impl Context {
1546 fn new(mode: Mode) -> Self {
1547 Context {
1548 mode,
1549 waker: None,
1550 async_callback: None,
1551 verify_host_callback: None,
1552 connection_initialized: false,
1553 app_context: HashMap::new(),
1554 #[cfg(feature = "unstable-renegotiate")]
1555 renegotiate_state: RenegotiateState::default(),
1556 #[cfg(feature = "unstable-cert_authorities")]
1557 cert_request_state: CertRequestState::default(),
1558 }
1559 }
1560}
1561
#[cfg(feature = "quic")]
impl Connection {
    /// Enables QUIC support on this connection (`s2n_connection_enable_quic`).
    pub fn enable_quic(&mut self) -> Result<&mut Self, Error> {
        unsafe { s2n_connection_enable_quic(self.connection.as_ptr()).into_result() }?;
        Ok(self)
    }

    /// Sets the local QUIC transport parameters.
    ///
    /// # Errors
    ///
    /// Returns `Error::INVALID_INPUT` if `buffer` is too long for s2n-tls's length
    /// type. Otherwise returns any error reported by s2n-tls.
    pub fn set_quic_transport_parameters(&mut self, buffer: &[u8]) -> Result<&mut Self, Error> {
        unsafe {
            s2n_connection_set_quic_transport_parameters(
                self.connection.as_ptr(),
                buffer.as_ptr(),
                buffer.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
            )
            .into_result()
        }?;
        Ok(self)
    }

    /// Returns the QUIC transport parameters reported by
    /// `s2n_connection_get_quic_transport_parameters` (per the s2n-tls QUIC API,
    /// the peer's).
    ///
    /// An empty slice is returned when s2n-tls hands back a null pointer or a
    /// zero length, e.g. when no parameters have been received yet.
    pub fn quic_transport_parameters(&mut self) -> Result<&[u8], Error> {
        let mut ptr = core::ptr::null();
        let mut len = 0;
        unsafe {
            s2n_connection_get_quic_transport_parameters(
                self.connection.as_ptr(),
                &mut ptr,
                &mut len,
            )
            .into_result()
        }?;
        // `slice::from_raw_parts` requires a non-null, aligned pointer even for a
        // zero-length slice, so a null/empty result must not reach it.
        if ptr.is_null() || len == 0 {
            return Ok(&[]);
        }
        let buffer = unsafe { core::slice::from_raw_parts(ptr, len as _) };
        Ok(buffer)
    }

    /// Registers a callback invoked by s2n-tls with connection secrets.
    ///
    /// # Safety
    ///
    /// s2n-tls stores `context` as-is and presumably passes it to `callback`. The
    /// caller must keep the pointee alive, and of the type `callback` expects, for
    /// as long as the callback may run.
    pub unsafe fn set_secret_callback(
        &mut self,
        callback: s2n_secret_cb,
        context: *mut c_void,
    ) -> Result<&mut Self, Error> {
        s2n_connection_set_secret_callback(self.connection.as_ptr(), callback, context)
            .into_result()?;
        Ok(self)
    }

    /// Processes a QUIC post-handshake message (`s2n_recv_quic_post_handshake_message`).
    pub fn quic_process_post_handshake_message(&mut self) -> Result<&mut Self, Error> {
        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
        unsafe {
            s2n_recv_quic_post_handshake_message(self.connection.as_ptr(), &mut blocked)
                .into_result()
        }?;
        Ok(self)
    }

    /// Reports whether session tickets are enabled for this connection.
    pub fn are_session_tickets_enabled(&self) -> bool {
        unsafe { s2n_connection_are_session_tickets_enabled(self.connection.as_ptr()) }
    }
}
1631
1632impl AsRef<Connection> for Connection {
1633 fn as_ref(&self) -> &Connection {
1634 self
1635 }
1636}
1637
1638impl AsMut<Connection> for Connection {
1639 fn as_mut(&mut self) -> &mut Connection {
1640 self
1641 }
1642}
1643
impl Drop for Connection {
    /// Releases the Rust-side state attached to this wrapper, then frees the
    /// underlying `s2n_connection`.
    ///
    /// Every step's error is deliberately discarded. `Drop::drop` cannot return
    /// errors, and using `let _ =` (instead of `?`) ensures a failure in one step
    /// never skips the remaining cleanup.
    fn drop(&mut self) {
        // SAFETY: `self.connection` is the pointer this wrapper owns, and it is freed
        // here. NOTE(review): assumes nothing else frees it; confirm no other path
        // calls `s2n_connection_free` on it.
        unsafe {
            // Tear down the application context first. Presumably `drop_context`
            // reads state stored on the s2n connection, which is why it must run
            // before the connection is freed. Confirm against its definition.
            let _ = self.drop_context();

            // Release this connection's config reference while the s2n connection
            // is still valid. The ordering suggests `drop_config` also reads from it.
            let _ = self.drop_config();

            // Finally, free the native connection itself.
            let _ = s2n_connection_free(self.connection.as_ptr()).into_result();
        }
    }
}
1660
#[cfg(test)]
mod tests {
    use futures_test::task::noop_waker;

    use super::*;
    use crate::{
        security::Policy,
        testing::{build_config, config_builder, LIFOSessionResumption, SniTestCerts, TestPair},
    };
    use std::{
        net::{IpAddr, Ipv4Addr, SocketAddr},
        time::SystemTime,
    };

    // Compile-time check that `Context` is `Send`.
    #[test]
    fn context_send_test() {
        fn assert_send<T: 'static + Send>() {}
        assert_send::<Context>();
    }

    // Compile-time check that `Context` is `Sync`.
    #[test]
    fn context_sync_test() {
        fn assert_sync<T: 'static + Sync>() {}
        assert_sync::<Context>();
    }

    // A stored application context can be read back with its concrete type.
    #[test]
    fn test_app_context_set_and_retrieve() {
        let mut connection = Connection::new_server();

        let test_value: u32 = 1142;

        // Nothing has been stored yet.
        assert!(connection.application_context::<u32>().is_none());

        connection.set_application_context(test_value);

        assert_eq!(*connection.application_context::<u32>().unwrap(), 1142);
    }

    // Changes made through `application_context_mut` are visible to later reads.
    #[test]
    fn test_app_context_modify() {
        let test_value: u64 = 0;

        let mut connection = Connection::new_server();
        connection.set_application_context(test_value);

        let context_value = connection.application_context_mut::<u64>().unwrap();
        *context_value += 1;

        assert_eq!(*connection.application_context::<u64>().unwrap(), 1);
    }

    // Setting a context again replaces the previous value, whether it has the same
    // type or a different one.
    #[test]
    fn test_app_context_override() {
        let mut connection = Connection::new_server();

        let test_value: u16 = 1142;
        connection.set_application_context(test_value);

        assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);

        // Same type: the new value wins.
        let test_value: u16 = 10;
        connection.set_application_context(test_value);

        assert_eq!(*connection.application_context::<u16>().unwrap(), 10);

        // Different type: the new value is retrievable under its own type.
        let test_value: i16 = -20;
        connection.set_application_context(test_value);

        assert_eq!(*connection.application_context::<i16>().unwrap(), -20);
    }

    // Contexts of different types can be set, read and removed individually.
    #[test]
    fn test_multiple_app_contexts() {
        let mut connection = Connection::new_server();

        let first_test_value: u16 = 1142;
        connection.set_application_context(first_test_value);

        assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);

        // A non-primitive type also works as a context.
        let second_test_value: SocketAddr =
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        connection.set_application_context(second_test_value);

        assert_eq!(
            *connection.application_context::<SocketAddr>().unwrap(),
            second_test_value
        );

        // Removing hands back the boxed value, which downcasts to the stored type.
        assert_eq!(
            second_test_value,
            *connection
                .remove_application_context::<SocketAddr>()
                .unwrap()
                .downcast::<SocketAddr>()
                .unwrap()
        );

        assert!(connection.application_context::<SocketAddr>().is_none());
    }

    // Lookups are keyed by type: asking for the wrong type yields `None`.
    #[test]
    fn test_app_context_invalid_type() {
        let mut connection = Connection::new_server();

        let test_value: u32 = 0;
        connection.set_application_context(test_value);

        // Wrong type: no match.
        assert!(connection.application_context::<i16>().is_none());

        // The correctly-typed value is still present.
        assert!(connection.application_context::<u32>().is_some());
    }

    // The server reports how its certificate matched the client's SNI: no SNI sent,
    // an SNI that matches nothing, and an exact match.
    #[test]
    fn test_certificate_match_variants() -> Result<(), Box<dyn std::error::Error>> {
        let scenarios = [
            (None, CertSNIMatch::NoSNI),
            (Some("nonmatching_sni"), CertSNIMatch::NoMatch),
            (Some("127.0.0.1"), CertSNIMatch::ExactMatch),
        ];

        // The config is identical for every scenario, so build it once.
        let config = build_config(&security::DEFAULT_TLS13)?;

        for (sni_opt, expected) in scenarios {
            let mut pair = TestPair::from_config(&config);

            if let Some(sni) = sni_opt {
                pair.client.set_server_name(sni)?;
            }

            pair.handshake()?;
            let cert_match = pair.server.certificate_match()?;

            assert_eq!(cert_match, expected, "unexpected match for SNI {:?}", sni_opt);
        }

        Ok(())
    }

    // A wildcard certificate matched via a subdomain SNI is reported as a wildcard match.
    #[test]
    fn test_certificate_match_returns_wildcard_match() -> Result<(), Box<dyn std::error::Error>> {
        let wildcard_cert = SniTestCerts::WildcardInsectRsa.get();

        let mut builder = crate::config::Builder::new();
        builder.load_pem(wildcard_cert.cert(), wildcard_cert.key())?;
        let server_config = builder.build()?;

        let mut client_builder = crate::config::Builder::new();
        client_builder.trust_pem(wildcard_cert.cert())?;
        let client_config = client_builder.build()?;

        let mut pair = TestPair::from_configs(&client_config, &server_config);

        pair.client.set_server_name("anything.insect.hexapod")?;
        pair.handshake()?;

        let cert_match = pair.server.certificate_match()?;
        assert_eq!(cert_match, CertSNIMatch::WildcardMatch);

        Ok(())
    }

    // `unstable_as_ptr` exposes a pointer that can be passed to raw s2n-tls APIs.
    #[cfg(s2n_tls_external_build)]
    #[test]
    fn test_unstable_as_ptr() -> Result<(), Error> {
        let mut connection = Connection::new_client();

        let test_server_name = "test-server-name";
        connection.set_server_name(test_server_name)?;

        let server_name = unsafe {
            let server_name = s2n_get_server_name(connection.unstable_as_ptr());
            CStr::from_ptr(server_name).to_str().unwrap()
        };

        assert_eq!(server_name, test_server_name);
        Ok(())
    }

    // No signature scheme is available before a handshake has happened.
    #[test]
    fn signature_scheme_before_handshake() {
        let connection = Connection::new_server();
        assert_eq!(connection.signature_scheme(), None);
    }

    // A full handshake reports the negotiated signature scheme on both sides. A
    // resumed handshake reports none.
    #[test]
    fn signature_scheme_after_handshake() -> Result<(), Box<dyn std::error::Error>> {
        let server_config = {
            let policy = Policy::from_version("20240503")?;
            let mut builder = config_builder(&policy)?;
            builder.add_session_ticket_key(
                b"a key name",
                b"good enough bytes for test",
                SystemTime::UNIX_EPOCH,
            )?;
            builder.build()?
        };

        let client_config = {
            let session_tickets = LIFOSessionResumption::default();
            let mut builder = config_builder(&security::DEFAULT)?;
            builder.enable_session_tickets(true)?;
            builder.set_session_ticket_callback(session_tickets.clone())?;
            builder.set_connection_initializer(session_tickets.clone())?;
            builder.build()?
        };

        // First connection: a full handshake.
        {
            let mut test_pair = TestPair::from_configs(&client_config, &server_config);
            test_pair.client.set_waker(Some(&noop_waker()))?;
            test_pair.handshake()?;
            // Reading on the client lets it process the server's session ticket for
            // the next connection. No application data is sent, so the read stays
            // pending.
            assert!(test_pair.client.poll_recv(&mut [0]).is_pending());

            for conn in [&test_pair.client, &test_pair.server] {
                assert_eq!(conn.signature_scheme(), Some("rsa_pss_rsae_sha256"))
            }
        }

        // Second connection resumes from the stored ticket, so neither side reports
        // a signature scheme.
        {
            let mut test_pair = TestPair::from_configs(&client_config, &server_config);
            test_pair.client.set_waker(Some(&noop_waker()))?;
            test_pair.handshake()?;
            assert!(test_pair.client.resumed());
            assert_eq!(test_pair.client.signature_scheme(), None);
            assert_eq!(test_pair.server.signature_scheme(), None);
        }
        Ok(())
    }
}