1#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78#[cfg(feature = "alpn")]
79use core_foundation::string::CFString;
80use core_foundation::{declare_TCFType, impl_TCFType};
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88};
89
90use security_framework_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::cvt;
106use crate::identity::SecIdentity;
107use crate::import_export::Pkcs12ImportOptions;
108use crate::policy::SecPolicy;
109use crate::trust::SecTrust;
110use security_framework_sys::base::errSecParam;
111
112#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117 pub const CLIENT: Self = Self(kSSLClientSide);
119 pub const SERVER: Self = Self(kSSLServerSide);
121}
122
123#[derive(Debug, Copy, Clone)]
125pub struct SslConnectionType(SSLConnectionType);
126
127impl SslConnectionType {
128 pub const DATAGRAM: Self = Self(kSSLDatagramType);
130 pub const STREAM: Self = Self(kSSLStreamType);
132}
133
134#[derive(Debug)]
136pub enum HandshakeError<S> {
137 Failure(Error),
139 Interrupted(MidHandshakeSslStream<S>),
141}
142
143impl<S> From<Error> for HandshakeError<S> {
144 #[inline(always)]
145 fn from(err: Error) -> Self {
146 Self::Failure(err)
147 }
148}
149
150#[derive(Debug)]
152pub enum ClientHandshakeError<S> {
153 Failure(Error),
155 Interrupted(MidHandshakeClientBuilder<S>),
157}
158
159impl<S> From<Error> for ClientHandshakeError<S> {
160 #[inline(always)]
161 fn from(err: Error) -> Self {
162 Self::Failure(err)
163 }
164}
165
166#[derive(Debug)]
168pub struct MidHandshakeSslStream<S> {
169 stream: SslStream<S>,
170 error: Error,
171}
172
173impl<S> MidHandshakeSslStream<S> {
174 #[inline(always)]
176 #[must_use]
177 pub fn get_ref(&self) -> &S {
178 self.stream.get_ref()
179 }
180
181 #[inline(always)]
183 pub fn get_mut(&mut self) -> &mut S {
184 self.stream.get_mut()
185 }
186
187 #[inline(always)]
189 #[must_use]
190 pub fn context(&self) -> &SslContext {
191 self.stream.context()
192 }
193
194 #[inline(always)]
196 pub fn context_mut(&mut self) -> &mut SslContext {
197 self.stream.context_mut()
198 }
199
200 #[inline(always)]
203 #[must_use]
204 pub fn server_auth_completed(&self) -> bool {
205 self.error.code() == errSSLPeerAuthCompleted
206 }
207
208 #[inline(always)]
211 #[must_use]
212 pub fn client_cert_requested(&self) -> bool {
213 self.error.code() == errSSLClientCertRequested
214 }
215
216 #[inline(always)]
219 #[must_use]
220 pub fn would_block(&self) -> bool {
221 self.error.code() == errSSLWouldBlock
222 }
223
224 #[inline(always)]
226 #[must_use]
227 pub const fn error(&self) -> &Error {
228 &self.error
229 }
230
231 #[inline(always)]
233 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
234 self.stream.handshake()
235 }
236}
237
238#[derive(Debug)]
240pub struct MidHandshakeClientBuilder<S> {
241 stream: MidHandshakeSslStream<S>,
242 domain: Option<String>,
243 certs: Vec<SecCertificate>,
244 trust_certs_only: bool,
245 danger_accept_invalid_certs: bool,
246}
247
248impl<S> MidHandshakeClientBuilder<S> {
249 #[inline(always)]
251 #[must_use]
252 pub fn get_ref(&self) -> &S {
253 self.stream.get_ref()
254 }
255
256 #[inline(always)]
258 pub fn get_mut(&mut self) -> &mut S {
259 self.stream.get_mut()
260 }
261
262 #[inline(always)]
264 #[must_use]
265 pub fn error(&self) -> &Error {
266 self.stream.error()
267 }
268
269 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
271 let Self {
272 stream,
273 domain,
274 certs,
275 trust_certs_only,
276 danger_accept_invalid_certs,
277 } = self;
278
279 let mut result = stream.handshake();
280 loop {
281 let stream = match result {
282 Ok(stream) => return Ok(stream),
283 Err(HandshakeError::Interrupted(stream)) => stream,
284 Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
285 };
286
287 if stream.would_block() {
288 let ret = Self {
289 stream,
290 domain,
291 certs,
292 trust_certs_only,
293 danger_accept_invalid_certs,
294 };
295 return Err(ClientHandshakeError::Interrupted(ret));
296 }
297
298 if stream.server_auth_completed() {
299 if danger_accept_invalid_certs {
300 result = stream.handshake();
301 continue;
302 }
303 let Some(mut trust) = stream.context().peer_trust2()? else {
304 result = stream.handshake();
305 continue;
306 };
307 trust.set_anchor_certificates(&certs)?;
308 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
309 let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310 trust.set_policy(&policy)?;
311 trust.evaluate_with_error().map_err(|error| {
312 #[cfg(feature = "log")]
313 log::warn!("SecTrustEvaluateWithError: {error}");
314 Error::from_code(error.code() as _)
315 })?;
316 result = stream.handshake();
317 continue;
318 }
319
320 let err = Error::from_code(stream.error().code());
321 return Err(ClientHandshakeError::Failure(err));
322 }
323 }
324}
325
326#[derive(Debug, PartialEq, Eq)]
328pub struct SessionState(SSLSessionState);
329
330impl SessionState {
331 pub const ABORTED: Self = Self(kSSLAborted);
333 pub const CLOSED: Self = Self(kSSLClosed);
335 pub const CONNECTED: Self = Self(kSSLConnected);
337 pub const HANDSHAKE: Self = Self(kSSLHandshake);
339 pub const IDLE: Self = Self(kSSLIdle);
341}
342
343#[derive(Debug, Copy, Clone, PartialEq, Eq)]
345pub struct SslAuthenticate(SSLAuthenticate);
346
347impl SslAuthenticate {
348 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
350 pub const NEVER: Self = Self(kNeverAuthenticate);
352 pub const TRY: Self = Self(kTryAuthenticate);
354}
355
356#[derive(Debug, Copy, Clone, PartialEq, Eq)]
358pub struct SslClientCertificateState(SSLClientCertificateState);
359
360impl SslClientCertificateState {
361 pub const NONE: Self = Self(kSSLClientCertNone);
363 pub const REJECTED: Self = Self(kSSLClientCertRejected);
365 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
367 pub const SENT: Self = Self(kSSLClientCertSent);
369}
370
371#[derive(Debug, Copy, Clone, PartialEq, Eq)]
373pub struct SslProtocol(SSLProtocol);
374
375impl SslProtocol {
376 pub const ALL: Self = Self(kSSLProtocolAll);
378 pub const DTLS1: Self = Self(kDTLSProtocol1);
380 pub const SSL2: Self = Self(kSSLProtocol2);
382 pub const SSL3: Self = Self(kSSLProtocol3);
385 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
387 pub const TLS1: Self = Self(kTLSProtocol1);
390 pub const TLS11: Self = Self(kTLSProtocol11);
393 pub const TLS12: Self = Self(kTLSProtocol12);
396 pub const TLS13: Self = Self(kTLSProtocol13);
399 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
401 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
403}
404
// Declares `SslContext`, the wrapper type around the raw `SSLContextRef`.
declare_TCFType! {
    SslContext, SSLContextRef
}

