1#[cfg(feature = "unstable-renegotiate")]
5use crate::renegotiate::RenegotiateCallback;
6use crate::{
7 callbacks::*,
8 cert_chain::CertificateChain,
9 enums::*,
10 error::{Error, ErrorType, Fallible},
11 security,
12};
13use core::{convert::TryInto, ptr::NonNull};
14use s2n_tls_sys::*;
15use std::{
16 ffi::{c_void, CString},
17 path::Path,
18 pin::Pin,
19 sync::atomic::{AtomicUsize, Ordering},
20 task::Poll,
21 time::{Duration, SystemTime},
22};
23
24#[derive(Debug, PartialEq)]
26pub struct Config(NonNull<s2n_config>);
27
28unsafe impl Send for Config {}
32
33unsafe impl Sync for Config {}
38
39impl Config {
40 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn builder() -> Builder {
57 Builder::default()
58 }
59
60 pub(crate) unsafe fn from_raw(config: NonNull<s2n_config>) -> Self {
67 let config = Self(config);
68
69 config.context();
72
73 config
74 }
75
76 pub(crate) fn as_mut_ptr(&mut self) -> *mut s2n_config {
77 self.0.as_ptr()
78 }
79
80 pub(crate) fn context(&self) -> &Context {
84 let mut ctx = core::ptr::null_mut();
85 unsafe {
86 s2n_config_get_ctx(self.0.as_ptr(), &mut ctx)
87 .into_result()
88 .unwrap();
89 &*(ctx as *const Context)
90 }
91 }
92
93 pub(crate) unsafe fn context_mut(&mut self) -> &mut Context {
101 let mut ctx = core::ptr::null_mut();
102 s2n_config_get_ctx(self.as_mut_ptr(), &mut ctx)
103 .into_result()
104 .unwrap();
105 &mut *(ctx as *mut Context)
106 }
107
/// Test-only accessor for the current value of the shared reference count.
#[cfg(test)]
pub fn test_get_refcount(&self) -> Result<usize, Error> {
    let refcount = &self.context().refcount;
    Ok(refcount.load(Ordering::SeqCst))
}
114}
115
116impl Default for Config {
117 fn default() -> Self {
118 Builder::new().build().unwrap()
119 }
120}
121
122impl Clone for Config {
123 fn clone(&self) -> Self {
124 let context = self.context();
125
126 let _count = context.refcount.fetch_add(1, Ordering::Relaxed);
133 Self(self.0)
134 }
135}
136
137impl Drop for Config {
138 fn drop(&mut self) {
140 let context = self.context();
141 let count = context.refcount.fetch_sub(1, Ordering::Release);
142 debug_assert!(count > 0, "refcount should not drop below 1 instance");
143
144 if count != 1 {
146 return;
147 }
148
149 std::sync::atomic::fence(Ordering::Acquire);
162
163 let context = unsafe {
165 Box::from_raw(self.context_mut())
168 };
169 drop(context);
170
171 let _ = unsafe { s2n_config_free(self.0.as_ptr()).into_result() };
172 }
173}
174
175pub struct Builder {
176 pub(crate) config: Config,
177 load_system_certs: bool,
178 enable_ocsp: bool,
179}
180
181impl Builder {
182 pub fn new() -> Self {
194 crate::init::init();
195 let config = unsafe { s2n_config_new_minimal().into_result() }.unwrap();
196
197 let context = Box::<Context>::default();
198 let context = Box::into_raw(context) as *mut c_void;
199
200 unsafe {
201 s2n_config_set_ctx(config.as_ptr(), context)
202 .into_result()
203 .unwrap();
204
205 s2n_config_set_client_hello_cb_mode(
209 config.as_ptr(),
210 s2n_client_hello_cb_mode::NONBLOCKING,
211 )
212 .into_result()
213 .unwrap();
214 }
215
216 Self {
217 config: Config(config),
218 load_system_certs: true,
219 enable_ocsp: false,
220 }
221 }
222
223 pub fn set_alert_behavior(&mut self, value: AlertBehavior) -> Result<&mut Self, Error> {
225 unsafe { s2n_config_set_alert_behavior(self.as_mut_ptr(), value.into()).into_result() }?;
226 Ok(self)
227 }
228
229 pub fn set_security_policy(&mut self, policy: &security::Policy) -> Result<&mut Self, Error> {
231 unsafe {
232 s2n_config_set_cipher_preferences(self.as_mut_ptr(), policy.as_cstr().as_ptr())
233 .into_result()
234 }?;
235 Ok(self)
236 }
237
238 pub fn set_application_protocol_preference<P: IntoIterator<Item = I>, I: AsRef<[u8]>>(
249 &mut self,
250 protocols: P,
251 ) -> Result<&mut Self, Error> {
252 unsafe {
254 s2n_config_set_protocol_preferences(self.as_mut_ptr(), core::ptr::null(), 0)
255 .into_result()
256 }?;
257
258 for protocol in protocols {
259 self.append_application_protocol_preference(protocol.as_ref())?;
260 }
261
262 Ok(self)
263 }
264
265 pub fn append_application_protocol_preference(
267 &mut self,
268 protocol: &[u8],
269 ) -> Result<&mut Self, Error> {
270 unsafe {
271 s2n_config_append_protocol_preference(
272 self.as_mut_ptr(),
273 protocol.as_ptr(),
274 protocol
275 .len()
276 .try_into()
277 .map_err(|_| Error::INVALID_INPUT)?,
278 )
279 .into_result()
280 }?;
281 Ok(self)
282 }
283
284 pub unsafe fn disable_x509_verification(&mut self) -> Result<&mut Self, Error> {
292 s2n_config_disable_x509_verification(self.as_mut_ptr()).into_result()?;
293 Ok(self)
294 }
295
296 pub fn add_dhparams(&mut self, pem: &[u8]) -> Result<&mut Self, Error> {
298 let cstring = CString::new(pem).map_err(|_| Error::INVALID_INPUT)?;
299 unsafe { s2n_config_add_dhparams(self.as_mut_ptr(), cstring.as_ptr()).into_result() }?;
300 Ok(self)
301 }
302
303 pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) -> Result<&mut Self, Error> {
312 let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?;
313 let private_key = CString::new(private_key).map_err(|_| Error::INVALID_INPUT)?;
314 unsafe {
315 s2n_config_add_cert_chain_and_key(
316 self.as_mut_ptr(),
317 certificate.as_ptr(),
318 private_key.as_ptr(),
319 )
320 .into_result()
321 }?;
322 Ok(self)
323 }
324
325 pub fn load_chain(&mut self, chain: CertificateChain<'static>) -> Result<&mut Self, Error> {
327 let result = unsafe {
333 s2n_config_add_cert_chain_and_key_to_store(
334 self.as_mut_ptr(),
335 chain.as_ptr() as *mut _,
338 )
339 .into_result()
340 };
341 let context = unsafe {
342 self.config.context_mut()
345 };
346 context.application_owned_certs.push(chain);
347 result?;
348
349 Ok(self)
350 }
351
352 pub fn set_default_chains<T: IntoIterator<Item = CertificateChain<'static>>>(
354 &mut self,
355 chains: T,
356 ) -> Result<&mut Self, Error> {
357 const CHAINS_MAX_COUNT: usize = 3;
359
360 let mut chain_arrays: [Option<CertificateChain<'static>>; CHAINS_MAX_COUNT] =
361 [None, None, None];
362 let mut pointer_array = [std::ptr::null_mut(); CHAINS_MAX_COUNT];
363 let mut cert_chain_count = 0;
364
365 for chain in chains.into_iter() {
366 if cert_chain_count >= CHAINS_MAX_COUNT {
367 return Err(Error::bindings(
368 ErrorType::UsageError,
369 "InvalidInput",
370 "A single default can be specified for RSA, ECDSA,
371 and RSA-PSS auth types, but more than 3 certs were supplied",
372 ));
373 }
374
375 pointer_array[cert_chain_count] = chain.as_ptr() as *mut _;
378 chain_arrays[cert_chain_count] = Some(chain);
379
380 cert_chain_count += 1;
381 }
382
383 let collected_chains = chain_arrays.into_iter().take(cert_chain_count).flatten();
384
385 let context = unsafe {
386 self.config.context_mut()
389 };
390 context.application_owned_certs.extend(collected_chains);
391
392 unsafe {
393 s2n_config_set_cert_chain_and_key_defaults(
394 self.as_mut_ptr(),
395 pointer_array.as_mut_ptr(),
396 cert_chain_count as u32,
397 )
398 .into_result()
399 }?;
400
401 Ok(self)
402 }
403
404 pub fn load_public_pem(&mut self, certificate: &[u8]) -> Result<&mut Self, Error> {
406 let size: u32 = certificate
407 .len()
408 .try_into()
409 .map_err(|_| Error::INVALID_INPUT)?;
410 let certificate = certificate.as_ptr() as *mut u8;
411 unsafe { s2n_config_add_cert_chain(self.as_mut_ptr(), certificate, size) }.into_result()?;
412 Ok(self)
413 }
414
415 pub fn trust_pem(&mut self, certificate: &[u8]) -> Result<&mut Self, Error> {
417 let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?;
418 unsafe {
419 s2n_config_add_pem_to_trust_store(self.as_mut_ptr(), certificate.as_ptr()).into_result()
420 }?;
421 Ok(self)
422 }
423
424 pub fn trust_location(
430 &mut self,
431 file: Option<&Path>,
432 dir: Option<&Path>,
433 ) -> Result<&mut Self, Error> {
434 fn to_cstr(input: Option<&Path>) -> Result<Option<CString>, Error> {
435 Ok(match input {
436 Some(input) => {
437 let string = input.to_str().ok_or(Error::INVALID_INPUT)?;
438 let cstring = CString::new(string).map_err(|_| Error::INVALID_INPUT)?;
439 Some(cstring)
440 }
441 None => None,
442 })
443 }
444
445 let file_cstr = to_cstr(file)?;
446 let file_ptr = file_cstr
447 .as_ref()
448 .map(|f| f.as_ptr())
449 .unwrap_or(core::ptr::null());
450
451 let dir_cstr = to_cstr(dir)?;
452 let dir_ptr = dir_cstr
453 .as_ref()
454 .map(|f| f.as_ptr())
455 .unwrap_or(core::ptr::null());
456
457 unsafe {
458 s2n_config_set_verification_ca_location(self.as_mut_ptr(), file_ptr, dir_ptr)
459 .into_result()
460 }?;
461
462 if !self.enable_ocsp {
465 unsafe {
466 s2n_config_set_status_request_type(self.as_mut_ptr(), s2n_status_request_type::NONE)
467 .into_result()?
468 };
469 }
470
471 Ok(self)
472 }
473
474 pub fn with_system_certs(&mut self, load_system_certs: bool) -> Result<&mut Self, Error> {
481 self.load_system_certs = load_system_certs;
482 Ok(self)
483 }
484
485 pub fn wipe_trust_store(&mut self) -> Result<&mut Self, Error> {
487 unsafe { s2n_config_wipe_trust_store(self.as_mut_ptr()).into_result()? };
488 Ok(self)
489 }
490
491 pub fn set_client_auth_type(&mut self, auth_type: ClientAuthType) -> Result<&mut Self, Error> {
497 unsafe {
498 s2n_config_set_client_auth_type(self.as_mut_ptr(), auth_type.into()).into_result()
499 }?;
500 Ok(self)
501 }
502
503 pub fn enable_ocsp(&mut self) -> Result<&mut Self, Error> {
507 unsafe {
508 s2n_config_set_status_request_type(self.as_mut_ptr(), s2n_status_request_type::OCSP)
509 .into_result()
510 }?;
511 self.enable_ocsp = true;
512 Ok(self)
513 }
514
515 pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
528 let size: u32 = data.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
529 unsafe {
530 s2n_config_set_extension_data(
531 self.as_mut_ptr(),
532 s2n_tls_extension_type::OCSP_STAPLING,
533 data.as_ptr(),
534 size,
535 )
536 .into_result()
537 }?;
538 self.enable_ocsp()
539 }
540
541 pub fn set_verify_host_callback<T: 'static + VerifyHostNameCallback>(
549 &mut self,
550 handler: T,
551 ) -> Result<&mut Self, Error> {
552 unsafe extern "C" fn verify_host_cb_fn(
553 host_name: *const ::libc::c_char,
554 host_name_len: usize,
555 context: *mut ::libc::c_void,
556 ) -> u8 {
557 let context = &mut *(context as *mut Context);
558 let handler = context.verify_host_callback.as_mut().unwrap();
559 verify_host(host_name, host_name_len, handler)
560 }
561
562 let context = unsafe {
563 self.config.context_mut()
566 };
567 context.verify_host_callback = Some(Box::new(handler));
568 unsafe {
569 s2n_config_set_verify_host_callback(
570 self.as_mut_ptr(),
571 Some(verify_host_cb_fn),
572 self.config.context() as *const Context as *mut c_void,
573 )
574 .into_result()?;
575 }
576 Ok(self)
577 }
578
579 pub unsafe fn set_key_log_callback(
585 &mut self,
586 callback: s2n_key_log_fn,
587 context: *mut core::ffi::c_void,
588 ) -> Result<&mut Self, Error> {
589 s2n_config_set_key_log_cb(self.as_mut_ptr(), callback, context).into_result()?;
590 Ok(self)
591 }
592
593 pub fn set_max_cert_chain_depth(&mut self, depth: u16) -> Result<&mut Self, Error> {
595 unsafe { s2n_config_set_max_cert_chain_depth(self.as_mut_ptr(), depth).into_result() }?;
596 Ok(self)
597 }
598
599 pub fn set_send_buffer_size(&mut self, size: u32) -> Result<&mut Self, Error> {
601 unsafe { s2n_config_set_send_buffer_size(self.as_mut_ptr(), size).into_result() }?;
602 Ok(self)
603 }
604
605 pub fn set_client_hello_callback<T: 'static + ClientHelloCallback>(
609 &mut self,
610 handler: T,
611 ) -> Result<&mut Self, Error> {
612 unsafe extern "C" fn client_hello_cb(
613 connection_ptr: *mut s2n_connection,
614 _context: *mut core::ffi::c_void,
615 ) -> libc::c_int {
616 with_context(connection_ptr, |conn, context| {
617 let callback = context.client_hello_callback.as_ref();
618 let future = callback
619 .map(|c| c.on_client_hello(conn))
620 .unwrap_or(Ok(None));
621 AsyncCallback::trigger_client_hello_cb(future, conn)
622 })
623 .into()
624 }
625
626 let handler = Box::new(handler);
627 let context = unsafe {
628 self.config.context_mut()
631 };
632 context.client_hello_callback = Some(handler);
633
634 unsafe {
635 s2n_config_set_client_hello_cb(
636 self.as_mut_ptr(),
637 Some(client_hello_cb),
638 core::ptr::null_mut(),
639 )
640 .into_result()?;
641 }
642
643 Ok(self)
644 }
645
646 pub fn set_connection_initializer<T: 'static + ConnectionInitializer>(
647 &mut self,
648 handler: T,
649 ) -> Result<&mut Self, Error> {
650 let handler = Box::new(handler);
652 let context = unsafe {
653 self.config.context_mut()
656 };
657 context.connection_initializer = Some(handler);
658 Ok(self)
659 }
660
661 pub fn set_session_ticket_callback<T: 'static + SessionTicketCallback>(
665 &mut self,
666 handler: T,
667 ) -> Result<&mut Self, Error> {
668 self.enable_session_tickets(true)?;
670
671 unsafe extern "C" fn session_ticket_cb(
673 conn_ptr: *mut s2n_connection,
674 _context: *mut ::libc::c_void,
675 session_ticket: *mut s2n_session_ticket,
676 ) -> libc::c_int {
677 let session_ticket = SessionTicket::from_ptr(&*session_ticket);
678 with_context(conn_ptr, |conn, context| {
679 let callback = context.session_ticket_callback.as_ref();
680 callback.map(|c| c.on_session_ticket(conn, session_ticket))
681 });
682 CallbackResult::Success.into()
683 }
684
685 let handler = Box::new(handler);
687 let context = unsafe {
688 self.config.context_mut()
691 };
692 context.session_ticket_callback = Some(handler);
693
694 unsafe {
695 s2n_config_set_session_ticket_cb(
696 self.as_mut_ptr(),
697 Some(session_ticket_cb),
698 self.config.context() as *const Context as *mut c_void,
699 )
700 .into_result()
701 }?;
702 Ok(self)
703 }
704
705 pub fn set_private_key_callback<T: 'static + PrivateKeyCallback>(
711 &mut self,
712 handler: T,
713 ) -> Result<&mut Self, Error> {
714 unsafe extern "C" fn private_key_cb(
715 conn_ptr: *mut s2n_connection,
716 op_ptr: *mut s2n_async_pkey_op,
717 ) -> libc::c_int {
718 with_context(conn_ptr, |conn, context| {
719 let state = PrivateKeyOperation::try_from_cb(conn, op_ptr);
720 let callback = context.private_key_callback.as_ref();
721 let future_result = state.and_then(|state| {
722 callback.map_or(Ok(None), |callback| callback.handle_operation(conn, state))
723 });
724 AsyncCallback::trigger(future_result, conn)
725 })
726 .into()
727 }
728
729 let handler = Box::new(handler);
730 let context = unsafe {
731 self.config.context_mut()
734 };
735 context.private_key_callback = Some(handler);
736
737 unsafe {
738 s2n_config_set_async_pkey_callback(self.as_mut_ptr(), Some(private_key_cb))
739 .into_result()?;
740 }
741 Ok(self)
742 }
743
744 pub fn set_wall_clock<T: 'static + WallClock>(
752 &mut self,
753 handler: T,
754 ) -> Result<&mut Self, Error> {
755 unsafe extern "C" fn clock_cb(
756 context: *mut ::libc::c_void,
757 time_in_nanos: *mut u64,
758 ) -> libc::c_int {
759 let context = &mut *(context as *mut Context);
760 if let Some(handler) = context.wall_clock.as_mut() {
761 if let Ok(nanos) = handler.get_time_since_epoch().as_nanos().try_into() {
762 *time_in_nanos = nanos;
763 return CallbackResult::Success.into();
764 }
765 }
766 CallbackResult::Failure.into()
767 }
768
769 let handler = Box::new(handler);
770 let context = unsafe {
771 self.config.context_mut()
774 };
775 context.wall_clock = Some(handler);
776 unsafe {
777 s2n_config_set_wall_clock(
778 self.as_mut_ptr(),
779 Some(clock_cb),
780 self.config.context() as *const _ as *mut c_void,
781 )
782 .into_result()?;
783 }
784 Ok(self)
785 }
786
787 pub fn set_monotonic_clock<T: 'static + MonotonicClock>(
795 &mut self,
796 handler: T,
797 ) -> Result<&mut Self, Error> {
798 unsafe extern "C" fn clock_cb(
799 context: *mut ::libc::c_void,
800 time_in_nanos: *mut u64,
801 ) -> libc::c_int {
802 let context = &mut *(context as *mut Context);
803 if let Some(handler) = context.monotonic_clock.as_mut() {
804 if let Ok(nanos) = handler.get_time().as_nanos().try_into() {
805 *time_in_nanos = nanos;
806 return CallbackResult::Success.into();
807 }
808 }
809 CallbackResult::Failure.into()
810 }
811
812 let handler = Box::new(handler);
813 let context = unsafe {
814 self.config.context_mut()
817 };
818 context.monotonic_clock = Some(handler);
819 unsafe {
820 s2n_config_set_monotonic_clock(
821 self.as_mut_ptr(),
822 Some(clock_cb),
823 self.config.context() as *const _ as *mut c_void,
824 )
825 .into_result()?;
826 }
827 Ok(self)
828 }
829
830 pub fn enable_session_tickets(&mut self, enable: bool) -> Result<&mut Self, Error> {
834 unsafe {
835 s2n_config_set_session_tickets_onoff(self.as_mut_ptr(), enable.into()).into_result()
836 }?;
837 Ok(self)
838 }
839
840 pub fn add_session_ticket_key(
846 &mut self,
847 key_name: &[u8],
848 key: &[u8],
849 intro_time: SystemTime,
850 ) -> Result<&mut Self, Error> {
851 let key_name_len: u32 = key_name
852 .len()
853 .try_into()
854 .map_err(|_| Error::INVALID_INPUT)?;
855 let key_len: u32 = key.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
856 let intro_time = intro_time
857 .duration_since(std::time::UNIX_EPOCH)
858 .map_err(|_| Error::INVALID_INPUT)?;
859 if key_len < 16 {
862 return Err(Error::INVALID_INPUT);
863 }
864 self.enable_session_tickets(true)?;
865 unsafe {
866 s2n_config_add_ticket_crypto_key(
867 self.as_mut_ptr(),
868 key_name.as_ptr(),
869 key_name_len,
870 key.as_ptr() as *mut u8,
872 key_len,
873 intro_time.as_secs(),
874 )
875 .into_result()
876 }?;
877 Ok(self)
878 }
879
880 pub fn set_ticket_key_encrypt_decrypt_lifetime(
885 &mut self,
886 lifetime: Duration,
887 ) -> Result<&mut Self, Error> {
888 unsafe {
889 s2n_config_set_ticket_encrypt_decrypt_key_lifetime(
890 self.as_mut_ptr(),
891 lifetime.as_secs(),
892 )
893 .into_result()
894 }?;
895 Ok(self)
896 }
897
898 pub fn set_ticket_key_decrypt_lifetime(
902 &mut self,
903 lifetime: Duration,
904 ) -> Result<&mut Self, Error> {
905 unsafe {
906 s2n_config_set_ticket_decrypt_key_lifetime(self.as_mut_ptr(), lifetime.as_secs())
907 .into_result()
908 }?;
909 Ok(self)
910 }
911
912 pub fn set_serialization_version(
917 &mut self,
918 version: SerializationVersion,
919 ) -> Result<&mut Self, Error> {
920 unsafe {
921 s2n_config_set_serialization_version(self.as_mut_ptr(), version.into()).into_result()
922 }?;
923 Ok(self)
924 }
925
926 pub fn set_max_blinding_delay(&mut self, seconds: u32) -> Result<&mut Self, Error> {
930 unsafe { s2n_config_set_max_blinding_delay(self.as_mut_ptr(), seconds).into_result() }?;
931 Ok(self)
932 }
933
934 pub fn require_ticket_forward_secrecy(&mut self, enable: bool) -> Result<&mut Self, Error> {
938 unsafe {
939 s2n_config_require_ticket_forward_secrecy(self.as_mut_ptr(), enable).into_result()
940 }?;
941 Ok(self)
942 }
943
944 pub fn build(mut self) -> Result<Config, Error> {
945 if self.load_system_certs {
946 unsafe {
947 s2n_config_load_system_certs(self.as_mut_ptr()).into_result()?;
948 }
949 }
950
951 Ok(self.config)
952 }
953
954 pub(crate) fn as_mut_ptr(&mut self) -> *mut s2n_config {
955 self.config.as_mut_ptr()
956 }
957}
958
#[cfg(feature = "quic")]
impl Builder {
    /// Calls `s2n_config_enable_quic` on the config.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying call.
    pub fn enable_quic(&mut self) -> Result<&mut Self, Error> {
        let rc = unsafe { s2n_tls_sys::s2n_config_enable_quic(self.as_mut_ptr()) };
        rc.into_result()?;
        Ok(self)
    }
}
967
968impl Default for Builder {
976 fn default() -> Self {
977 Self::new()
978 }
979}
980
981pub(crate) struct Context {
982 refcount: AtomicUsize,
983 application_owned_certs: Vec<CertificateChain<'static>>,
993 pub(crate) client_hello_callback: Option<Box<dyn ClientHelloCallback>>,
994 pub(crate) private_key_callback: Option<Box<dyn PrivateKeyCallback>>,
995 pub(crate) verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
996 pub(crate) session_ticket_callback: Option<Box<dyn SessionTicketCallback>>,
997 pub(crate) connection_initializer: Option<Box<dyn ConnectionInitializer>>,
998 pub(crate) wall_clock: Option<Box<dyn WallClock>>,
999 pub(crate) monotonic_clock: Option<Box<dyn MonotonicClock>>,
1000 #[cfg(feature = "unstable-renegotiate")]
1001 pub(crate) renegotiate: Option<Box<dyn RenegotiateCallback>>,
1002}
1003
1004impl Default for Context {
1005 fn default() -> Self {
1006 let refcount = AtomicUsize::new(1);
1009
1010 Self {
1011 refcount,
1012 application_owned_certs: Vec::new(),
1013 client_hello_callback: None,
1014 private_key_callback: None,
1015 verify_host_callback: None,
1016 session_ticket_callback: None,
1017 connection_initializer: None,
1018 wall_clock: None,
1019 monotonic_clock: None,
1020 #[cfg(feature = "unstable-renegotiate")]
1021 renegotiate: None,
1022 }
1023 }
1024}
1025
1026pub trait ConnectionInitializer: 'static + Send + Sync {
1034 fn initialize_connection(
1038 &self,
1039 connection: &mut crate::connection::Connection,
1040 ) -> ConnectionFutureResult;
1041}
1042
1043impl<A: ConnectionInitializer, B: ConnectionInitializer> ConnectionInitializer for (A, B) {
1044 fn initialize_connection(
1045 &self,
1046 connection: &mut crate::connection::Connection,
1047 ) -> ConnectionFutureResult {
1048 let a = self.0.initialize_connection(connection)?;
1049 let b = self.1.initialize_connection(connection)?;
1050 match (a, b) {
1051 (None, None) => Ok(None),
1052 (None, Some(fut)) => Ok(Some(fut)),
1053 (Some(fut), None) => Ok(Some(fut)),
1054 (Some(fut_a), Some(fut_b)) => Ok(Some(Box::pin(ConcurrentConnectionFuture::new([
1055 fut_a, fut_b,
1056 ])))),
1057 }
1058 }
1059}
1060
1061struct ConcurrentConnectionFuture<const N: usize> {
1062 futures: [Option<Pin<Box<dyn ConnectionFuture>>>; N],
1063}
1064
1065impl<const N: usize> ConcurrentConnectionFuture<N> {
1066 fn new(futures: [Pin<Box<dyn ConnectionFuture>>; N]) -> Self {
1067 let futures = futures.map(Some);
1068 Self { futures }
1069 }
1070}
1071
1072impl<const N: usize> ConnectionFuture for ConcurrentConnectionFuture<N> {
1073 fn poll(
1074 mut self: std::pin::Pin<&mut Self>,
1075 connection: &mut crate::connection::Connection,
1076 ctx: &mut core::task::Context,
1077 ) -> std::task::Poll<Result<(), Error>> {
1078 let mut is_pending = false;
1079 for container in self.futures.iter_mut() {
1080 if let Some(future) = container.as_mut() {
1081 match future.as_mut().poll(connection, ctx) {
1082 Poll::Ready(result) => {
1083 result?;
1084 *container = None;
1085 }
1086 Poll::Pending => is_pending = true,
1087 }
1088 }
1089 }
1090 if is_pending {
1091 Poll::Pending
1092 } else {
1093 Poll::Ready(Ok(()))
1094 }
1095 }
1096}
1097
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Context` is `'static + Send + Sync`.
    #[test]
    fn context_send_sync_test() {
        fn require_send_sync<T>()
        where
            T: 'static + Send + Sync,
        {
        }

        require_send_sync::<Context>();
    }
}