security_framework/
secure_transport.rs

1//! SSL/TLS encryption support using Secure Transport.
2//!
3//! # Examples
4//!
5//! To connect as a client to a server with a certificate trusted by the system:
6//!
7//! ```rust
8//! use security_framework::secure_transport::ClientBuilder;
9//! use std::io::prelude::*;
10//! use std::net::TcpStream;
11//!
12//! let stream = TcpStream::connect("google.com:443").unwrap();
13//! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14//!
15//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16//! let mut page = vec![];
17//! stream.read_to_end(&mut page).unwrap();
18//! println!("{}", String::from_utf8_lossy(&page));
19//! ```
20//!
21//! To connect to a server with a certificate that's *not* trusted by the
22//! system, specify the root certificates for the server's chain to the
23//! `ClientBuilder`:
24//!
25//! ```rust,no_run
26//! use security_framework::secure_transport::ClientBuilder;
27//! use std::io::prelude::*;
28//! use std::net::TcpStream;
29//!
30//! # let root_cert = unsafe { std::mem::zeroed() };
31//! let stream = TcpStream::connect("my_server.com:443").unwrap();
32//! let mut stream = ClientBuilder::new()
33//!     .anchor_certificates(&[root_cert])
34//!     .handshake("my_server.com", stream)
35//!     .unwrap();
36//!
37//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38//! let mut page = vec![];
39//! stream.read_to_end(&mut page).unwrap();
40//! println!("{}", String::from_utf8_lossy(&page));
41//! ```
42//!
43//! For more advanced configuration, the `SslContext` type can be used directly.
44//!
45//! To run a server:
46//!
47//! ```rust,no_run
48//! use security_framework::secure_transport::{SslConnectionType, SslContext, SslProtocolSide};
49//! use std::net::TcpListener;
50//! use std::thread;
51//!
52//! // Create a TCP listener and start accepting on it.
//! let listener = TcpListener::bind("0.0.0.0:443").unwrap();
54//!
55//! for stream in listener.incoming() {
56//!     let stream = stream.unwrap();
57//!     thread::spawn(move || {
58//!         // Create a new context configured to operate on the server side of
59//!         // a traditional SSL/TLS session.
60//!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61//!                           .unwrap();
62//!
63//!         // Install the certificate chain that we will be using.
64//!         # let identity = unsafe { std::mem::zeroed() };
65//!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66//!         # let root_cert = unsafe { std::mem::zeroed() };
//!         ctx.set_certificate(&identity, &[intermediate_cert, root_cert]).unwrap();
68//!
69//!         // Perform the SSL/TLS handshake and get our stream.
70//!         let mut stream = ctx.handshake(stream).unwrap();
71//!     });
72//! }
73//! ```
74#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78#[cfg(feature = "alpn")]
79use core_foundation::string::CFString;
80use core_foundation::{declare_TCFType, impl_TCFType};
81use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82use std::os::raw::c_void;
83
84#[allow(unused_imports)]
85use security_framework_sys::base::{
86    errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87    errSecUnimplemented,
88};
89
90use security_framework_sys::secure_transport::*;
91use std::any::Any;
92use std::cmp;
93use std::fmt;
94use std::io;
95use std::io::prelude::*;
96use std::marker::PhantomData;
97use std::panic::{self, AssertUnwindSafe};
98use std::ptr;
99use std::result;
100use std::slice;
101
102use crate::base::{Error, Result};
103use crate::certificate::SecCertificate;
104use crate::cipher_suite::CipherSuite;
105use crate::cvt;
106use crate::identity::SecIdentity;
107use crate::import_export::Pkcs12ImportOptions;
108use crate::policy::SecPolicy;
109use crate::trust::SecTrust;
110use security_framework_sys::base::errSecParam;
111
112/// Specifies a side of a TLS session.
113#[derive(Debug, Copy, Clone, PartialEq, Eq)]
114pub struct SslProtocolSide(SSLProtocolSide);
115
116impl SslProtocolSide {
117    /// The client side of the session.
118    pub const CLIENT: Self = Self(kSSLClientSide);
119    /// The server side of the session.
120    pub const SERVER: Self = Self(kSSLServerSide);
121}
122
123/// Specifies the type of TLS session.
124#[derive(Debug, Copy, Clone)]
125pub struct SslConnectionType(SSLConnectionType);
126
127impl SslConnectionType {
128    /// A DTLS session.
129    pub const DATAGRAM: Self = Self(kSSLDatagramType);
130    /// A traditional TLS stream.
131    pub const STREAM: Self = Self(kSSLStreamType);
132}
133
134/// An error or intermediate state after a TLS handshake attempt.
135#[derive(Debug)]
136pub enum HandshakeError<S> {
137    /// The handshake failed.
138    Failure(Error),
139    /// The handshake was interrupted midway through.
140    Interrupted(MidHandshakeSslStream<S>),
141}
142
143impl<S> From<Error> for HandshakeError<S> {
144    #[inline(always)]
145    fn from(err: Error) -> Self {
146        Self::Failure(err)
147    }
148}
149
150/// An error or intermediate state after a TLS handshake attempt.
151#[derive(Debug)]
152pub enum ClientHandshakeError<S> {
153    /// The handshake failed.
154    Failure(Error),
155    /// The handshake was interrupted midway through.
156    Interrupted(MidHandshakeClientBuilder<S>),
157}
158
159impl<S> From<Error> for ClientHandshakeError<S> {
160    #[inline(always)]
161    fn from(err: Error) -> Self {
162        Self::Failure(err)
163    }
164}
165
166/// An SSL stream midway through the handshake process.
167#[derive(Debug)]
168pub struct MidHandshakeSslStream<S> {
169    stream: SslStream<S>,
170    error: Error,
171}
172
173impl<S> MidHandshakeSslStream<S> {
174    /// Returns a shared reference to the inner stream.
175    #[inline(always)]
176    #[must_use]
177    pub fn get_ref(&self) -> &S {
178        self.stream.get_ref()
179    }
180
181    /// Returns a mutable reference to the inner stream.
182    #[inline(always)]
183    pub fn get_mut(&mut self) -> &mut S {
184        self.stream.get_mut()
185    }
186
187    /// Returns a shared reference to the `SslContext` of the stream.
188    #[inline(always)]
189    #[must_use]
190    pub fn context(&self) -> &SslContext {
191        self.stream.context()
192    }
193
194    /// Returns a mutable reference to the `SslContext` of the stream.
195    #[inline(always)]
196    pub fn context_mut(&mut self) -> &mut SslContext {
197        self.stream.context_mut()
198    }
199
200    /// Returns `true` iff `break_on_server_auth` was set and the handshake has
201    /// progressed to that point.
202    #[inline(always)]
203    #[must_use]
204    pub fn server_auth_completed(&self) -> bool {
205        self.error.code() == errSSLPeerAuthCompleted
206    }
207
208    /// Returns `true` iff `break_on_cert_requested` was set and the handshake
209    /// has progressed to that point.
210    #[inline(always)]
211    #[must_use]
212    pub fn client_cert_requested(&self) -> bool {
213        self.error.code() == errSSLClientCertRequested
214    }
215
216    /// Returns `true` iff the underlying stream returned an error with the
217    /// `WouldBlock` kind.
218    #[inline(always)]
219    #[must_use]
220    pub fn would_block(&self) -> bool {
221        self.error.code() == errSSLWouldBlock
222    }
223
224    /// Returns the error which caused the handshake interruption.
225    #[inline(always)]
226    #[must_use]
227    pub const fn error(&self) -> &Error {
228        &self.error
229    }
230
231    /// Restarts the handshake process.
232    #[inline(always)]
233    pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
234        self.stream.handshake()
235    }
236}
237
238/// An SSL stream midway through the handshake process.
239#[derive(Debug)]
240pub struct MidHandshakeClientBuilder<S> {
241    stream: MidHandshakeSslStream<S>,
242    domain: Option<String>,
243    certs: Vec<SecCertificate>,
244    trust_certs_only: bool,
245    danger_accept_invalid_certs: bool,
246}
247
248impl<S> MidHandshakeClientBuilder<S> {
249    /// Returns a shared reference to the inner stream.
250    #[inline(always)]
251    #[must_use]
252    pub fn get_ref(&self) -> &S {
253        self.stream.get_ref()
254    }
255
256    /// Returns a mutable reference to the inner stream.
257    #[inline(always)]
258    pub fn get_mut(&mut self) -> &mut S {
259        self.stream.get_mut()
260    }
261
262    /// Returns the error which caused the handshake interruption.
263    #[inline(always)]
264    #[must_use]
265    pub fn error(&self) -> &Error {
266        self.stream.error()
267    }
268
269    /// Restarts the handshake process.
270    pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
271        let Self {
272            stream,
273            domain,
274            certs,
275            trust_certs_only,
276            danger_accept_invalid_certs,
277        } = self;
278
279        let mut result = stream.handshake();
280        loop {
281            let stream = match result {
282                Ok(stream) => return Ok(stream),
283                Err(HandshakeError::Interrupted(stream)) => stream,
284                Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
285            };
286
287            if stream.would_block() {
288                let ret = Self {
289                    stream,
290                    domain,
291                    certs,
292                    trust_certs_only,
293                    danger_accept_invalid_certs,
294                };
295                return Err(ClientHandshakeError::Interrupted(ret));
296            }
297
298            if stream.server_auth_completed() {
299                if danger_accept_invalid_certs {
300                    result = stream.handshake();
301                    continue;
302                }
303                let Some(mut trust) = stream.context().peer_trust2()? else {
304                    result = stream.handshake();
305                    continue;
306                };
307                trust.set_anchor_certificates(&certs)?;
308                trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
309                let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310                trust.set_policy(&policy)?;
311                trust.evaluate_with_error().map_err(|error| {
312                    #[cfg(feature = "log")]
313                    log::warn!("SecTrustEvaluateWithError: {error}");
314                    Error::from_code(error.code() as _)
315                })?;
316                result = stream.handshake();
317                continue;
318            }
319
320            let err = Error::from_code(stream.error().code());
321            return Err(ClientHandshakeError::Failure(err));
322        }
323    }
324}
325
326/// Specifies the state of a TLS session.
327#[derive(Debug, PartialEq, Eq)]
328pub struct SessionState(SSLSessionState);
329
330impl SessionState {
331    /// The session has been aborted due to an error.
332    pub const ABORTED: Self = Self(kSSLAborted);
333    /// The session has been terminated.
334    pub const CLOSED: Self = Self(kSSLClosed);
335    /// The session is connected.
336    pub const CONNECTED: Self = Self(kSSLConnected);
337    /// The session is in the handshake process.
338    pub const HANDSHAKE: Self = Self(kSSLHandshake);
339    /// The session has not yet started.
340    pub const IDLE: Self = Self(kSSLIdle);
341}
342
343/// Specifies a server's requirement for client certificates.
344#[derive(Debug, Copy, Clone, PartialEq, Eq)]
345pub struct SslAuthenticate(SSLAuthenticate);
346
347impl SslAuthenticate {
348    /// Require a client certificate.
349    pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
350    /// Do not request a client certificate.
351    pub const NEVER: Self = Self(kNeverAuthenticate);
352    /// Request but do not require a client certificate.
353    pub const TRY: Self = Self(kTryAuthenticate);
354}
355
356/// Specifies the state of client certificate processing.
357#[derive(Debug, Copy, Clone, PartialEq, Eq)]
358pub struct SslClientCertificateState(SSLClientCertificateState);
359
360impl SslClientCertificateState {
361    /// A client certificate has not been requested or sent.
362    pub const NONE: Self = Self(kSSLClientCertNone);
363    /// A client certificate has been received but has failed to validate.
364    pub const REJECTED: Self = Self(kSSLClientCertRejected);
365    /// A client certificate has been requested but not recieved.
366    pub const REQUESTED: Self = Self(kSSLClientCertRequested);
367    /// A client certificate has been received and successfully validated.
368    pub const SENT: Self = Self(kSSLClientCertSent);
369}
370
371/// Specifies protocol versions.
372#[derive(Debug, Copy, Clone, PartialEq, Eq)]
373pub struct SslProtocol(SSLProtocol);
374
375impl SslProtocol {
376    /// All supported TLS/SSL versions are accepted.
377    pub const ALL: Self = Self(kSSLProtocolAll);
378    /// The `DTLSv1` protocol is preferred.
379    pub const DTLS1: Self = Self(kDTLSProtocol1);
380    /// Only the SSL 2.0 protocol is accepted.
381    pub const SSL2: Self = Self(kSSLProtocol2);
382    /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
383    /// SSL 3.0.
384    pub const SSL3: Self = Self(kSSLProtocol3);
385    /// Only the SSL 3.0 protocol is accepted.
386    pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
387    /// The TLS 1.0 protocol is preferred, though lower versions may be used
388    /// if the peer does not support TLS 1.0.
389    pub const TLS1: Self = Self(kTLSProtocol1);
390    /// The TLS 1.1 protocol is preferred, though lower versions may be used
391    /// if the peer does not support TLS 1.1.
392    pub const TLS11: Self = Self(kTLSProtocol11);
393    /// The TLS 1.2 protocol is preferred, though lower versions may be used
394    /// if the peer does not support TLS 1.2.
395    pub const TLS12: Self = Self(kTLSProtocol12);
396    /// The TLS 1.3 protocol is preferred, though lower versions may be used
397    /// if the peer does not support TLS 1.3.
398    pub const TLS13: Self = Self(kTLSProtocol13);
399    /// Only the TLS 1.0 protocol is accepted.
400    pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
401    /// No protocol has been or should be negotiated or specified; use the default.
402    pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
403}
404
405declare_TCFType! {
406    /// A Secure Transport SSL/TLS context object.
407    SslContext, SSLContextRef
408}
409
410impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
411
412impl fmt::Debug for SslContext {
413    #[cold]
414    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
415        let mut builder = fmt.debug_struct("SslContext");
416        if let Ok(state) = self.state() {
417            builder.field("state", &state);
418        }
419        builder.finish()
420    }
421}
422
423unsafe impl Sync for SslContext {}
424unsafe impl Send for SslContext {}
425
426impl SslContext {
427    pub(crate) fn as_inner(&self) -> SSLContextRef {
428        self.0
429    }
430}
431
// Generates a setter/getter pair on `SslContext` for each listed session option.
// Entry syntax: `const OPTION: getter & setter,`. Any attributes (including doc
// comments) written on an entry are copied onto both generated methods. The
// setter forwards to `SSLSetSessionOption`; the getter reads the flag back with
// `SSLGetSessionOption` and reports it as a `bool`.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            #[allow(deprecated)]
            $(#[$attr])*
            #[inline(always)]
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let flag = Boolean::from(value);
                unsafe { cvt(SSLSetSessionOption(self.0, $option, flag)) }
            }

            #[allow(deprecated)]
            $(#[$attr])*
            #[inline]
            pub fn $getter(&self) -> Result<bool> {
                let mut flag: Boolean = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $option, &mut flag))?; }
                Ok(flag != 0)
            }
        )*
    }
}
453
454impl SslContext {
455    /// Creates a new `SslContext` for the specified side and type of SSL
456    /// connection.
457    #[inline]
458    pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
459        unsafe {
460            let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
461            Ok(Self(ctx))
462        }
463    }
464
465    /// Sets the fully qualified domain name of the peer.
466    ///
467    /// This will be used on the client side of a session to validate the
468    /// common name field of the server's certificate. It has no effect if
469    /// called on a server-side `SslContext`.
470    ///
471    /// It is *highly* recommended to call this method before starting the
472    /// handshake process.
473    #[inline]
474    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
475        unsafe {
476            // SSLSetPeerDomainName doesn't need a null terminated string
477            cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
478        }
479    }
480
481    /// Returns the peer domain name set by `set_peer_domain_name`.
482    pub fn peer_domain_name(&self) -> Result<String> {
483        unsafe {
484            let mut len = 0;
485            cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
486            let mut buf = vec![0; len];
487            cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
488            String::from_utf8(buf).map_err(|_| Error::from_code(-1))
489        }
490    }
491
492    /// Sets the certificate to be used by this side of the SSL session.
493    ///
494    /// This must be called before the handshake for server-side connections,
495    /// and can be used on the client-side to specify a client certificate.
496    ///
497    /// The `identity` corresponds to the leaf certificate and private
498    /// key, and the `certs` correspond to extra certificates in the chain.
499    pub fn set_certificate(
500        &mut self,
501        identity: &SecIdentity,
502        certs: &[SecCertificate],
503    ) -> Result<()> {
504        let mut arr = vec![identity.as_CFType()];
505        arr.extend(certs.iter().map(|c| c.as_CFType()));
506        let certs = CFArray::from_CFTypes(&arr);
507
508        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
509    }
510
511    /// Sets the peer ID of this session.
512    ///
513    /// A peer ID is an opaque sequence of bytes that will be used by Secure
514    /// Transport to identify the peer of an SSL session. If the peer ID of
515    /// this session matches that of a previously terminated session, the
516    /// previous session can be resumed without requiring a full handshake.
517    #[inline]
518    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
519        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
520    }
521
522    /// Returns the peer ID of this session.
523    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
524        unsafe {
525            let mut ptr = ptr::null();
526            let mut len = 0;
527            cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
528            if ptr.is_null() {
529                Ok(None)
530            } else {
531                Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
532            }
533        }
534    }
535
536    /// Returns the list of ciphers that are supported by Secure Transport.
537    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
538        unsafe {
539            let mut num_ciphers = 0;
540            cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
541            let mut ciphers = vec![0; num_ciphers];
542            cvt(SSLGetSupportedCiphers(
543                self.0,
544                ciphers.as_mut_ptr(),
545                &mut num_ciphers,
546            ))?;
547            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
548        }
549    }
550
551    /// Returns the list of ciphers that are eligible to be used for
552    /// negotiation.
553    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
554        unsafe {
555            let mut num_ciphers = 0;
556            cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
557            let mut ciphers = vec![0; num_ciphers];
558            cvt(SSLGetEnabledCiphers(
559                self.0,
560                ciphers.as_mut_ptr(),
561                &mut num_ciphers,
562            ))?;
563            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
564        }
565    }
566
567    /// Sets the list of ciphers that are eligible to be used for negotiation.
568    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
569        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
570        unsafe {
571            cvt(SSLSetEnabledCiphers(
572                self.0,
573                ciphers.as_ptr(),
574                ciphers.len(),
575            ))
576        }
577    }
578
579    /// Returns the cipher being used by the session.
580    #[inline]
581    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
582        unsafe {
583            let mut cipher = 0;
584            cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
585            Ok(CipherSuite::from_raw(cipher))
586        }
587    }
588
589    /// Sets the requirements for client certificates.
590    ///
591    /// Should only be called on server-side sessions.
592    #[inline]
593    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
594        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
595    }
596
597    /// Returns the state of client certificate processing.
598    #[inline]
599    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
600        let mut state = 0;
601
602        unsafe {
603            cvt(SSLGetClientCertificateState(self.0, &mut state))?;
604        }
605        Ok(SslClientCertificateState(state))
606    }
607
608    /// Returns the `SecTrust` object corresponding to the peer.
609    ///
610    /// This can be used in conjunction with `set_break_on_server_auth` to
611    /// validate certificates which do not have roots in the default set.
612    pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
613        // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
614        // so explicitly check for that
615        if self.state()? == SessionState::IDLE {
616            return Err(Error::from_code(errSecBadReq));
617        }
618
619        unsafe {
620            let mut trust = ptr::null_mut();
621            cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
622            if trust.is_null() {
623                Ok(None)
624            } else {
625                Ok(Some(SecTrust::wrap_under_create_rule(trust)))
626            }
627        }
628    }
629
630    /// Returns the state of the session.
631    #[inline]
632    pub fn state(&self) -> Result<SessionState> {
633        unsafe {
634            let mut state = 0;
635            cvt(SSLGetSessionState(self.0, &mut state))?;
636            Ok(SessionState(state))
637        }
638    }
639
640    /// Returns the protocol version being used by the session.
641    #[inline]
642    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
643        unsafe {
644            let mut version = 0;
645            cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
646            Ok(SslProtocol(version))
647        }
648    }
649
650    /// Returns the maximum protocol version allowed by the session.
651    #[inline]
652    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
653        unsafe {
654            let mut version = 0;
655            cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
656            Ok(SslProtocol(version))
657        }
658    }
659
660    /// Sets the maximum protocol version allowed by the session.
661    #[inline]
662    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
663        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
664    }
665
666    /// Returns the minimum protocol version allowed by the session.
667    #[inline]
668    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
669        unsafe {
670            let mut version = 0;
671            cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
672            Ok(SslProtocol(version))
673        }
674    }
675
676    /// Sets the minimum protocol version allowed by the session.
677    #[inline]
678    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
679        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
680    }
681
682    /// Returns the set of protocols selected via ALPN if it succeeded.
683    #[cfg(feature = "alpn")]
684    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
685        let mut array: CFArrayRef = ptr::null();
686        unsafe {
687            #[cfg(feature = "OSX_10_13")]
688            {
689                cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
690            }
691
692            #[cfg(not(feature = "OSX_10_13"))]
693            {
694                dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
695                if let Some(f) = SSLCopyALPNProtocols.get() {
696                    cvt(f(self.0, &mut array))?;
697                } else {
698                    return Err(Error::from_code(errSecUnimplemented));
699                }
700            }
701
702            if array.is_null() {
703                return Ok(vec![]);
704            }
705
706            let array = CFArray::<CFString>::wrap_under_create_rule(array);
707            Ok(array.into_iter().map(|p| p.to_string()).collect())
708        }
709    }
710
711    /// Configures the set of protocols use for ALPN.
712    ///
713    /// This is only used for client-side connections.
714    #[cfg(feature = "alpn")]
715    pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
716        // When CFMutableArray is added to core-foundation and IntoIterator trait
717        // is implemented for CFMutableArray, the code below should directly collect
718        // into a CFMutableArray.
719        let protocols = CFArray::from_CFTypes(
720            &protocols
721                .iter()
722                .map(|proto| CFString::new(proto))
723                .collect::<Vec<_>>(),
724        );
725
726        #[cfg(feature = "OSX_10_13")]
727        {
728            unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
729        }
730        #[cfg(not(feature = "OSX_10_13"))]
731        {
732            dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
733            if let Some(f) = SSLSetALPNProtocols.get() {
734                unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
735            } else {
736                Err(Error::from_code(errSecUnimplemented))
737            }
738        }
739    }
740
741    /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
742    ///
743    /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
744    /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
745    /// ticket returned by the server.
746    ///
747    /// [`SslContext::set_peer_id`]: #method.set_peer_id
748    #[cfg(feature = "session-tickets")]
749    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
750        #[cfg(feature = "OSX_10_13")]
751        {
752            unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
753        }
754        #[cfg(not(feature = "OSX_10_13"))]
755        {
756            dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
757            if let Some(f) = SSLSetSessionTicketsEnabled.get() {
758                unsafe { cvt(f(self.0, Boolean::from(enabled))) }
759            } else {
760                Err(Error::from_code(errSecUnimplemented))
761            }
762        }
763    }
764
765    /// Returns the number of bytes which can be read without triggering a
766    /// `read` call in the underlying stream.
767    #[inline]
768    pub fn buffered_read_size(&self) -> Result<usize> {
769        unsafe {
770            let mut size = 0;
771            cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
772            Ok(size)
773        }
774    }
775
776    impl_options! {
777        /// If enabled, the handshake process will pause and return instead of
778        /// automatically validating a server's certificate.
779        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
780        /// If enabled, the handshake process will pause and return after
781        /// the server requests a certificate from the client.
782        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
783        /// If enabled, the handshake process will pause and return instead of
784        /// automatically validating a client's certificate.
785        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
786        /// If enabled, TLS false start will be performed if an appropriate
787        /// cipher suite is negotiated.
788        ///
789        const kSSLSessionOptionFalseStart: false_start & set_false_start,
790        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
791        /// connections using block ciphers to mitigate the BEAST attack.
792        ///
793        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
794    }
795
796    fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
797    where S: Read + Write {
798        unsafe {
799            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
800            if ret != errSecSuccess {
801                return Err(Error::from_code(ret));
802            }
803
804            let stream = Connection { stream, err: None, panic: None };
805            let stream = Box::into_raw(Box::new(stream));
806            let ret = SSLSetConnection(self.0, stream.cast());
807            if ret != errSecSuccess {
808                let _conn = Box::from_raw(stream);
809                return Err(Error::from_code(ret));
810            }
811
812            Ok(SslStream { ctx: self, _m: PhantomData })
813        }
814    }
815
816    /// Performs the SSL/TLS handshake.
817    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
818    where
819        S: Read + Write,
820    {
821        self.into_stream(stream)
822            .map_err(HandshakeError::Failure)
823            .and_then(SslStream::handshake)
824    }
825}
826
/// State handed to Secure Transport as the connection reference for the I/O
/// callbacks.
///
/// Besides the wrapped stream, it records the last I/O error and any panic caught
/// while reading or writing. Neither can be passed back through the C callback's
/// `OSStatus` return value.
struct Connection<S> {
    /// Panic payload caught inside a callback.
    panic: Option<Box<dyn Any + Send>>,
    /// Most recent I/O error reported by `stream`.
    err: Option<io::Error>,
    /// The caller's underlying byte stream.
    stream: S,
}
832
833// the logic here is based off of libcurl's
834#[cold]
835fn translate_err(e: &io::Error) -> OSStatus {
836    match e.kind() {
837        io::ErrorKind::NotFound => errSSLClosedGraceful,
838        io::ErrorKind::ConnectionReset => errSSLClosedAbort,
839        io::ErrorKind::WouldBlock |
840        io::ErrorKind::NotConnected => errSSLWouldBlock,
841        _ => errSecIO,
842    }
843}
844
/// Secure Transport read callback, registered through `SSLSetIOFuncs` in `into_stream`.
///
/// Tries to fill the `*data_length`-byte buffer at `data` from the wrapped stream. It always
/// writes the number of bytes actually read back into `*data_length`, including when it stops
/// early.
///
/// Return values:
/// - `errSecSuccess` once the buffer is full.
/// - `errSSLClosedNoNotify` if the stream returns `Ok(0)`.
/// - `errSecIO` if the stream reports more bytes than the buffer holds.
/// - The `translate_err` mapping of an I/O error. The original error is stored in `conn.err`
///   so that `SslStream::get_error` can surface it.
///
/// A panic raised by the stream is caught, stored in `conn.panic` to be re-raised later by
/// `SslStream::check_panic`, and reported as `errSecIO`.
///
/// # Safety
///
/// `connection` must be the `Connection<S>` pointer installed by `into_stream`, and
/// `data`/`data_length` must describe a writable buffer of that many bytes.
unsafe extern "C" fn read_func<S>(
    connection: SSLConnectionRef,
    data: *mut c_void,
    data_length: *mut usize,
) -> OSStatus
where S: Read {
    // The connection pointer is the boxed `Connection<S>` registered in `into_stream`.
    let conn: &mut Connection<S> = &mut *(connection as *mut _);
    // Bytes read so far; written back to `*data_length` even on error or panic.
    let mut read = 0;

    // Panics must not unwind across the `extern "C"` boundary, so capture them here.
    let ret = panic::catch_unwind(AssertUnwindSafe(|| {
        let mut data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
        while !data.is_empty() {
            match conn.stream.read(data) {
                // EOF from the underlying stream.
                Ok(0) => return errSSLClosedNoNotify,
                Ok(len) => {
                    // Defends against a `Read` impl that claims more bytes than it was given.
                    let Some(rest) = data.get_mut(len..) else {
                        return errSecIO;
                    };
                    data = rest;
                    read += len;
                },
                Err(e) => {
                    let ret = translate_err(&e);
                    conn.err = Some(e);
                    return ret;
                },
            }
        }
        errSecSuccess
    }))
    .unwrap_or_else(|e| {
        conn.panic = Some(e);
        errSecIO
    });

    *data_length = read;
    ret
}
883
/// Secure Transport write callback, registered through `SSLSetIOFuncs` in `into_stream`.
///
/// Writes the `*data_length` bytes at `data` to the wrapped stream, then flushes it. It always
/// writes the number of bytes accepted by the stream back into `*data_length`.
///
/// Return values:
/// - `errSecSuccess` after everything has been written and flushed.
/// - `errSSLClosedNoNotify` if the stream accepts zero bytes.
/// - `errSecIO` if the stream reports more bytes than were supplied.
/// - The `translate_err` mapping of a write or flush error. The original error is stored in
///   `conn.err`.
///
/// A panic raised by the stream is caught, stored in `conn.panic`, and reported as `errSecIO`.
///
/// NOTE(review): if `flush` fails, every byte is still reported as written while an error
/// status is returned. Confirm that Secure Transport handles that combination as intended.
///
/// # Safety
///
/// `connection` must be the `Connection<S>` pointer installed by `into_stream`, and
/// `data`/`data_length` must describe a readable buffer of that many bytes.
unsafe extern "C" fn write_func<S>(
    connection: SSLConnectionRef,
    data: *const c_void,
    data_length: *mut usize,
) -> OSStatus
where S: Write {
    // The connection pointer is the boxed `Connection<S>` registered in `into_stream`.
    let conn: &mut Connection<S> = &mut *(connection as *mut _);
    // Bytes accepted so far; written back to `*data_length` even on error or panic.
    let mut written = 0;

    // Panics must not unwind across the `extern "C"` boundary, so capture them here.
    let ret = panic::catch_unwind(AssertUnwindSafe(|| {
        let mut data = slice::from_raw_parts(data.cast::<u8>(), *data_length);
        while !data.is_empty() {
            match conn.stream.write(data) {
                // A zero-length write means the stream can no longer accept data.
                Ok(0) => return errSSLClosedNoNotify,
                Ok(len) => {
                    // Defends against a `Write` impl that claims more bytes than it was given.
                    let Some(rest) = data.get(len..) else {
                        return errSecIO;
                    };
                    data = rest;
                    written += len;
                },
                Err(e) => {
                    let ret = translate_err(&e);
                    conn.err = Some(e);
                    return ret;
                },
            }
        }
        // Need to flush during the handshake so that the handshake doesn't stall on buffered
        // write streams. It would be better if we only flushed automatically during the
        // handshake, and not for the remainder of the stream.
        if let Err(e) = conn.stream.flush() {
            let ret = translate_err(&e);
            conn.err = Some(e);
            return ret;
        }
        errSecSuccess
    }))
    .unwrap_or_else(|e| {
        conn.panic = Some(e);
        errSecIO
    });

    *data_length = written;
    ret
}
930
931/// A type implementing SSL/TLS encryption over an underlying stream.
932pub struct SslStream<S> {
933    ctx: SslContext,
934    _m: PhantomData<S>,
935}
936
937impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
938    #[cold]
939    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
940        fmt.debug_struct("SslStream")
941            .field("context", &self.ctx)
942            .field("stream", self.get_ref())
943            .finish()
944    }
945}
946
947impl<S> Drop for SslStream<S> {
948    fn drop(&mut self) {
949        unsafe {
950            let mut conn = ptr::null();
951            let ret = SSLGetConnection(self.ctx.0, &mut conn);
952            assert!(ret == errSecSuccess);
953            let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
954        }
955    }
956}
957
958impl<S> SslStream<S> {
959    fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
960        match unsafe { SSLHandshake(self.ctx.0) } {
961            errSecSuccess => Ok(self),
962            reason @ errSSLPeerAuthCompleted
963            | reason @ errSSLClientCertRequested
964            | reason @ errSSLWouldBlock
965            | reason @ errSSLClientHelloReceived => {
966                Err(HandshakeError::Interrupted(MidHandshakeSslStream {
967                    stream: self,
968                    error: Error::from_code(reason),
969                }))
970            },
971            err => {
972                self.check_panic();
973                Err(HandshakeError::Failure(Error::from_code(err)))
974            },
975        }
976    }
977
978    /// Returns a shared reference to the inner stream.
979    #[inline(always)]
980    #[must_use]
981    pub fn get_ref(&self) -> &S {
982        &self.connection().stream
983    }
984
985    /// Returns a mutable reference to the underlying stream.
986    #[inline(always)]
987    pub fn get_mut(&mut self) -> &mut S {
988        &mut self.connection_mut().stream
989    }
990
991    /// Returns a shared reference to the `SslContext` of the stream.
992    #[inline(always)]
993    #[must_use]
994    pub fn context(&self) -> &SslContext {
995        &self.ctx
996    }
997
998    /// Returns a mutable reference to the `SslContext` of the stream.
999    #[inline(always)]
1000    pub fn context_mut(&mut self) -> &mut SslContext {
1001        &mut self.ctx
1002    }
1003
1004    /// Shuts down the connection.
1005    pub fn close(&mut self) -> result::Result<(), io::Error> {
1006        unsafe {
1007            let ret = SSLClose(self.ctx.0);
1008            if ret == errSecSuccess {
1009                Ok(())
1010            } else {
1011                Err(self.get_error(ret))
1012            }
1013        }
1014    }
1015
1016    fn connection(&self) -> &Connection<S> {
1017        unsafe {
1018            let mut conn = ptr::null();
1019            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1020            assert!(ret == errSecSuccess);
1021
1022            &mut *(conn as *mut Connection<S>)
1023        }
1024    }
1025
1026    fn connection_mut(&mut self) -> &mut Connection<S> {
1027        unsafe {
1028            let mut conn = ptr::null();
1029            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1030            assert!(ret == errSecSuccess);
1031
1032            &mut *(conn as *mut Connection<S>)
1033        }
1034    }
1035
1036    #[cold]
1037    fn check_panic(&mut self) {
1038        let conn = self.connection_mut();
1039        if let Some(err) = conn.panic.take() {
1040            panic::resume_unwind(err);
1041        }
1042    }
1043
1044    #[cold]
1045    fn get_error(&mut self, ret: OSStatus) -> io::Error {
1046        self.check_panic();
1047
1048        if let Some(err) = self.connection_mut().err.take() {
1049            err
1050        } else {
1051            io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1052        }
1053    }
1054}
1055
1056impl<S: Read + Write> Read for SslStream<S> {
1057    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1058        // Below we base our return value off the amount of data read, so a
1059        // zero-length buffer might cause us to erroneously interpret this
1060        // request as an error. Instead short-circuit that logic and return
1061        // `Ok(0)` instead.
1062        if buf.is_empty() {
1063            return Ok(0);
1064        }
1065
1066        // If some data was buffered but not enough to fill `buf`, SSLRead
1067        // will try to read a new packet. This is bad because there may be
1068        // no more data but the socket is remaining open (e.g HTTPS with
1069        // Connection: keep-alive).
1070        let buffered = self.context().buffered_read_size().unwrap_or(0);
1071        let to_read = if buffered > 0 {
1072            cmp::min(buffered, buf.len())
1073        } else {
1074            buf.len()
1075        };
1076
1077        unsafe {
1078            let mut nread = 0;
1079            let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1080            // SSLRead can return an error at the same time it returns the last
1081            // chunk of data (!)
1082            if nread > 0 {
1083                return Ok(nread);
1084            }
1085
1086            match ret {
1087                errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1088                // this error isn't fatal
1089                errSSLPeerAuthCompleted => self.read(buf),
1090                _ => Err(self.get_error(ret)),
1091            }
1092        }
1093    }
1094}
1095
1096impl<S: Read + Write> Write for SslStream<S> {
1097    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1098        // Like above in read, short circuit a 0-length write
1099        if buf.is_empty() {
1100            return Ok(0);
1101        }
1102        unsafe {
1103            let mut nwritten = 0;
1104            let ret = SSLWrite(
1105                self.ctx.0,
1106                buf.as_ptr().cast(),
1107                buf.len(),
1108                &mut nwritten,
1109            );
1110            // just to be safe, base success off of nwritten rather than ret
1111            // for the same reason as in read
1112            if nwritten > 0 {
1113                Ok(nwritten)
1114            } else {
1115                Err(self.get_error(ret))
1116            }
1117        }
1118    }
1119
1120    fn flush(&mut self) -> io::Result<()> {
1121        self.connection_mut().stream.flush()
1122    }
1123}
1124
1125/// A builder type to simplify the creation of client side `SslStream`s.
1126#[derive(Debug)]
1127pub struct ClientBuilder {
1128    identity: Option<SecIdentity>,
1129    certs: Vec<SecCertificate>,
1130    chain: Vec<SecCertificate>,
1131    protocol_min: Option<SslProtocol>,
1132    protocol_max: Option<SslProtocol>,
1133    trust_certs_only: bool,
1134    use_sni: bool,
1135    danger_accept_invalid_certs: bool,
1136    danger_accept_invalid_hostnames: bool,
1137    whitelisted_ciphers: Vec<CipherSuite>,
1138    blacklisted_ciphers: Vec<CipherSuite>,
1139    #[cfg(feature = "alpn")]
1140    alpn: Option<Vec<String>>,
1141    #[cfg(feature = "session-tickets")]
1142    enable_session_tickets: bool,
1143}
1144
1145impl Default for ClientBuilder {
1146    #[inline(always)]
1147    fn default() -> Self {
1148        Self::new()
1149    }
1150}
1151
1152impl ClientBuilder {
1153    /// Creates a new builder with default options.
1154    #[inline]
1155    #[must_use]
1156    pub fn new() -> Self {
1157        Self {
1158            identity: None,
1159            certs: Vec::new(),
1160            chain: Vec::new(),
1161            protocol_min: None,
1162            protocol_max: None,
1163            trust_certs_only: false,
1164            use_sni: true,
1165            danger_accept_invalid_certs: false,
1166            danger_accept_invalid_hostnames: false,
1167            whitelisted_ciphers: Vec::new(),
1168            blacklisted_ciphers: Vec::new(),
1169            #[cfg(feature = "alpn")]
1170            alpn: None,
1171            #[cfg(feature = "session-tickets")]
1172            enable_session_tickets: false,
1173        }
1174    }
1175
1176    /// Specifies the set of root certificates to trust when
1177    /// verifying the server's certificate.
1178    #[inline]
1179    pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1180        certs.clone_into(&mut self.certs);
1181        self
1182    }
1183
1184    /// Add the certificate the set of root certificates to trust
1185    /// when verifying the server's certificate.
1186    #[inline]
1187    pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1188        self.certs.push(certs.to_owned());
1189        self
1190    }
1191
1192    /// Specifies whether to trust the built-in certificates in addition
1193    /// to specified anchor certificates.
1194    #[inline(always)]
1195    pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1196        self.trust_certs_only = only;
1197        self
1198    }
1199
1200    /// Specifies whether to trust invalid certificates.
1201    ///
1202    /// # Warning
1203    ///
1204    /// You should think very carefully before using this method. If invalid
1205    /// certificates are trusted, *any* certificate for *any* site will be
1206    /// trusted for use. This includes expired certificates. This introduces
1207    /// significant vulnerabilities, and should only be used as a last resort.
1208    #[inline(always)]
1209    pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1210        self.danger_accept_invalid_certs = noverify;
1211        self
1212    }
1213
1214    /// Specifies whether to use Server Name Indication (SNI).
1215    #[inline(always)]
1216    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1217        self.use_sni = use_sni;
1218        self
1219    }
1220
1221    /// Specifies whether to verify that the server's hostname matches its certificate.
1222    ///
1223    /// # Warning
1224    ///
1225    /// You should think very carefully before using this method. If hostnames are not verified,
1226    /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1227    /// vulnerabilities, and should only be used as a last resort.
1228    #[inline(always)]
1229    pub fn danger_accept_invalid_hostnames(
1230        &mut self,
1231        danger_accept_invalid_hostnames: bool,
1232    ) -> &mut Self {
1233        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1234        self
1235    }
1236
1237    /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
1238    pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1239        whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1240        self
1241    }
1242
1243    /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
1244    pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1245        blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1246        self
1247    }
1248
1249    /// Use the specified identity as a SSL/TLS client certificate.
1250    pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1251        self.identity = Some(identity.clone());
1252        chain.clone_into(&mut self.chain);
1253        self
1254    }
1255
1256    /// Configure the minimum protocol that this client will support.
1257    #[inline(always)]
1258    pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1259        self.protocol_min = Some(min);
1260        self
1261    }
1262
1263    /// Configure the minimum protocol that this client will support.
1264    #[inline(always)]
1265    pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1266        self.protocol_max = Some(max);
1267        self
1268    }
1269
1270    /// Configures the set of protocols used for ALPN.
1271    #[cfg(feature = "alpn")]
1272    pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1273        self.alpn = Some(protocols.iter().map(|s| (*s).to_string()).collect());
1274        self
1275    }
1276
1277    /// Configures the use of the RFC 5077 `SessionTicket` extension.
1278    ///
1279    /// Defaults to `false`.
1280    #[cfg(feature = "session-tickets")]
1281    #[inline(always)]
1282    pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1283        self.enable_session_tickets = enable;
1284        self
1285    }
1286
1287    /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1288    ///
1289    /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
1290    pub fn handshake<S>(
1291        &self,
1292        domain: &str,
1293        stream: S,
1294    ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1295    where
1296        S: Read + Write,
1297    {
1298        // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1299        // of the handshake logic through that.
1300        let stream = MidHandshakeSslStream {
1301            stream: self.ctx_into_stream(domain, stream)?,
1302            error: Error::from(errSecSuccess),
1303        };
1304
1305        let certs = self.certs.clone();
1306        let stream = MidHandshakeClientBuilder {
1307            stream,
1308            domain: if self.danger_accept_invalid_hostnames {
1309                None
1310            } else {
1311                Some(domain.to_string())
1312            },
1313            certs,
1314            trust_certs_only: self.trust_certs_only,
1315            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1316        };
1317        stream.handshake()
1318    }
1319
1320    fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1321    where S: Read + Write {
1322        let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1323
1324        if self.use_sni {
1325            ctx.set_peer_domain_name(domain)?;
1326        }
1327        if let Some(ref identity) = self.identity {
1328            ctx.set_certificate(identity, &self.chain)?;
1329        }
1330        #[cfg(feature = "alpn")]
1331        {
1332            if let Some(ref alpn) = self.alpn {
1333                ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1334            }
1335        }
1336        #[cfg(feature = "session-tickets")]
1337        {
1338            if self.enable_session_tickets {
1339                // We must use the domain here to ensure that we go through certificate validation
1340                // again rather than resuming the session if the domain changes.
1341                ctx.set_peer_id(domain.as_bytes())?;
1342                ctx.set_session_tickets_enabled(true)?;
1343            }
1344        }
1345        ctx.set_break_on_server_auth(true)?;
1346        self.configure_protocols(&mut ctx)?;
1347        self.configure_ciphers(&mut ctx)?;
1348
1349        ctx.into_stream(stream)
1350    }
1351
1352    fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1353        if let Some(min) = self.protocol_min {
1354            ctx.set_protocol_version_min(min)?;
1355        }
1356        if let Some(max) = self.protocol_max {
1357            ctx.set_protocol_version_max(max)?;
1358        }
1359        Ok(())
1360    }
1361
1362    fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1363        let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1364            ctx.enabled_ciphers()?
1365        } else {
1366            self.whitelisted_ciphers.clone()
1367        };
1368
1369        if !self.blacklisted_ciphers.is_empty() {
1370            ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1371        }
1372
1373        ctx.set_enabled_ciphers(&ciphers)?;
1374        Ok(())
1375    }
1376}
1377
1378/// A builder type to simplify the creation of server-side `SslStream`s.
1379#[derive(Debug)]
1380pub struct ServerBuilder {
1381    identity: SecIdentity,
1382    certs: Vec<SecCertificate>,
1383}
1384
1385impl ServerBuilder {
1386    /// Creates a new `ServerBuilder` which will use the specified identity
1387    /// and certificate chain for handshakes.
1388    #[must_use]
1389    pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1390        Self {
1391            identity: identity.clone(),
1392            certs: certs.to_owned(),
1393        }
1394    }
1395
1396    /// Creates a new `ServerBuilder` which will use the identity
1397    /// from the given PKCS #12 data.
1398    ///
1399    /// This operation fails if PKCS #12 file contains zero or more than one identity.
1400    ///
1401    /// This is a shortcut for the most common operation.
1402    pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1403        let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1404            .passphrase(passphrase)
1405            .import(pkcs12_der)?
1406            .into_iter()
1407            .filter_map(|idendity| {
1408                Some((idendity.identity?, idendity.cert_chain.unwrap_or_default()))
1409            })
1410            .take(2)
1411            .collect();
1412        if identities.len() == 1 {
1413            let (identity, certs) = identities.pop().unwrap();
1414            Ok(Self { identity, certs })
1415        } else {
1416            // This error code is not really helpful
1417            Err(Error::from_code(errSecParam))
1418        }
1419    }
1420
1421    /// Create a SSL context for lower-level stream initialization.
1422    pub fn new_ssl_context(&self) -> Result<SslContext> {
1423        let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1424        ctx.set_certificate(&self.identity, &self.certs)?;
1425        Ok(ctx)
1426    }
1427
1428    /// Initiates a new SSL/TLS session over a stream.
1429    pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1430    where S: Read + Write {
1431        match self.new_ssl_context()?.handshake(stream) {
1432            Ok(stream) => Ok(stream),
1433            Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1434            Err(HandshakeError::Failure(err)) => Err(err),
1435        }
1436    }
1437}
1438
1439#[cfg(test)]
1440mod test {
1441    use std::io::prelude::*;
1442    use std::net::TcpStream;
1443
1444    use super::*;
1445
1446    #[test]
1447    fn server_builder_from_pkcs12() {
1448        let pkcs12_der = include_bytes!("../test/server.p12");
1449        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1450    }
1451
1452    #[test]
1453    fn connect() {
1454        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1455        p!(ctx.set_peer_domain_name("google.com"));
1456        let stream = p!(TcpStream::connect("google.com:443"));
1457        p!(ctx.handshake(stream));
1458    }
1459
1460    #[test]
1461    fn connect_bad_domain() {
1462        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1463        p!(ctx.set_peer_domain_name("foobar.com"));
1464        let stream = p!(TcpStream::connect("google.com:443"));
1465        ctx.handshake(stream).expect_err("expected failure");
1466    }
1467
1468    #[test]
1469    fn connect_buffered_stream() {
1470        use std::io::BufWriter;
1471
1472        /// Small wrapper around a `TcpStream` to provide buffered writes.
1473        #[derive(Debug)]
1474        struct BufferedTcpStream {
1475            reader: TcpStream,
1476            writer: BufWriter<TcpStream>,
1477        }
1478
1479        impl BufferedTcpStream {
1480            fn new(tcp: TcpStream) -> std::io::Result<Self> {
1481                Ok(Self {
1482                    writer: BufWriter::with_capacity(500, tcp.try_clone()?),
1483                    reader: tcp,
1484                })
1485            }
1486        }
1487
1488        impl Read for BufferedTcpStream {
1489            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1490                self.reader.read(buf)
1491            }
1492        }
1493
1494        impl Write for BufferedTcpStream {
1495            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1496                self.writer.write(buf)
1497            }
1498
1499            fn flush(&mut self) -> std::io::Result<()> {
1500                self.writer.flush()
1501            }
1502        }
1503
1504        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1505        p!(ctx.set_peer_domain_name("google.com"));
1506        let stream = p!(TcpStream::connect("google.com:443"));
1507        let stream = p!(BufferedTcpStream::new(stream));
1508        p!(ctx.handshake(stream));
1509    }
1510
1511    #[test]
1512    fn load_page() {
1513        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1514        p!(ctx.set_peer_domain_name("google.com"));
1515        let stream = p!(TcpStream::connect("google.com:443"));
1516        let mut stream = p!(ctx.handshake(stream));
1517        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1518        p!(stream.flush());
1519        let mut buf = vec![];
1520        p!(stream.read_to_end(&mut buf));
1521        println!("{}", String::from_utf8_lossy(&buf));
1522    }
1523
1524    #[test]
1525    fn client_no_session_ticket_resumption() {
1526        for _ in 0..2 {
1527            let stream = p!(TcpStream::connect("google.com:443"));
1528
1529            // Manually handshake here.
1530            let stream = MidHandshakeSslStream {
1531                stream: ClientBuilder::new()
1532                    .ctx_into_stream("google.com", stream)
1533                    .unwrap(),
1534                error: Error::from(errSecSuccess),
1535            };
1536
1537            let mut result = stream.handshake();
1538
1539            if let Err(HandshakeError::Interrupted(stream)) = result {
1540                assert!(stream.server_auth_completed());
1541                result = stream.handshake();
1542            } else {
1543                panic!("Unexpectedly skipped server auth");
1544            }
1545
1546            assert!(result.is_ok());
1547        }
1548    }
1549
1550    #[test]
1551    #[cfg(feature = "session-tickets")]
1552    fn client_session_ticket_resumption() {
1553        // The first time through this loop, we should do a full handshake. The second time, we
1554        // should immediately finish the handshake without breaking on server auth.
1555        for i in 0..2 {
1556            let stream = p!(TcpStream::connect("google.com:443"));
1557            let mut builder = ClientBuilder::new();
1558            builder.enable_session_tickets(true);
1559
1560            // Manually handshake here.
1561            let stream = MidHandshakeSslStream {
1562                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1563                error: Error::from(errSecSuccess),
1564            };
1565
1566            let mut result = stream.handshake();
1567
1568            if let Err(HandshakeError::Interrupted(stream)) = result {
1569                assert!(stream.server_auth_completed());
1570                assert_eq!(i, 0, "Session ticket resumption did not work, server auth was not skipped");
1571                result = stream.handshake();
1572            } else {
1573                assert_eq!(i, 1, "Unexpectedly skipped server auth");
1574            }
1575
1576            assert!(result.is_ok());
1577        }
1578    }
1579
1580    #[test]
1581    #[cfg(feature = "alpn")]
1582    fn client_alpn_accept() {
1583        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1584        p!(ctx.set_peer_domain_name("google.com"));
1585        p!(ctx.set_alpn_protocols(&vec!["h2"]));
1586        let stream = p!(TcpStream::connect("google.com:443"));
1587        let stream = ctx.handshake(stream).unwrap();
1588        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1589    }
1590
1591    #[test]
1592    #[cfg(feature = "alpn")]
1593    fn client_alpn_reject() {
1594        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1595        p!(ctx.set_peer_domain_name("google.com"));
1596        p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1597        let stream = p!(TcpStream::connect("google.com:443"));
1598        let stream = ctx.handshake(stream).unwrap();
1599        assert!(stream.context().alpn_protocols().is_err());
1600    }
1601
1602    #[test]
1603    fn client_no_anchor_certs() {
1604        let stream = p!(TcpStream::connect("google.com:443"));
1605        assert!(ClientBuilder::new()
1606            .trust_anchor_certificates_only(true)
1607            .handshake("google.com", stream)
1608            .is_err());
1609    }
1610
1611    #[test]
1612    fn client_bad_domain() {
1613        let stream = p!(TcpStream::connect("google.com:443"));
1614        assert!(ClientBuilder::new()
1615            .handshake("foobar.com", stream)
1616            .is_err());
1617    }
1618
1619    #[test]
1620    fn client_bad_domain_ignored() {
1621        let stream = p!(TcpStream::connect("google.com:443"));
1622        ClientBuilder::new()
1623            .danger_accept_invalid_hostnames(true)
1624            .handshake("foobar.com", stream)
1625            .unwrap();
1626    }
1627
1628    #[test]
1629    fn connect_no_verify_ssl() {
1630        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1631        let mut builder = ClientBuilder::new();
1632        builder.danger_accept_invalid_certs(true);
1633        builder.handshake("expired.badssl.com", stream).unwrap();
1634    }
1635
1636    #[test]
1637    fn load_page_client() {
1638        let stream = p!(TcpStream::connect("google.com:443"));
1639        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1640        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1641        p!(stream.flush());
1642        let mut buf = vec![];
1643        p!(stream.read_to_end(&mut buf));
1644        println!("{}", String::from_utf8_lossy(&buf));
1645    }
1646
1647    #[test]
1648    #[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] // FIXME what's going on with ios?
1649    fn cipher_configuration() {
1650        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
1651        let ciphers = p!(ctx.enabled_ciphers());
1652        let ciphers = ciphers
1653            .iter()
1654            .enumerate()
1655            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1656            .collect::<Vec<_>>();
1657        p!(ctx.set_enabled_ciphers(&ciphers));
1658        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1659    }
1660
1661    #[test]
1662    fn test_builder_whitelist_ciphers() {
1663        let stream = p!(TcpStream::connect("google.com:443"));
1664
1665        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1666        assert!(p!(ctx.enabled_ciphers()).len() > 1);
1667
1668        let ciphers = p!(ctx.enabled_ciphers());
1669        let cipher = ciphers.first().unwrap();
1670        let stream = p!(ClientBuilder::new()
1671            .whitelist_ciphers(&[*cipher])
1672            .ctx_into_stream("google.com", stream));
1673
1674        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1675    }
1676
1677    #[test]
1678    #[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] // FIXME same issue as cipher_configuration
1679    fn test_builder_blacklist_ciphers() {
1680        let stream = p!(TcpStream::connect("google.com:443"));
1681
1682        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1683        let num = p!(ctx.enabled_ciphers()).len();
1684        assert!(num > 1);
1685
1686        let ciphers = p!(ctx.enabled_ciphers());
1687        let cipher = ciphers.first().unwrap();
1688        let stream = p!(ClientBuilder::new()
1689            .blacklist_ciphers(&[*cipher])
1690            .ctx_into_stream("google.com", stream));
1691
1692        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1693    }
1694
1695    #[test]
1696    fn idle_context_peer_trust() {
1697        let ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
1698        assert!(ctx.peer_trust2().is_err());
1699    }
1700
1701    #[test]
1702    fn peer_id() {
1703        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
1704        assert!(p!(ctx.peer_id()).is_none());
1705        p!(ctx.set_peer_id(b"foobar"));
1706        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1707    }
1708
1709    #[test]
1710    fn peer_domain_name() {
1711        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1712        assert_eq!("", p!(ctx.peer_domain_name()));
1713        p!(ctx.set_peer_domain_name("foobar.com"));
1714        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1715    }
1716
1717    #[test]
1718    #[should_panic(expected = "blammo")]
1719    fn write_panic() {
1720        struct ExplodingStream(TcpStream);
1721
1722        impl Read for ExplodingStream {
1723            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1724                self.0.read(buf)
1725            }
1726        }
1727
1728        impl Write for ExplodingStream {
1729            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1730                panic!("blammo");
1731            }
1732
1733            fn flush(&mut self) -> io::Result<()> {
1734                self.0.flush()
1735            }
1736        }
1737
1738        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1739        p!(ctx.set_peer_domain_name("google.com"));
1740        let stream = p!(TcpStream::connect("google.com:443"));
1741        let _ = ctx.handshake(ExplodingStream(stream));
1742    }
1743
1744    #[test]
1745    #[should_panic(expected = "blammo")]
1746    fn read_panic() {
1747        struct ExplodingStream(TcpStream);
1748
1749        impl Read for ExplodingStream {
1750            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1751                panic!("blammo");
1752            }
1753        }
1754
1755        impl Write for ExplodingStream {
1756            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1757                self.0.write(buf)
1758            }
1759
1760            fn flush(&mut self) -> io::Result<()> {
1761                self.0.flush()
1762            }
1763        }
1764
1765        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1766        p!(ctx.set_peer_domain_name("google.com"));
1767        let stream = p!(TcpStream::connect("google.com:443"));
1768        let _ = ctx.handshake(ExplodingStream(stream));
1769    }
1770
1771    #[test]
1772    fn zero_length_buffers() {
1773        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
1774        p!(ctx.set_peer_domain_name("google.com"));
1775        let stream = p!(TcpStream::connect("google.com:443"));
1776        let mut stream = ctx.handshake(stream).unwrap();
1777        assert_eq!(stream.write(b"").unwrap(), 0);
1778        assert_eq!(stream.read(&mut []).unwrap(), 0);
1779    }
1780}