1#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78#[cfg(feature = "alpn")]
79use core_foundation::string::CFString;
80use core_foundation::{declare_TCFType, impl_TCFType};
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88};
89
90use security_framework_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::cvt;
106use crate::identity::SecIdentity;
107use crate::import_export::Pkcs12ImportOptions;
108use crate::policy::SecPolicy;
109use crate::trust::SecTrust;
110use security_framework_sys::base::errSecParam;
111
112#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117 pub const CLIENT: Self = Self(kSSLClientSide);
119 pub const SERVER: Self = Self(kSSLServerSide);
121}
122
123#[derive(Debug, Copy, Clone)]
125pub struct SslConnectionType(SSLConnectionType);
126
127impl SslConnectionType {
128 pub const DATAGRAM: Self = Self(kSSLDatagramType);
130 pub const STREAM: Self = Self(kSSLStreamType);
132}
133
134#[derive(Debug)]
136pub enum HandshakeError<S> {
137 Failure(Error),
139 Interrupted(MidHandshakeSslStream<S>),
141}
142
143impl<S> From<Error> for HandshakeError<S> {
144 #[inline(always)]
145 fn from(err: Error) -> Self {
146 Self::Failure(err)
147 }
148}
149
150#[derive(Debug)]
152pub enum ClientHandshakeError<S> {
153 Failure(Error),
155 Interrupted(MidHandshakeClientBuilder<S>),
157}
158
159impl<S> From<Error> for ClientHandshakeError<S> {
160 #[inline(always)]
161 fn from(err: Error) -> Self {
162 Self::Failure(err)
163 }
164}
165
166#[derive(Debug)]
168pub struct MidHandshakeSslStream<S> {
169 stream: SslStream<S>,
170 error: Error,
171}
172
173impl<S> MidHandshakeSslStream<S> {
174 #[inline(always)]
176 #[must_use]
177 pub fn get_ref(&self) -> &S {
178 self.stream.get_ref()
179 }
180
181 #[inline(always)]
183 pub fn get_mut(&mut self) -> &mut S {
184 self.stream.get_mut()
185 }
186
187 #[inline(always)]
189 #[must_use]
190 pub fn context(&self) -> &SslContext {
191 self.stream.context()
192 }
193
194 #[inline(always)]
196 pub fn context_mut(&mut self) -> &mut SslContext {
197 self.stream.context_mut()
198 }
199
200 #[inline(always)]
203 #[must_use]
204 pub fn server_auth_completed(&self) -> bool {
205 self.error.code() == errSSLPeerAuthCompleted
206 }
207
208 #[inline(always)]
211 #[must_use]
212 pub fn client_cert_requested(&self) -> bool {
213 self.error.code() == errSSLClientCertRequested
214 }
215
216 #[inline(always)]
219 #[must_use]
220 pub fn would_block(&self) -> bool {
221 self.error.code() == errSSLWouldBlock
222 }
223
224 #[inline(always)]
226 #[must_use]
227 pub const fn error(&self) -> &Error {
228 &self.error
229 }
230
231 #[inline(always)]
233 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
234 self.stream.handshake()
235 }
236}
237
238#[derive(Debug)]
240pub struct MidHandshakeClientBuilder<S> {
241 stream: MidHandshakeSslStream<S>,
242 domain: Option<String>,
243 certs: Vec<SecCertificate>,
244 trust_certs_only: bool,
245 danger_accept_invalid_certs: bool,
246}
247
248impl<S> MidHandshakeClientBuilder<S> {
249 #[inline(always)]
251 #[must_use]
252 pub fn get_ref(&self) -> &S {
253 self.stream.get_ref()
254 }
255
256 #[inline(always)]
258 pub fn get_mut(&mut self) -> &mut S {
259 self.stream.get_mut()
260 }
261
262 #[inline(always)]
264 #[must_use]
265 pub fn error(&self) -> &Error {
266 self.stream.error()
267 }
268
269 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
271 let Self {
272 stream,
273 domain,
274 certs,
275 trust_certs_only,
276 danger_accept_invalid_certs,
277 } = self;
278
279 let mut result = stream.handshake();
280 loop {
281 let stream = match result {
282 Ok(stream) => return Ok(stream),
283 Err(HandshakeError::Interrupted(stream)) => stream,
284 Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
285 };
286
287 if stream.would_block() {
288 let ret = Self {
289 stream,
290 domain,
291 certs,
292 trust_certs_only,
293 danger_accept_invalid_certs,
294 };
295 return Err(ClientHandshakeError::Interrupted(ret));
296 }
297
298 if stream.server_auth_completed() {
299 if danger_accept_invalid_certs {
300 result = stream.handshake();
301 continue;
302 }
303 let Some(mut trust) = stream.context().peer_trust2()? else {
304 result = stream.handshake();
305 continue;
306 };
307 trust.set_anchor_certificates(&certs)?;
308 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
309 let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310 trust.set_policy(&policy)?;
311 trust.evaluate_with_error().map_err(|error| {
312 #[cfg(feature = "log")]
313 log::warn!("SecTrustEvaluateWithError: {error}");
314 Error::from_code(error.code() as _)
315 })?;
316 result = stream.handshake();
317 continue;
318 }
319
320 let err = Error::from_code(stream.error().code());
321 return Err(ClientHandshakeError::Failure(err));
322 }
323 }
324}
325
326#[derive(Debug, PartialEq, Eq)]
328pub struct SessionState(SSLSessionState);
329
330impl SessionState {
331 pub const ABORTED: Self = Self(kSSLAborted);
333 pub const CLOSED: Self = Self(kSSLClosed);
335 pub const CONNECTED: Self = Self(kSSLConnected);
337 pub const HANDSHAKE: Self = Self(kSSLHandshake);
339 pub const IDLE: Self = Self(kSSLIdle);
341}
342
343#[derive(Debug, Copy, Clone, PartialEq, Eq)]
345pub struct SslAuthenticate(SSLAuthenticate);
346
347impl SslAuthenticate {
348 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
350 pub const NEVER: Self = Self(kNeverAuthenticate);
352 pub const TRY: Self = Self(kTryAuthenticate);
354}
355
356#[derive(Debug, Copy, Clone, PartialEq, Eq)]
358pub struct SslClientCertificateState(SSLClientCertificateState);
359
360impl SslClientCertificateState {
361 pub const NONE: Self = Self(kSSLClientCertNone);
363 pub const REJECTED: Self = Self(kSSLClientCertRejected);
365 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
367 pub const SENT: Self = Self(kSSLClientCertSent);
369}
370
371#[derive(Debug, Copy, Clone, PartialEq, Eq)]
373pub struct SslProtocol(SSLProtocol);
374
375impl SslProtocol {
376 pub const ALL: Self = Self(kSSLProtocolAll);
378 pub const DTLS1: Self = Self(kDTLSProtocol1);
380 pub const SSL2: Self = Self(kSSLProtocol2);
382 pub const SSL3: Self = Self(kSSLProtocol3);
385 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
387 pub const TLS1: Self = Self(kTLSProtocol1);
390 pub const TLS11: Self = Self(kTLSProtocol11);
393 pub const TLS12: Self = Self(kTLSProtocol12);
396 pub const TLS13: Self = Self(kTLSProtocol13);
399 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
401 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
403}
404
405declare_TCFType! {
406 SslContext, SSLContextRef
408}
409
410impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
411
412impl fmt::Debug for SslContext {
413 #[cold]
414 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
415 let mut builder = fmt.debug_struct("SslContext");
416 if let Ok(state) = self.state() {
417 builder.field("state", &state);
418 }
419 builder.finish()
420 }
421}
422
423unsafe impl Sync for SslContext {}
424unsafe impl Send for SslContext {}
425
426impl SslContext {
427 pub(crate) fn as_inner(&self) -> SSLContextRef {
428 self.0
429 }
430}
431
/// Generates a `set_*`/getter pair per Secure Transport boolean session option.
///
/// Each entry `const OPTION: getter & setter,` expands to:
/// - `setter(&mut self, bool)`, which calls `SSLSetSessionOption`;
/// - `getter(&self) -> Result<bool>`, which calls `SSLGetSessionOption`.
///
/// Any attributes (including doc comments) on an entry are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            #[allow(deprecated)]
            $(#[$attr])*
            #[inline(always)]
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let raw = Boolean::from(value);
                unsafe { cvt(SSLSetSessionOption(self.0, $option, raw)) }
            }

            #[allow(deprecated)]
            $(#[$attr])*
            #[inline]
            pub fn $getter(&self) -> Result<bool> {
                let mut raw: Boolean = 0;
                unsafe {
                    cvt(SSLGetSessionOption(self.0, $option, &mut raw))?;
                }
                Ok(raw != 0)
            }
        )*
    };
}
453
454impl SslContext {
455 #[inline]
458 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
459 unsafe {
460 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
461 Ok(Self(ctx))
462 }
463 }
464
465 #[inline]
474 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
475 unsafe {
476 cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
478 }
479 }
480
481 pub fn peer_domain_name(&self) -> Result<String> {
483 unsafe {
484 let mut len = 0;
485 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
486 let mut buf = vec![0; len];
487 cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
488 Ok(String::from_utf8(buf).unwrap())
489 }
490 }
491
492 pub fn set_certificate(
500 &mut self,
501 identity: &SecIdentity,
502 certs: &[SecCertificate],
503 ) -> Result<()> {
504 let mut arr = vec![identity.as_CFType()];
505 arr.extend(certs.iter().map(|c| c.as_CFType()));
506 let certs = CFArray::from_CFTypes(&arr);
507
508 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
509 }
510
511 #[inline]
518 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
519 unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
520 }
521
522 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
524 unsafe {
525 let mut ptr = ptr::null();
526 let mut len = 0;
527 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
528 if ptr.is_null() {
529 Ok(None)
530 } else {
531 Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
532 }
533 }
534 }
535
536 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
538 unsafe {
539 let mut num_ciphers = 0;
540 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
541 let mut ciphers = vec![0; num_ciphers];
542 cvt(SSLGetSupportedCiphers(
543 self.0,
544 ciphers.as_mut_ptr(),
545 &mut num_ciphers,
546 ))?;
547 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
548 }
549 }
550
551 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
554 unsafe {
555 let mut num_ciphers = 0;
556 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
557 let mut ciphers = vec![0; num_ciphers];
558 cvt(SSLGetEnabledCiphers(
559 self.0,
560 ciphers.as_mut_ptr(),
561 &mut num_ciphers,
562 ))?;
563 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
564 }
565 }
566
567 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
569 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
570 unsafe {
571 cvt(SSLSetEnabledCiphers(
572 self.0,
573 ciphers.as_ptr(),
574 ciphers.len(),
575 ))
576 }
577 }
578
579 #[inline]
581 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
582 unsafe {
583 let mut cipher = 0;
584 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
585 Ok(CipherSuite::from_raw(cipher))
586 }
587 }
588
589 #[inline]
593 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
594 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
595 }
596
597 #[inline]
599 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
600 let mut state = 0;
601
602 unsafe {
603 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
604 }
605 Ok(SslClientCertificateState(state))
606 }
607
608 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
613 if self.state()? == SessionState::IDLE {
616 return Err(Error::from_code(errSecBadReq));
617 }
618
619 unsafe {
620 let mut trust = ptr::null_mut();
621 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
622 if trust.is_null() {
623 Ok(None)
624 } else {
625 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
626 }
627 }
628 }
629
630 #[inline]
632 pub fn state(&self) -> Result<SessionState> {
633 unsafe {
634 let mut state = 0;
635 cvt(SSLGetSessionState(self.0, &mut state))?;
636 Ok(SessionState(state))
637 }
638 }
639
640 #[inline]
642 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
643 unsafe {
644 let mut version = 0;
645 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
646 Ok(SslProtocol(version))
647 }
648 }
649
650 #[inline]
652 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
653 unsafe {
654 let mut version = 0;
655 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
656 Ok(SslProtocol(version))
657 }
658 }
659
660 #[inline]
662 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
663 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
664 }
665
666 #[inline]
668 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
669 unsafe {
670 let mut version = 0;
671 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
672 Ok(SslProtocol(version))
673 }
674 }
675
676 #[inline]
678 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
679 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
680 }
681
682 #[cfg(feature = "alpn")]
684 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
685 let mut array: CFArrayRef = ptr::null();
686 unsafe {
687 #[cfg(feature = "OSX_10_13")]
688 {
689 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
690 }
691
692 #[cfg(not(feature = "OSX_10_13"))]
693 {
694 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
695 if let Some(f) = SSLCopyALPNProtocols.get() {
696 cvt(f(self.0, &mut array))?;
697 } else {
698 return Err(Error::from_code(errSecUnimplemented));
699 }
700 }
701
702 if array.is_null() {
703 return Ok(vec![]);
704 }
705
706 let array = CFArray::<CFString>::wrap_under_create_rule(array);
707 Ok(array.into_iter().map(|p| p.to_string()).collect())
708 }
709 }
710
711 #[cfg(feature = "alpn")]
715 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
716 let protocols = CFArray::from_CFTypes(
720 &protocols
721 .iter()
722 .map(|proto| CFString::new(proto))
723 .collect::<Vec<_>>(),
724 );
725
726 #[cfg(feature = "OSX_10_13")]
727 {
728 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
729 }
730 #[cfg(not(feature = "OSX_10_13"))]
731 {
732 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
733 if let Some(f) = SSLSetALPNProtocols.get() {
734 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
735 } else {
736 Err(Error::from_code(errSecUnimplemented))
737 }
738 }
739 }
740
741 #[cfg(feature = "session-tickets")]
749 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
750 #[cfg(feature = "OSX_10_13")]
751 {
752 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
753 }
754 #[cfg(not(feature = "OSX_10_13"))]
755 {
756 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
757 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
758 unsafe { cvt(f(self.0, Boolean::from(enabled))) }
759 } else {
760 Err(Error::from_code(errSecUnimplemented))
761 }
762 }
763 }
764
765 #[inline]
768 pub fn buffered_read_size(&self) -> Result<usize> {
769 unsafe {
770 let mut size = 0;
771 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
772 Ok(size)
773 }
774 }
775
776 impl_options! {
777 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
780 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
783 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
786 const kSSLSessionOptionFalseStart: false_start & set_false_start,
790 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
794 }
795
796 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
797 where S: Read + Write {
798 unsafe {
799 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
800 if ret != errSecSuccess {
801 return Err(Error::from_code(ret));
802 }
803
804 let stream = Connection { stream, err: None, panic: None };
805 let stream = Box::into_raw(Box::new(stream));
806 let ret = SSLSetConnection(self.0, stream.cast());
807 if ret != errSecSuccess {
808 let _conn = Box::from_raw(stream);
809 return Err(Error::from_code(ret));
810 }
811
812 Ok(SslStream { ctx: self, _m: PhantomData })
813 }
814 }
815
816 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
818 where
819 S: Read + Write,
820 {
821 self.into_stream(stream)
822 .map_err(HandshakeError::Failure)
823 .and_then(SslStream::handshake)
824 }
825}
826
/// Per-stream state registered with Secure Transport through `SSLSetConnection`.
///
/// The I/O callbacks (`read_func`/`write_func`) receive it back as their
/// connection reference.
struct Connection<S> {
    /// The wrapped byte stream.
    stream: S,
    /// Last I/O error from `stream`. `SslStream::get_error` returns it in
    /// preference to a bare status code.
    err: Option<io::Error>,
    /// Panic payload caught inside a callback. `SslStream::check_panic` resumes it.
    panic: Option<Box<dyn Any + Send>>,
}
832
833#[cold]
835fn translate_err(e: &io::Error) -> OSStatus {
836 match e.kind() {
837 io::ErrorKind::NotFound => errSSLClosedGraceful,
838 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
839 io::ErrorKind::WouldBlock |
840 io::ErrorKind::NotConnected => errSSLWouldBlock,
841 _ => errSecIO,
842 }
843}
844
845unsafe extern "C" fn read_func<S>(
846 connection: SSLConnectionRef,
847 data: *mut c_void,
848 data_length: *mut usize,
849) -> OSStatus
850where S: Read {
851 let conn: &mut Connection<S> = &mut *(connection as *mut _);
852 let data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
853 let mut start = 0;
854 let mut ret = errSecSuccess;
855
856 while start < data.len() {
857 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
858 Ok(Ok(0)) => {
859 ret = errSSLClosedNoNotify;
860 break;
861 },
862 Ok(Ok(len)) => start += len,
863 Ok(Err(e)) => {
864 ret = translate_err(&e);
865 conn.err = Some(e);
866 break;
867 },
868 Err(e) => {
869 ret = errSecIO;
870 conn.panic = Some(e);
871 break;
872 },
873 }
874 }
875
876 *data_length = start;
877 ret
878}
879
880unsafe extern "C" fn write_func<S>(
881 connection: SSLConnectionRef,
882 data: *const c_void,
883 data_length: *mut usize,
884) -> OSStatus
885where S: Write {
886 let conn: &mut Connection<S> = &mut *(connection as *mut _);
887 let data = slice::from_raw_parts(data as *mut u8, *data_length);
888 let mut start = 0;
889 let mut ret = errSecSuccess;
890
891 while start < data.len() {
892 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
893 Ok(Ok(0)) => {
894 ret = errSSLClosedNoNotify;
895 break;
896 },
897 Ok(Ok(len)) => start += len,
898 Ok(Err(e)) => {
899 ret = translate_err(&e);
900 conn.err = Some(e);
901 break;
902 },
903 Err(e) => {
904 ret = errSecIO;
905 conn.panic = Some(e);
906 break;
907 },
908 }
909 }
910
911 *data_length = start;
912 ret
913}
914
915pub struct SslStream<S> {
917 ctx: SslContext,
918 _m: PhantomData<S>,
919}
920
921impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
922 #[cold]
923 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
924 fmt.debug_struct("SslStream")
925 .field("context", &self.ctx)
926 .field("stream", self.get_ref())
927 .finish()
928 }
929}
930
931impl<S> Drop for SslStream<S> {
932 fn drop(&mut self) {
933 unsafe {
934 let mut conn = ptr::null();
935 let ret = SSLGetConnection(self.ctx.0, &mut conn);
936 assert!(ret == errSecSuccess);
937 let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
938 }
939 }
940}
941
942impl<S> SslStream<S> {
943 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
944 match unsafe { SSLHandshake(self.ctx.0) } {
945 errSecSuccess => Ok(self),
946 reason @ errSSLPeerAuthCompleted
947 | reason @ errSSLClientCertRequested
948 | reason @ errSSLWouldBlock
949 | reason @ errSSLClientHelloReceived => {
950 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
951 stream: self,
952 error: Error::from_code(reason),
953 }))
954 },
955 err => {
956 self.check_panic();
957 Err(HandshakeError::Failure(Error::from_code(err)))
958 },
959 }
960 }
961
962 #[inline(always)]
964 #[must_use]
965 pub fn get_ref(&self) -> &S {
966 &self.connection().stream
967 }
968
969 #[inline(always)]
971 pub fn get_mut(&mut self) -> &mut S {
972 &mut self.connection_mut().stream
973 }
974
975 #[inline(always)]
977 #[must_use]
978 pub fn context(&self) -> &SslContext {
979 &self.ctx
980 }
981
982 #[inline(always)]
984 pub fn context_mut(&mut self) -> &mut SslContext {
985 &mut self.ctx
986 }
987
988 pub fn close(&mut self) -> result::Result<(), io::Error> {
990 unsafe {
991 let ret = SSLClose(self.ctx.0);
992 if ret == errSecSuccess {
993 Ok(())
994 } else {
995 Err(self.get_error(ret))
996 }
997 }
998 }
999
1000 fn connection(&self) -> &Connection<S> {
1001 unsafe {
1002 let mut conn = ptr::null();
1003 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1004 assert!(ret == errSecSuccess);
1005
1006 &mut *(conn as *mut Connection<S>)
1007 }
1008 }
1009
1010 fn connection_mut(&mut self) -> &mut Connection<S> {
1011 unsafe {
1012 let mut conn = ptr::null();
1013 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1014 assert!(ret == errSecSuccess);
1015
1016 &mut *(conn as *mut Connection<S>)
1017 }
1018 }
1019
1020 #[cold]
1021 fn check_panic(&mut self) {
1022 let conn = self.connection_mut();
1023 if let Some(err) = conn.panic.take() {
1024 panic::resume_unwind(err);
1025 }
1026 }
1027
1028 #[cold]
1029 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1030 self.check_panic();
1031
1032 if let Some(err) = self.connection_mut().err.take() {
1033 err
1034 } else {
1035 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1036 }
1037 }
1038}
1039
1040impl<S: Read + Write> Read for SslStream<S> {
1041 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1042 if buf.is_empty() {
1047 return Ok(0);
1048 }
1049
1050 let buffered = self.context().buffered_read_size().unwrap_or(0);
1055 let to_read = if buffered > 0 {
1056 cmp::min(buffered, buf.len())
1057 } else {
1058 buf.len()
1059 };
1060
1061 unsafe {
1062 let mut nread = 0;
1063 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1064 if nread > 0 {
1067 return Ok(nread);
1068 }
1069
1070 match ret {
1071 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1072 errSSLPeerAuthCompleted => self.read(buf),
1074 _ => Err(self.get_error(ret)),
1075 }
1076 }
1077 }
1078}
1079
1080impl<S: Read + Write> Write for SslStream<S> {
1081 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1082 if buf.is_empty() {
1084 return Ok(0);
1085 }
1086 unsafe {
1087 let mut nwritten = 0;
1088 let ret = SSLWrite(
1089 self.ctx.0,
1090 buf.as_ptr().cast(),
1091 buf.len(),
1092 &mut nwritten,
1093 );
1094 if nwritten > 0 {
1097 Ok(nwritten)
1098 } else {
1099 Err(self.get_error(ret))
1100 }
1101 }
1102 }
1103
1104 fn flush(&mut self) -> io::Result<()> {
1105 self.connection_mut().stream.flush()
1106 }
1107}
1108
1109#[derive(Debug)]
1111pub struct ClientBuilder {
1112 identity: Option<SecIdentity>,
1113 certs: Vec<SecCertificate>,
1114 chain: Vec<SecCertificate>,
1115 protocol_min: Option<SslProtocol>,
1116 protocol_max: Option<SslProtocol>,
1117 trust_certs_only: bool,
1118 use_sni: bool,
1119 danger_accept_invalid_certs: bool,
1120 danger_accept_invalid_hostnames: bool,
1121 whitelisted_ciphers: Vec<CipherSuite>,
1122 blacklisted_ciphers: Vec<CipherSuite>,
1123 #[cfg(feature = "alpn")]
1124 alpn: Option<Vec<String>>,
1125 #[cfg(feature = "session-tickets")]
1126 enable_session_tickets: bool,
1127}
1128
1129impl Default for ClientBuilder {
1130 #[inline(always)]
1131 fn default() -> Self {
1132 Self::new()
1133 }
1134}
1135
1136impl ClientBuilder {
1137 #[inline]
1139 #[must_use]
1140 pub fn new() -> Self {
1141 Self {
1142 identity: None,
1143 certs: Vec::new(),
1144 chain: Vec::new(),
1145 protocol_min: None,
1146 protocol_max: None,
1147 trust_certs_only: false,
1148 use_sni: true,
1149 danger_accept_invalid_certs: false,
1150 danger_accept_invalid_hostnames: false,
1151 whitelisted_ciphers: Vec::new(),
1152 blacklisted_ciphers: Vec::new(),
1153 #[cfg(feature = "alpn")]
1154 alpn: None,
1155 #[cfg(feature = "session-tickets")]
1156 enable_session_tickets: false,
1157 }
1158 }
1159
1160 #[inline]
1163 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1164 certs.clone_into(&mut self.certs);
1165 self
1166 }
1167
1168 #[inline]
1171 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1172 self.certs.push(certs.to_owned());
1173 self
1174 }
1175
1176 #[inline(always)]
1179 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1180 self.trust_certs_only = only;
1181 self
1182 }
1183
1184 #[inline(always)]
1193 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1194 self.danger_accept_invalid_certs = noverify;
1195 self
1196 }
1197
1198 #[inline(always)]
1200 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1201 self.use_sni = use_sni;
1202 self
1203 }
1204
1205 #[inline(always)]
1213 pub fn danger_accept_invalid_hostnames(
1214 &mut self,
1215 danger_accept_invalid_hostnames: bool,
1216 ) -> &mut Self {
1217 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1218 self
1219 }
1220
1221 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1223 whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1224 self
1225 }
1226
1227 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1229 blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1230 self
1231 }
1232
1233 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1235 self.identity = Some(identity.clone());
1236 chain.clone_into(&mut self.chain);
1237 self
1238 }
1239
1240 #[inline(always)]
1242 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1243 self.protocol_min = Some(min);
1244 self
1245 }
1246
1247 #[inline(always)]
1249 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1250 self.protocol_max = Some(max);
1251 self
1252 }
1253
1254 #[cfg(feature = "alpn")]
1256 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1257 self.alpn = Some(protocols.iter().map(|s| (*s).to_string()).collect());
1258 self
1259 }
1260
1261 #[cfg(feature = "session-tickets")]
1265 #[inline(always)]
1266 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1267 self.enable_session_tickets = enable;
1268 self
1269 }
1270
1271 pub fn handshake<S>(
1275 &self,
1276 domain: &str,
1277 stream: S,
1278 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1279 where
1280 S: Read + Write,
1281 {
1282 let stream = MidHandshakeSslStream {
1285 stream: self.ctx_into_stream(domain, stream)?,
1286 error: Error::from(errSecSuccess),
1287 };
1288
1289 let certs = self.certs.clone();
1290 let stream = MidHandshakeClientBuilder {
1291 stream,
1292 domain: if self.danger_accept_invalid_hostnames {
1293 None
1294 } else {
1295 Some(domain.to_string())
1296 },
1297 certs,
1298 trust_certs_only: self.trust_certs_only,
1299 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1300 };
1301 stream.handshake()
1302 }
1303
1304 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1305 where S: Read + Write {
1306 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1307
1308 if self.use_sni {
1309 ctx.set_peer_domain_name(domain)?;
1310 }
1311 if let Some(ref identity) = self.identity {
1312 ctx.set_certificate(identity, &self.chain)?;
1313 }
1314 #[cfg(feature = "alpn")]
1315 {
1316 if let Some(ref alpn) = self.alpn {
1317 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1318 }
1319 }
1320 #[cfg(feature = "session-tickets")]
1321 {
1322 if self.enable_session_tickets {
1323 ctx.set_peer_id(domain.as_bytes())?;
1326 ctx.set_session_tickets_enabled(true)?;
1327 }
1328 }
1329 ctx.set_break_on_server_auth(true)?;
1330 self.configure_protocols(&mut ctx)?;
1331 self.configure_ciphers(&mut ctx)?;
1332
1333 ctx.into_stream(stream)
1334 }
1335
1336 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1337 if let Some(min) = self.protocol_min {
1338 ctx.set_protocol_version_min(min)?;
1339 }
1340 if let Some(max) = self.protocol_max {
1341 ctx.set_protocol_version_max(max)?;
1342 }
1343 Ok(())
1344 }
1345
1346 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1347 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1348 ctx.enabled_ciphers()?
1349 } else {
1350 self.whitelisted_ciphers.clone()
1351 };
1352
1353 if !self.blacklisted_ciphers.is_empty() {
1354 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1355 }
1356
1357 ctx.set_enabled_ciphers(&ciphers)?;
1358 Ok(())
1359 }
1360}
1361
1362#[derive(Debug)]
1364pub struct ServerBuilder {
1365 identity: SecIdentity,
1366 certs: Vec<SecCertificate>,
1367}
1368
1369impl ServerBuilder {
1370 #[must_use]
1373 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1374 Self {
1375 identity: identity.clone(),
1376 certs: certs.to_owned(),
1377 }
1378 }
1379
1380 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1387 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1388 .passphrase(passphrase)
1389 .import(pkcs12_der)?
1390 .into_iter()
1391 .filter_map(|idendity| {
1392 let certs = idendity.cert_chain.unwrap_or_default();
1393 idendity.identity.map(|identity| (identity, certs))
1394 })
1395 .collect();
1396 if identities.len() == 1 {
1397 let (identity, certs) = identities.pop().unwrap();
1398 Ok(Self::new(&identity, &certs))
1399 } else {
1400 Err(Error::from_code(errSecParam))
1402 }
1403 }
1404
1405 pub fn new_ssl_context(&self) -> Result<SslContext> {
1407 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1408 ctx.set_certificate(&self.identity, &self.certs)?;
1409 Ok(ctx)
1410 }
1411
1412 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1414 where S: Read + Write {
1415 match self.new_ssl_context()?.handshake(stream) {
1416 Ok(stream) => Ok(stream),
1417 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1418 Err(HandshakeError::Failure(err)) => Err(err),
1419 }
1420 }
1421}
1422
#[cfg(test)]
mod test {
    // Most tests below open real TCP connections (google.com:443,
    // expired.badssl.com:443), so they need outbound network access.
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    #[test]
    fn server_builder_from_pkcs12() {
        let pkcs12_der = include_bytes!("../test/server.p12");
        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
    }

