Skip to main content

security_framework/
secure_transport.rs

1//! SSL/TLS encryption support using Secure Transport.
2//!
3//! # Examples
4//!
5//! To connect as a client to a server with a certificate trusted by the system:
6//!
7//! ```rust
8//! use security_framework::secure_transport::ClientBuilder;
9//! use std::io::prelude::*;
10//! use std::net::TcpStream;
11//!
12//! let stream = TcpStream::connect("google.com:443").unwrap();
13//! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14//!
15//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16//! let mut page = vec![];
17//! stream.read_to_end(&mut page).unwrap();
18//! println!("{}", String::from_utf8_lossy(&page));
19//! ```
20//!
21//! To connect to a server with a certificate that's *not* trusted by the
22//! system, specify the root certificates for the server's chain to the
23//! `ClientBuilder`:
24//!
25//! ```rust,no_run
26//! use security_framework::secure_transport::ClientBuilder;
27//! use std::io::prelude::*;
28//! use std::net::TcpStream;
29//!
30//! # let root_cert = unsafe { std::mem::zeroed() };
31//! let stream = TcpStream::connect("my_server.com:443").unwrap();
32//! let mut stream = ClientBuilder::new()
33//!     .anchor_certificates(&[root_cert])
34//!     .handshake("my_server.com", stream)
35//!     .unwrap();
36//!
37//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38//! let mut page = vec![];
39//! stream.read_to_end(&mut page).unwrap();
40//! println!("{}", String::from_utf8_lossy(&page));
41//! ```
42//!
43//! For more advanced configuration, the `SslContext` type can be used directly.
44//!
45//! To run a server:
46//!
47//! ```rust,no_run
48//! use security_framework::secure_transport::{SslConnectionType, SslContext, SslProtocolSide};
49//! use std::net::TcpListener;
50//! use std::thread;
51//!
52//! // Create a TCP listener and start accepting on it.
53//! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54//!
55//! for stream in listener.incoming() {
56//!     let stream = stream.unwrap();
57//!     thread::spawn(move || {
58//!         // Create a new context configured to operate on the server side of
59//!         // a traditional SSL/TLS session.
60//!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61//!                           .unwrap();
62//!
63//!         // Install the certificate chain that we will be using.
64//!         # let identity = unsafe { std::mem::zeroed() };
65//!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66//!         # let root_cert = unsafe { std::mem::zeroed() };
67//!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68//!
69//!         // Perform the SSL/TLS handshake and get our stream.
70//!         let mut stream = ctx.handshake(stream).unwrap();
71//!     });
72//! }
73//! ```
74#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78#[cfg(feature = "alpn")]
79use core_foundation::string::CFString;
80use core_foundation::{declare_TCFType, impl_TCFType};
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use security_framework_sys::base::{
86    errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87    errSecUnimplemented,
88};
89
90use security_framework_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::cvt;
106use crate::identity::SecIdentity;
107use crate::import_export::Pkcs12ImportOptions;
108use crate::policy::SecPolicy;
109use crate::trust::SecTrust;
110use security_framework_sys::base::errSecParam;
111
112/// Specifies a side of a TLS session.
113#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117    /// The client side of the session.
118    pub const CLIENT: Self = Self(kSSLClientSide);
119    /// The server side of the session.
120    pub const SERVER: Self = Self(kSSLServerSide);
121}
122
123/// Specifies the type of TLS session.
124#[derive(Debug, Copy, Clone)]
125pub struct SslConnectionType(SSLConnectionType);
126
127impl SslConnectionType {
128    /// A DTLS session.
129    pub const DATAGRAM: Self = Self(kSSLDatagramType);
130    /// A traditional TLS stream.
131    pub const STREAM: Self = Self(kSSLStreamType);
132}
133
134/// An error or intermediate state after a TLS handshake attempt.
135#[derive(Debug)]
136pub enum HandshakeError<S> {
137    /// The handshake failed.
138    Failure(Error),
139    /// The handshake was interrupted midway through.
140    Interrupted(MidHandshakeSslStream<S>),
141}
142
143impl<S> From<Error> for HandshakeError<S> {
144    #[inline(always)]
145    fn from(err: Error) -> Self {
146        Self::Failure(err)
147    }
148}
149
150/// An error or intermediate state after a TLS handshake attempt.
151#[derive(Debug)]
152pub enum ClientHandshakeError<S> {
153    /// The handshake failed.
154    Failure(Error),
155    /// The handshake was interrupted midway through.
156    Interrupted(MidHandshakeClientBuilder<S>),
157}
158
159impl<S> From<Error> for ClientHandshakeError<S> {
160    #[inline(always)]
161    fn from(err: Error) -> Self {
162        Self::Failure(err)
163    }
164}
165
166/// An SSL stream midway through the handshake process.
167#[derive(Debug)]
168pub struct MidHandshakeSslStream<S> {
169    stream: SslStream<S>,
170    error: Error,
171}
172
173impl<S> MidHandshakeSslStream<S> {
174    /// Returns a shared reference to the inner stream.
175    #[inline(always)]
176    #[must_use]
177    pub fn get_ref(&self) -> &S {
178        self.stream.get_ref()
179    }
180
181    /// Returns a mutable reference to the inner stream.
182    #[inline(always)]
183    pub fn get_mut(&mut self) -> &mut S {
184        self.stream.get_mut()
185    }
186
187    /// Returns a shared reference to the `SslContext` of the stream.
188    #[inline(always)]
189    #[must_use]
190    pub fn context(&self) -> &SslContext {
191        self.stream.context()
192    }
193
194    /// Returns a mutable reference to the `SslContext` of the stream.
195    #[inline(always)]
196    pub fn context_mut(&mut self) -> &mut SslContext {
197        self.stream.context_mut()
198    }
199
200    /// Returns `true` iff `break_on_server_auth` was set and the handshake has
201    /// progressed to that point.
202    #[inline(always)]
203    #[must_use]
204    pub fn server_auth_completed(&self) -> bool {
205        self.error.code() == errSSLPeerAuthCompleted
206    }
207
208    /// Returns `true` iff `break_on_cert_requested` was set and the handshake
209    /// has progressed to that point.
210    #[inline(always)]
211    #[must_use]
212    pub fn client_cert_requested(&self) -> bool {
213        self.error.code() == errSSLClientCertRequested
214    }
215
216    /// Returns `true` iff the underlying stream returned an error with the
217    /// `WouldBlock` kind.
218    #[inline(always)]
219    #[must_use]
220    pub fn would_block(&self) -> bool {
221        self.error.code() == errSSLWouldBlock
222    }
223
224    /// Returns the error which caused the handshake interruption.
225    #[inline(always)]
226    #[must_use]
227    pub const fn error(&self) -> &Error {
228        &self.error
229    }
230
231    /// Restarts the handshake process.
232    #[inline(always)]
233    pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
234        self.stream.handshake()
235    }
236}
237
238/// An SSL stream midway through the handshake process.
239#[derive(Debug)]
240pub struct MidHandshakeClientBuilder<S> {
241    stream: MidHandshakeSslStream<S>,
242    domain: Option<String>,
243    certs: Vec<SecCertificate>,
244    trust_certs_only: bool,
245    danger_accept_invalid_certs: bool,
246}
247
248impl<S> MidHandshakeClientBuilder<S> {
249    /// Returns a shared reference to the inner stream.
250    #[inline(always)]
251    #[must_use]
252    pub fn get_ref(&self) -> &S {
253        self.stream.get_ref()
254    }
255
256    /// Returns a mutable reference to the inner stream.
257    #[inline(always)]
258    pub fn get_mut(&mut self) -> &mut S {
259        self.stream.get_mut()
260    }
261
262    /// Returns the error which caused the handshake interruption.
263    #[inline(always)]
264    #[must_use]
265    pub fn error(&self) -> &Error {
266        self.stream.error()
267    }
268
269    /// Restarts the handshake process.
270    pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
271        let Self {
272            stream,
273            domain,
274            certs,
275            trust_certs_only,
276            danger_accept_invalid_certs,
277        } = self;
278
279        let mut result = stream.handshake();
280        loop {
281            let stream = match result {
282                Ok(stream) => return Ok(stream),
283                Err(HandshakeError::Interrupted(stream)) => stream,
284                Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
285            };
286
287            if stream.would_block() {
288                let ret = Self {
289                    stream,
290                    domain,
291                    certs,
292                    trust_certs_only,
293                    danger_accept_invalid_certs,
294                };
295                return Err(ClientHandshakeError::Interrupted(ret));
296            }
297
298            if stream.server_auth_completed() {
299                if danger_accept_invalid_certs {
300                    result = stream.handshake();
301                    continue;
302                }
303                let Some(mut trust) = stream.context().peer_trust2()? else {
304                    result = stream.handshake();
305                    continue;
306                };
307                trust.set_anchor_certificates(&certs)?;
308                trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
309                let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310                trust.set_policy(&policy)?;
311                trust.evaluate_with_error().map_err(|error| {
312                    #[cfg(feature = "log")]
313                    log::warn!("SecTrustEvaluateWithError: {error}");
314                    Error::from_code(error.code() as _)
315                })?;
316                result = stream.handshake();
317                continue;
318            }
319
320            let err = Error::from_code(stream.error().code());
321            return Err(ClientHandshakeError::Failure(err));
322        }
323    }
324}
325
326/// Specifies the state of a TLS session.
327#[derive(Debug, PartialEq, Eq)]
328pub struct SessionState(SSLSessionState);
329
330impl SessionState {
331    /// The session has been aborted due to an error.
332    pub const ABORTED: Self = Self(kSSLAborted);
333    /// The session has been terminated.
334    pub const CLOSED: Self = Self(kSSLClosed);
335    /// The session is connected.
336    pub const CONNECTED: Self = Self(kSSLConnected);
337    /// The session is in the handshake process.
338    pub const HANDSHAKE: Self = Self(kSSLHandshake);
339    /// The session has not yet started.
340    pub const IDLE: Self = Self(kSSLIdle);
341}
342
343/// Specifies a server's requirement for client certificates.
344#[derive(Debug, Copy, Clone, PartialEq, Eq)]
345pub struct SslAuthenticate(SSLAuthenticate);
346
347impl SslAuthenticate {
348    /// Require a client certificate.
349    pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
350    /// Do not request a client certificate.
351    pub const NEVER: Self = Self(kNeverAuthenticate);
352    /// Request but do not require a client certificate.
353    pub const TRY: Self = Self(kTryAuthenticate);
354}
355
356/// Specifies the state of client certificate processing.
357#[derive(Debug, Copy, Clone, PartialEq, Eq)]
358pub struct SslClientCertificateState(SSLClientCertificateState);
359
360impl SslClientCertificateState {
361    /// A client certificate has not been requested or sent.
362    pub const NONE: Self = Self(kSSLClientCertNone);
363    /// A client certificate has been received but has failed to validate.
364    pub const REJECTED: Self = Self(kSSLClientCertRejected);
365    /// A client certificate has been requested but not recieved.
366    pub const REQUESTED: Self = Self(kSSLClientCertRequested);
367    /// A client certificate has been received and successfully validated.
368    pub const SENT: Self = Self(kSSLClientCertSent);
369}
370
371/// Specifies protocol versions.
372#[derive(Debug, Copy, Clone, PartialEq, Eq)]
373pub struct SslProtocol(SSLProtocol);
374
375impl SslProtocol {
376    /// All supported TLS/SSL versions are accepted.
377    pub const ALL: Self = Self(kSSLProtocolAll);
378    /// The `DTLSv1` protocol is preferred.
379    pub const DTLS1: Self = Self(kDTLSProtocol1);
380    /// Only the SSL 2.0 protocol is accepted.
381    pub const SSL2: Self = Self(kSSLProtocol2);
382    /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
383    /// SSL 3.0.
384    pub const SSL3: Self = Self(kSSLProtocol3);
385    /// Only the SSL 3.0 protocol is accepted.
386    pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
387    /// The TLS 1.0 protocol is preferred, though lower versions may be used
388    /// if the peer does not support TLS 1.0.
389    pub const TLS1: Self = Self(kTLSProtocol1);
390    /// The TLS 1.1 protocol is preferred, though lower versions may be used
391    /// if the peer does not support TLS 1.1.
392    pub const TLS11: Self = Self(kTLSProtocol11);
393    /// The TLS 1.2 protocol is preferred, though lower versions may be used
394    /// if the peer does not support TLS 1.2.
395    pub const TLS12: Self = Self(kTLSProtocol12);
396    /// The TLS 1.3 protocol is preferred, though lower versions may be used
397    /// if the peer does not support TLS 1.3.
398    pub const TLS13: Self = Self(kTLSProtocol13);
399    /// Only the TLS 1.0 protocol is accepted.
400    pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
401    /// No protocol has been or should be negotiated or specified; use the default.
402    pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
403}
404
405declare_TCFType! {
406    /// A Secure Transport SSL/TLS context object.
407    SslContext, SSLContextRef
408}
409
410impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
411
412impl fmt::Debug for SslContext {
413    #[cold]
414    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
415        let mut builder = fmt.debug_struct("SslContext");
416        if let Ok(state) = self.state() {
417            builder.field("state", &state);
418        }
419        builder.finish()
420    }
421}
422
423unsafe impl Sync for SslContext {}
424unsafe impl Send for SslContext {}
425
426#[cfg(target_os = "macos")]
427impl SslContext {
428    pub(crate) fn as_inner(&self) -> SSLContextRef {
429        self.0
430    }
431}
432
433macro_rules! impl_options {
434    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
435        $(
436            #[allow(deprecated)]
437            $(#[$a])*
438            #[inline(always)]
439            pub fn $set(&mut self, value: bool) -> Result<()> {
440                unsafe { cvt(SSLSetSessionOption(self.0, $opt, Boolean::from(value))) }
441            }
442
443            #[allow(deprecated)]
444            $(#[$a])*
445            #[inline]
446            pub fn $get(&self) -> Result<bool> {
447                let mut value = 0;
448                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; }
449                Ok(value != 0)
450            }
451        )*
452    }
453}
454
455impl SslContext {
456    /// Creates a new `SslContext` for the specified side and type of SSL
457    /// connection.
458    #[inline]
459    pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
460        unsafe {
461            let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
462            Ok(Self(ctx))
463        }
464    }
465
466    /// Sets the fully qualified domain name of the peer.
467    ///
468    /// This will be used on the client side of a session to validate the
469    /// common name field of the server's certificate. It has no effect if
470    /// called on a server-side `SslContext`.
471    ///
472    /// It is *highly* recommended to call this method before starting the
473    /// handshake process.
474    #[inline]
475    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
476        unsafe {
477            // SSLSetPeerDomainName doesn't need a null terminated string
478            cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
479        }
480    }
481
482    /// Returns the peer domain name set by `set_peer_domain_name`.
483    pub fn peer_domain_name(&self) -> Result<String> {
484        unsafe {
485            let mut len = 0;
486            cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
487            let mut buf = vec![0; len];
488            cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
489            String::from_utf8(buf).map_err(|_| Error::from_code(-1))
490        }
491    }
492
493    /// Sets the certificate to be used by this side of the SSL session.
494    ///
495    /// This must be called before the handshake for server-side connections,
496    /// and can be used on the client-side to specify a client certificate.
497    ///
498    /// The `identity` corresponds to the leaf certificate and private
499    /// key, and the `certs` correspond to extra certificates in the chain.
500    pub fn set_certificate(
501        &mut self,
502        identity: &SecIdentity,
503        certs: &[SecCertificate],
504    ) -> Result<()> {
505        let mut arr = vec![identity.as_CFType()];
506        arr.extend(certs.iter().map(|c| c.as_CFType()));
507        let certs = CFArray::from_CFTypes(&arr);
508
509        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
510    }
511
512    /// Sets the peer ID of this session.
513    ///
514    /// A peer ID is an opaque sequence of bytes that will be used by Secure
515    /// Transport to identify the peer of an SSL session. If the peer ID of
516    /// this session matches that of a previously terminated session, the
517    /// previous session can be resumed without requiring a full handshake.
518    #[inline]
519    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
520        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
521    }
522
523    /// Returns the peer ID of this session.
524    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
525        unsafe {
526            let mut ptr = ptr::null();
527            let mut len = 0;
528            cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
529            if ptr.is_null() {
530                Ok(None)
531            } else {
532                Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
533            }
534        }
535    }
536
537    /// Returns the list of ciphers that are supported by Secure Transport.
538    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
539        unsafe {
540            let mut num_ciphers = 0;
541            cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
542            let mut ciphers = vec![0; num_ciphers];
543            cvt(SSLGetSupportedCiphers(
544                self.0,
545                ciphers.as_mut_ptr(),
546                &mut num_ciphers,
547            ))?;
548            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
549        }
550    }
551
552    /// Returns the list of ciphers that are eligible to be used for
553    /// negotiation.
554    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
555        unsafe {
556            let mut num_ciphers = 0;
557            cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
558            let mut ciphers = vec![0; num_ciphers];
559            cvt(SSLGetEnabledCiphers(
560                self.0,
561                ciphers.as_mut_ptr(),
562                &mut num_ciphers,
563            ))?;
564            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
565        }
566    }
567
568    /// Sets the list of ciphers that are eligible to be used for negotiation.
569    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
570        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
571        unsafe {
572            cvt(SSLSetEnabledCiphers(
573                self.0,
574                ciphers.as_ptr(),
575                ciphers.len(),
576            ))
577        }
578    }
579
580    /// Returns the cipher being used by the session.
581    #[inline]
582    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
583        unsafe {
584            let mut cipher = 0;
585            cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
586            Ok(CipherSuite::from_raw(cipher))
587        }
588    }
589
590    /// Sets the requirements for client certificates.
591    ///
592    /// Should only be called on server-side sessions.
593    #[inline]
594    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
595        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
596    }
597
598    /// Returns the state of client certificate processing.
599    #[inline]
600    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
601        let mut state = 0;
602
603        unsafe {
604            cvt(SSLGetClientCertificateState(self.0, &mut state))?;
605        }
606        Ok(SslClientCertificateState(state))
607    }
608
609    /// Returns the `SecTrust` object corresponding to the peer.
610    ///
611    /// This can be used in conjunction with `set_break_on_server_auth` to
612    /// validate certificates which do not have roots in the default set.
613    pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
614        // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
615        // so explicitly check for that
616        if self.state()? == SessionState::IDLE {
617            return Err(Error::from_code(errSecBadReq));
618        }
619
620        unsafe {
621            let mut trust = ptr::null_mut();
622            cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
623            if trust.is_null() {
624                Ok(None)
625            } else {
626                Ok(Some(SecTrust::wrap_under_create_rule(trust)))
627            }
628        }
629    }
630
631    /// Returns the state of the session.
632    #[inline]
633    pub fn state(&self) -> Result<SessionState> {
634        unsafe {
635            let mut state = 0;
636            cvt(SSLGetSessionState(self.0, &mut state))?;
637            Ok(SessionState(state))
638        }
639    }
640
641    /// Returns the protocol version being used by the session.
642    #[inline]
643    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
644        unsafe {
645            let mut version = 0;
646            cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
647            Ok(SslProtocol(version))
648        }
649    }
650
651    /// Returns the maximum protocol version allowed by the session.
652    #[inline]
653    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
654        unsafe {
655            let mut version = 0;
656            cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
657            Ok(SslProtocol(version))
658        }
659    }
660
661    /// Sets the maximum protocol version allowed by the session.
662    #[inline]
663    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
664        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
665    }
666
667    /// Returns the minimum protocol version allowed by the session.
668    #[inline]
669    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
670        unsafe {
671            let mut version = 0;
672            cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
673            Ok(SslProtocol(version))
674        }
675    }
676
677    /// Sets the minimum protocol version allowed by the session.
678    #[inline]
679    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
680        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
681    }
682
683    /// Returns the set of protocols selected via ALPN if it succeeded.
684    #[cfg(feature = "alpn")]
685    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
686        let mut array: CFArrayRef = ptr::null();
687        unsafe {
688            #[cfg(feature = "OSX_10_13")]
689            {
690                cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
691            }
692
693            #[cfg(not(feature = "OSX_10_13"))]
694            {
695                dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
696                if let Some(f) = SSLCopyALPNProtocols.get() {
697                    cvt(f(self.0, &mut array))?;
698                } else {
699                    return Err(Error::from_code(errSecUnimplemented));
700                }
701            }
702
703            if array.is_null() {
704                return Ok(vec![]);
705            }
706
707            let array = CFArray::<CFString>::wrap_under_create_rule(array);
708            Ok(array.into_iter().map(|p| p.to_string()).collect())
709        }
710    }
711
712    /// Configures the set of protocols use for ALPN.
713    ///
714    /// This is only used for client-side connections.
715    #[cfg(feature = "alpn")]
716    pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
717        // When CFMutableArray is added to core-foundation and IntoIterator trait
718        // is implemented for CFMutableArray, the code below should directly collect
719        // into a CFMutableArray.
720        let protocols = CFArray::from_CFTypes(
721            &protocols
722                .iter()
723                .map(|proto| CFString::new(proto))
724                .collect::<Vec<_>>(),
725        );
726
727        #[cfg(feature = "OSX_10_13")]
728        {
729            unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
730        }
731        #[cfg(not(feature = "OSX_10_13"))]
732        {
733            dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
734            if let Some(f) = SSLSetALPNProtocols.get() {
735                unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
736            } else {
737                Err(Error::from_code(errSecUnimplemented))
738            }
739        }
740    }
741
742    /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
743    ///
744    /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
745    /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
746    /// ticket returned by the server.
747    ///
748    /// [`SslContext::set_peer_id`]: #method.set_peer_id
749    #[cfg(feature = "session-tickets")]
750    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
751        #[cfg(feature = "OSX_10_13")]
752        {
753            unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
754        }
755        #[cfg(not(feature = "OSX_10_13"))]
756        {
757            dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
758            if let Some(f) = SSLSetSessionTicketsEnabled.get() {
759                unsafe { cvt(f(self.0, Boolean::from(enabled))) }
760            } else {
761                Err(Error::from_code(errSecUnimplemented))
762            }
763        }
764    }
765
766    /// Returns the number of bytes which can be read without triggering a
767    /// `read` call in the underlying stream.
768    #[inline]
769    pub fn buffered_read_size(&self) -> Result<usize> {
770        unsafe {
771            let mut size = 0;
772            cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
773            Ok(size)
774        }
775    }
776
777    impl_options! {
778        /// If enabled, the handshake process will pause and return instead of
779        /// automatically validating a server's certificate.
780        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
781        /// If enabled, the handshake process will pause and return after
782        /// the server requests a certificate from the client.
783        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
784        /// If enabled, the handshake process will pause and return instead of
785        /// automatically validating a client's certificate.
786        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
787        /// If enabled, TLS false start will be performed if an appropriate
788        /// cipher suite is negotiated.
789        ///
790        const kSSLSessionOptionFalseStart: false_start & set_false_start,
791        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
792        /// connections using block ciphers to mitigate the BEAST attack.
793        ///
794        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
795    }
796
797    fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
798    where S: Read + Write {
799        unsafe {
800            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
801            if ret != errSecSuccess {
802                return Err(Error::from_code(ret));
803            }
804
805            let stream = Connection { stream, err: None, panic: None };
806            let stream = Box::into_raw(Box::new(stream));
807            let ret = SSLSetConnection(self.0, stream.cast());
808            if ret != errSecSuccess {
809                let _conn = Box::from_raw(stream);
810                return Err(Error::from_code(ret));
811            }
812
813            Ok(SslStream { ctx: self, _m: PhantomData })
814        }
815    }
816
817    /// Performs the SSL/TLS handshake.
818    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
819    where
820        S: Read + Write,
821    {
822        self.into_stream(stream)
823            .map_err(HandshakeError::Failure)
824            .and_then(SslStream::handshake)
825    }
826}
827
828struct Connection<S> {
829    stream: S,
830    err: Option<io::Error>,
831    panic: Option<Box<dyn Any + Send>>,
832}
833
834// the logic here is based off of libcurl's
835#[cold]
836fn translate_err(e: &io::Error) -> OSStatus {
837    match e.kind() {
838        io::ErrorKind::NotFound => errSSLClosedGraceful,
839        io::ErrorKind::ConnectionReset => errSSLClosedAbort,
840        io::ErrorKind::WouldBlock |
841        io::ErrorKind::NotConnected => errSSLWouldBlock,
842        _ => errSecIO,
843    }
844}
845
846unsafe extern "C" fn read_func<S>(
847    connection: SSLConnectionRef,
848    data: *mut c_void,
849    data_length: *mut usize,
850) -> OSStatus
851where S: Read {
852    let conn: &mut Connection<S> = &mut *(connection as *mut _);
853    let mut read = 0;
854
855    let ret = panic::catch_unwind(AssertUnwindSafe(|| {
856        let mut data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
857        while !data.is_empty() {
858            match conn.stream.read(data) {
859                Ok(0) => return errSSLClosedNoNotify,
860                Ok(len) => {
861                    let Some(rest) = data.get_mut(len..) else {
862                        return errSecIO;
863                    };
864                    data = rest;
865                    read += len;
866                },
867                Err(e) => {
868                    let ret = translate_err(&e);
869                    conn.err = Some(e);
870                    return ret;
871                },
872            }
873        }
874        errSecSuccess
875    }))
876    .unwrap_or_else(|e| {
877        conn.panic = Some(e);
878        errSecIO
879    });
880
881    *data_length = read;
882    ret
883}
884
885unsafe extern "C" fn write_func<S>(
886    connection: SSLConnectionRef,
887    data: *const c_void,
888    data_length: *mut usize,
889) -> OSStatus
890where S: Write {
891    let conn: &mut Connection<S> = &mut *(connection as *mut _);
892    let mut written = 0;
893
894    let ret = panic::catch_unwind(AssertUnwindSafe(|| {
895        let mut data = slice::from_raw_parts(data.cast::<u8>(), *data_length);
896        while !data.is_empty() {
897            match conn.stream.write(data) {
898                Ok(0) => return errSSLClosedNoNotify,
899                Ok(len) => {
900                    let Some(rest) = data.get(len..) else {
901                        return errSecIO;
902                    };
903                    data = rest;
904                    written += len;
905                },
906                Err(e) => {
907                    let ret = translate_err(&e);
908                    conn.err = Some(e);
909                    return ret;
910                },
911            }
912        }
913        // Need to flush during the handshake so that the handshake doesn't stall on buffered
914        // write streams. It would be better if we only flushed automatically during the
915        // handshake, and not for the remainder of the stream.
916        if let Err(e) = conn.stream.flush() {
917            let ret = translate_err(&e);
918            conn.err = Some(e);
919            return ret;
920        }
921        errSecSuccess
922    }))
923    .unwrap_or_else(|e| {
924        conn.panic = Some(e);
925        errSecIO
926    });
927
928    *data_length = written;
929    ret
930}
931
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The underlying stream is not a field of this struct. It lives in a
/// heap-allocated `Connection<S>` registered with `ctx` as its connection
/// reference, and is freed when the `SslStream` is dropped.
pub struct SslStream<S> {
    /// The owned Secure Transport context. Its connection reference points at
    /// the boxed `Connection<S>` holding the stream.
    ctx: SslContext,
    /// Ties `S` to this type, since the `S` value itself is only reachable
    /// through `ctx`'s raw connection pointer.
    _m: PhantomData<S>,
}
937
938impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
939    #[cold]
940    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
941        fmt.debug_struct("SslStream")
942            .field("context", &self.ctx)
943            .field("stream", self.get_ref())
944            .finish()
945    }
946}
947
948impl<S> Drop for SslStream<S> {
949    fn drop(&mut self) {
950        unsafe {
951            let mut conn = ptr::null();
952            let ret = SSLGetConnection(self.ctx.0, &mut conn);
953            assert!(ret == errSecSuccess);
954            let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
955        }
956    }
957}
958
959impl<S> SslStream<S> {
960    fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
961        match unsafe { SSLHandshake(self.ctx.0) } {
962            errSecSuccess => Ok(self),
963            reason @ (errSSLPeerAuthCompleted
964            | errSSLClientCertRequested
965            | errSSLWouldBlock
966            | errSSLClientHelloReceived) => {
967                Err(HandshakeError::Interrupted(MidHandshakeSslStream {
968                    stream: self,
969                    error: Error::from_code(reason),
970                }))
971            },
972            err => {
973                self.check_panic();
974                Err(HandshakeError::Failure(Error::from_code(err)))
975            },
976        }
977    }
978
979    /// Returns a shared reference to the inner stream.
980    #[inline(always)]
981    #[must_use]
982    pub fn get_ref(&self) -> &S {
983        &self.connection().stream
984    }
985
986    /// Returns a mutable reference to the underlying stream.
987    #[inline(always)]
988    pub fn get_mut(&mut self) -> &mut S {
989        &mut self.connection_mut().stream
990    }
991
992    /// Returns a shared reference to the `SslContext` of the stream.
993    #[inline(always)]
994    #[must_use]
995    pub fn context(&self) -> &SslContext {
996        &self.ctx
997    }
998
999    /// Returns a mutable reference to the `SslContext` of the stream.
1000    #[inline(always)]
1001    pub fn context_mut(&mut self) -> &mut SslContext {
1002        &mut self.ctx
1003    }
1004
1005    /// Shuts down the connection.
1006    pub fn close(&mut self) -> result::Result<(), io::Error> {
1007        unsafe {
1008            let ret = SSLClose(self.ctx.0);
1009            if ret == errSecSuccess {
1010                Ok(())
1011            } else {
1012                Err(self.get_error(ret))
1013            }
1014        }
1015    }
1016
1017    fn connection(&self) -> &Connection<S> {
1018        unsafe {
1019            let mut conn = ptr::null();
1020            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1021            assert!(ret == errSecSuccess);
1022
1023            &mut *(conn as *mut Connection<S>)
1024        }
1025    }
1026
1027    fn connection_mut(&mut self) -> &mut Connection<S> {
1028        unsafe {
1029            let mut conn = ptr::null();
1030            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1031            assert!(ret == errSecSuccess);
1032
1033            &mut *(conn as *mut Connection<S>)
1034        }
1035    }
1036
1037    #[cold]
1038    fn check_panic(&mut self) {
1039        let conn = self.connection_mut();
1040        if let Some(err) = conn.panic.take() {
1041            panic::resume_unwind(err);
1042        }
1043    }
1044
1045    #[cold]
1046    fn get_error(&mut self, ret: OSStatus) -> io::Error {
1047        self.check_panic();
1048
1049        if let Some(err) = self.connection_mut().err.take() {
1050            err
1051        } else {
1052            io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1053        }
1054    }
1055}
1056
1057impl<S: Read + Write> Read for SslStream<S> {
1058    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1059        // Below we base our return value off the amount of data read, so a
1060        // zero-length buffer might cause us to erroneously interpret this
1061        // request as an error. Instead short-circuit that logic and return
1062        // `Ok(0)` instead.
1063        if buf.is_empty() {
1064            return Ok(0);
1065        }
1066
1067        // If some data was buffered but not enough to fill `buf`, SSLRead
1068        // will try to read a new packet. This is bad because there may be
1069        // no more data but the socket is remaining open (e.g HTTPS with
1070        // Connection: keep-alive).
1071        let buffered = self.context().buffered_read_size().unwrap_or(0);
1072        let to_read = if buffered > 0 {
1073            cmp::min(buffered, buf.len())
1074        } else {
1075            buf.len()
1076        };
1077
1078        unsafe {
1079            let mut nread = 0;
1080            let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1081            // SSLRead can return an error at the same time it returns the last
1082            // chunk of data (!)
1083            if nread > 0 {
1084                return Ok(nread);
1085            }
1086
1087            match ret {
1088                errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1089                // this error isn't fatal
1090                errSSLPeerAuthCompleted => self.read(buf),
1091                _ => Err(self.get_error(ret)),
1092            }
1093        }
1094    }
1095}
1096
1097impl<S: Read + Write> Write for SslStream<S> {
1098    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1099        // Like above in read, short circuit a 0-length write
1100        if buf.is_empty() {
1101            return Ok(0);
1102        }
1103        unsafe {
1104            let mut nwritten = 0;
1105            let ret = SSLWrite(
1106                self.ctx.0,
1107                buf.as_ptr().cast(),
1108                buf.len(),
1109                &mut nwritten,
1110            );
1111            // just to be safe, base success off of nwritten rather than ret
1112            // for the same reason as in read
1113            if nwritten > 0 {
1114                Ok(nwritten)
1115            } else {
1116                Err(self.get_error(ret))
1117            }
1118        }
1119    }
1120
1121    fn flush(&mut self) -> io::Result<()> {
1122        self.connection_mut().stream.flush()
1123    }
1124}
1125
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Client identity to present, set via `identity`.
    identity: Option<SecIdentity>,
    /// Anchor (root) certificates used when verifying the server's certificate.
    certs: Vec<SecCertificate>,
    /// Certificate chain sent alongside `identity`.
    chain: Vec<SecCertificate>,
    /// Lowest protocol version to allow. `None` keeps the context's default.
    protocol_min: Option<SslProtocol>,
    /// Highest protocol version to allow. `None` keeps the context's default.
    protocol_max: Option<SslProtocol>,
    /// Whether to trust only `certs` rather than also the built-in roots.
    trust_certs_only: bool,
    /// Whether the domain is sent via Server Name Indication.
    use_sni: bool,
    /// Skip certificate validity checks (dangerous).
    danger_accept_invalid_certs: bool,
    /// Skip hostname verification (dangerous).
    danger_accept_invalid_hostnames: bool,
    /// If non-empty, the exact set of ciphers to enable.
    whitelisted_ciphers: Vec<CipherSuite>,
    /// Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    /// ALPN protocol names to offer, if configured.
    #[cfg(feature = "alpn")]
    alpn: Option<Vec<String>>,
    /// Whether RFC 5077 session tickets are enabled.
    #[cfg(feature = "session-tickets")]
    enable_session_tickets: bool,
}
1145
1146impl Default for ClientBuilder {
1147    #[inline(always)]
1148    fn default() -> Self {
1149        Self::new()
1150    }
1151}
1152
1153impl ClientBuilder {
1154    /// Creates a new builder with default options.
1155    #[inline]
1156    #[must_use]
1157    pub fn new() -> Self {
1158        Self {
1159            identity: None,
1160            certs: Vec::new(),
1161            chain: Vec::new(),
1162            protocol_min: None,
1163            protocol_max: None,
1164            trust_certs_only: false,
1165            use_sni: true,
1166            danger_accept_invalid_certs: false,
1167            danger_accept_invalid_hostnames: false,
1168            whitelisted_ciphers: Vec::new(),
1169            blacklisted_ciphers: Vec::new(),
1170            #[cfg(feature = "alpn")]
1171            alpn: None,
1172            #[cfg(feature = "session-tickets")]
1173            enable_session_tickets: false,
1174        }
1175    }
1176
1177    /// Specifies the set of root certificates to trust when
1178    /// verifying the server's certificate.
1179    #[inline]
1180    pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1181        certs.clone_into(&mut self.certs);
1182        self
1183    }
1184
1185    /// Add the certificate the set of root certificates to trust
1186    /// when verifying the server's certificate.
1187    #[inline]
1188    pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1189        self.certs.push(certs.to_owned());
1190        self
1191    }
1192
1193    /// Specifies whether to trust the built-in certificates in addition
1194    /// to specified anchor certificates.
1195    #[inline(always)]
1196    pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1197        self.trust_certs_only = only;
1198        self
1199    }
1200
1201    /// Specifies whether to trust invalid certificates.
1202    ///
1203    /// # Warning
1204    ///
1205    /// You should think very carefully before using this method. If invalid
1206    /// certificates are trusted, *any* certificate for *any* site will be
1207    /// trusted for use. This includes expired certificates. This introduces
1208    /// significant vulnerabilities, and should only be used as a last resort.
1209    #[inline(always)]
1210    pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1211        self.danger_accept_invalid_certs = noverify;
1212        self
1213    }
1214
1215    /// Specifies whether to use Server Name Indication (SNI).
1216    #[inline(always)]
1217    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1218        self.use_sni = use_sni;
1219        self
1220    }
1221
1222    /// Specifies whether to verify that the server's hostname matches its certificate.
1223    ///
1224    /// # Warning
1225    ///
1226    /// You should think very carefully before using this method. If hostnames are not verified,
1227    /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1228    /// vulnerabilities, and should only be used as a last resort.
1229    #[inline(always)]
1230    pub fn danger_accept_invalid_hostnames(
1231        &mut self,
1232        danger_accept_invalid_hostnames: bool,
1233    ) -> &mut Self {
1234        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1235        self
1236    }
1237
1238    /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
1239    pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1240        whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1241        self
1242    }
1243
1244    /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
1245    pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1246        blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1247        self
1248    }
1249
1250    /// Use the specified identity as a SSL/TLS client certificate.
1251    pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1252        self.identity = Some(identity.clone());
1253        chain.clone_into(&mut self.chain);
1254        self
1255    }
1256
1257    /// Configure the minimum protocol that this client will support.
1258    #[inline(always)]
1259    pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1260        self.protocol_min = Some(min);
1261        self
1262    }
1263
1264    /// Configure the minimum protocol that this client will support.
1265    #[inline(always)]
1266    pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1267        self.protocol_max = Some(max);
1268        self
1269    }
1270
1271    /// Configures the set of protocols used for ALPN.
1272    #[cfg(feature = "alpn")]
1273    pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1274        self.alpn = Some(protocols.iter().map(|s| (*s).to_string()).collect());
1275        self
1276    }
1277
1278    /// Configures the use of the RFC 5077 `SessionTicket` extension.
1279    ///
1280    /// Defaults to `false`.
1281    #[cfg(feature = "session-tickets")]
1282    #[inline(always)]
1283    pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1284        self.enable_session_tickets = enable;
1285        self
1286    }
1287
1288    /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1289    ///
1290    /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
1291    pub fn handshake<S>(
1292        &self,
1293        domain: &str,
1294        stream: S,
1295    ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1296    where
1297        S: Read + Write,
1298    {
1299        // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1300        // of the handshake logic through that.
1301        let stream = MidHandshakeSslStream {
1302            stream: self.ctx_into_stream(domain, stream)?,
1303            error: Error::from(errSecSuccess),
1304        };
1305
1306        let certs = self.certs.clone();
1307        let stream = MidHandshakeClientBuilder {
1308            stream,
1309            domain: if self.danger_accept_invalid_hostnames {
1310                None
1311            } else {
1312                Some(domain.to_string())
1313            },
1314            certs,
1315            trust_certs_only: self.trust_certs_only,
1316            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1317        };
1318        stream.handshake()
1319    }
1320
1321    fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1322    where S: Read + Write {
1323        let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1324
1325        if self.use_sni {
1326            ctx.set_peer_domain_name(domain)?;
1327        }
1328        if let Some(identity) = &self.identity {
1329            ctx.set_certificate(identity, &self.chain)?;
1330        }
1331        #[cfg(feature = "alpn")]
1332        {
1333            if let Some(alpn) = &self.alpn {
1334                ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1335            }
1336        }
1337        #[cfg(feature = "session-tickets")]
1338        {
1339            if self.enable_session_tickets {
1340                // We must use the domain here to ensure that we go through certificate validation
1341                // again rather than resuming the session if the domain changes.
1342                ctx.set_peer_id(domain.as_bytes())?;
1343                ctx.set_session_tickets_enabled(true)?;
1344            }
1345        }
1346        ctx.set_break_on_server_auth(true)?;
1347        self.configure_protocols(&mut ctx)?;
1348        self.configure_ciphers(&mut ctx)?;
1349
1350        ctx.into_stream(stream)
1351    }
1352
1353    fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1354        if let Some(min) = self.protocol_min {
1355            ctx.set_protocol_version_min(min)?;
1356        }
1357        if let Some(max) = self.protocol_max {
1358            ctx.set_protocol_version_max(max)?;
1359        }
1360        Ok(())
1361    }
1362
1363    fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1364        let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1365            ctx.enabled_ciphers()?
1366        } else {
1367            self.whitelisted_ciphers.clone()
1368        };
1369
1370        if !self.blacklisted_ciphers.is_empty() {
1371            ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1372        }
1373
1374        ctx.set_enabled_ciphers(&ciphers)?;
1375        Ok(())
1376    }
1377}
1378
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    /// The identity (certificate and private key) presented to clients.
    identity: SecIdentity,
    /// Additional certificates sent with `identity` as its chain.
    certs: Vec<SecCertificate>,
}
1385
1386impl ServerBuilder {
1387    /// Creates a new `ServerBuilder` which will use the specified identity
1388    /// and certificate chain for handshakes.
1389    #[must_use]
1390    pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1391        Self {
1392            identity: identity.clone(),
1393            certs: certs.to_owned(),
1394        }
1395    }
1396
1397    /// Creates a new `ServerBuilder` which will use the identity
1398    /// from the given PKCS #12 data.
1399    ///
1400    /// This operation fails if PKCS #12 file contains zero or more than one identity.
1401    ///
1402    /// This is a shortcut for the most common operation.
1403    pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1404        let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1405            .passphrase(passphrase)
1406            .import(pkcs12_der)?
1407            .into_iter()
1408            .filter_map(|idendity| {
1409                Some((idendity.identity?, idendity.cert_chain.unwrap_or_default()))
1410            })
1411            .take(2)
1412            .collect();
1413        if identities.len() == 1 {
1414            let (identity, certs) = identities.pop().unwrap();
1415            Ok(Self { identity, certs })
1416        } else {
1417            // This error code is not really helpful
1418            Err(Error::from_code(errSecParam))
1419        }
1420    }
1421
1422    /// Create a SSL context for lower-level stream initialization.
1423    pub fn new_ssl_context(&self) -> Result<SslContext> {
1424        let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1425        ctx.set_certificate(&self.identity, &self.certs)?;
1426        Ok(ctx)
1427    }
1428
1429    /// Initiates a new SSL/TLS session over a stream.
1430    pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1431    where S: Read + Write {
1432        match self.new_ssl_context()?.handshake(stream) {
1433            Ok(stream) => Ok(stream),
1434            Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1435            Err(HandshakeError::Failure(err)) => Err(err),
1436        }
1437    }
1438}
1439
1440#[cfg(test)]
1441mod test {
1442    use std::io::prelude::*;
1443    use std::net::TcpStream;
1444
1445    use super::*;
1446
1447    #[test]
1448    fn server_builder_from_pkcs12() {
1449        let pkcs12_der = include_bytes!("../test/server.p12");
1450        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1451    }
1452
1453    #[test]
1454    fn connect() {
1455        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1456        p!(ctx.set_peer_domain_name("google.com"));
1457        let stream = p!(TcpStream::connect("google.com:443"));
1458        p!(ctx.handshake(stream));
1459    }
1460
1461    #[test]
1462    fn connect_bad_domain() {
1463        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1464        p!(ctx.set_peer_domain_name("foobar.com"));
1465        let stream = p!(TcpStream::connect("google.com:443"));
1466        ctx.handshake(stream).expect_err("expected failure");
1467    }
1468
1469    #[test]
1470    fn connect_buffered_stream() {
1471        use std::io::BufWriter;
1472
1473        /// Small wrapper around a `TcpStream` to provide buffered writes.
1474        #[derive(Debug)]
1475        struct BufferedTcpStream {
1476            reader: TcpStream,
1477            writer: BufWriter<TcpStream>,
1478        }
1479
1480        impl BufferedTcpStream {
1481            fn new(tcp: TcpStream) -> std::io::Result<Self> {
1482                Ok(Self {
1483                    writer: BufWriter::with_capacity(500, tcp.try_clone()?),
1484                    reader: tcp,
1485                })
1486            }
1487        }
1488
1489        impl Read for BufferedTcpStream {
1490            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1491                self.reader.read(buf)
1492            }
1493        }
1494
1495        impl Write for BufferedTcpStream {
1496            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1497                self.writer.write(buf)
1498            }
1499
1500            fn flush(&mut self) -> std::io::Result<()> {
1501                self.writer.flush()
1502            }
1503        }
1504
1505        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1506        p!(ctx.set_peer_domain_name("google.com"));
1507        let stream = p!(TcpStream::connect("google.com:443"));
1508        let stream = p!(BufferedTcpStream::new(stream));
1509        p!(ctx.handshake(stream));
1510    }
1511
1512    #[test]
1513    fn load_page() {
1514        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1515        p!(ctx.set_peer_domain_name("google.com"));
1516        let stream = p!(TcpStream::connect("google.com:443"));
1517        let mut stream = p!(ctx.handshake(stream));
1518        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1519        p!(stream.flush());
1520        let mut buf = vec![];
1521        p!(stream.read_to_end(&mut buf));
1522        println!("{}", String::from_utf8_lossy(&buf));
1523    }
1524
1525    #[test]
1526    fn client_no_session_ticket_resumption() {
1527        for _ in 0..2 {
1528            let stream = p!(TcpStream::connect("google.com:443"));
1529
1530            // Manually handshake here.
1531            let stream = MidHandshakeSslStream {
1532                stream: ClientBuilder::new()
1533                    .ctx_into_stream("google.com", stream)
1534                    .unwrap(),
1535                error: Error::from(errSecSuccess),
1536            };
1537
1538            let mut result = stream.handshake();
1539
1540            if let Err(HandshakeError::Interrupted(stream)) = result {
1541                assert!(stream.server_auth_completed());
1542                result = stream.handshake();
1543            } else {
1544                panic!("Unexpectedly skipped server auth");
1545            }
1546
1547            assert!(result.is_ok());
1548        }
1549    }
1550
1551    #[test]
1552    #[cfg(feature = "session-tickets")]
1553    fn client_session_ticket_resumption() {
1554        // The first time through this loop, we should do a full handshake. The second time, we
1555        // should immediately finish the handshake without breaking on server auth.
1556        for i in 0..2 {
1557            let stream = p!(TcpStream::connect("google.com:443"));
1558            let mut builder = ClientBuilder::new();
1559            builder.enable_session_tickets(true);
1560
1561            // Manually handshake here.
1562            let stream = MidHandshakeSslStream {
1563                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1564                error: Error::from(errSecSuccess),
1565            };
1566
1567            let mut result = stream.handshake();
1568
1569            if let Err(HandshakeError::Interrupted(stream)) = result {
1570                assert!(stream.server_auth_completed());
1571                assert_eq!(i, 0, "Session ticket resumption did not work, server auth was not skipped");
1572                result = stream.handshake();
1573            } else {
1574                assert_eq!(i, 1, "Unexpectedly skipped server auth");
1575            }
1576
1577            assert!(result.is_ok());
1578        }
1579    }
1580
1581    #[test]
1582    #[cfg(feature = "alpn")]
1583    fn client_alpn_accept() {
1584        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1585        p!(ctx.set_peer_domain_name("google.com"));
1586        p!(ctx.set_alpn_protocols(&["h2"]));
1587        let stream = p!(TcpStream::connect("google.com:443"));
1588        let stream = ctx.handshake(stream).unwrap();
1589        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1590    }
1591
1592    #[test]
1593    #[cfg(feature = "alpn")]
1594    fn client_alpn_reject() {
1595        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1596        p!(ctx.set_peer_domain_name("google.com"));
1597        p!(ctx.set_alpn_protocols(&["h2c"]));
1598        let stream = p!(TcpStream::connect("google.com:443"));
1599        let stream = ctx.handshake(stream).unwrap();
1600        assert!(stream.context().alpn_protocols().is_err());
1601    }
1602
1603    #[test]
1604    fn client_no_anchor_certs() {
1605        let stream = p!(TcpStream::connect("google.com:443"));
1606        assert!(ClientBuilder::new()
1607            .trust_anchor_certificates_only(true)
1608            .handshake("google.com", stream)
1609            .is_err());
1610    }
1611
1612    #[test]
1613    fn client_bad_domain() {
1614        let stream = p!(TcpStream::connect("google.com:443"));
1615        assert!(ClientBuilder::new()
1616            .handshake("foobar.com", stream)
1617            .is_err());
1618    }
1619
1620    #[test]
1621    fn client_bad_domain_ignored() {
1622        let stream = p!(TcpStream::connect("google.com:443"));
1623        ClientBuilder::new()
1624            .danger_accept_invalid_hostnames(true)
1625            .handshake("foobar.com", stream)
1626            .unwrap();
1627    }
1628
1629    #[test]
1630    fn connect_no_verify_ssl() {
1631        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1632        let mut builder = ClientBuilder::new();
1633        builder.danger_accept_invalid_certs(true);
1634        builder.handshake("expired.badssl.com", stream).unwrap();
1635    }
1636
1637    #[test]
1638    fn load_page_client() {
1639        let stream = p!(TcpStream::connect("google.com:443"));
1640        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1641        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1642        p!(stream.flush());
1643        let mut buf = vec![];
1644        p!(stream.read_to_end(&mut buf));
1645        println!("{}", String::from_utf8_lossy(&buf));
1646    }
1647
1648    #[test]
1649    #[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] // FIXME what's going on with ios?
1650    fn cipher_configuration() {
1651        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
1652        let ciphers = p!(ctx.enabled_ciphers());
1653        let ciphers = ciphers
1654            .iter()
1655            .enumerate()
1656            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1657            .collect::<Vec<_>>();
1658        p!(ctx.set_enabled_ciphers(&ciphers));
1659        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1660    }
1661
1662    #[test]
1663    fn test_builder_whitelist_ciphers() {
1664        let stream = p!(TcpStream::connect("google.com:443"));
1665
1666        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1667        assert!(p!(ctx.enabled_ciphers()).len() > 1);
1668
1669        let ciphers = p!(ctx.enabled_ciphers());
1670        let cipher = ciphers.first().unwrap();
1671        let stream = p!(ClientBuilder::new()
1672            .whitelist_ciphers(&[*cipher])
1673            .ctx_into_stream("google.com", stream));
1674
1675        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1676    }
1677
1678    #[test]
1679    #[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] // FIXME same issue as cipher_configuration
1680    fn test_builder_blacklist_ciphers() {
1681        let stream = p!(TcpStream::connect("google.com:443"));
1682
1683        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1684        let num = p!(ctx.enabled_ciphers()).len();
1685        assert!(num > 1);
1686
1687        let ciphers = p!(ctx.enabled_ciphers());
1688        let cipher = ciphers.first().unwrap();
1689        let stream = p!(ClientBuilder::new()
1690            .blacklist_ciphers(&[*cipher])
1691            .ctx_into_stream("google.com", stream));
1692
1693        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1694    }
1695
1696    #[test]
1697    fn idle_context_peer_trust() {
1698        let ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
1699        assert!(ctx.peer_trust2().is_err());
1700    }
1701
1702    #[test]
1703    fn peer_id() {
1704        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
1705        assert!(p!(ctx.peer_id()).is_none());
1706        p!(ctx.set_peer_id(b"foobar"));
1707        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1708    }
1709
1710    #[test]
1711    fn peer_domain_name() {
1712        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1713        assert_eq!("", p!(ctx.peer_domain_name()));
1714        p!(ctx.set_peer_domain_name("foobar.com"));
1715        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1716    }
1717
1718    #[test]
1719    #[should_panic(expected = "blammo")]
1720    fn write_panic() {
1721        struct ExplodingStream(TcpStream);
1722
1723        impl Read for ExplodingStream {
1724            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1725                self.0.read(buf)
1726            }
1727        }
1728
1729        impl Write for ExplodingStream {
1730            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1731                panic!("blammo");
1732            }
1733
1734            fn flush(&mut self) -> io::Result<()> {
1735                self.0.flush()
1736            }
1737        }
1738
1739        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1740        p!(ctx.set_peer_domain_name("google.com"));
1741        let stream = p!(TcpStream::connect("google.com:443"));
1742        let _ = ctx.handshake(ExplodingStream(stream));
1743    }
1744
1745    #[test]
1746    #[should_panic(expected = "blammo")]
1747    fn read_panic() {
1748        struct ExplodingStream(TcpStream);
1749
1750        impl Read for ExplodingStream {
1751            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1752                panic!("blammo");
1753            }
1754        }
1755
1756        impl Write for ExplodingStream {
1757            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1758                self.0.write(buf)
1759            }
1760
1761            fn flush(&mut self) -> io::Result<()> {
1762                self.0.flush()
1763            }
1764        }
1765
1766        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1767        p!(ctx.set_peer_domain_name("google.com"));
1768        let stream = p!(TcpStream::connect("google.com:443"));
1769        let _ = ctx.handshake(ExplodingStream(stream));
1770    }
1771
1772    #[test]
1773    fn zero_length_buffers() {
1774        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1775        p!(ctx.set_peer_domain_name("google.com"));
1776        let stream = p!(TcpStream::connect("google.com:443"));
1777        let mut stream = ctx.handshake(stream).unwrap();
1778        assert_eq!(stream.write(b"").unwrap(), 0);
1779        assert_eq!(stream.read(&mut []).unwrap(), 0);
1780    }
1781}