1#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78use core_foundation::string::CFString;
79use core_foundation::{declare_TCFType, impl_TCFType};
80use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
81use std::os::raw::c_void;
82
83#[allow(unused_imports)]
84use security_framework_sys::base::{
85 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
86 errSecUnimplemented,
87};
88
89use security_framework_sys::secure_transport::*;
90use std::any::Any;
91use std::cmp;
92use std::fmt;
93use std::io;
94use std::io::prelude::*;
95use std::marker::PhantomData;
96use std::panic::{self, AssertUnwindSafe};
97use std::ptr;
98use std::result;
99use std::slice;
100
101use crate::base::{Error, Result};
102use crate::certificate::SecCertificate;
103use crate::cipher_suite::CipherSuite;
104use crate::cvt;
105use crate::identity::SecIdentity;
106use crate::import_export::Pkcs12ImportOptions;
107use crate::policy::SecPolicy;
108use crate::trust::SecTrust;
109use security_framework_sys::base::errSecParam;
110
111#[derive(Debug, Copy, Clone, PartialEq, Eq)]
113pub struct SslProtocolSide(SSLProtocolSide);
114
115impl SslProtocolSide {
116 pub const CLIENT: Self = Self(kSSLClientSide);
118 pub const SERVER: Self = Self(kSSLServerSide);
120}
121
122#[derive(Debug, Copy, Clone)]
124pub struct SslConnectionType(SSLConnectionType);
125
126impl SslConnectionType {
127 pub const DATAGRAM: Self = Self(kSSLDatagramType);
129 pub const STREAM: Self = Self(kSSLStreamType);
131}
132
133#[derive(Debug)]
135pub enum HandshakeError<S> {
136 Failure(Error),
138 Interrupted(MidHandshakeSslStream<S>),
140}
141
142impl<S> From<Error> for HandshakeError<S> {
143 #[inline(always)]
144 fn from(err: Error) -> Self {
145 Self::Failure(err)
146 }
147}
148
149#[derive(Debug)]
151pub enum ClientHandshakeError<S> {
152 Failure(Error),
154 Interrupted(MidHandshakeClientBuilder<S>),
156}
157
158impl<S> From<Error> for ClientHandshakeError<S> {
159 #[inline(always)]
160 fn from(err: Error) -> Self {
161 Self::Failure(err)
162 }
163}
164
165#[derive(Debug)]
167pub struct MidHandshakeSslStream<S> {
168 stream: SslStream<S>,
169 error: Error,
170}
171
172impl<S> MidHandshakeSslStream<S> {
173 #[inline(always)]
175 #[must_use]
176 pub fn get_ref(&self) -> &S {
177 self.stream.get_ref()
178 }
179
180 #[inline(always)]
182 pub fn get_mut(&mut self) -> &mut S {
183 self.stream.get_mut()
184 }
185
186 #[inline(always)]
188 #[must_use]
189 pub fn context(&self) -> &SslContext {
190 self.stream.context()
191 }
192
193 #[inline(always)]
195 pub fn context_mut(&mut self) -> &mut SslContext {
196 self.stream.context_mut()
197 }
198
199 #[inline(always)]
202 #[must_use]
203 pub fn server_auth_completed(&self) -> bool {
204 self.error.code() == errSSLPeerAuthCompleted
205 }
206
207 #[inline(always)]
210 #[must_use]
211 pub fn client_cert_requested(&self) -> bool {
212 self.error.code() == errSSLClientCertRequested
213 }
214
215 #[inline(always)]
218 #[must_use]
219 pub fn would_block(&self) -> bool {
220 self.error.code() == errSSLWouldBlock
221 }
222
223 #[inline(always)]
225 #[must_use]
226 pub const fn error(&self) -> &Error {
227 &self.error
228 }
229
230 #[inline(always)]
232 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
233 self.stream.handshake()
234 }
235}
236
237#[derive(Debug)]
239pub struct MidHandshakeClientBuilder<S> {
240 stream: MidHandshakeSslStream<S>,
241 domain: Option<String>,
242 certs: Vec<SecCertificate>,
243 trust_certs_only: bool,
244 danger_accept_invalid_certs: bool,
245}
246
247impl<S> MidHandshakeClientBuilder<S> {
248 #[inline(always)]
250 #[must_use]
251 pub fn get_ref(&self) -> &S {
252 self.stream.get_ref()
253 }
254
255 #[inline(always)]
257 pub fn get_mut(&mut self) -> &mut S {
258 self.stream.get_mut()
259 }
260
261 #[inline(always)]
263 #[must_use]
264 pub fn error(&self) -> &Error {
265 self.stream.error()
266 }
267
268 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
270 let Self {
271 stream,
272 domain,
273 certs,
274 trust_certs_only,
275 danger_accept_invalid_certs,
276 } = self;
277
278 let mut result = stream.handshake();
279 loop {
280 let stream = match result {
281 Ok(stream) => return Ok(stream),
282 Err(HandshakeError::Interrupted(stream)) => stream,
283 Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
284 };
285
286 if stream.would_block() {
287 let ret = Self {
288 stream,
289 domain,
290 certs,
291 trust_certs_only,
292 danger_accept_invalid_certs,
293 };
294 return Err(ClientHandshakeError::Interrupted(ret));
295 }
296
297 if stream.server_auth_completed() {
298 if danger_accept_invalid_certs {
299 result = stream.handshake();
300 continue;
301 }
302 let Some(mut trust) = stream.context().peer_trust2()? else {
303 result = stream.handshake();
304 continue;
305 };
306 trust.set_anchor_certificates(&certs)?;
307 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
308 let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
309 trust.set_policy(&policy)?;
310 trust.evaluate_with_error().map_err(|error| {
311 #[cfg(feature = "log")]
312 log::warn!("SecTrustEvaluateWithError: {error}");
313 Error::from_code(error.code() as _)
314 })?;
315 result = stream.handshake();
316 continue;
317 }
318
319 let err = Error::from_code(stream.error().code());
320 return Err(ClientHandshakeError::Failure(err));
321 }
322 }
323}
324
325#[derive(Debug, PartialEq, Eq)]
327pub struct SessionState(SSLSessionState);
328
329impl SessionState {
330 pub const ABORTED: Self = Self(kSSLAborted);
332 pub const CLOSED: Self = Self(kSSLClosed);
334 pub const CONNECTED: Self = Self(kSSLConnected);
336 pub const HANDSHAKE: Self = Self(kSSLHandshake);
338 pub const IDLE: Self = Self(kSSLIdle);
340}
341
342#[derive(Debug, Copy, Clone, PartialEq, Eq)]
344pub struct SslAuthenticate(SSLAuthenticate);
345
346impl SslAuthenticate {
347 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
349 pub const NEVER: Self = Self(kNeverAuthenticate);
351 pub const TRY: Self = Self(kTryAuthenticate);
353}
354
355#[derive(Debug, Copy, Clone, PartialEq, Eq)]
357pub struct SslClientCertificateState(SSLClientCertificateState);
358
359impl SslClientCertificateState {
360 pub const NONE: Self = Self(kSSLClientCertNone);
362 pub const REJECTED: Self = Self(kSSLClientCertRejected);
364 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
366 pub const SENT: Self = Self(kSSLClientCertSent);
368}
369
370#[derive(Debug, Copy, Clone, PartialEq, Eq)]
372pub struct SslProtocol(SSLProtocol);
373
374impl SslProtocol {
375 pub const ALL: Self = Self(kSSLProtocolAll);
377 pub const DTLS1: Self = Self(kDTLSProtocol1);
379 pub const SSL2: Self = Self(kSSLProtocol2);
381 pub const SSL3: Self = Self(kSSLProtocol3);
384 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
386 pub const TLS1: Self = Self(kTLSProtocol1);
389 pub const TLS11: Self = Self(kTLSProtocol11);
392 pub const TLS12: Self = Self(kTLSProtocol12);
395 pub const TLS13: Self = Self(kTLSProtocol13);
398 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
400 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
402}
403
404declare_TCFType! {
405 SslContext, SSLContextRef
407}
408
409impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
410
411impl fmt::Debug for SslContext {
412 #[cold]
413 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
414 let mut builder = fmt.debug_struct("SslContext");
415 if let Ok(state) = self.state() {
416 builder.field("state", &state);
417 }
418 builder.finish()
419 }
420}
421
422unsafe impl Sync for SslContext {}
423unsafe impl Send for SslContext {}
424
#[cfg(target_os = "macos")]
impl SslContext {
    /// Exposes the raw `SSLContextRef` to the rest of the crate (macOS only).
    pub(crate) fn as_inner(&self) -> SSLContextRef {
        let Self(raw) = self;
        *raw
    }
}
431
// Expands each `const OPTION: getter & setter,` entry into a pair of methods
// that read and write the boolean session option through
// `SSLGetSessionOption` / `SSLSetSessionOption`. Attributes written before
// `const` (for example doc comments) are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            #[allow(deprecated)]
            $(#[$attr])*
            #[inline(always)]
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let flag = Boolean::from(value);
                cvt(unsafe { SSLSetSessionOption(self.0, $option, flag) })
            }

            #[allow(deprecated)]
            $(#[$attr])*
            #[inline]
            pub fn $getter(&self) -> Result<bool> {
                let mut raw = 0;
                cvt(unsafe { SSLGetSessionOption(self.0, $option, &mut raw) })?;
                Ok(raw != 0)
            }
        )*
    }
}
453
454impl SslContext {
455 #[inline]
458 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
459 unsafe {
460 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
461 Ok(Self(ctx))
462 }
463 }
464
465 #[inline]
474 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
475 unsafe {
476 cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
478 }
479 }
480
481 pub fn peer_domain_name(&self) -> Result<String> {
483 unsafe {
484 let mut len = 0;
485 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
486 let mut buf = vec![0; len];
487 cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
488 String::from_utf8(buf).map_err(|_| Error::from_code(-1))
489 }
490 }
491
492 pub fn set_certificate(
500 &mut self,
501 identity: &SecIdentity,
502 certs: &[SecCertificate],
503 ) -> Result<()> {
504 let mut arr = vec![identity.as_CFType()];
505 arr.extend(certs.iter().map(|c| c.as_CFType()));
506 let certs = CFArray::from_CFTypes(&arr);
507
508 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
509 }
510
511 #[inline]
518 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
519 unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
520 }
521
522 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
524 unsafe {
525 let mut ptr = ptr::null();
526 let mut len = 0;
527 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
528 if ptr.is_null() {
529 Ok(None)
530 } else {
531 Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
532 }
533 }
534 }
535
536 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
538 unsafe {
539 let mut num_ciphers = 0;
540 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
541 let mut ciphers = vec![0; num_ciphers];
542 cvt(SSLGetSupportedCiphers(
543 self.0,
544 ciphers.as_mut_ptr(),
545 &mut num_ciphers,
546 ))?;
547 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
548 }
549 }
550
551 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
554 unsafe {
555 let mut num_ciphers = 0;
556 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
557 let mut ciphers = vec![0; num_ciphers];
558 cvt(SSLGetEnabledCiphers(
559 self.0,
560 ciphers.as_mut_ptr(),
561 &mut num_ciphers,
562 ))?;
563 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
564 }
565 }
566
567 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
569 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
570 unsafe {
571 cvt(SSLSetEnabledCiphers(
572 self.0,
573 ciphers.as_ptr(),
574 ciphers.len(),
575 ))
576 }
577 }
578
579 #[inline]
581 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
582 unsafe {
583 let mut cipher = 0;
584 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
585 Ok(CipherSuite::from_raw(cipher))
586 }
587 }
588
589 #[inline]
593 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
594 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
595 }
596
597 #[inline]
599 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
600 let mut state = 0;
601
602 unsafe {
603 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
604 }
605 Ok(SslClientCertificateState(state))
606 }
607
608 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
613 if self.state()? == SessionState::IDLE {
616 return Err(Error::from_code(errSecBadReq));
617 }
618
619 unsafe {
620 let mut trust = ptr::null_mut();
621 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
622 if trust.is_null() {
623 Ok(None)
624 } else {
625 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
626 }
627 }
628 }
629
630 #[inline]
632 pub fn state(&self) -> Result<SessionState> {
633 unsafe {
634 let mut state = 0;
635 cvt(SSLGetSessionState(self.0, &mut state))?;
636 Ok(SessionState(state))
637 }
638 }
639
640 #[inline]
642 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
643 unsafe {
644 let mut version = 0;
645 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
646 Ok(SslProtocol(version))
647 }
648 }
649
650 #[inline]
652 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
653 unsafe {
654 let mut version = 0;
655 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
656 Ok(SslProtocol(version))
657 }
658 }
659
660 #[inline]
662 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
663 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
664 }
665
666 #[inline]
668 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
669 unsafe {
670 let mut version = 0;
671 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
672 Ok(SslProtocol(version))
673 }
674 }
675
676 #[inline]
678 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
679 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
680 }
681
682 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
684 let mut array: CFArrayRef = ptr::null();
685 unsafe {
686 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
687
688 if array.is_null() {
689 return Ok(vec![]);
690 }
691
692 let array = CFArray::<CFString>::wrap_under_create_rule(array);
693 Ok(array.into_iter().map(|p| p.to_string()).collect())
694 }
695 }
696
697 pub fn set_alpn_protocols(&mut self, protocols: &[impl AsRef<str>]) -> Result<()> {
701 let protocols = CFArray::from_CFTypes(
705 &protocols
706 .iter()
707 .map(|proto| CFString::new(proto.as_ref()))
708 .collect::<Vec<_>>(),
709 );
710
711 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
712 }
713
714 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
722 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
723 }
724
725 #[inline]
728 pub fn buffered_read_size(&self) -> Result<usize> {
729 unsafe {
730 let mut size = 0;
731 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
732 Ok(size)
733 }
734 }
735
736 impl_options! {
737 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
740 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
743 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
746 const kSSLSessionOptionFalseStart: false_start & set_false_start,
750 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
754 }
755
756 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
757 where S: Read + Write {
758 unsafe {
759 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
760 if ret != errSecSuccess {
761 return Err(Error::from_code(ret));
762 }
763
764 let stream = Connection { stream, err: None, panic: None };
765 let stream = Box::into_raw(Box::new(stream));
766 let ret = SSLSetConnection(self.0, stream.cast());
767 if ret != errSecSuccess {
768 let _conn = Box::from_raw(stream);
769 return Err(Error::from_code(ret));
770 }
771
772 Ok(SslStream { ctx: self, _m: PhantomData })
773 }
774 }
775
776 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
778 where
779 S: Read + Write,
780 {
781 self.into_stream(stream)
782 .map_err(HandshakeError::Failure)
783 .and_then(SslStream::handshake)
784 }
785}
786
/// Per-stream state shared with the Secure Transport I/O callbacks.
struct Connection<S> {
    /// The wrapped transport the callbacks read from and write to.
    stream: S,
    /// The last I/O error a callback hit, kept so it can be returned later.
    err: Option<io::Error>,
    /// A panic payload caught inside a callback, kept so it can be resumed
    /// later.
    panic: Option<Box<dyn Any + Send>>,
}
792
793#[cold]
795fn translate_err(e: &io::Error) -> OSStatus {
796 match e.kind() {
797 io::ErrorKind::NotFound => errSSLClosedGraceful,
798 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
799 io::ErrorKind::WouldBlock |
800 io::ErrorKind::NotConnected => errSSLWouldBlock,
801 _ => errSecIO,
802 }
803}
804
805unsafe extern "C" fn read_func<S>(
806 connection: SSLConnectionRef,
807 data: *mut c_void,
808 data_length: *mut usize,
809) -> OSStatus
810where S: Read {
811 if data.is_null() || data_length.is_null() || connection.is_null() {
812 return errSecParam;
813 }
814
815 let conn: &mut Connection<S> = unsafe { &mut *(connection as *mut _) };
816 let data = unsafe { slice::from_raw_parts_mut(data.cast::<u8>(), *data_length) };
817 let mut read = 0;
818
819 let ret = panic::catch_unwind(AssertUnwindSafe(|| {
820 let mut data = data;
821 while !data.is_empty() {
822 match conn.stream.read(data) {
823 Ok(0) => return errSSLClosedNoNotify,
824 Ok(len) => {
825 let Some(rest) = data.get_mut(len..) else {
826 return errSecIO;
827 };
828 data = rest;
829 read += len;
830 },
831 Err(e) => {
832 let ret = translate_err(&e);
833 conn.err = Some(e);
834 return ret;
835 },
836 }
837 }
838 errSecSuccess
839 }))
840 .unwrap_or_else(|e| {
841 conn.panic = Some(e);
842 errSecIO
843 });
844
845 unsafe {
846 *data_length = read;
847 }
848 ret
849}
850
851unsafe extern "C" fn write_func<S>(
852 connection: SSLConnectionRef,
853 data: *const c_void,
854 data_length: *mut usize,
855) -> OSStatus
856where S: Write {
857 if data.is_null() || data_length.is_null() || connection.is_null() {
858 return errSecParam;
859 }
860
861 let conn: &mut Connection<S> = unsafe { &mut *(connection as *mut _) };
862 let mut written = 0;
863 let mut data = unsafe {
864 slice::from_raw_parts(data.cast::<u8>(), *data_length)
865 };
866
867 let ret = panic::catch_unwind(AssertUnwindSafe(|| {
868 while !data.is_empty() {
869 match conn.stream.write(data) {
870 Ok(0) => return errSSLClosedNoNotify,
871 Ok(len) => {
872 let Some(rest) = data.get(len..) else {
873 return errSecIO;
874 };
875 data = rest;
876 written += len;
877 },
878 Err(e) => {
879 let ret = translate_err(&e);
880 conn.err = Some(e);
881 return ret;
882 },
883 }
884 }
885 if let Err(e) = conn.stream.flush() {
889 let ret = translate_err(&e);
890 conn.err = Some(e);
891 return ret;
892 }
893 errSecSuccess
894 }))
895 .unwrap_or_else(|e| {
896 conn.panic = Some(e);
897 errSecIO
898 });
899
900 unsafe { *data_length = written };
901 ret
902}
903
904pub struct SslStream<S> {
906 ctx: SslContext,
907 _m: PhantomData<S>,
908}
909
910impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
911 #[cold]
912 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
913 fmt.debug_struct("SslStream")
914 .field("context", &self.ctx)
915 .field("stream", self.get_ref())
916 .finish()
917 }
918}
919
920impl<S> Drop for SslStream<S> {
921 fn drop(&mut self) {
922 unsafe {
923 let mut conn = ptr::null();
924 let ret = SSLGetConnection(self.ctx.0, &mut conn);
925 assert!(ret == errSecSuccess);
926 let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
927 }
928 }
929}
930
931impl<S> SslStream<S> {
932 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
933 match unsafe { SSLHandshake(self.ctx.0) } {
934 errSecSuccess => Ok(self),
935 reason @ (errSSLPeerAuthCompleted
936 | errSSLClientCertRequested
937 | errSSLWouldBlock
938 | errSSLClientHelloReceived) => {
939 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
940 stream: self,
941 error: Error::from_code(reason),
942 }))
943 },
944 err => {
945 self.check_panic();
946 Err(HandshakeError::Failure(Error::from_code(err)))
947 },
948 }
949 }
950
951 #[inline(always)]
953 #[must_use]
954 pub fn get_ref(&self) -> &S {
955 &self.connection().stream
956 }
957
958 #[inline(always)]
960 pub fn get_mut(&mut self) -> &mut S {
961 &mut self.connection_mut().stream
962 }
963
964 #[inline(always)]
966 #[must_use]
967 pub fn context(&self) -> &SslContext {
968 &self.ctx
969 }
970
971 #[inline(always)]
973 pub fn context_mut(&mut self) -> &mut SslContext {
974 &mut self.ctx
975 }
976
977 pub fn close(&mut self) -> result::Result<(), io::Error> {
979 unsafe {
980 let ret = SSLClose(self.ctx.0);
981 if ret == errSecSuccess {
982 Ok(())
983 } else {
984 Err(self.get_error(ret))
985 }
986 }
987 }
988
989 fn connection(&self) -> &Connection<S> {
990 unsafe {
991 let mut conn = ptr::null();
992 let ret = SSLGetConnection(self.ctx.0, &mut conn);
993 assert!(ret == errSecSuccess);
994
995 &mut *(conn as *mut Connection<S>)
996 }
997 }
998
999 fn connection_mut(&mut self) -> &mut Connection<S> {
1000 unsafe {
1001 let mut conn = ptr::null();
1002 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1003 assert!(ret == errSecSuccess);
1004
1005 &mut *(conn as *mut Connection<S>)
1006 }
1007 }
1008
1009 #[cold]
1010 fn check_panic(&mut self) {
1011 let conn = self.connection_mut();
1012 if let Some(err) = conn.panic.take() {
1013 panic::resume_unwind(err);
1014 }
1015 }
1016
1017 #[cold]
1018 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1019 self.check_panic();
1020
1021 if let Some(err) = self.connection_mut().err.take() {
1022 err
1023 } else {
1024 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1025 }
1026 }
1027}
1028
1029impl<S: Read + Write> Read for SslStream<S> {
1030 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1031 if buf.is_empty() {
1036 return Ok(0);
1037 }
1038
1039 let buffered = self.context().buffered_read_size().unwrap_or(0);
1044 let to_read = if buffered > 0 {
1045 cmp::min(buffered, buf.len())
1046 } else {
1047 buf.len()
1048 };
1049
1050 unsafe {
1051 let mut nread = 0;
1052 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1053 if nread > 0 {
1056 return Ok(nread);
1057 }
1058
1059 match ret {
1060 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1061 errSSLPeerAuthCompleted => self.read(buf),
1063 _ => Err(self.get_error(ret)),
1064 }
1065 }
1066 }
1067}
1068
1069impl<S: Read + Write> Write for SslStream<S> {
1070 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1071 if buf.is_empty() {
1073 return Ok(0);
1074 }
1075 unsafe {
1076 let mut nwritten = 0;
1077 let ret = SSLWrite(
1078 self.ctx.0,
1079 buf.as_ptr().cast(),
1080 buf.len(),
1081 &mut nwritten,
1082 );
1083 if nwritten > 0 {
1086 Ok(nwritten)
1087 } else {
1088 Err(self.get_error(ret))
1089 }
1090 }
1091 }
1092
1093 fn flush(&mut self) -> io::Result<()> {
1094 self.connection_mut().stream.flush()
1095 }
1096}
1097
1098#[derive(Debug)]
1100pub struct ClientBuilder {
1101 identity: Option<SecIdentity>,
1102 certs: Vec<SecCertificate>,
1103 chain: Vec<SecCertificate>,
1104 protocol_min: Option<SslProtocol>,
1105 protocol_max: Option<SslProtocol>,
1106 trust_certs_only: bool,
1107 use_sni: bool,
1108 danger_accept_invalid_certs: bool,
1109 danger_accept_invalid_hostnames: bool,
1110 whitelisted_ciphers: Vec<CipherSuite>,
1111 blacklisted_ciphers: Vec<CipherSuite>,
1112 alpn: Option<Vec<Box<str>>>,
1113 enable_session_tickets: bool,
1114}
1115
1116impl Default for ClientBuilder {
1117 #[inline(always)]
1118 fn default() -> Self {
1119 Self::new()
1120 }
1121}
1122
1123impl ClientBuilder {
1124 #[inline]
1126 #[must_use]
1127 pub fn new() -> Self {
1128 Self {
1129 identity: None,
1130 certs: Vec::new(),
1131 chain: Vec::new(),
1132 protocol_min: None,
1133 protocol_max: None,
1134 trust_certs_only: false,
1135 use_sni: true,
1136 danger_accept_invalid_certs: false,
1137 danger_accept_invalid_hostnames: false,
1138 whitelisted_ciphers: Vec::new(),
1139 blacklisted_ciphers: Vec::new(),
1140 alpn: None,
1141 enable_session_tickets: false,
1142 }
1143 }
1144
1145 #[inline]
1148 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1149 certs.clone_into(&mut self.certs);
1150 self
1151 }
1152
1153 #[inline]
1156 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1157 self.certs.push(certs.to_owned());
1158 self
1159 }
1160
1161 #[inline(always)]
1164 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1165 self.trust_certs_only = only;
1166 self
1167 }
1168
1169 #[inline(always)]
1178 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1179 self.danger_accept_invalid_certs = noverify;
1180 self
1181 }
1182
1183 #[inline(always)]
1185 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1186 self.use_sni = use_sni;
1187 self
1188 }
1189
1190 #[inline(always)]
1198 pub fn danger_accept_invalid_hostnames(
1199 &mut self,
1200 danger_accept_invalid_hostnames: bool,
1201 ) -> &mut Self {
1202 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1203 self
1204 }
1205
1206 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1208 whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1209 self
1210 }
1211
1212 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1214 blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1215 self
1216 }
1217
1218 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1220 self.identity = Some(identity.clone());
1221 chain.clone_into(&mut self.chain);
1222 self
1223 }
1224
1225 #[inline(always)]
1227 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1228 self.protocol_min = Some(min);
1229 self
1230 }
1231
1232 #[inline(always)]
1234 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1235 self.protocol_max = Some(max);
1236 self
1237 }
1238
1239 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1241 self.alpn = Some(protocols.iter().copied().map(Box::from).collect());
1242 self
1243 }
1244
1245 #[inline(always)]
1249 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1250 self.enable_session_tickets = enable;
1251 self
1252 }
1253
1254 pub fn handshake<S>(
1258 &self,
1259 domain: &str,
1260 stream: S,
1261 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1262 where
1263 S: Read + Write,
1264 {
1265 let stream = MidHandshakeSslStream {
1268 stream: self.ctx_into_stream(domain, stream)?,
1269 error: Error::from(errSecSuccess),
1270 };
1271
1272 let certs = self.certs.clone();
1273 let stream = MidHandshakeClientBuilder {
1274 stream,
1275 domain: if self.danger_accept_invalid_hostnames {
1276 None
1277 } else {
1278 Some(domain.to_string())
1279 },
1280 certs,
1281 trust_certs_only: self.trust_certs_only,
1282 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1283 };
1284 stream.handshake()
1285 }
1286
1287 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1288 where S: Read + Write {
1289 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1290
1291 if self.use_sni {
1292 ctx.set_peer_domain_name(domain)?;
1293 }
1294 if let Some(identity) = &self.identity {
1295 ctx.set_certificate(identity, &self.chain)?;
1296 }
1297 if let Some(alpn) = &self.alpn {
1298 ctx.set_alpn_protocols(alpn)?;
1299 }
1300 if self.enable_session_tickets {
1301 ctx.set_peer_id(domain.as_bytes())?;
1304 ctx.set_session_tickets_enabled(true)?;
1305 }
1306 ctx.set_break_on_server_auth(true)?;
1307 self.configure_protocols(&mut ctx)?;
1308 self.configure_ciphers(&mut ctx)?;
1309
1310 ctx.into_stream(stream)
1311 }
1312
1313 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1314 if let Some(min) = self.protocol_min {
1315 ctx.set_protocol_version_min(min)?;
1316 }
1317 if let Some(max) = self.protocol_max {
1318 ctx.set_protocol_version_max(max)?;
1319 }
1320 Ok(())
1321 }
1322
1323 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1324 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1325 ctx.enabled_ciphers()?
1326 } else {
1327 self.whitelisted_ciphers.clone()
1328 };
1329
1330 if !self.blacklisted_ciphers.is_empty() {
1331 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1332 }
1333
1334 ctx.set_enabled_ciphers(&ciphers)?;
1335 Ok(())
1336 }
1337}
1338
1339#[derive(Debug)]
1341pub struct ServerBuilder {
1342 identity: SecIdentity,
1343 certs: Vec<SecCertificate>,
1344}
1345
1346impl ServerBuilder {
1347 #[must_use]
1350 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1351 Self {
1352 identity: identity.clone(),
1353 certs: certs.to_owned(),
1354 }
1355 }
1356
1357 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1364 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1365 .passphrase(passphrase)
1366 .import(pkcs12_der)?
1367 .into_iter()
1368 .filter_map(|idendity| {
1369 Some((idendity.identity?, idendity.cert_chain.unwrap_or_default()))
1370 })
1371 .take(2)
1372 .collect();
1373 if identities.len() == 1 {
1374 let (identity, certs) = identities.pop().unwrap();
1375 Ok(Self { identity, certs })
1376 } else {
1377 Err(Error::from_code(errSecParam))
1379 }
1380 }
1381
1382 pub fn new_ssl_context(&self) -> Result<SslContext> {
1384 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1385 ctx.set_certificate(&self.identity, &self.certs)?;
1386 Ok(ctx)
1387 }
1388
1389 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1391 where S: Read + Write {
1392 match self.new_ssl_context()?.handshake(stream) {
1393 Ok(stream) => Ok(stream),
1394 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1395 Err(HandshakeError::Failure(err)) => Err(err),
1396 }
1397 }
1398}
1399
1400#[cfg(test)]
1401mod test {
1402 use std::io::prelude::*;
1403 use std::net::TcpStream;
1404
1405 use super::*;
1406
1407 #[test]
1408 fn server_builder_from_pkcs12() {
1409 let pkcs12_der = include_bytes!("../test/server.p12");
1410 ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1411 }
1412
1413 #[test]
1414 fn connect() {
1415 let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1416 p!(ctx.set_peer_domain_name("google.com"));
1417 let stream = p!(TcpStream::connect("google.com:443"));
1418 p!(ctx.handshake(stream));
1419 }
1420
1421 #[test]
1422 fn connect_bad_domain() {
1423 let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1424 p!(ctx.set_peer_domain_name("foobar.com"));
1425 let stream = p!(TcpStream::connect("google.com:443"));
1426 ctx.handshake(stream).expect_err("expected failure");
1427 }
1428
1429 #[test]
1430 fn connect_buffered_stream() {
1431 use std::io::BufWriter;
1432
1433 #[derive(Debug)]
1435 struct BufferedTcpStream {
1436 reader: TcpStream,
1437 writer: BufWriter<TcpStream>,
1438 }
1439
1440 impl BufferedTcpStream {
1441 fn new(tcp: TcpStream) -> std::io::Result<Self> {
1442 Ok(Self {
1443 writer: BufWriter::with_capacity(500, tcp.try_clone()?),
1444 reader: tcp,
1445 })
1446 }
1447 }
1448
1449 impl Read for BufferedTcpStream {
1450 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1451 self.reader.read(buf)
1452 }
1453 }
1454
1455 impl Write for BufferedTcpStream {
1456 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1457 self.writer.write(buf)
1458 }
1459
1460 fn flush(&mut self) -> std::io::Result<()> {
1461 self.writer.flush()
1462 }
1463 }
1464
1465 let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1466 p!(ctx.set_peer_domain_name("google.com"));
1467 let stream = p!(TcpStream::connect("google.com:443"));
1468 let stream = p!(BufferedTcpStream::new(stream));
1469 p!(ctx.handshake(stream));
1470 }
1471
/// Performs an HTTP/1.0 GET against google.com through an `SslContext` handshake and prints
/// the response.
#[test]
fn load_page() {
    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    p!(context.set_peer_domain_name("google.com"));
    let tcp = p!(TcpStream::connect("google.com:443"));
    let mut tls = p!(context.handshake(tcp));

    p!(tls.write_all(b"GET / HTTP/1.0\r\n\r\n"));
    p!(tls.flush());

    // Read until the stream reports end-of-file, then print the bytes lossily as UTF-8.
    let mut response = Vec::new();
    p!(tls.read_to_end(&mut response));
    let text = String::from_utf8_lossy(&response);
    println!("{}", text);
}
1484
/// Runs two handshakes without enabling session tickets; each one must be interrupted
/// for server auth before it completes.
#[test]
fn client_no_session_ticket_resumption() {
    for _ in 0..2 {
        let tcp = p!(TcpStream::connect("google.com:443"));

        let ssl = ClientBuilder::new()
            .ctx_into_stream("google.com", tcp)
            .unwrap();
        let pending = MidHandshakeSslStream {
            stream: ssl,
            error: Error::from(errSecSuccess),
        };

        let outcome = match pending.handshake() {
            Err(HandshakeError::Interrupted(paused)) => {
                assert!(paused.server_auth_completed());
                // Resume the handshake after the server-auth interruption.
                paused.handshake()
            }
            _ => panic!("Unexpectedly skipped server auth"),
        };

        assert!(outcome.is_ok());
    }
}
1510
/// Runs two handshakes with session tickets enabled: the first must be interrupted for
/// server auth, the second must not.
#[test]
fn client_session_ticket_resumption() {
    for attempt in 0..2 {
        let tcp = p!(TcpStream::connect("google.com:443"));
        let mut builder = ClientBuilder::new();
        builder.enable_session_tickets(true);

        let ssl = builder.ctx_into_stream("google.com", tcp).unwrap();
        let pending = MidHandshakeSslStream {
            stream: ssl,
            error: Error::from(errSecSuccess),
        };

        let outcome = match pending.handshake() {
            Err(HandshakeError::Interrupted(paused)) => {
                assert!(paused.server_auth_completed());
                assert_eq!(attempt, 0, "Session ticket resumption did not work, server auth was not skipped");
                paused.handshake()
            }
            other => {
                // Any result other than an interruption is only expected on the second pass.
                assert_eq!(attempt, 1, "Unexpectedly skipped server auth");
                other
            }
        };

        assert!(outcome.is_ok());
    }
}
1539
/// Offers "h2" via ALPN to google.com and expects "h2" to be reported after the handshake.
#[test]
fn client_alpn_accept() {
    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    p!(context.set_peer_domain_name("google.com"));
    p!(context.set_alpn_protocols(&["h2"]));
    let tcp = p!(TcpStream::connect("google.com:443"));
    let tls = context.handshake(tcp).unwrap();
    let negotiated = tls.context().alpn_protocols().unwrap();
    assert_eq!(vec!["h2"], negotiated);
}
1549
/// Offers only "h2c" via ALPN to google.com; the handshake must succeed but querying the
/// ALPN protocols afterwards must return an error.
#[test]
fn client_alpn_reject() {
    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    p!(context.set_peer_domain_name("google.com"));
    p!(context.set_alpn_protocols(&["h2c"]));
    let tcp = p!(TcpStream::connect("google.com:443"));
    let tls = context.handshake(tcp).unwrap();
    let negotiated = tls.context().alpn_protocols();
    assert!(negotiated.is_err());
}
1559
/// With `trust_anchor_certificates_only(true)` and no anchors configured, the handshake
/// with google.com must fail.
#[test]
fn client_no_anchor_certs() {
    let tcp = p!(TcpStream::connect("google.com:443"));
    let outcome = ClientBuilder::new()
        .trust_anchor_certificates_only(true)
        .handshake("google.com", tcp);
    assert!(outcome.is_err());
}
1568
/// A `ClientBuilder` handshake against google.com using the domain "foobar.com" must fail.
#[test]
fn client_bad_domain() {
    let tcp = p!(TcpStream::connect("google.com:443"));
    let outcome = ClientBuilder::new().handshake("foobar.com", tcp);
    assert!(outcome.is_err());
}
1576
/// With `danger_accept_invalid_hostnames(true)`, a handshake against google.com using the
/// domain "foobar.com" must succeed.
#[test]
fn client_bad_domain_ignored() {
    let tcp = p!(TcpStream::connect("google.com:443"));
    let outcome = ClientBuilder::new()
        .danger_accept_invalid_hostnames(true)
        .handshake("foobar.com", tcp);
    outcome.unwrap();
}
1585
/// With `danger_accept_invalid_certs(true)`, a handshake with expired.badssl.com must succeed.
#[test]
fn connect_no_verify_ssl() {
    let tcp = p!(TcpStream::connect("expired.badssl.com:443"));
    let mut client = ClientBuilder::new();
    client.danger_accept_invalid_certs(true);
    let outcome = client.handshake("expired.badssl.com", tcp);
    outcome.unwrap();
}
1593
/// Performs an HTTP/1.0 GET against google.com through a `ClientBuilder` handshake and
/// prints the response.
#[test]
fn load_page_client() {
    let tcp = p!(TcpStream::connect("google.com:443"));
    let mut tls = p!(ClientBuilder::new().handshake("google.com", tcp));

    p!(tls.write_all(b"GET / HTTP/1.0\r\n\r\n"));
    p!(tls.flush());

    // Read until the stream reports end-of-file, then print the bytes lossily as UTF-8.
    let mut response = Vec::new();
    p!(tls.read_to_end(&mut response));
    let text = String::from_utf8_lossy(&response);
    println!("{}", text);
}
1604
/// Enables every other cipher from a server context's enabled list and checks that the
/// list read back afterwards matches.
#[test]
#[cfg_attr(
    any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"),
    ignore
)]
fn cipher_configuration() {
    let mut context = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
    let all_ciphers = p!(context.enabled_ciphers());
    // Keep the entries at even indices (0, 2, 4, ...).
    let every_other = all_ciphers.iter().step_by(2).copied().collect::<Vec<_>>();
    p!(context.set_enabled_ciphers(&every_other));
    assert_eq!(every_other, p!(context.enabled_ciphers()));
}
1618
/// Whitelisting a single cipher on a `ClientBuilder` must leave exactly one cipher enabled
/// on the resulting stream's context.
#[test]
fn test_builder_whitelist_ciphers() {
    let tcp = p!(TcpStream::connect("google.com:443"));

    let reference = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    let default_count = p!(reference.enabled_ciphers()).len();
    assert!(default_count > 1);

    // Take the first cipher a fresh client context reports as enabled.
    let defaults = p!(reference.enabled_ciphers());
    let chosen = *defaults.first().unwrap();
    let tls = p!(ClientBuilder::new()
        .whitelist_ciphers(&[chosen])
        .ctx_into_stream("google.com", tcp));

    let enabled_count = p!(tls.context().enabled_ciphers()).len();
    assert_eq!(1, enabled_count);
}
1634
/// Blacklisting a single cipher on a `ClientBuilder` must leave one fewer cipher enabled
/// than a fresh client context reports.
#[test]
#[cfg_attr(
    any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"),
    ignore
)]
fn test_builder_blacklist_ciphers() {
    let tcp = p!(TcpStream::connect("google.com:443"));

    let reference = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    let default_count = p!(reference.enabled_ciphers()).len();
    assert!(default_count > 1);

    // Take the first cipher a fresh client context reports as enabled.
    let defaults = p!(reference.enabled_ciphers());
    let banned = *defaults.first().unwrap();
    let tls = p!(ClientBuilder::new()
        .blacklist_ciphers(&[banned])
        .ctx_into_stream("google.com", tcp));

    let enabled_count = p!(tls.context().enabled_ciphers()).len();
    assert_eq!(default_count - 1, enabled_count);
}
1652
/// `peer_trust2` on a freshly created server context must return an error.
#[test]
fn idle_context_peer_trust() {
    let context = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
    let trust = context.peer_trust2();
    assert!(trust.is_err());
}
1658
/// The peer ID is `None` on a fresh server context and reads back as set afterwards.
#[test]
fn peer_id() {
    let mut context = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));

    let initial = p!(context.peer_id());
    assert!(initial.is_none());

    p!(context.set_peer_id(b"foobar"));
    let expected: &[u8] = &b"foobar"[..];
    assert_eq!(p!(context.peer_id()), Some(expected));
}
1666
/// The peer domain name is empty on a fresh client context and reads back as set afterwards.
#[test]
fn peer_domain_name() {
    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));

    let initial = p!(context.peer_domain_name());
    assert_eq!("", initial);

    p!(context.set_peer_domain_name("foobar.com"));
    let updated = p!(context.peer_domain_name());
    assert_eq!("foobar.com", updated);
}
1674
/// A panic raised inside the underlying stream's `write` must surface from `handshake`
/// with its original message.
#[test]
#[should_panic(expected = "blammo")]
fn write_panic() {
    // Delegates reads and flushes to the socket, but panics on every write.
    struct ExplodingStream(TcpStream);

    impl Read for ExplodingStream {
        fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
            let ExplodingStream(inner) = self;
            inner.read(out)
        }
    }

    impl Write for ExplodingStream {
        fn write(&mut self, _data: &[u8]) -> io::Result<usize> {
            panic!("blammo")
        }

        fn flush(&mut self) -> io::Result<()> {
            let ExplodingStream(inner) = self;
            inner.flush()
        }
    }

    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    p!(context.set_peer_domain_name("google.com"));
    let tcp = p!(TcpStream::connect("google.com:443"));
    let exploding = ExplodingStream(tcp);
    let _ = context.handshake(exploding);
}
1701
/// A panic raised inside the underlying stream's `read` must surface from `handshake`
/// with its original message.
#[test]
#[should_panic(expected = "blammo")]
fn read_panic() {
    // Delegates writes and flushes to the socket, but panics on every read.
    struct ExplodingStream(TcpStream);

    impl Read for ExplodingStream {
        fn read(&mut self, _out: &mut [u8]) -> io::Result<usize> {
            panic!("blammo")
        }
    }

    impl Write for ExplodingStream {
        fn write(&mut self, data: &[u8]) -> io::Result<usize> {
            let ExplodingStream(inner) = self;
            inner.write(data)
        }

        fn flush(&mut self) -> io::Result<()> {
            let ExplodingStream(inner) = self;
            inner.flush()
        }
    }

    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    p!(context.set_peer_domain_name("google.com"));
    let tcp = p!(TcpStream::connect("google.com:443"));
    let exploding = ExplodingStream(tcp);
    let _ = context.handshake(exploding);
}
1728
/// Writing an empty slice and reading into an empty buffer on an established stream must
/// both return `Ok(0)`.
#[test]
fn zero_length_buffers() {
    let mut context = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
    p!(context.set_peer_domain_name("google.com"));
    let tcp = p!(TcpStream::connect("google.com:443"));
    let mut tls = context.handshake(tcp).unwrap();

    let written = tls.write(b"").unwrap();
    assert_eq!(written, 0);

    let read = tls.read(&mut []).unwrap();
    assert_eq!(read, 0);
}
1738}