// Implements `TCFType` for `SslContext`, using `SSLContextGetTypeID` as the
// type-ID function.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
411
412impl fmt::Debug for SslContext {
413 #[cold]
414 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
415 let mut builder = fmt.debug_struct("SslContext");
416 if let Ok(state) = self.state() {
417 builder.field("state", &state);
418 }
419 builder.finish()
420 }
421}
422
423unsafe impl Sync for SslContext {}
424unsafe impl Send for SslContext {}
425
#[cfg(target_os = "macos")]
impl SslContext {
    /// Returns a copy of the raw `SSLContextRef` held by this wrapper.
    pub(crate) fn as_inner(&self) -> SSLContextRef {
        let Self(raw) = self;
        *raw
    }
}
432
// Generates a setter/getter pair for each listed session option.
//
// Each entry has the form `const <OPTION>: <getter> & <setter>,` and may be
// preceded by attributes, which are copied onto both generated methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            #[allow(deprecated)]
            $(#[$a])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let flag = Boolean::from(value);
                cvt(unsafe { SSLSetSessionOption(self.0, $opt, flag) })
            }

            #[allow(deprecated)]
            $(#[$a])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut raw = 0;
                cvt(unsafe { SSLGetSessionOption(self.0, $opt, &mut raw) })?;
                Ok(raw != 0)
            }
        )*
    }
}
454
455impl SslContext {
456 #[inline]
459 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
460 unsafe {
461 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
462 Ok(Self(ctx))
463 }
464 }
465
466 #[inline]
475 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
476 unsafe {
477 cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
479 }
480 }
481
482 pub fn peer_domain_name(&self) -> Result<String> {
484 unsafe {
485 let mut len = 0;
486 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
487 let mut buf = vec![0; len];
488 cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
489 String::from_utf8(buf).map_err(|_| Error::from_code(-1))
490 }
491 }
492
493 pub fn set_certificate(
501 &mut self,
502 identity: &SecIdentity,
503 certs: &[SecCertificate],
504 ) -> Result<()> {
505 let mut arr = vec![identity.as_CFType()];
506 arr.extend(certs.iter().map(|c| c.as_CFType()));
507 let certs = CFArray::from_CFTypes(&arr);
508
509 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
510 }
511
512 #[inline]
519 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
520 unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
521 }
522
523 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
525 unsafe {
526 let mut ptr = ptr::null();
527 let mut len = 0;
528 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
529 if ptr.is_null() {
530 Ok(None)
531 } else {
532 Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
533 }
534 }
535 }
536
537 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
539 unsafe {
540 let mut num_ciphers = 0;
541 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
542 let mut ciphers = vec![0; num_ciphers];
543 cvt(SSLGetSupportedCiphers(
544 self.0,
545 ciphers.as_mut_ptr(),
546 &mut num_ciphers,
547 ))?;
548 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
549 }
550 }
551
552 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
555 unsafe {
556 let mut num_ciphers = 0;
557 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
558 let mut ciphers = vec![0; num_ciphers];
559 cvt(SSLGetEnabledCiphers(
560 self.0,
561 ciphers.as_mut_ptr(),
562 &mut num_ciphers,
563 ))?;
564 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
565 }
566 }
567
568 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
570 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
571 unsafe {
572 cvt(SSLSetEnabledCiphers(
573 self.0,
574 ciphers.as_ptr(),
575 ciphers.len(),
576 ))
577 }
578 }
579
580 #[inline]
582 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
583 unsafe {
584 let mut cipher = 0;
585 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
586 Ok(CipherSuite::from_raw(cipher))
587 }
588 }
589
590 #[inline]
594 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
595 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
596 }
597
598 #[inline]
600 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
601 let mut state = 0;
602
603 unsafe {
604 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
605 }
606 Ok(SslClientCertificateState(state))
607 }
608
609 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
614 if self.state()? == SessionState::IDLE {
617 return Err(Error::from_code(errSecBadReq));
618 }
619
620 unsafe {
621 let mut trust = ptr::null_mut();
622 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
623 if trust.is_null() {
624 Ok(None)
625 } else {
626 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
627 }
628 }
629 }
630
631 #[inline]
633 pub fn state(&self) -> Result<SessionState> {
634 unsafe {
635 let mut state = 0;
636 cvt(SSLGetSessionState(self.0, &mut state))?;
637 Ok(SessionState(state))
638 }
639 }
640
641 #[inline]
643 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
644 unsafe {
645 let mut version = 0;
646 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
647 Ok(SslProtocol(version))
648 }
649 }
650
651 #[inline]
653 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
654 unsafe {
655 let mut version = 0;
656 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
657 Ok(SslProtocol(version))
658 }
659 }
660
661 #[inline]
663 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
664 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
665 }
666
667 #[inline]
669 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
670 unsafe {
671 let mut version = 0;
672 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
673 Ok(SslProtocol(version))
674 }
675 }
676
677 #[inline]
679 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
680 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
681 }
682
683 #[cfg(feature = "alpn")]
685 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
686 let mut array: CFArrayRef = ptr::null();
687 unsafe {
688 #[cfg(feature = "OSX_10_13")]
689 {
690 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
691 }
692
693 #[cfg(not(feature = "OSX_10_13"))]
694 {
695 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
696 if let Some(f) = SSLCopyALPNProtocols.get() {
697 cvt(f(self.0, &mut array))?;
698 } else {
699 return Err(Error::from_code(errSecUnimplemented));
700 }
701 }
702
703 if array.is_null() {
704 return Ok(vec![]);
705 }
706
707 let array = CFArray::<CFString>::wrap_under_create_rule(array);
708 Ok(array.into_iter().map(|p| p.to_string()).collect())
709 }
710 }
711
712 #[cfg(feature = "alpn")]
716 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
717 let protocols = CFArray::from_CFTypes(
721 &protocols
722 .iter()
723 .map(|proto| CFString::new(proto))
724 .collect::<Vec<_>>(),
725 );
726
727 #[cfg(feature = "OSX_10_13")]
728 {
729 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
730 }
731 #[cfg(not(feature = "OSX_10_13"))]
732 {
733 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
734 if let Some(f) = SSLSetALPNProtocols.get() {
735 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
736 } else {
737 Err(Error::from_code(errSecUnimplemented))
738 }
739 }
740 }
741
742 #[cfg(feature = "session-tickets")]
750 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
751 #[cfg(feature = "OSX_10_13")]
752 {
753 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
754 }
755 #[cfg(not(feature = "OSX_10_13"))]
756 {
757 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
758 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
759 unsafe { cvt(f(self.0, Boolean::from(enabled))) }
760 } else {
761 Err(Error::from_code(errSecUnimplemented))
762 }
763 }
764 }
765
766 #[inline]
769 pub fn buffered_read_size(&self) -> Result<usize> {
770 unsafe {
771 let mut size = 0;
772 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
773 Ok(size)
774 }
775 }
776
    // Generates a `<name>()` getter and a `set_<name>(bool)` setter for each
    // session option below, backed by `SSLGetSessionOption` and
    // `SSLSetSessionOption` respectively.
    impl_options! {
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }
796
797 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
798 where S: Read + Write {
799 unsafe {
800 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
801 if ret != errSecSuccess {
802 return Err(Error::from_code(ret));
803 }
804
805 let stream = Connection { stream, err: None, panic: None };
806 let stream = Box::into_raw(Box::new(stream));
807 let ret = SSLSetConnection(self.0, stream.cast());
808 if ret != errSecSuccess {
809 let _conn = Box::from_raw(stream);
810 return Err(Error::from_code(ret));
811 }
812
813 Ok(SslStream { ctx: self, _m: PhantomData })
814 }
815 }
816
817 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
819 where
820 S: Read + Write,
821 {
822 self.into_stream(stream)
823 .map_err(HandshakeError::Failure)
824 .and_then(SslStream::handshake)
825 }
826}
827
/// State shared with the Secure Transport I/O callbacks.
struct Connection<S> {
    /// The underlying transport the callbacks read from and write to.
    stream: S,
    /// The last I/O error raised by `stream` inside a callback, if any.
    err: Option<io::Error>,
    /// The payload of a panic caught inside a callback, if any.
    panic: Option<Box<dyn Any + Send>>,
}
833
834#[cold]
836fn translate_err(e: &io::Error) -> OSStatus {
837 match e.kind() {
838 io::ErrorKind::NotFound => errSSLClosedGraceful,
839 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
840 io::ErrorKind::WouldBlock |
841 io::ErrorKind::NotConnected => errSSLWouldBlock,
842 _ => errSecIO,
843 }
844}
845
846unsafe extern "C" fn read_func<S>(
847 connection: SSLConnectionRef,
848 data: *mut c_void,
849 data_length: *mut usize,
850) -> OSStatus
851where S: Read {
852 let conn: &mut Connection<S> = &mut *(connection as *mut _);
853 let mut read = 0;
854
855 let ret = panic::catch_unwind(AssertUnwindSafe(|| {
856 let mut data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
857 while !data.is_empty() {
858 match conn.stream.read(data) {
859 Ok(0) => return errSSLClosedNoNotify,
860 Ok(len) => {
861 let Some(rest) = data.get_mut(len..) else {
862 return errSecIO;
863 };
864 data = rest;
865 read += len;
866 },
867 Err(e) => {
868 let ret = translate_err(&e);
869 conn.err = Some(e);
870 return ret;
871 },
872 }
873 }
874 errSecSuccess
875 }))
876 .unwrap_or_else(|e| {
877 conn.panic = Some(e);
878 errSecIO
879 });
880
881 *data_length = read;
882 ret
883}
884
885unsafe extern "C" fn write_func<S>(
886 connection: SSLConnectionRef,
887 data: *const c_void,
888 data_length: *mut usize,
889) -> OSStatus
890where S: Write {
891 let conn: &mut Connection<S> = &mut *(connection as *mut _);
892 let mut written = 0;
893
894 let ret = panic::catch_unwind(AssertUnwindSafe(|| {
895 let mut data = slice::from_raw_parts(data.cast::<u8>(), *data_length);
896 while !data.is_empty() {
897 match conn.stream.write(data) {
898 Ok(0) => return errSSLClosedNoNotify,
899 Ok(len) => {
900 let Some(rest) = data.get(len..) else {
901 return errSecIO;
902 };
903 data = rest;
904 written += len;
905 },
906 Err(e) => {
907 let ret = translate_err(&e);
908 conn.err = Some(e);
909 return ret;
910 },
911 }
912 }
913 if let Err(e) = conn.stream.flush() {
917 let ret = translate_err(&e);
918 conn.err = Some(e);
919 return ret;
920 }
921 errSecSuccess
922 }))
923 .unwrap_or_else(|e| {
924 conn.panic = Some(e);
925 errSecIO
926 });
927
928 *data_length = written;
929 ret
930}
931
932pub struct SslStream<S> {
934 ctx: SslContext,
935 _m: PhantomData<S>,
936}
937
938impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
939 #[cold]
940 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
941 fmt.debug_struct("SslStream")
942 .field("context", &self.ctx)
943 .field("stream", self.get_ref())
944 .finish()
945 }
946}
947
948impl<S> Drop for SslStream<S> {
949 fn drop(&mut self) {
950 unsafe {
951 let mut conn = ptr::null();
952 let ret = SSLGetConnection(self.ctx.0, &mut conn);
953 assert!(ret == errSecSuccess);
954 let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
955 }
956 }
957}
958
959impl<S> SslStream<S> {
960 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
961 match unsafe { SSLHandshake(self.ctx.0) } {
962 errSecSuccess => Ok(self),
963 reason @ (errSSLPeerAuthCompleted
964 | errSSLClientCertRequested
965 | errSSLWouldBlock
966 | errSSLClientHelloReceived) => {
967 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
968 stream: self,
969 error: Error::from_code(reason),
970 }))
971 },
972 err => {
973 self.check_panic();
974 Err(HandshakeError::Failure(Error::from_code(err)))
975 },
976 }
977 }
978
979 #[inline(always)]
981 #[must_use]
982 pub fn get_ref(&self) -> &S {
983 &self.connection().stream
984 }
985
986 #[inline(always)]
988 pub fn get_mut(&mut self) -> &mut S {
989 &mut self.connection_mut().stream
990 }
991
992 #[inline(always)]
994 #[must_use]
995 pub fn context(&self) -> &SslContext {
996 &self.ctx
997 }
998
999 #[inline(always)]
1001 pub fn context_mut(&mut self) -> &mut SslContext {
1002 &mut self.ctx
1003 }
1004
1005 pub fn close(&mut self) -> result::Result<(), io::Error> {
1007 unsafe {
1008 let ret = SSLClose(self.ctx.0);
1009 if ret == errSecSuccess {
1010 Ok(())
1011 } else {
1012 Err(self.get_error(ret))
1013 }
1014 }
1015 }
1016
1017 fn connection(&self) -> &Connection<S> {
1018 unsafe {
1019 let mut conn = ptr::null();
1020 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1021 assert!(ret == errSecSuccess);
1022
1023 &mut *(conn as *mut Connection<S>)
1024 }
1025 }
1026
1027 fn connection_mut(&mut self) -> &mut Connection<S> {
1028 unsafe {
1029 let mut conn = ptr::null();
1030 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1031 assert!(ret == errSecSuccess);
1032
1033 &mut *(conn as *mut Connection<S>)
1034 }
1035 }
1036
1037 #[cold]
1038 fn check_panic(&mut self) {
1039 let conn = self.connection_mut();
1040 if let Some(err) = conn.panic.take() {
1041 panic::resume_unwind(err);
1042 }
1043 }
1044
1045 #[cold]
1046 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1047 self.check_panic();
1048
1049 if let Some(err) = self.connection_mut().err.take() {
1050 err
1051 } else {
1052 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1053 }
1054 }
1055}
1056
1057impl<S: Read + Write> Read for SslStream<S> {
1058 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1059 if buf.is_empty() {
1064 return Ok(0);
1065 }
1066
1067 let buffered = self.context().buffered_read_size().unwrap_or(0);
1072 let to_read = if buffered > 0 {
1073 cmp::min(buffered, buf.len())
1074 } else {
1075 buf.len()
1076 };
1077
1078 unsafe {
1079 let mut nread = 0;
1080 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1081 if nread > 0 {
1084 return Ok(nread);
1085 }
1086
1087 match ret {
1088 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1089 errSSLPeerAuthCompleted => self.read(buf),
1091 _ => Err(self.get_error(ret)),
1092 }
1093 }
1094 }
1095}
1096
1097impl<S: Read + Write> Write for SslStream<S> {
1098 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1099 if buf.is_empty() {
1101 return Ok(0);
1102 }
1103 unsafe {
1104 let mut nwritten = 0;
1105 let ret = SSLWrite(
1106 self.ctx.0,
1107 buf.as_ptr().cast(),
1108 buf.len(),
1109 &mut nwritten,
1110 );
1111 if nwritten > 0 {
1114 Ok(nwritten)
1115 } else {
1116 Err(self.get_error(ret))
1117 }
1118 }
1119 }
1120
1121 fn flush(&mut self) -> io::Result<()> {
1122 self.connection_mut().stream.flush()
1123 }
1124}
1125
1126#[derive(Debug)]
1128pub struct ClientBuilder {
1129 identity: Option<SecIdentity>,
1130 certs: Vec<SecCertificate>,
1131 chain: Vec<SecCertificate>,
1132 protocol_min: Option<SslProtocol>,
1133 protocol_max: Option<SslProtocol>,
1134 trust_certs_only: bool,
1135 use_sni: bool,
1136 danger_accept_invalid_certs: bool,
1137 danger_accept_invalid_hostnames: bool,
1138 whitelisted_ciphers: Vec<CipherSuite>,
1139 blacklisted_ciphers: Vec<CipherSuite>,
1140 #[cfg(feature = "alpn")]
1141 alpn: Option<Vec<String>>,
1142 #[cfg(feature = "session-tickets")]
1143 enable_session_tickets: bool,
1144}
1145
1146impl Default for ClientBuilder {
1147 #[inline(always)]
1148 fn default() -> Self {
1149 Self::new()
1150 }
1151}
1152
1153impl ClientBuilder {
1154 #[inline]
1156 #[must_use]
1157 pub fn new() -> Self {
1158 Self {
1159 identity: None,
1160 certs: Vec::new(),
1161 chain: Vec::new(),
1162 protocol_min: None,
1163 protocol_max: None,
1164 trust_certs_only: false,
1165 use_sni: true,
1166 danger_accept_invalid_certs: false,
1167 danger_accept_invalid_hostnames: false,
1168 whitelisted_ciphers: Vec::new(),
1169 blacklisted_ciphers: Vec::new(),
1170 #[cfg(feature = "alpn")]
1171 alpn: None,
1172 #[cfg(feature = "session-tickets")]
1173 enable_session_tickets: false,
1174 }
1175 }
1176
1177 #[inline]
1180 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1181 certs.clone_into(&mut self.certs);
1182 self
1183 }
1184
1185 #[inline]
1188 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1189 self.certs.push(certs.to_owned());
1190 self
1191 }
1192
1193 #[inline(always)]
1196 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1197 self.trust_certs_only = only;
1198 self
1199 }
1200
1201 #[inline(always)]
1210 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1211 self.danger_accept_invalid_certs = noverify;
1212 self
1213 }
1214
1215 #[inline(always)]
1217 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1218 self.use_sni = use_sni;
1219 self
1220 }
1221
1222 #[inline(always)]
1230 pub fn danger_accept_invalid_hostnames(
1231 &mut self,
1232 danger_accept_invalid_hostnames: bool,
1233 ) -> &mut Self {
1234 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1235 self
1236 }
1237
1238 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1240 whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1241 self
1242 }
1243
1244 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1246 blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1247 self
1248 }
1249
1250 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1252 self.identity = Some(identity.clone());
1253 chain.clone_into(&mut self.chain);
1254 self
1255 }
1256
1257 #[inline(always)]
1259 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1260 self.protocol_min = Some(min);
1261 self
1262 }
1263
1264 #[inline(always)]
1266 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1267 self.protocol_max = Some(max);
1268 self
1269 }
1270
1271 #[cfg(feature = "alpn")]
1273 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1274 self.alpn = Some(protocols.iter().map(|s| (*s).to_string()).collect());
1275 self
1276 }
1277
1278 #[cfg(feature = "session-tickets")]
1282 #[inline(always)]
1283 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1284 self.enable_session_tickets = enable;
1285 self
1286 }
1287
1288 pub fn handshake<S>(
1292 &self,
1293 domain: &str,
1294 stream: S,
1295 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1296 where
1297 S: Read + Write,
1298 {
1299 let stream = MidHandshakeSslStream {
1302 stream: self.ctx_into_stream(domain, stream)?,
1303 error: Error::from(errSecSuccess),
1304 };
1305
1306 let certs = self.certs.clone();
1307 let stream = MidHandshakeClientBuilder {
1308 stream,
1309 domain: if self.danger_accept_invalid_hostnames {
1310 None
1311 } else {
1312 Some(domain.to_string())
1313 },
1314 certs,
1315 trust_certs_only: self.trust_certs_only,
1316 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1317 };
1318 stream.handshake()
1319 }
1320
1321 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1322 where S: Read + Write {
1323 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1324
1325 if self.use_sni {
1326 ctx.set_peer_domain_name(domain)?;
1327 }
1328 if let Some(identity) = &self.identity {
1329 ctx.set_certificate(identity, &self.chain)?;
1330 }
1331 #[cfg(feature = "alpn")]
1332 {
1333 if let Some(alpn) = &self.alpn {
1334 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1335 }
1336 }
1337 #[cfg(feature = "session-tickets")]
1338 {
1339 if self.enable_session_tickets {
1340 ctx.set_peer_id(domain.as_bytes())?;
1343 ctx.set_session_tickets_enabled(true)?;
1344 }
1345 }
1346 ctx.set_break_on_server_auth(true)?;
1347 self.configure_protocols(&mut ctx)?;
1348 self.configure_ciphers(&mut ctx)?;
1349
1350 ctx.into_stream(stream)
1351 }
1352
1353 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1354 if let Some(min) = self.protocol_min {
1355 ctx.set_protocol_version_min(min)?;
1356 }
1357 if let Some(max) = self.protocol_max {
1358 ctx.set_protocol_version_max(max)?;
1359 }
1360 Ok(())
1361 }
1362
1363 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1364 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1365 ctx.enabled_ciphers()?
1366 } else {
1367 self.whitelisted_ciphers.clone()
1368 };
1369
1370 if !self.blacklisted_ciphers.is_empty() {
1371 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1372 }
1373
1374 ctx.set_enabled_ciphers(&ciphers)?;
1375 Ok(())
1376 }
1377}
1378
1379#[derive(Debug)]
1381pub struct ServerBuilder {
1382 identity: SecIdentity,
1383 certs: Vec<SecCertificate>,
1384}
1385
1386impl ServerBuilder {
1387 #[must_use]
1390 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1391 Self {
1392 identity: identity.clone(),
1393 certs: certs.to_owned(),
1394 }
1395 }
1396
1397 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1404 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1405 .passphrase(passphrase)
1406 .import(pkcs12_der)?
1407 .into_iter()
1408 .filter_map(|idendity| {
1409 Some((idendity.identity?, idendity.cert_chain.unwrap_or_default()))
1410 })
1411 .take(2)
1412 .collect();
1413 if identities.len() == 1 {
1414 let (identity, certs) = identities.pop().unwrap();
1415 Ok(Self { identity, certs })
1416 } else {
1417 Err(Error::from_code(errSecParam))
1419 }
1420 }
1421
1422 pub fn new_ssl_context(&self) -> Result<SslContext> {
1424 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1425 ctx.set_certificate(&self.identity, &self.certs)?;
1426 Ok(ctx)
1427 }
1428
1429 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1431 where S: Read + Write {
1432 match self.new_ssl_context()?.handshake(stream) {
1433 Ok(stream) => Ok(stream),
1434 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1435 Err(HandshakeError::Failure(err)) => Err(err),
1436 }
1437 }
1438}
1439
1440#[cfg(test)]
1441mod test {
1442 use std::io::prelude::*;
1443 use std::net::TcpStream;
1444
1445 use super::*;
1446
1447 #[test]
1448 fn server_builder_from_pkcs12() {
1449 let pkcs12_der = include_bytes!("../test/server.p12");
1450 ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1451 }
1452
1453 #[test]
1454 fn connect() {
1455 let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1456 p!(ctx.set_peer_domain_name("google.com"));
1457 let stream = p!(TcpStream::connect("google.com:443"));
1458 p!(ctx.handshake(stream));
1459 }
1460
1461 #[test]
1462 fn connect_bad_domain() {
1463 let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1464 p!(ctx.set_peer_domain_name("foobar.com"));
1465 let stream = p!(TcpStream::connect("google.com:443"));
1466 ctx.handshake(stream).expect_err("expected failure");
1467 }
1468
/// Handshakes over a transport whose writes are buffered and only reach the
/// socket on flush, while reads go to the socket directly.
#[test]
fn connect_buffered_stream() {
    use std::io::BufWriter;

    /// Pairs a raw reading handle with a `BufWriter` over a clone of the same socket.
    #[derive(Debug)]
    struct BufferedTcpStream {
        reader: TcpStream,
        writer: BufWriter<TcpStream>,
    }

    impl BufferedTcpStream {
        /// Clones `tcp` for the buffered write side and keeps `tcp` itself for reading.
        fn new(tcp: TcpStream) -> std::io::Result<Self> {
            let write_half = tcp.try_clone()?;
            Ok(Self {
                reader: tcp,
                writer: BufWriter::with_capacity(500, write_half),
            })
        }
    }

    impl Read for BufferedTcpStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            Read::read(&mut self.reader, buf)
        }
    }

    impl Write for BufferedTcpStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            Write::write(&mut self.writer, buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Write::flush(&mut self.writer)
        }
    }

    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));

    let socket = p!(TcpStream::connect("google.com:443"));
    let buffered = p!(BufferedTcpStream::new(socket));
    p!(client.handshake(buffered));
}
1511
/// Sends a minimal HTTP/1.0 request through an `SslContext`-driven stream
/// and prints whatever the server sends back.
#[test]
fn load_page() {
    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));
    let socket = p!(TcpStream::connect("google.com:443"));
    let mut tls = p!(client.handshake(socket));

    p!(tls.write_all(b"GET / HTTP/1.0\r\n\r\n"));
    p!(tls.flush());

    let mut response = Vec::new();
    p!(tls.read_to_end(&mut response));
    let page = String::from_utf8_lossy(&response);
    println!("{}", page);
}
1524
/// With a default `ClientBuilder`, both connection attempts must be
/// interrupted once with server auth completed before the handshake finishes.
#[test]
fn client_no_session_ticket_resumption() {
    for _attempt in 0..2 {
        let socket = p!(TcpStream::connect("google.com:443"));

        let pending = MidHandshakeSslStream {
            stream: ClientBuilder::new()
                .ctx_into_stream("google.com", socket)
                .unwrap(),
            error: Error::from(errSecSuccess),
        };

        let outcome = match pending.handshake() {
            Err(HandshakeError::Interrupted(paused)) => {
                assert!(paused.server_auth_completed());
                // Resume the handshake after the interruption.
                paused.handshake()
            }
            _ => panic!("Unexpectedly skipped server auth"),
        };

        assert!(outcome.is_ok());
    }
}
1550
/// With session tickets enabled, only the first of two connections may be
/// interrupted for server auth; the second must go straight through.
#[test]
#[cfg(feature = "session-tickets")]
fn client_session_ticket_resumption() {
    for attempt in 0..2 {
        let socket = p!(TcpStream::connect("google.com:443"));
        let mut builder = ClientBuilder::new();
        builder.enable_session_tickets(true);

        let pending = MidHandshakeSslStream {
            stream: builder.ctx_into_stream("google.com", socket).unwrap(),
            error: Error::from(errSecSuccess),
        };

        let outcome = match pending.handshake() {
            Err(HandshakeError::Interrupted(paused)) => {
                assert!(paused.server_auth_completed());
                assert_eq!(attempt, 0, "Session ticket resumption did not work, server auth was not skipped");
                paused.handshake()
            }
            other => {
                assert_eq!(attempt, 1, "Unexpectedly skipped server auth");
                other
            }
        };

        assert!(outcome.is_ok());
    }
}
1580
/// Offering `h2` must result in `h2` being reported by `alpn_protocols`.
#[test]
#[cfg(feature = "alpn")]
fn client_alpn_accept() {
    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));
    p!(client.set_alpn_protocols(&["h2"]));

    let socket = p!(TcpStream::connect("google.com:443"));
    let tls = client.handshake(socket).unwrap();
    let negotiated = tls.context().alpn_protocols().unwrap();
    assert_eq!(vec!["h2"], negotiated);
}
1591
/// Offering only `h2c` still completes the handshake, but `alpn_protocols`
/// is expected to return an error afterwards.
#[test]
#[cfg(feature = "alpn")]
fn client_alpn_reject() {
    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));
    p!(client.set_alpn_protocols(&["h2c"]));

    let socket = p!(TcpStream::connect("google.com:443"));
    let tls = client.handshake(socket).unwrap();
    let negotiated = tls.context().alpn_protocols();
    assert!(negotiated.is_err());
}
1602
/// Restricting trust to the (empty) set of configured anchor certificates
/// must make the handshake fail.
#[test]
fn client_no_anchor_certs() {
    let socket = p!(TcpStream::connect("google.com:443"));

    let mut builder = ClientBuilder::new();
    builder.trust_anchor_certificates_only(true);

    let outcome = builder.handshake("google.com", socket);
    assert!(outcome.is_err());
}
1611
/// `ClientBuilder` must reject a connection whose requested domain does not
/// match the host connected to.
#[test]
fn client_bad_domain() {
    let socket = p!(TcpStream::connect("google.com:443"));
    let outcome = ClientBuilder::new().handshake("foobar.com", socket);
    assert!(outcome.is_err());
}
1619
/// With `danger_accept_invalid_hostnames(true)` a mismatching domain no
/// longer prevents the handshake from succeeding.
#[test]
fn client_bad_domain_ignored() {
    let socket = p!(TcpStream::connect("google.com:443"));

    let mut builder = ClientBuilder::new();
    builder.danger_accept_invalid_hostnames(true);

    builder.handshake("foobar.com", socket).unwrap();
}
1628
/// With `danger_accept_invalid_certs(true)` the handshake against
/// `expired.badssl.com` is expected to succeed.
#[test]
fn connect_no_verify_ssl() {
    let socket = p!(TcpStream::connect("expired.badssl.com:443"));
    ClientBuilder::new()
        .danger_accept_invalid_certs(true)
        .handshake("expired.badssl.com", socket)
        .unwrap();
}
1636
/// Sends a minimal HTTP/1.0 request through a `ClientBuilder`-created stream
/// and prints whatever the server sends back.
#[test]
fn load_page_client() {
    let socket = p!(TcpStream::connect("google.com:443"));
    let mut tls = p!(ClientBuilder::new().handshake("google.com", socket));

    p!(tls.write_all(b"GET / HTTP/1.0\r\n\r\n"));
    p!(tls.flush());

    let mut response = Vec::new();
    p!(tls.read_to_end(&mut response));
    let page = String::from_utf8_lossy(&response);
    println!("{}", page);
}
1647
/// Enables every other cipher from the default list and checks that the
/// context reports exactly that subset back.
#[test]
#[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)]
fn cipher_configuration() {
    let mut server = p!(SslContext::new(
        SslProtocolSide::SERVER,
        SslConnectionType::STREAM
    ));

    let defaults = p!(server.enabled_ciphers());
    // Keep the ciphers at even positions (0, 2, 4, ...).
    let ciphers: Vec<_> = defaults.into_iter().step_by(2).collect();

    p!(server.set_enabled_ciphers(&ciphers));
    assert_eq!(ciphers, p!(server.enabled_ciphers()));
}
1661
/// Whitelisting a single cipher on the builder must leave exactly one
/// cipher enabled on the resulting stream's context.
#[test]
fn test_builder_whitelist_ciphers() {
    let socket = p!(TcpStream::connect("google.com:443"));

    let probe = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    let available = p!(probe.enabled_ciphers()).len();
    assert!(available > 1);

    // Copy out the first default cipher to use as the sole whitelist entry.
    let chosen = *p!(probe.enabled_ciphers()).first().unwrap();
    let tls = p!(ClientBuilder::new()
        .whitelist_ciphers(&[chosen])
        .ctx_into_stream("google.com", socket));

    let enabled = p!(tls.context().enabled_ciphers());
    assert_eq!(1, enabled.len());
}
1677
/// Blacklisting a single cipher on the builder must leave one fewer cipher
/// enabled on the resulting stream's context than a default context has.
#[test]
#[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)]
fn test_builder_blacklist_ciphers() {
    let socket = p!(TcpStream::connect("google.com:443"));

    let probe = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    let baseline = p!(probe.enabled_ciphers()).len();
    assert!(baseline > 1);

    // Copy out the first default cipher to use as the sole blacklist entry.
    let banned = *p!(probe.enabled_ciphers()).first().unwrap();
    let tls = p!(ClientBuilder::new()
        .blacklist_ciphers(&[banned])
        .ctx_into_stream("google.com", socket));

    let remaining = p!(tls.context().enabled_ciphers());
    assert_eq!(baseline - 1, remaining.len());
}
1695
/// Asking a context that has never handshaken for its peer trust must
/// return an error.
#[test]
fn idle_context_peer_trust() {
    let idle = p!(SslContext::new(
        SslProtocolSide::SERVER,
        SslConnectionType::STREAM
    ));
    let trust = idle.peer_trust2();
    assert!(trust.is_err());
}
1701
/// The peer id starts out unset and reads back exactly what was stored.
#[test]
fn peer_id() {
    let mut server = p!(SslContext::new(
        SslProtocolSide::SERVER,
        SslConnectionType::STREAM
    ));

    let initial = p!(server.peer_id());
    assert!(initial.is_none());

    let id: &[u8] = b"foobar";
    p!(server.set_peer_id(id));
    assert_eq!(p!(server.peer_id()), Some(id));
}
1709
/// The peer domain name starts out empty and reads back what was stored.
#[test]
fn peer_domain_name() {
    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));

    let initial = p!(client.peer_domain_name());
    assert_eq!("", initial);

    p!(client.set_peer_domain_name("foobar.com"));
    let stored = p!(client.peer_domain_name());
    assert_eq!("foobar.com", stored);
}
1717
/// A panic raised inside the transport's `write` must surface from
/// `handshake` with its original message.
#[test]
#[should_panic(expected = "blammo")]
fn write_panic() {
    /// Transport that reads normally but panics on any write.
    struct ExplodingStream(TcpStream);

    impl Read for ExplodingStream {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let Self(inner) = self;
            inner.read(buf)
        }
    }

    impl Write for ExplodingStream {
        fn write(&mut self, _: &[u8]) -> io::Result<usize> {
            panic!("blammo");
        }

        fn flush(&mut self) -> io::Result<()> {
            let Self(inner) = self;
            inner.flush()
        }
    }

    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));

    let socket = p!(TcpStream::connect("google.com:443"));
    let exploding = ExplodingStream(socket);
    let _ = client.handshake(exploding);
}
1744
/// A panic raised inside the transport's `read` must surface from
/// `handshake` with its original message.
#[test]
#[should_panic(expected = "blammo")]
fn read_panic() {
    /// Transport that writes normally but panics on any read.
    struct ExplodingStream(TcpStream);

    impl Read for ExplodingStream {
        fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
            panic!("blammo");
        }
    }

    impl Write for ExplodingStream {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let Self(inner) = self;
            inner.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            let Self(inner) = self;
            inner.flush()
        }
    }

    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));

    let socket = p!(TcpStream::connect("google.com:443"));
    let exploding = ExplodingStream(socket);
    let _ = client.handshake(exploding);
}
1771
/// Writing an empty slice and reading into an empty buffer must both
/// succeed and report zero bytes transferred.
#[test]
fn zero_length_buffers() {
    let mut client = p!(SslContext::new(
        SslProtocolSide::CLIENT,
        SslConnectionType::STREAM
    ));
    p!(client.set_peer_domain_name("google.com"));
    let socket = p!(TcpStream::connect("google.com:443"));
    let mut tls = client.handshake(socket).unwrap();

    let written = tls.write(b"").unwrap();
    assert_eq!(written, 0);

    let mut empty: [u8; 0] = [];
    let received = tls.read(&mut empty).unwrap();
    assert_eq!(received, 0);
}
1781}