    // Handshake succeeds when the expected peer domain matches the host.
    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    // Expecting "foobar.com" while talking to google.com must fail.
    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        ctx.handshake(stream).expect_err("expected failure");
    }

    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    // Without session tickets, every connection is expected to stop at the
    // server-auth interruption point before completing.
    #[test]
    fn client_no_session_ticket_resumption() {
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    // With session tickets enabled, the first connection stops for server auth;
    // the second is expected to resume and skip that step.
    #[test]
    #[cfg(feature = "session-tickets")]
    fn client_session_ticket_resumption() {
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(i, 0, "Session ticket resumption did not work, server auth was not skipped");
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        // A slice literal suffices; no heap-allocated Vec is needed here.
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    // When the server rejects every offered protocol, no ALPN result is available.
    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    // Restricting trust to (here, no) explicit anchors must reject the server.
    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    // An expired certificate is accepted only because invalid certs are allowed.
    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    // Round-trips a reduced cipher list (every other enabled suite) through
    // set_enabled_ciphers / enabled_ciphers.
    #[test]
    #[cfg_attr(
        any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"),
        ignore
    )]
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        let ciphers = p!(ctx.enabled_ciphers());
        let ciphers = ciphers.iter().step_by(2).copied().collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        let ciphers = p!(ctx.enabled_ciphers());
        // Need at least two suites so whitelisting one is observable.
        assert!(ciphers.len() > 1);

        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(
        any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"),
        ignore
    )]
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        let ciphers = p!(ctx.enabled_ciphers());
        let num = ciphers.len();
        assert!(num > 1);

        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    // A context that never handshook has no peer trust to report.
    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    // A panic inside the wrapped stream's `write` must propagate out of the
    // handshake rather than being swallowed at the callback boundary.
    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    // Same as `write_panic`, but the panic originates in `read`.
    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    // Empty writes and reads on an established stream must return 0, not error.
    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}