1#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78#[cfg(feature = "alpn")]
79use core_foundation::string::CFString;
80use core_foundation::{declare_TCFType, impl_TCFType};
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88};
89
90use security_framework_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::cvt;
106use crate::identity::SecIdentity;
107use crate::import_export::Pkcs12ImportOptions;
108use crate::policy::SecPolicy;
109use crate::trust::SecTrust;
110use security_framework_sys::base::errSecParam;
111
112#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117 pub const CLIENT: Self = Self(kSSLClientSide);
119 pub const SERVER: Self = Self(kSSLServerSide);
121}
122
123#[derive(Debug, Copy, Clone)]
125pub struct SslConnectionType(SSLConnectionType);
126
127impl SslConnectionType {
128 pub const DATAGRAM: Self = Self(kSSLDatagramType);
130 pub const STREAM: Self = Self(kSSLStreamType);
132}
133
134#[derive(Debug)]
136pub enum HandshakeError<S> {
137 Failure(Error),
139 Interrupted(MidHandshakeSslStream<S>),
141}
142
143impl<S> From<Error> for HandshakeError<S> {
144 #[inline(always)]
145 fn from(err: Error) -> Self {
146 Self::Failure(err)
147 }
148}
149
150#[derive(Debug)]
152pub enum ClientHandshakeError<S> {
153 Failure(Error),
155 Interrupted(MidHandshakeClientBuilder<S>),
157}
158
159impl<S> From<Error> for ClientHandshakeError<S> {
160 #[inline(always)]
161 fn from(err: Error) -> Self {
162 Self::Failure(err)
163 }
164}
165
/// An SSL stream whose handshake has not finished yet.
///
/// `SslStream::handshake` builds one when `SSLHandshake` returns one of its
/// break/retry statuses. `ClientBuilder::handshake` also builds one (with
/// `errSecSuccess`) as the starting point of the client handshake loop.
#[derive(Debug)]
pub struct MidHandshakeSslStream<S> {
    // The partially set-up stream; owns the context and the boxed connection.
    stream: SslStream<S>,
    // The status that interrupted the handshake.
    error: Error,
}
172
173impl<S> MidHandshakeSslStream<S> {
174 #[inline(always)]
176 #[must_use]
177 pub fn get_ref(&self) -> &S {
178 self.stream.get_ref()
179 }
180
181 #[inline(always)]
183 pub fn get_mut(&mut self) -> &mut S {
184 self.stream.get_mut()
185 }
186
187 #[inline(always)]
189 #[must_use]
190 pub fn context(&self) -> &SslContext {
191 self.stream.context()
192 }
193
194 #[inline(always)]
196 pub fn context_mut(&mut self) -> &mut SslContext {
197 self.stream.context_mut()
198 }
199
200 #[inline(always)]
203 #[must_use]
204 pub fn server_auth_completed(&self) -> bool {
205 self.error.code() == errSSLPeerAuthCompleted
206 }
207
208 #[inline(always)]
211 #[must_use]
212 pub fn client_cert_requested(&self) -> bool {
213 self.error.code() == errSSLClientCertRequested
214 }
215
216 #[inline(always)]
219 #[must_use]
220 pub fn would_block(&self) -> bool {
221 self.error.code() == errSSLWouldBlock
222 }
223
224 #[inline(always)]
226 #[must_use]
227 pub const fn error(&self) -> &Error {
228 &self.error
229 }
230
231 #[inline(always)]
233 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
234 self.stream.handshake()
235 }
236}
237
/// A client-side handshake in progress, together with the verification
/// settings copied from the `ClientBuilder` that started it.
#[derive(Debug)]
pub struct MidHandshakeClientBuilder<S> {
    // The interrupted low-level stream.
    stream: MidHandshakeSslStream<S>,
    // Hostname passed to the SSL policy; `None` when hostname checks are disabled.
    domain: Option<String>,
    // Anchor certificates installed on the peer trust before evaluation.
    certs: Vec<SecCertificate>,
    // Whether only `certs` (and not the system anchors) may be trusted.
    trust_certs_only: bool,
    // Skips trust evaluation entirely when set.
    danger_accept_invalid_certs: bool,
}
247
248impl<S> MidHandshakeClientBuilder<S> {
249 #[inline(always)]
251 #[must_use]
252 pub fn get_ref(&self) -> &S {
253 self.stream.get_ref()
254 }
255
256 #[inline(always)]
258 pub fn get_mut(&mut self) -> &mut S {
259 self.stream.get_mut()
260 }
261
262 #[inline(always)]
264 #[must_use]
265 pub fn error(&self) -> &Error {
266 self.stream.error()
267 }
268
269 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
271 let Self {
272 stream,
273 domain,
274 certs,
275 trust_certs_only,
276 danger_accept_invalid_certs,
277 } = self;
278
279 let mut result = stream.handshake();
280 loop {
281 let stream = match result {
282 Ok(stream) => return Ok(stream),
283 Err(HandshakeError::Interrupted(stream)) => stream,
284 Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
285 };
286
287 if stream.would_block() {
288 let ret = Self {
289 stream,
290 domain,
291 certs,
292 trust_certs_only,
293 danger_accept_invalid_certs,
294 };
295 return Err(ClientHandshakeError::Interrupted(ret));
296 }
297
298 if stream.server_auth_completed() {
299 if danger_accept_invalid_certs {
300 result = stream.handshake();
301 continue;
302 }
303 let Some(mut trust) = stream.context().peer_trust2()? else {
304 result = stream.handshake();
305 continue;
306 };
307 trust.set_anchor_certificates(&certs)?;
308 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
309 let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310 trust.set_policy(&policy)?;
311 trust.evaluate_with_error().map_err(|error| {
312 #[cfg(feature = "log")]
313 log::warn!("SecTrustEvaluateWithError: {error}");
314 Error::from_code(error.code() as _)
315 })?;
316 result = stream.handshake();
317 continue;
318 }
319
320 let err = Error::from_code(stream.error().code());
321 return Err(ClientHandshakeError::Failure(err));
322 }
323 }
324}
325
326#[derive(Debug, PartialEq, Eq)]
328pub struct SessionState(SSLSessionState);
329
330impl SessionState {
331 pub const ABORTED: Self = Self(kSSLAborted);
333 pub const CLOSED: Self = Self(kSSLClosed);
335 pub const CONNECTED: Self = Self(kSSLConnected);
337 pub const HANDSHAKE: Self = Self(kSSLHandshake);
339 pub const IDLE: Self = Self(kSSLIdle);
341}
342
343#[derive(Debug, Copy, Clone, PartialEq, Eq)]
345pub struct SslAuthenticate(SSLAuthenticate);
346
347impl SslAuthenticate {
348 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
350 pub const NEVER: Self = Self(kNeverAuthenticate);
352 pub const TRY: Self = Self(kTryAuthenticate);
354}
355
356#[derive(Debug, Copy, Clone, PartialEq, Eq)]
358pub struct SslClientCertificateState(SSLClientCertificateState);
359
360impl SslClientCertificateState {
361 pub const NONE: Self = Self(kSSLClientCertNone);
363 pub const REJECTED: Self = Self(kSSLClientCertRejected);
365 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
367 pub const SENT: Self = Self(kSSLClientCertSent);
369}
370
371#[derive(Debug, Copy, Clone, PartialEq, Eq)]
373pub struct SslProtocol(SSLProtocol);
374
375impl SslProtocol {
376 pub const ALL: Self = Self(kSSLProtocolAll);
378 pub const DTLS1: Self = Self(kDTLSProtocol1);
380 pub const SSL2: Self = Self(kSSLProtocol2);
382 pub const SSL3: Self = Self(kSSLProtocol3);
385 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
387 pub const TLS1: Self = Self(kTLSProtocol1);
390 pub const TLS11: Self = Self(kTLSProtocol11);
393 pub const TLS12: Self = Self(kTLSProtocol12);
396 pub const TLS13: Self = Self(kTLSProtocol13);
399 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
401 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
403}
404
// `SslContext` is the CoreFoundation-style wrapper around a Secure Transport
// `SSLContextRef`; the two macros generate the newtype and its `TCFType`
// plumbing (type id via `SSLContextGetTypeID`).
declare_TCFType! {
    SslContext, SSLContextRef
}

impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
411
412impl fmt::Debug for SslContext {
413 #[cold]
414 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
415 let mut builder = fmt.debug_struct("SslContext");
416 if let Ok(state) = self.state() {
417 builder.field("state", &state);
418 }
419 builder.finish()
420 }
421}
422
// SAFETY: NOTE(review): these impls assume an `SSLContextRef` may be moved to
// and read from other threads. Mutating methods take `&mut self`, while `&self`
// methods only call Secure Transport getters. Confirm against Apple's
// thread-safety guarantees for Secure Transport contexts.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
425
426impl SslContext {
427 pub(crate) fn as_inner(&self) -> SSLContextRef {
428 self.0
429 }
430}
431
// Generates a `set_*`/getter method pair for each boolean Secure Transport
// session option, forwarding any attributes (e.g. doc comments) to both.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            #[allow(deprecated)]
            $(#[$a])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let raw = Boolean::from(value);
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, raw)) }
            }

            #[allow(deprecated)]
            $(#[$a])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut raw: Boolean = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut raw)) }.map(|()| raw != 0)
            }
        )*
    }
}
453
454impl SslContext {
455 #[inline]
458 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
459 unsafe {
460 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
461 Ok(Self(ctx))
462 }
463 }
464
465 #[inline]
474 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
475 unsafe {
476 cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
478 }
479 }
480
481 pub fn peer_domain_name(&self) -> Result<String> {
483 unsafe {
484 let mut len = 0;
485 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
486 let mut buf = vec![0; len];
487 cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
488 String::from_utf8(buf).map_err(|_| Error::from_code(-1))
489 }
490 }
491
492 pub fn set_certificate(
500 &mut self,
501 identity: &SecIdentity,
502 certs: &[SecCertificate],
503 ) -> Result<()> {
504 let mut arr = vec![identity.as_CFType()];
505 arr.extend(certs.iter().map(|c| c.as_CFType()));
506 let certs = CFArray::from_CFTypes(&arr);
507
508 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
509 }
510
511 #[inline]
518 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
519 unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
520 }
521
522 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
524 unsafe {
525 let mut ptr = ptr::null();
526 let mut len = 0;
527 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
528 if ptr.is_null() {
529 Ok(None)
530 } else {
531 Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
532 }
533 }
534 }
535
536 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
538 unsafe {
539 let mut num_ciphers = 0;
540 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
541 let mut ciphers = vec![0; num_ciphers];
542 cvt(SSLGetSupportedCiphers(
543 self.0,
544 ciphers.as_mut_ptr(),
545 &mut num_ciphers,
546 ))?;
547 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
548 }
549 }
550
551 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
554 unsafe {
555 let mut num_ciphers = 0;
556 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
557 let mut ciphers = vec![0; num_ciphers];
558 cvt(SSLGetEnabledCiphers(
559 self.0,
560 ciphers.as_mut_ptr(),
561 &mut num_ciphers,
562 ))?;
563 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
564 }
565 }
566
567 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
569 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
570 unsafe {
571 cvt(SSLSetEnabledCiphers(
572 self.0,
573 ciphers.as_ptr(),
574 ciphers.len(),
575 ))
576 }
577 }
578
579 #[inline]
581 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
582 unsafe {
583 let mut cipher = 0;
584 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
585 Ok(CipherSuite::from_raw(cipher))
586 }
587 }
588
589 #[inline]
593 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
594 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
595 }
596
597 #[inline]
599 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
600 let mut state = 0;
601
602 unsafe {
603 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
604 }
605 Ok(SslClientCertificateState(state))
606 }
607
608 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
613 if self.state()? == SessionState::IDLE {
616 return Err(Error::from_code(errSecBadReq));
617 }
618
619 unsafe {
620 let mut trust = ptr::null_mut();
621 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
622 if trust.is_null() {
623 Ok(None)
624 } else {
625 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
626 }
627 }
628 }
629
630 #[inline]
632 pub fn state(&self) -> Result<SessionState> {
633 unsafe {
634 let mut state = 0;
635 cvt(SSLGetSessionState(self.0, &mut state))?;
636 Ok(SessionState(state))
637 }
638 }
639
640 #[inline]
642 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
643 unsafe {
644 let mut version = 0;
645 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
646 Ok(SslProtocol(version))
647 }
648 }
649
650 #[inline]
652 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
653 unsafe {
654 let mut version = 0;
655 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
656 Ok(SslProtocol(version))
657 }
658 }
659
660 #[inline]
662 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
663 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
664 }
665
666 #[inline]
668 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
669 unsafe {
670 let mut version = 0;
671 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
672 Ok(SslProtocol(version))
673 }
674 }
675
676 #[inline]
678 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
679 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
680 }
681
682 #[cfg(feature = "alpn")]
684 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
685 let mut array: CFArrayRef = ptr::null();
686 unsafe {
687 #[cfg(feature = "OSX_10_13")]
688 {
689 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
690 }
691
692 #[cfg(not(feature = "OSX_10_13"))]
693 {
694 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
695 if let Some(f) = SSLCopyALPNProtocols.get() {
696 cvt(f(self.0, &mut array))?;
697 } else {
698 return Err(Error::from_code(errSecUnimplemented));
699 }
700 }
701
702 if array.is_null() {
703 return Ok(vec![]);
704 }
705
706 let array = CFArray::<CFString>::wrap_under_create_rule(array);
707 Ok(array.into_iter().map(|p| p.to_string()).collect())
708 }
709 }
710
711 #[cfg(feature = "alpn")]
715 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
716 let protocols = CFArray::from_CFTypes(
720 &protocols
721 .iter()
722 .map(|proto| CFString::new(proto))
723 .collect::<Vec<_>>(),
724 );
725
726 #[cfg(feature = "OSX_10_13")]
727 {
728 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
729 }
730 #[cfg(not(feature = "OSX_10_13"))]
731 {
732 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
733 if let Some(f) = SSLSetALPNProtocols.get() {
734 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
735 } else {
736 Err(Error::from_code(errSecUnimplemented))
737 }
738 }
739 }
740
741 #[cfg(feature = "session-tickets")]
749 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
750 #[cfg(feature = "OSX_10_13")]
751 {
752 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
753 }
754 #[cfg(not(feature = "OSX_10_13"))]
755 {
756 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
757 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
758 unsafe { cvt(f(self.0, Boolean::from(enabled))) }
759 } else {
760 Err(Error::from_code(errSecUnimplemented))
761 }
762 }
763 }
764
765 #[inline]
768 pub fn buffered_read_size(&self) -> Result<usize> {
769 unsafe {
770 let mut size = 0;
771 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
772 Ok(size)
773 }
774 }
775
776 impl_options! {
777 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
780 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
783 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
786 const kSSLSessionOptionFalseStart: false_start & set_false_start,
790 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
794 }
795
796 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
797 where S: Read + Write {
798 unsafe {
799 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
800 if ret != errSecSuccess {
801 return Err(Error::from_code(ret));
802 }
803
804 let stream = Connection { stream, err: None, panic: None };
805 let stream = Box::into_raw(Box::new(stream));
806 let ret = SSLSetConnection(self.0, stream.cast());
807 if ret != errSecSuccess {
808 let _conn = Box::from_raw(stream);
809 return Err(Error::from_code(ret));
810 }
811
812 Ok(SslStream { ctx: self, _m: PhantomData })
813 }
814 }
815
816 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
818 where
819 S: Read + Write,
820 {
821 self.into_stream(stream)
822 .map_err(HandshakeError::Failure)
823 .and_then(SslStream::handshake)
824 }
825}
826
/// Per-stream state handed to Secure Transport as the opaque connection
/// pointer (see `SslContext::into_stream`) and read back by the I/O callbacks.
struct Connection<S> {
    // The wrapped transport that `read_func`/`write_func` drive.
    stream: S,
    // Last I/O error raised inside a callback, so `SslStream::get_error` can
    // return it instead of a bare `OSStatus`.
    err: Option<io::Error>,
    // Panic payload caught inside a callback; re-raised by `check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
832
833#[cold]
835fn translate_err(e: &io::Error) -> OSStatus {
836 match e.kind() {
837 io::ErrorKind::NotFound => errSSLClosedGraceful,
838 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
839 io::ErrorKind::WouldBlock |
840 io::ErrorKind::NotConnected => errSSLWouldBlock,
841 _ => errSecIO,
842 }
843}
844
845unsafe extern "C" fn read_func<S>(
846 connection: SSLConnectionRef,
847 data: *mut c_void,
848 data_length: *mut usize,
849) -> OSStatus
850where S: Read {
851 let conn: &mut Connection<S> = &mut *(connection as *mut _);
852 let mut read = 0;
853
854 let ret = panic::catch_unwind(AssertUnwindSafe(|| {
855 let mut data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
856 while !data.is_empty() {
857 match conn.stream.read(data) {
858 Ok(0) => return errSSLClosedNoNotify,
859 Ok(len) => {
860 let Some(rest) = data.get_mut(len..) else {
861 return errSecIO;
862 };
863 data = rest;
864 read += len;
865 },
866 Err(e) => {
867 let ret = translate_err(&e);
868 conn.err = Some(e);
869 return ret;
870 },
871 }
872 }
873 errSecSuccess
874 }))
875 .unwrap_or_else(|e| {
876 conn.panic = Some(e);
877 errSecIO
878 });
879
880 *data_length = read;
881 ret
882}
883
884unsafe extern "C" fn write_func<S>(
885 connection: SSLConnectionRef,
886 data: *const c_void,
887 data_length: *mut usize,
888) -> OSStatus
889where S: Write {
890 let conn: &mut Connection<S> = &mut *(connection as *mut _);
891 let mut written = 0;
892
893 let ret = panic::catch_unwind(AssertUnwindSafe(|| {
894 let mut data = slice::from_raw_parts(data.cast::<u8>(), *data_length);
895 while !data.is_empty() {
896 match conn.stream.write(data) {
897 Ok(0) => return errSSLClosedNoNotify,
898 Ok(len) => {
899 let Some(rest) = data.get(len..) else {
900 return errSecIO;
901 };
902 data = rest;
903 written += len;
904 },
905 Err(e) => {
906 let ret = translate_err(&e);
907 conn.err = Some(e);
908 return ret;
909 },
910 }
911 }
912 if let Err(e) = conn.stream.flush() {
916 let ret = translate_err(&e);
917 conn.err = Some(e);
918 return ret;
919 }
920 errSecSuccess
921 }))
922 .unwrap_or_else(|e| {
923 conn.panic = Some(e);
924 errSecIO
925 });
926
927 *data_length = written;
928 ret
929}
930
/// A stream that wraps `S` in an SSL/TLS session.
///
/// The wrapped stream itself lives in a boxed `Connection<S>` owned by the
/// context (see `SslContext::into_stream`), not in this struct.
pub struct SslStream<S> {
    // The Secure Transport context; holds the raw `Connection<S>` pointer.
    ctx: SslContext,
    // Ties the stream type to this value for typing and drop checking.
    _m: PhantomData<S>,
}
936
937impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
938 #[cold]
939 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
940 fmt.debug_struct("SslStream")
941 .field("context", &self.ctx)
942 .field("stream", self.get_ref())
943 .finish()
944 }
945}
946
947impl<S> Drop for SslStream<S> {
948 fn drop(&mut self) {
949 unsafe {
950 let mut conn = ptr::null();
951 let ret = SSLGetConnection(self.ctx.0, &mut conn);
952 assert!(ret == errSecSuccess);
953 let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
954 }
955 }
956}
957
958impl<S> SslStream<S> {
959 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
960 match unsafe { SSLHandshake(self.ctx.0) } {
961 errSecSuccess => Ok(self),
962 reason @ errSSLPeerAuthCompleted
963 | reason @ errSSLClientCertRequested
964 | reason @ errSSLWouldBlock
965 | reason @ errSSLClientHelloReceived => {
966 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
967 stream: self,
968 error: Error::from_code(reason),
969 }))
970 },
971 err => {
972 self.check_panic();
973 Err(HandshakeError::Failure(Error::from_code(err)))
974 },
975 }
976 }
977
978 #[inline(always)]
980 #[must_use]
981 pub fn get_ref(&self) -> &S {
982 &self.connection().stream
983 }
984
985 #[inline(always)]
987 pub fn get_mut(&mut self) -> &mut S {
988 &mut self.connection_mut().stream
989 }
990
991 #[inline(always)]
993 #[must_use]
994 pub fn context(&self) -> &SslContext {
995 &self.ctx
996 }
997
998 #[inline(always)]
1000 pub fn context_mut(&mut self) -> &mut SslContext {
1001 &mut self.ctx
1002 }
1003
1004 pub fn close(&mut self) -> result::Result<(), io::Error> {
1006 unsafe {
1007 let ret = SSLClose(self.ctx.0);
1008 if ret == errSecSuccess {
1009 Ok(())
1010 } else {
1011 Err(self.get_error(ret))
1012 }
1013 }
1014 }
1015
1016 fn connection(&self) -> &Connection<S> {
1017 unsafe {
1018 let mut conn = ptr::null();
1019 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1020 assert!(ret == errSecSuccess);
1021
1022 &mut *(conn as *mut Connection<S>)
1023 }
1024 }
1025
1026 fn connection_mut(&mut self) -> &mut Connection<S> {
1027 unsafe {
1028 let mut conn = ptr::null();
1029 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1030 assert!(ret == errSecSuccess);
1031
1032 &mut *(conn as *mut Connection<S>)
1033 }
1034 }
1035
1036 #[cold]
1037 fn check_panic(&mut self) {
1038 let conn = self.connection_mut();
1039 if let Some(err) = conn.panic.take() {
1040 panic::resume_unwind(err);
1041 }
1042 }
1043
1044 #[cold]
1045 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1046 self.check_panic();
1047
1048 if let Some(err) = self.connection_mut().err.take() {
1049 err
1050 } else {
1051 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1052 }
1053 }
1054}
1055
1056impl<S: Read + Write> Read for SslStream<S> {
1057 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1058 if buf.is_empty() {
1063 return Ok(0);
1064 }
1065
1066 let buffered = self.context().buffered_read_size().unwrap_or(0);
1071 let to_read = if buffered > 0 {
1072 cmp::min(buffered, buf.len())
1073 } else {
1074 buf.len()
1075 };
1076
1077 unsafe {
1078 let mut nread = 0;
1079 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1080 if nread > 0 {
1083 return Ok(nread);
1084 }
1085
1086 match ret {
1087 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1088 errSSLPeerAuthCompleted => self.read(buf),
1090 _ => Err(self.get_error(ret)),
1091 }
1092 }
1093 }
1094}
1095
1096impl<S: Read + Write> Write for SslStream<S> {
1097 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1098 if buf.is_empty() {
1100 return Ok(0);
1101 }
1102 unsafe {
1103 let mut nwritten = 0;
1104 let ret = SSLWrite(
1105 self.ctx.0,
1106 buf.as_ptr().cast(),
1107 buf.len(),
1108 &mut nwritten,
1109 );
1110 if nwritten > 0 {
1113 Ok(nwritten)
1114 } else {
1115 Err(self.get_error(ret))
1116 }
1117 }
1118 }
1119
1120 fn flush(&mut self) -> io::Result<()> {
1121 self.connection_mut().stream.flush()
1122 }
1123}
1124
/// A builder type to simplify the creation of client-side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    // Optional client identity presented to the server.
    identity: Option<SecIdentity>,
    // Extra anchor certificates used when evaluating the server's trust.
    certs: Vec<SecCertificate>,
    // Intermediate certificates sent along with `identity`.
    chain: Vec<SecCertificate>,
    // Lower/upper protocol bounds; `None` leaves Secure Transport's default.
    protocol_min: Option<SslProtocol>,
    protocol_max: Option<SslProtocol>,
    // Trust only `certs`, not the system anchors.
    trust_certs_only: bool,
    // Send the domain as the peer name (SNI).
    use_sni: bool,
    // Skip server trust evaluation entirely.
    danger_accept_invalid_certs: bool,
    // Evaluate trust without a hostname.
    danger_accept_invalid_hostnames: bool,
    // If non-empty, replaces the context's enabled cipher list.
    whitelisted_ciphers: Vec<CipherSuite>,
    // Ciphers removed from the (whitelisted or default) list.
    blacklisted_ciphers: Vec<CipherSuite>,
    #[cfg(feature = "alpn")]
    alpn: Option<Vec<String>>,
    #[cfg(feature = "session-tickets")]
    enable_session_tickets: bool,
}
1144
1145impl Default for ClientBuilder {
1146 #[inline(always)]
1147 fn default() -> Self {
1148 Self::new()
1149 }
1150}
1151
1152impl ClientBuilder {
1153 #[inline]
1155 #[must_use]
1156 pub fn new() -> Self {
1157 Self {
1158 identity: None,
1159 certs: Vec::new(),
1160 chain: Vec::new(),
1161 protocol_min: None,
1162 protocol_max: None,
1163 trust_certs_only: false,
1164 use_sni: true,
1165 danger_accept_invalid_certs: false,
1166 danger_accept_invalid_hostnames: false,
1167 whitelisted_ciphers: Vec::new(),
1168 blacklisted_ciphers: Vec::new(),
1169 #[cfg(feature = "alpn")]
1170 alpn: None,
1171 #[cfg(feature = "session-tickets")]
1172 enable_session_tickets: false,
1173 }
1174 }
1175
1176 #[inline]
1179 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1180 certs.clone_into(&mut self.certs);
1181 self
1182 }
1183
1184 #[inline]
1187 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1188 self.certs.push(certs.to_owned());
1189 self
1190 }
1191
1192 #[inline(always)]
1195 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1196 self.trust_certs_only = only;
1197 self
1198 }
1199
1200 #[inline(always)]
1209 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1210 self.danger_accept_invalid_certs = noverify;
1211 self
1212 }
1213
1214 #[inline(always)]
1216 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1217 self.use_sni = use_sni;
1218 self
1219 }
1220
1221 #[inline(always)]
1229 pub fn danger_accept_invalid_hostnames(
1230 &mut self,
1231 danger_accept_invalid_hostnames: bool,
1232 ) -> &mut Self {
1233 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1234 self
1235 }
1236
1237 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1239 whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1240 self
1241 }
1242
1243 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1245 blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1246 self
1247 }
1248
1249 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1251 self.identity = Some(identity.clone());
1252 chain.clone_into(&mut self.chain);
1253 self
1254 }
1255
1256 #[inline(always)]
1258 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1259 self.protocol_min = Some(min);
1260 self
1261 }
1262
1263 #[inline(always)]
1265 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1266 self.protocol_max = Some(max);
1267 self
1268 }
1269
1270 #[cfg(feature = "alpn")]
1272 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1273 self.alpn = Some(protocols.iter().map(|s| (*s).to_string()).collect());
1274 self
1275 }
1276
1277 #[cfg(feature = "session-tickets")]
1281 #[inline(always)]
1282 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1283 self.enable_session_tickets = enable;
1284 self
1285 }
1286
1287 pub fn handshake<S>(
1291 &self,
1292 domain: &str,
1293 stream: S,
1294 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1295 where
1296 S: Read + Write,
1297 {
1298 let stream = MidHandshakeSslStream {
1301 stream: self.ctx_into_stream(domain, stream)?,
1302 error: Error::from(errSecSuccess),
1303 };
1304
1305 let certs = self.certs.clone();
1306 let stream = MidHandshakeClientBuilder {
1307 stream,
1308 domain: if self.danger_accept_invalid_hostnames {
1309 None
1310 } else {
1311 Some(domain.to_string())
1312 },
1313 certs,
1314 trust_certs_only: self.trust_certs_only,
1315 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1316 };
1317 stream.handshake()
1318 }
1319
1320 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1321 where S: Read + Write {
1322 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1323
1324 if self.use_sni {
1325 ctx.set_peer_domain_name(domain)?;
1326 }
1327 if let Some(ref identity) = self.identity {
1328 ctx.set_certificate(identity, &self.chain)?;
1329 }
1330 #[cfg(feature = "alpn")]
1331 {
1332 if let Some(ref alpn) = self.alpn {
1333 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1334 }
1335 }
1336 #[cfg(feature = "session-tickets")]
1337 {
1338 if self.enable_session_tickets {
1339 ctx.set_peer_id(domain.as_bytes())?;
1342 ctx.set_session_tickets_enabled(true)?;
1343 }
1344 }
1345 ctx.set_break_on_server_auth(true)?;
1346 self.configure_protocols(&mut ctx)?;
1347 self.configure_ciphers(&mut ctx)?;
1348
1349 ctx.into_stream(stream)
1350 }
1351
1352 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1353 if let Some(min) = self.protocol_min {
1354 ctx.set_protocol_version_min(min)?;
1355 }
1356 if let Some(max) = self.protocol_max {
1357 ctx.set_protocol_version_max(max)?;
1358 }
1359 Ok(())
1360 }
1361
1362 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1363 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1364 ctx.enabled_ciphers()?
1365 } else {
1366 self.whitelisted_ciphers.clone()
1367 };
1368
1369 if !self.blacklisted_ciphers.is_empty() {
1370 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1371 }
1372
1373 ctx.set_enabled_ciphers(&ciphers)?;
1374 Ok(())
1375 }
1376}
1377
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    // The server's identity (certificate + private key).
    identity: SecIdentity,
    // Intermediate certificates sent after the identity.
    certs: Vec<SecCertificate>,
}
1384
1385impl ServerBuilder {
1386 #[must_use]
1389 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1390 Self {
1391 identity: identity.clone(),
1392 certs: certs.to_owned(),
1393 }
1394 }
1395
1396 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1403 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1404 .passphrase(passphrase)
1405 .import(pkcs12_der)?
1406 .into_iter()
1407 .filter_map(|idendity| {
1408 Some((idendity.identity?, idendity.cert_chain.unwrap_or_default()))
1409 })
1410 .take(2)
1411 .collect();
1412 if identities.len() == 1 {
1413 let (identity, certs) = identities.pop().unwrap();
1414 Ok(Self { identity, certs })
1415 } else {
1416 Err(Error::from_code(errSecParam))
1418 }
1419 }
1420
1421 pub fn new_ssl_context(&self) -> Result<SslContext> {
1423 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1424 ctx.set_certificate(&self.identity, &self.certs)?;
1425 Ok(ctx)
1426 }
1427
1428 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1430 where S: Read + Write {
1431 match self.new_ssl_context()?.handshake(stream) {
1432 Ok(stream) => Ok(stream),
1433 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1434 Err(HandshakeError::Failure(err)) => Err(err),
1435 }
1436 }
1437}
1438
#[cfg(test)]
mod test {
    // Most tests below open real TCP connections to public hosts
    // (`google.com`, `expired.badssl.com`) and therefore need network access.
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    #[test]
    fn server_builder_from_pkcs12() {
        let pkcs12_der = include_bytes!("../test/server.p12");
        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
    }

    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        ctx.handshake(stream).expect_err("expected failure");
    }

    #[test]
    fn connect_buffered_stream() {
        use std::io::BufWriter;

        // `Debug` is required so that `p!` can format a failed `HandshakeError<S>`.
        #[derive(Debug)]
        struct BufferedTcpStream {
            reader: TcpStream,
            writer: BufWriter<TcpStream>,
        }

        impl BufferedTcpStream {
            fn new(tcp: TcpStream) -> std::io::Result<Self> {
                Ok(Self {
                    writer: BufWriter::with_capacity(500, tcp.try_clone()?),
                    reader: tcp,
                })
            }
        }

        impl Read for BufferedTcpStream {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                self.reader.read(buf)
            }
        }

        impl Write for BufferedTcpStream {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                self.writer.write(buf)
            }

            fn flush(&mut self) -> std::io::Result<()> {
                self.writer.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = p!(BufferedTcpStream::new(stream));
        p!(ctx.handshake(stream));
    }

    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    fn client_no_session_ticket_resumption() {
        // Without session tickets, every connection must stop after server auth.
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "session-tickets")]
    fn client_session_ticket_resumption() {
        // The first iteration performs a full handshake (stopping at server auth);
        // the second is expected to resume via a session ticket and skip it.
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(i, 0, "Session ticket resumption did not work, server auth was not skipped");
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg_attr(
        any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"),
        ignore
    )]
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        let enabled = p!(ctx.enabled_ciphers());
        // Keep every other enabled suite and check the context reports exactly that set back.
        let ciphers: Vec<_> = enabled.iter().step_by(2).copied().collect();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        let ciphers = p!(ctx.enabled_ciphers());
        assert!(ciphers.len() > 1);

        let cipher = ciphers[0];
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(
        any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"),
        ignore
    )]
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        let ciphers = p!(ctx.enabled_ciphers());
        let num = ciphers.len();
        assert!(num > 1);

        let cipher = ciphers[0];
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    // A panic inside the stream's `write` must propagate out of the handshake
    // rather than being swallowed at the FFI boundary.
    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    // Same as `write_panic`, but the panic originates in the stream's `read`.
    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}