Skip to main content

security_framework/
secure_transport.rs

1//! SSL/TLS encryption support using Secure Transport.
2//!
3//! # Examples
4//!
5//! To connect as a client to a server with a certificate trusted by the system:
6//!
7//! ```rust
8//! use security_framework::secure_transport::ClientBuilder;
9//! use std::io::prelude::*;
10//! use std::net::TcpStream;
11//!
12//! let stream = TcpStream::connect("google.com:443").unwrap();
13//! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14//!
15//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16//! let mut page = vec![];
17//! stream.read_to_end(&mut page).unwrap();
18//! println!("{}", String::from_utf8_lossy(&page));
19//! ```
20//!
21//! To connect to a server with a certificate that's *not* trusted by the
22//! system, specify the root certificates for the server's chain to the
23//! `ClientBuilder`:
24//!
25//! ```rust,no_run
26//! use security_framework::secure_transport::ClientBuilder;
27//! use std::io::prelude::*;
28//! use std::net::TcpStream;
29//!
30//! # let root_cert = unsafe { std::mem::zeroed() };
31//! let stream = TcpStream::connect("my_server.com:443").unwrap();
32//! let mut stream = ClientBuilder::new()
33//!     .anchor_certificates(&[root_cert])
34//!     .handshake("my_server.com", stream)
35//!     .unwrap();
36//!
37//! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38//! let mut page = vec![];
39//! stream.read_to_end(&mut page).unwrap();
40//! println!("{}", String::from_utf8_lossy(&page));
41//! ```
42//!
43//! For more advanced configuration, the `SslContext` type can be used directly.
44//!
45//! To run a server:
46//!
47//! ```rust,no_run
48//! use security_framework::secure_transport::{SslConnectionType, SslContext, SslProtocolSide};
49//! use std::net::TcpListener;
50//! use std::thread;
51//!
52//! // Create a TCP listener and start accepting on it.
53//! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54//!
55//! for stream in listener.incoming() {
56//!     let stream = stream.unwrap();
57//!     thread::spawn(move || {
58//!         // Create a new context configured to operate on the server side of
59//!         // a traditional SSL/TLS session.
60//!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61//!                           .unwrap();
62//!
63//!         // Install the certificate chain that we will be using.
64//!         # let identity = unsafe { std::mem::zeroed() };
65//!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66//!         # let root_cert = unsafe { std::mem::zeroed() };
67//!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68//!
69//!         // Perform the SSL/TLS handshake and get our stream.
70//!         let mut stream = ctx.handshake(stream).unwrap();
71//!     });
72//! }
73//! ```
74#![allow(clippy::result_large_err)]
75#[allow(unused_imports)]
76use core_foundation::array::{CFArray, CFArrayRef};
77use core_foundation::base::{Boolean, TCFType};
78use core_foundation::string::CFString;
79use core_foundation::{declare_TCFType, impl_TCFType};
80use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
81use std::os::raw::c_void;
82
83#[allow(unused_imports)]
84use security_framework_sys::base::{
85    errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
86    errSecUnimplemented,
87};
88
89use security_framework_sys::secure_transport::*;
90use std::any::Any;
91use std::cmp;
92use std::fmt;
93use std::io;
94use std::io::prelude::*;
95use std::marker::PhantomData;
96use std::panic::{self, AssertUnwindSafe};
97use std::ptr;
98use std::result;
99use std::slice;
100
101use crate::base::{Error, Result};
102use crate::certificate::SecCertificate;
103use crate::cipher_suite::CipherSuite;
104use crate::cvt;
105use crate::identity::SecIdentity;
106use crate::import_export::Pkcs12ImportOptions;
107use crate::policy::SecPolicy;
108use crate::trust::SecTrust;
109use security_framework_sys::base::errSecParam;
110
111/// Specifies a side of a TLS session.
112#[derive(Debug, Copy, Clone, PartialEq, Eq)]
113pub struct SslProtocolSide(SSLProtocolSide);
114
115impl SslProtocolSide {
116    /// The client side of the session.
117    pub const CLIENT: Self = Self(kSSLClientSide);
118    /// The server side of the session.
119    pub const SERVER: Self = Self(kSSLServerSide);
120}
121
122/// Specifies the type of TLS session.
123#[derive(Debug, Copy, Clone)]
124pub struct SslConnectionType(SSLConnectionType);
125
126impl SslConnectionType {
127    /// A DTLS session.
128    pub const DATAGRAM: Self = Self(kSSLDatagramType);
129    /// A traditional TLS stream.
130    pub const STREAM: Self = Self(kSSLStreamType);
131}
132
133/// An error or intermediate state after a TLS handshake attempt.
134#[derive(Debug)]
135pub enum HandshakeError<S> {
136    /// The handshake failed.
137    Failure(Error),
138    /// The handshake was interrupted midway through.
139    Interrupted(MidHandshakeSslStream<S>),
140}
141
142impl<S> From<Error> for HandshakeError<S> {
143    #[inline(always)]
144    fn from(err: Error) -> Self {
145        Self::Failure(err)
146    }
147}
148
149/// An error or intermediate state after a TLS handshake attempt.
150#[derive(Debug)]
151pub enum ClientHandshakeError<S> {
152    /// The handshake failed.
153    Failure(Error),
154    /// The handshake was interrupted midway through.
155    Interrupted(MidHandshakeClientBuilder<S>),
156}
157
158impl<S> From<Error> for ClientHandshakeError<S> {
159    #[inline(always)]
160    fn from(err: Error) -> Self {
161        Self::Failure(err)
162    }
163}
164
165/// An SSL stream midway through the handshake process.
166#[derive(Debug)]
167pub struct MidHandshakeSslStream<S> {
168    stream: SslStream<S>,
169    error: Error,
170}
171
172impl<S> MidHandshakeSslStream<S> {
173    /// Returns a shared reference to the inner stream.
174    #[inline(always)]
175    #[must_use]
176    pub fn get_ref(&self) -> &S {
177        self.stream.get_ref()
178    }
179
180    /// Returns a mutable reference to the inner stream.
181    #[inline(always)]
182    pub fn get_mut(&mut self) -> &mut S {
183        self.stream.get_mut()
184    }
185
186    /// Returns a shared reference to the `SslContext` of the stream.
187    #[inline(always)]
188    #[must_use]
189    pub fn context(&self) -> &SslContext {
190        self.stream.context()
191    }
192
193    /// Returns a mutable reference to the `SslContext` of the stream.
194    #[inline(always)]
195    pub fn context_mut(&mut self) -> &mut SslContext {
196        self.stream.context_mut()
197    }
198
199    /// Returns `true` iff `break_on_server_auth` was set and the handshake has
200    /// progressed to that point.
201    #[inline(always)]
202    #[must_use]
203    pub fn server_auth_completed(&self) -> bool {
204        self.error.code() == errSSLPeerAuthCompleted
205    }
206
207    /// Returns `true` iff `break_on_cert_requested` was set and the handshake
208    /// has progressed to that point.
209    #[inline(always)]
210    #[must_use]
211    pub fn client_cert_requested(&self) -> bool {
212        self.error.code() == errSSLClientCertRequested
213    }
214
215    /// Returns `true` iff the underlying stream returned an error with the
216    /// `WouldBlock` kind.
217    #[inline(always)]
218    #[must_use]
219    pub fn would_block(&self) -> bool {
220        self.error.code() == errSSLWouldBlock
221    }
222
223    /// Returns the error which caused the handshake interruption.
224    #[inline(always)]
225    #[must_use]
226    pub const fn error(&self) -> &Error {
227        &self.error
228    }
229
230    /// Restarts the handshake process.
231    #[inline(always)]
232    pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
233        self.stream.handshake()
234    }
235}
236
237/// An SSL stream midway through the handshake process.
238#[derive(Debug)]
239pub struct MidHandshakeClientBuilder<S> {
240    stream: MidHandshakeSslStream<S>,
241    domain: Option<String>,
242    certs: Vec<SecCertificate>,
243    trust_certs_only: bool,
244    danger_accept_invalid_certs: bool,
245}
246
247impl<S> MidHandshakeClientBuilder<S> {
248    /// Returns a shared reference to the inner stream.
249    #[inline(always)]
250    #[must_use]
251    pub fn get_ref(&self) -> &S {
252        self.stream.get_ref()
253    }
254
255    /// Returns a mutable reference to the inner stream.
256    #[inline(always)]
257    pub fn get_mut(&mut self) -> &mut S {
258        self.stream.get_mut()
259    }
260
261    /// Returns the error which caused the handshake interruption.
262    #[inline(always)]
263    #[must_use]
264    pub fn error(&self) -> &Error {
265        self.stream.error()
266    }
267
268    /// Restarts the handshake process.
269    pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
270        let Self {
271            stream,
272            domain,
273            certs,
274            trust_certs_only,
275            danger_accept_invalid_certs,
276        } = self;
277
278        let mut result = stream.handshake();
279        loop {
280            let stream = match result {
281                Ok(stream) => return Ok(stream),
282                Err(HandshakeError::Interrupted(stream)) => stream,
283                Err(HandshakeError::Failure(err)) => return Err(ClientHandshakeError::Failure(err)),
284            };
285
286            if stream.would_block() {
287                let ret = Self {
288                    stream,
289                    domain,
290                    certs,
291                    trust_certs_only,
292                    danger_accept_invalid_certs,
293                };
294                return Err(ClientHandshakeError::Interrupted(ret));
295            }
296
297            if stream.server_auth_completed() {
298                if danger_accept_invalid_certs {
299                    result = stream.handshake();
300                    continue;
301                }
302                let Some(mut trust) = stream.context().peer_trust2()? else {
303                    result = stream.handshake();
304                    continue;
305                };
306                trust.set_anchor_certificates(&certs)?;
307                trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
308                let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
309                trust.set_policy(&policy)?;
310                trust.evaluate_with_error().map_err(|error| {
311                    #[cfg(feature = "log")]
312                    log::warn!("SecTrustEvaluateWithError: {error}");
313                    Error::from_code(error.code() as _)
314                })?;
315                result = stream.handshake();
316                continue;
317            }
318
319            let err = Error::from_code(stream.error().code());
320            return Err(ClientHandshakeError::Failure(err));
321        }
322    }
323}
324
325/// Specifies the state of a TLS session.
326#[derive(Debug, PartialEq, Eq)]
327pub struct SessionState(SSLSessionState);
328
329impl SessionState {
330    /// The session has been aborted due to an error.
331    pub const ABORTED: Self = Self(kSSLAborted);
332    /// The session has been terminated.
333    pub const CLOSED: Self = Self(kSSLClosed);
334    /// The session is connected.
335    pub const CONNECTED: Self = Self(kSSLConnected);
336    /// The session is in the handshake process.
337    pub const HANDSHAKE: Self = Self(kSSLHandshake);
338    /// The session has not yet started.
339    pub const IDLE: Self = Self(kSSLIdle);
340}
341
342/// Specifies a server's requirement for client certificates.
343#[derive(Debug, Copy, Clone, PartialEq, Eq)]
344pub struct SslAuthenticate(SSLAuthenticate);
345
346impl SslAuthenticate {
347    /// Require a client certificate.
348    pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
349    /// Do not request a client certificate.
350    pub const NEVER: Self = Self(kNeverAuthenticate);
351    /// Request but do not require a client certificate.
352    pub const TRY: Self = Self(kTryAuthenticate);
353}
354
355/// Specifies the state of client certificate processing.
356#[derive(Debug, Copy, Clone, PartialEq, Eq)]
357pub struct SslClientCertificateState(SSLClientCertificateState);
358
359impl SslClientCertificateState {
360    /// A client certificate has not been requested or sent.
361    pub const NONE: Self = Self(kSSLClientCertNone);
362    /// A client certificate has been received but has failed to validate.
363    pub const REJECTED: Self = Self(kSSLClientCertRejected);
364    /// A client certificate has been requested but not recieved.
365    pub const REQUESTED: Self = Self(kSSLClientCertRequested);
366    /// A client certificate has been received and successfully validated.
367    pub const SENT: Self = Self(kSSLClientCertSent);
368}
369
370/// Specifies protocol versions.
371#[derive(Debug, Copy, Clone, PartialEq, Eq)]
372pub struct SslProtocol(SSLProtocol);
373
374impl SslProtocol {
375    /// All supported TLS/SSL versions are accepted.
376    pub const ALL: Self = Self(kSSLProtocolAll);
377    /// The `DTLSv1` protocol is preferred.
378    pub const DTLS1: Self = Self(kDTLSProtocol1);
379    /// Only the SSL 2.0 protocol is accepted.
380    pub const SSL2: Self = Self(kSSLProtocol2);
381    /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
382    /// SSL 3.0.
383    pub const SSL3: Self = Self(kSSLProtocol3);
384    /// Only the SSL 3.0 protocol is accepted.
385    pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
386    /// The TLS 1.0 protocol is preferred, though lower versions may be used
387    /// if the peer does not support TLS 1.0.
388    pub const TLS1: Self = Self(kTLSProtocol1);
389    /// The TLS 1.1 protocol is preferred, though lower versions may be used
390    /// if the peer does not support TLS 1.1.
391    pub const TLS11: Self = Self(kTLSProtocol11);
392    /// The TLS 1.2 protocol is preferred, though lower versions may be used
393    /// if the peer does not support TLS 1.2.
394    pub const TLS12: Self = Self(kTLSProtocol12);
395    /// The TLS 1.3 protocol is preferred, though lower versions may be used
396    /// if the peer does not support TLS 1.3.
397    pub const TLS13: Self = Self(kTLSProtocol13);
398    /// Only the TLS 1.0 protocol is accepted.
399    pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
400    /// No protocol has been or should be negotiated or specified; use the default.
401    pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
402}
403
404declare_TCFType! {
405    /// A Secure Transport SSL/TLS context object.
406    SslContext, SSLContextRef
407}
408
409impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
410
411impl fmt::Debug for SslContext {
412    #[cold]
413    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
414        let mut builder = fmt.debug_struct("SslContext");
415        if let Ok(state) = self.state() {
416            builder.field("state", &state);
417        }
418        builder.finish()
419    }
420}
421
422unsafe impl Sync for SslContext {}
423unsafe impl Send for SslContext {}
424
#[cfg(target_os = "macos")]
impl SslContext {
    /// Returns the raw `SSLContextRef` without changing its reference count.
    pub(crate) fn as_inner(&self) -> SSLContextRef {
        self.as_concrete_TypeRef()
    }
}
431
// Generates one setter/getter pair per listed Secure Transport session option.
// Each entry is written `const OPTION: getter & setter,`. Any attributes
// (usually doc comments) in front of an entry are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            #[allow(deprecated)]
            $(#[$attr])*
            #[inline(always)]
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let flag = Boolean::from(value);
                unsafe { cvt(SSLSetSessionOption(self.0, $option, flag)) }
            }

            #[allow(deprecated)]
            $(#[$attr])*
            #[inline]
            pub fn $getter(&self) -> Result<bool> {
                let mut flag: Boolean = 0;
                unsafe {
                    cvt(SSLGetSessionOption(self.0, $option, &mut flag))?;
                }
                Ok(flag != 0)
            }
        )*
    };
}
453
454impl SslContext {
455    /// Creates a new `SslContext` for the specified side and type of SSL
456    /// connection.
457    #[inline]
458    pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
459        unsafe {
460            let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
461            Ok(Self(ctx))
462        }
463    }
464
465    /// Sets the fully qualified domain name of the peer.
466    ///
467    /// This will be used on the client side of a session to validate the
468    /// common name field of the server's certificate. It has no effect if
469    /// called on a server-side `SslContext`.
470    ///
471    /// It is *highly* recommended to call this method before starting the
472    /// handshake process.
473    #[inline]
474    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
475        unsafe {
476            // SSLSetPeerDomainName doesn't need a null terminated string
477            cvt(SSLSetPeerDomainName(self.0, peer_name.as_ptr().cast(), peer_name.len()))
478        }
479    }
480
481    /// Returns the peer domain name set by `set_peer_domain_name`.
482    pub fn peer_domain_name(&self) -> Result<String> {
483        unsafe {
484            let mut len = 0;
485            cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
486            let mut buf = vec![0; len];
487            cvt(SSLGetPeerDomainName(self.0, buf.as_mut_ptr().cast(), &mut len))?;
488            String::from_utf8(buf).map_err(|_| Error::from_code(-1))
489        }
490    }
491
492    /// Sets the certificate to be used by this side of the SSL session.
493    ///
494    /// This must be called before the handshake for server-side connections,
495    /// and can be used on the client-side to specify a client certificate.
496    ///
497    /// The `identity` corresponds to the leaf certificate and private
498    /// key, and the `certs` correspond to extra certificates in the chain.
499    pub fn set_certificate(
500        &mut self,
501        identity: &SecIdentity,
502        certs: &[SecCertificate],
503    ) -> Result<()> {
504        let mut arr = vec![identity.as_CFType()];
505        arr.extend(certs.iter().map(|c| c.as_CFType()));
506        let certs = CFArray::from_CFTypes(&arr);
507
508        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
509    }
510
511    /// Sets the peer ID of this session.
512    ///
513    /// A peer ID is an opaque sequence of bytes that will be used by Secure
514    /// Transport to identify the peer of an SSL session. If the peer ID of
515    /// this session matches that of a previously terminated session, the
516    /// previous session can be resumed without requiring a full handshake.
517    #[inline]
518    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
519        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
520    }
521
522    /// Returns the peer ID of this session.
523    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
524        unsafe {
525            let mut ptr = ptr::null();
526            let mut len = 0;
527            cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
528            if ptr.is_null() {
529                Ok(None)
530            } else {
531                Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
532            }
533        }
534    }
535
536    /// Returns the list of ciphers that are supported by Secure Transport.
537    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
538        unsafe {
539            let mut num_ciphers = 0;
540            cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
541            let mut ciphers = vec![0; num_ciphers];
542            cvt(SSLGetSupportedCiphers(
543                self.0,
544                ciphers.as_mut_ptr(),
545                &mut num_ciphers,
546            ))?;
547            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
548        }
549    }
550
551    /// Returns the list of ciphers that are eligible to be used for
552    /// negotiation.
553    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
554        unsafe {
555            let mut num_ciphers = 0;
556            cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
557            let mut ciphers = vec![0; num_ciphers];
558            cvt(SSLGetEnabledCiphers(
559                self.0,
560                ciphers.as_mut_ptr(),
561                &mut num_ciphers,
562            ))?;
563            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
564        }
565    }
566
567    /// Sets the list of ciphers that are eligible to be used for negotiation.
568    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
569        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
570        unsafe {
571            cvt(SSLSetEnabledCiphers(
572                self.0,
573                ciphers.as_ptr(),
574                ciphers.len(),
575            ))
576        }
577    }
578
579    /// Returns the cipher being used by the session.
580    #[inline]
581    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
582        unsafe {
583            let mut cipher = 0;
584            cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
585            Ok(CipherSuite::from_raw(cipher))
586        }
587    }
588
589    /// Sets the requirements for client certificates.
590    ///
591    /// Should only be called on server-side sessions.
592    #[inline]
593    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
594        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
595    }
596
597    /// Returns the state of client certificate processing.
598    #[inline]
599    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
600        let mut state = 0;
601
602        unsafe {
603            cvt(SSLGetClientCertificateState(self.0, &mut state))?;
604        }
605        Ok(SslClientCertificateState(state))
606    }
607
608    /// Returns the `SecTrust` object corresponding to the peer.
609    ///
610    /// This can be used in conjunction with `set_break_on_server_auth` to
611    /// validate certificates which do not have roots in the default set.
612    pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
613        // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
614        // so explicitly check for that
615        if self.state()? == SessionState::IDLE {
616            return Err(Error::from_code(errSecBadReq));
617        }
618
619        unsafe {
620            let mut trust = ptr::null_mut();
621            cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
622            if trust.is_null() {
623                Ok(None)
624            } else {
625                Ok(Some(SecTrust::wrap_under_create_rule(trust)))
626            }
627        }
628    }
629
630    /// Returns the state of the session.
631    #[inline]
632    pub fn state(&self) -> Result<SessionState> {
633        unsafe {
634            let mut state = 0;
635            cvt(SSLGetSessionState(self.0, &mut state))?;
636            Ok(SessionState(state))
637        }
638    }
639
640    /// Returns the protocol version being used by the session.
641    #[inline]
642    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
643        unsafe {
644            let mut version = 0;
645            cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
646            Ok(SslProtocol(version))
647        }
648    }
649
650    /// Returns the maximum protocol version allowed by the session.
651    #[inline]
652    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
653        unsafe {
654            let mut version = 0;
655            cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
656            Ok(SslProtocol(version))
657        }
658    }
659
660    /// Sets the maximum protocol version allowed by the session.
661    #[inline]
662    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
663        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
664    }
665
666    /// Returns the minimum protocol version allowed by the session.
667    #[inline]
668    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
669        unsafe {
670            let mut version = 0;
671            cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
672            Ok(SslProtocol(version))
673        }
674    }
675
676    /// Sets the minimum protocol version allowed by the session.
677    #[inline]
678    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
679        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
680    }
681
682    /// Returns the set of protocols selected via ALPN if it succeeded.
683    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
684        let mut array: CFArrayRef = ptr::null();
685        unsafe {
686            cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
687
688            if array.is_null() {
689                return Ok(vec![]);
690            }
691
692            let array = CFArray::<CFString>::wrap_under_create_rule(array);
693            Ok(array.into_iter().map(|p| p.to_string()).collect())
694        }
695    }
696
697    /// Configures the set of protocols use for ALPN.
698    ///
699    /// This is only used for client-side connections.
700    pub fn set_alpn_protocols(&mut self, protocols: &[impl AsRef<str>]) -> Result<()> {
701        // When CFMutableArray is added to core-foundation and IntoIterator trait
702        // is implemented for CFMutableArray, the code below should directly collect
703        // into a CFMutableArray.
704        let protocols = CFArray::from_CFTypes(
705            &protocols
706                .iter()
707                .map(|proto| CFString::new(proto.as_ref()))
708                .collect::<Vec<_>>(),
709        );
710
711        unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
712    }
713
714    /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
715    ///
716    /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
717    /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
718    /// ticket returned by the server.
719    ///
720    /// [`SslContext::set_peer_id`]: #method.set_peer_id
721    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
722        unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, Boolean::from(enabled))) }
723    }
724
725    /// Returns the number of bytes which can be read without triggering a
726    /// `read` call in the underlying stream.
727    #[inline]
728    pub fn buffered_read_size(&self) -> Result<usize> {
729        unsafe {
730            let mut size = 0;
731            cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
732            Ok(size)
733        }
734    }
735
736    impl_options! {
737        /// If enabled, the handshake process will pause and return instead of
738        /// automatically validating a server's certificate.
739        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
740        /// If enabled, the handshake process will pause and return after
741        /// the server requests a certificate from the client.
742        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
743        /// If enabled, the handshake process will pause and return instead of
744        /// automatically validating a client's certificate.
745        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
746        /// If enabled, TLS false start will be performed if an appropriate
747        /// cipher suite is negotiated.
748        ///
749        const kSSLSessionOptionFalseStart: false_start & set_false_start,
750        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
751        /// connections using block ciphers to mitigate the BEAST attack.
752        ///
753        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
754    }
755
756    fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
757    where S: Read + Write {
758        unsafe {
759            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
760            if ret != errSecSuccess {
761                return Err(Error::from_code(ret));
762            }
763
764            let stream = Connection { stream, err: None, panic: None };
765            let stream = Box::into_raw(Box::new(stream));
766            let ret = SSLSetConnection(self.0, stream.cast());
767            if ret != errSecSuccess {
768                let _conn = Box::from_raw(stream);
769                return Err(Error::from_code(ret));
770            }
771
772            Ok(SslStream { ctx: self, _m: PhantomData })
773        }
774    }
775
776    /// Performs the SSL/TLS handshake.
777    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
778    where
779        S: Read + Write,
780    {
781        self.into_stream(stream)
782            .map_err(HandshakeError::Failure)
783            .and_then(SslStream::handshake)
784    }
785}
786
/// State handed to Secure Transport as the opaque connection pointer; the
/// read/write callbacks recover it to reach the user's stream.
struct Connection<S> {
    /// The user-supplied transport.
    stream: S,
    /// The last I/O error raised by `stream`, kept so it can be surfaced later.
    err: Option<io::Error>,
    /// The payload of a panic caught inside a callback, kept so it does not
    /// unwind across the FFI boundary.
    panic: Option<Box<dyn Any + Send>>,
}
792
793// the logic here is based off of libcurl's
794#[cold]
795fn translate_err(e: &io::Error) -> OSStatus {
796    match e.kind() {
797        io::ErrorKind::NotFound => errSSLClosedGraceful,
798        io::ErrorKind::ConnectionReset => errSSLClosedAbort,
799        io::ErrorKind::WouldBlock |
800        io::ErrorKind::NotConnected => errSSLWouldBlock,
801        _ => errSecIO,
802    }
803}
804
805unsafe extern "C" fn read_func<S>(
806    connection: SSLConnectionRef,
807    data: *mut c_void,
808    data_length: *mut usize,
809) -> OSStatus
810where S: Read {
811    if data.is_null() || data_length.is_null() || connection.is_null() {
812        return errSecParam;
813    }
814
815    let conn: &mut Connection<S> = unsafe { &mut *(connection as *mut _) };
816    let data = unsafe { slice::from_raw_parts_mut(data.cast::<u8>(), *data_length) };
817    let mut read = 0;
818
819    let ret = panic::catch_unwind(AssertUnwindSafe(|| {
820        let mut data = data;
821        while !data.is_empty() {
822            match conn.stream.read(data) {
823                Ok(0) => return errSSLClosedNoNotify,
824                Ok(len) => {
825                    let Some(rest) = data.get_mut(len..) else {
826                        return errSecIO;
827                    };
828                    data = rest;
829                    read += len;
830                },
831                Err(e) => {
832                    let ret = translate_err(&e);
833                    conn.err = Some(e);
834                    return ret;
835                },
836            }
837        }
838        errSecSuccess
839    }))
840    .unwrap_or_else(|e| {
841        conn.panic = Some(e);
842        errSecIO
843    });
844
845    unsafe {
846        *data_length = read;
847    }
848    ret
849}
850
851unsafe extern "C" fn write_func<S>(
852    connection: SSLConnectionRef,
853    data: *const c_void,
854    data_length: *mut usize,
855) -> OSStatus
856where S: Write {
857    if data.is_null() || data_length.is_null() || connection.is_null() {
858        return errSecParam;
859    }
860
861    let conn: &mut Connection<S> = unsafe { &mut *(connection as *mut _) };
862    let mut written = 0;
863    let mut data = unsafe {
864        slice::from_raw_parts(data.cast::<u8>(), *data_length)
865    };
866
867    let ret = panic::catch_unwind(AssertUnwindSafe(|| {
868        while !data.is_empty() {
869            match conn.stream.write(data) {
870                Ok(0) => return errSSLClosedNoNotify,
871                Ok(len) => {
872                    let Some(rest) = data.get(len..) else {
873                        return errSecIO;
874                    };
875                    data = rest;
876                    written += len;
877                },
878                Err(e) => {
879                    let ret = translate_err(&e);
880                    conn.err = Some(e);
881                    return ret;
882                },
883            }
884        }
885        // Need to flush during the handshake so that the handshake doesn't stall on buffered
886        // write streams. It would be better if we only flushed automatically during the
887        // handshake, and not for the remainder of the stream.
888        if let Err(e) = conn.stream.flush() {
889            let ret = translate_err(&e);
890            conn.err = Some(e);
891            return ret;
892        }
893        errSecSuccess
894    }))
895    .unwrap_or_else(|e| {
896        conn.panic = Some(e);
897        errSecIO
898    });
899
900    unsafe { *data_length = written };
901    ret
902}
903
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The wrapped stream is not stored inline: it lives in a heap-allocated
/// `Connection<S>` whose pointer is registered with the `SslContext` and is
/// reclaimed in this type's `Drop` impl.
pub struct SslStream<S> {
    // The Secure Transport context that owns the session state.
    ctx: SslContext,
    // Marks this type as logically owning an `S` (held behind the context's
    // connection pointer rather than as a field).
    _m: PhantomData<S>,
}
909
910impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
911    #[cold]
912    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
913        fmt.debug_struct("SslStream")
914            .field("context", &self.ctx)
915            .field("stream", self.get_ref())
916            .finish()
917    }
918}
919
920impl<S> Drop for SslStream<S> {
921    fn drop(&mut self) {
922        unsafe {
923            let mut conn = ptr::null();
924            let ret = SSLGetConnection(self.ctx.0, &mut conn);
925            assert!(ret == errSecSuccess);
926            let _ = Box::<Connection<S>>::from_raw(conn as *mut _);
927        }
928    }
929}
930
931impl<S> SslStream<S> {
932    fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
933        match unsafe { SSLHandshake(self.ctx.0) } {
934            errSecSuccess => Ok(self),
935            reason @ (errSSLPeerAuthCompleted
936            | errSSLClientCertRequested
937            | errSSLWouldBlock
938            | errSSLClientHelloReceived) => {
939                Err(HandshakeError::Interrupted(MidHandshakeSslStream {
940                    stream: self,
941                    error: Error::from_code(reason),
942                }))
943            },
944            err => {
945                self.check_panic();
946                Err(HandshakeError::Failure(Error::from_code(err)))
947            },
948        }
949    }
950
951    /// Returns a shared reference to the inner stream.
952    #[inline(always)]
953    #[must_use]
954    pub fn get_ref(&self) -> &S {
955        &self.connection().stream
956    }
957
958    /// Returns a mutable reference to the underlying stream.
959    #[inline(always)]
960    pub fn get_mut(&mut self) -> &mut S {
961        &mut self.connection_mut().stream
962    }
963
964    /// Returns a shared reference to the `SslContext` of the stream.
965    #[inline(always)]
966    #[must_use]
967    pub fn context(&self) -> &SslContext {
968        &self.ctx
969    }
970
971    /// Returns a mutable reference to the `SslContext` of the stream.
972    #[inline(always)]
973    pub fn context_mut(&mut self) -> &mut SslContext {
974        &mut self.ctx
975    }
976
977    /// Shuts down the connection.
978    pub fn close(&mut self) -> result::Result<(), io::Error> {
979        unsafe {
980            let ret = SSLClose(self.ctx.0);
981            if ret == errSecSuccess {
982                Ok(())
983            } else {
984                Err(self.get_error(ret))
985            }
986        }
987    }
988
989    fn connection(&self) -> &Connection<S> {
990        unsafe {
991            let mut conn = ptr::null();
992            let ret = SSLGetConnection(self.ctx.0, &mut conn);
993            assert!(ret == errSecSuccess);
994
995            &mut *(conn as *mut Connection<S>)
996        }
997    }
998
999    fn connection_mut(&mut self) -> &mut Connection<S> {
1000        unsafe {
1001            let mut conn = ptr::null();
1002            let ret = SSLGetConnection(self.ctx.0, &mut conn);
1003            assert!(ret == errSecSuccess);
1004
1005            &mut *(conn as *mut Connection<S>)
1006        }
1007    }
1008
1009    #[cold]
1010    fn check_panic(&mut self) {
1011        let conn = self.connection_mut();
1012        if let Some(err) = conn.panic.take() {
1013            panic::resume_unwind(err);
1014        }
1015    }
1016
1017    #[cold]
1018    fn get_error(&mut self, ret: OSStatus) -> io::Error {
1019        self.check_panic();
1020
1021        if let Some(err) = self.connection_mut().err.take() {
1022            err
1023        } else {
1024            io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1025        }
1026    }
1027}
1028
1029impl<S: Read + Write> Read for SslStream<S> {
1030    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1031        // Below we base our return value off the amount of data read, so a
1032        // zero-length buffer might cause us to erroneously interpret this
1033        // request as an error. Instead short-circuit that logic and return
1034        // `Ok(0)` instead.
1035        if buf.is_empty() {
1036            return Ok(0);
1037        }
1038
1039        // If some data was buffered but not enough to fill `buf`, SSLRead
1040        // will try to read a new packet. This is bad because there may be
1041        // no more data but the socket is remaining open (e.g HTTPS with
1042        // Connection: keep-alive).
1043        let buffered = self.context().buffered_read_size().unwrap_or(0);
1044        let to_read = if buffered > 0 {
1045            cmp::min(buffered, buf.len())
1046        } else {
1047            buf.len()
1048        };
1049
1050        unsafe {
1051            let mut nread = 0;
1052            let ret = SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), to_read, &mut nread);
1053            // SSLRead can return an error at the same time it returns the last
1054            // chunk of data (!)
1055            if nread > 0 {
1056                return Ok(nread);
1057            }
1058
1059            match ret {
1060                errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1061                // this error isn't fatal
1062                errSSLPeerAuthCompleted => self.read(buf),
1063                _ => Err(self.get_error(ret)),
1064            }
1065        }
1066    }
1067}
1068
1069impl<S: Read + Write> Write for SslStream<S> {
1070    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1071        // Like above in read, short circuit a 0-length write
1072        if buf.is_empty() {
1073            return Ok(0);
1074        }
1075        unsafe {
1076            let mut nwritten = 0;
1077            let ret = SSLWrite(
1078                self.ctx.0,
1079                buf.as_ptr().cast(),
1080                buf.len(),
1081                &mut nwritten,
1082            );
1083            // just to be safe, base success off of nwritten rather than ret
1084            // for the same reason as in read
1085            if nwritten > 0 {
1086                Ok(nwritten)
1087            } else {
1088                Err(self.get_error(ret))
1089            }
1090        }
1091    }
1092
1093    fn flush(&mut self) -> io::Result<()> {
1094        self.connection_mut().stream.flush()
1095    }
1096}
1097
/// A builder type to simplify the creation of client side `SslStream`s.
///
/// Configure it with the chained setters, then call `handshake` to connect.
#[derive(Debug)]
pub struct ClientBuilder {
    // Optional client identity presented to the server (set via `identity`).
    identity: Option<SecIdentity>,
    // Extra root (anchor) certificates used when verifying the server.
    certs: Vec<SecCertificate>,
    // Certificate chain sent alongside `identity`.
    chain: Vec<SecCertificate>,
    // Optional lower bound on the negotiated protocol version.
    protocol_min: Option<SslProtocol>,
    // Optional upper bound on the negotiated protocol version.
    protocol_max: Option<SslProtocol>,
    // When true, only `certs` are trusted, not the system's built-in roots.
    trust_certs_only: bool,
    // Whether the peer domain name is set on the context (SNI).
    use_sni: bool,
    // Skip certificate validity checks (dangerous).
    danger_accept_invalid_certs: bool,
    // Skip hostname verification (dangerous).
    danger_accept_invalid_hostnames: bool,
    // If non-empty, the exact set of ciphers to enable.
    whitelisted_ciphers: Vec<CipherSuite>,
    // Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    // ALPN protocol names to offer, if configured.
    alpn: Option<Vec<Box<str>>>,
    // Whether RFC 5077 session tickets are enabled.
    enable_session_tickets: bool,
}
1115
1116impl Default for ClientBuilder {
1117    #[inline(always)]
1118    fn default() -> Self {
1119        Self::new()
1120    }
1121}
1122
1123impl ClientBuilder {
1124    /// Creates a new builder with default options.
1125    #[inline]
1126    #[must_use]
1127    pub fn new() -> Self {
1128        Self {
1129            identity: None,
1130            certs: Vec::new(),
1131            chain: Vec::new(),
1132            protocol_min: None,
1133            protocol_max: None,
1134            trust_certs_only: false,
1135            use_sni: true,
1136            danger_accept_invalid_certs: false,
1137            danger_accept_invalid_hostnames: false,
1138            whitelisted_ciphers: Vec::new(),
1139            blacklisted_ciphers: Vec::new(),
1140            alpn: None,
1141            enable_session_tickets: false,
1142        }
1143    }
1144
1145    /// Specifies the set of root certificates to trust when
1146    /// verifying the server's certificate.
1147    #[inline]
1148    pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1149        certs.clone_into(&mut self.certs);
1150        self
1151    }
1152
1153    /// Add the certificate the set of root certificates to trust
1154    /// when verifying the server's certificate.
1155    #[inline]
1156    pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1157        self.certs.push(certs.to_owned());
1158        self
1159    }
1160
1161    /// Specifies whether to trust the built-in certificates in addition
1162    /// to specified anchor certificates.
1163    #[inline(always)]
1164    pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1165        self.trust_certs_only = only;
1166        self
1167    }
1168
1169    /// Specifies whether to trust invalid certificates.
1170    ///
1171    /// # Warning
1172    ///
1173    /// You should think very carefully before using this method. If invalid
1174    /// certificates are trusted, *any* certificate for *any* site will be
1175    /// trusted for use. This includes expired certificates. This introduces
1176    /// significant vulnerabilities, and should only be used as a last resort.
1177    #[inline(always)]
1178    pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1179        self.danger_accept_invalid_certs = noverify;
1180        self
1181    }
1182
1183    /// Specifies whether to use Server Name Indication (SNI).
1184    #[inline(always)]
1185    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1186        self.use_sni = use_sni;
1187        self
1188    }
1189
1190    /// Specifies whether to verify that the server's hostname matches its certificate.
1191    ///
1192    /// # Warning
1193    ///
1194    /// You should think very carefully before using this method. If hostnames are not verified,
1195    /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1196    /// vulnerabilities, and should only be used as a last resort.
1197    #[inline(always)]
1198    pub fn danger_accept_invalid_hostnames(
1199        &mut self,
1200        danger_accept_invalid_hostnames: bool,
1201    ) -> &mut Self {
1202        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1203        self
1204    }
1205
1206    /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
1207    pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1208        whitelisted_ciphers.clone_into(&mut self.whitelisted_ciphers);
1209        self
1210    }
1211
1212    /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
1213    pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1214        blacklisted_ciphers.clone_into(&mut self.blacklisted_ciphers);
1215        self
1216    }
1217
1218    /// Use the specified identity as a SSL/TLS client certificate.
1219    pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1220        self.identity = Some(identity.clone());
1221        chain.clone_into(&mut self.chain);
1222        self
1223    }
1224
1225    /// Configure the minimum protocol that this client will support.
1226    #[inline(always)]
1227    pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1228        self.protocol_min = Some(min);
1229        self
1230    }
1231
1232    /// Configure the minimum protocol that this client will support.
1233    #[inline(always)]
1234    pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1235        self.protocol_max = Some(max);
1236        self
1237    }
1238
1239    /// Configures the set of protocols used for ALPN.
1240    pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1241        self.alpn = Some(protocols.iter().copied().map(Box::from).collect());
1242        self
1243    }
1244
1245    /// Configures the use of the RFC 5077 `SessionTicket` extension.
1246    ///
1247    /// Defaults to `false`.
1248    #[inline(always)]
1249    pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1250        self.enable_session_tickets = enable;
1251        self
1252    }
1253
1254    /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1255    ///
1256    /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
1257    pub fn handshake<S>(
1258        &self,
1259        domain: &str,
1260        stream: S,
1261    ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1262    where
1263        S: Read + Write,
1264    {
1265        // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1266        // of the handshake logic through that.
1267        let stream = MidHandshakeSslStream {
1268            stream: self.ctx_into_stream(domain, stream)?,
1269            error: Error::from(errSecSuccess),
1270        };
1271
1272        let certs = self.certs.clone();
1273        let stream = MidHandshakeClientBuilder {
1274            stream,
1275            domain: if self.danger_accept_invalid_hostnames {
1276                None
1277            } else {
1278                Some(domain.to_string())
1279            },
1280            certs,
1281            trust_certs_only: self.trust_certs_only,
1282            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1283        };
1284        stream.handshake()
1285    }
1286
1287    fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1288    where S: Read + Write {
1289        let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1290
1291        if self.use_sni {
1292            ctx.set_peer_domain_name(domain)?;
1293        }
1294        if let Some(identity) = &self.identity {
1295            ctx.set_certificate(identity, &self.chain)?;
1296        }
1297        if let Some(alpn) = &self.alpn {
1298            ctx.set_alpn_protocols(alpn)?;
1299        }
1300        if self.enable_session_tickets {
1301            // We must use the domain here to ensure that we go through certificate validation
1302            // again rather than resuming the session if the domain changes.
1303            ctx.set_peer_id(domain.as_bytes())?;
1304            ctx.set_session_tickets_enabled(true)?;
1305        }
1306        ctx.set_break_on_server_auth(true)?;
1307        self.configure_protocols(&mut ctx)?;
1308        self.configure_ciphers(&mut ctx)?;
1309
1310        ctx.into_stream(stream)
1311    }
1312
1313    fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1314        if let Some(min) = self.protocol_min {
1315            ctx.set_protocol_version_min(min)?;
1316        }
1317        if let Some(max) = self.protocol_max {
1318            ctx.set_protocol_version_max(max)?;
1319        }
1320        Ok(())
1321    }
1322
1323    fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1324        let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1325            ctx.enabled_ciphers()?
1326        } else {
1327            self.whitelisted_ciphers.clone()
1328        };
1329
1330        if !self.blacklisted_ciphers.is_empty() {
1331            ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1332        }
1333
1334        ctx.set_enabled_ciphers(&ciphers)?;
1335        Ok(())
1336    }
1337}
1338
/// A builder type to simplify the creation of server-side `SslStream`s.
///
/// Holds the server identity and certificate chain that are installed on every
/// context it creates.
#[derive(Debug)]
pub struct ServerBuilder {
    // Server identity presented during the handshake.
    identity: SecIdentity,
    // Certificate chain sent alongside `identity`.
    certs: Vec<SecCertificate>,
}
1345
1346impl ServerBuilder {
1347    /// Creates a new `ServerBuilder` which will use the specified identity
1348    /// and certificate chain for handshakes.
1349    #[must_use]
1350    pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1351        Self {
1352            identity: identity.clone(),
1353            certs: certs.to_owned(),
1354        }
1355    }
1356
1357    /// Creates a new `ServerBuilder` which will use the identity
1358    /// from the given PKCS #12 data.
1359    ///
1360    /// This operation fails if PKCS #12 file contains zero or more than one identity.
1361    ///
1362    /// This is a shortcut for the most common operation.
1363    pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1364        let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1365            .passphrase(passphrase)
1366            .import(pkcs12_der)?
1367            .into_iter()
1368            .filter_map(|idendity| {
1369                Some((idendity.identity?, idendity.cert_chain.unwrap_or_default()))
1370            })
1371            .take(2)
1372            .collect();
1373        if identities.len() == 1 {
1374            let (identity, certs) = identities.pop().unwrap();
1375            Ok(Self { identity, certs })
1376        } else {
1377            // This error code is not really helpful
1378            Err(Error::from_code(errSecParam))
1379        }
1380    }
1381
1382    /// Create a SSL context for lower-level stream initialization.
1383    pub fn new_ssl_context(&self) -> Result<SslContext> {
1384        let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1385        ctx.set_certificate(&self.identity, &self.certs)?;
1386        Ok(ctx)
1387    }
1388
1389    /// Initiates a new SSL/TLS session over a stream.
1390    pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1391    where S: Read + Write {
1392        match self.new_ssl_context()?.handshake(stream) {
1393            Ok(stream) => Ok(stream),
1394            Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1395            Err(HandshakeError::Failure(err)) => Err(err),
1396        }
1397    }
1398}
1399
#[cfg(test)]
mod test {
    // Most tests below open real TCP connections (google.com:443 or
    // expired.badssl.com:443) and therefore need network access.
    // `p!` is a helper macro defined outside this chunk of the file.
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    /// Importing the bundled test PKCS #12 file must yield exactly one identity.
    #[test]
    fn server_builder_from_pkcs12() {
        let pkcs12_der = include_bytes!("../test/server.p12");
        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
    }

    /// A plain client handshake with the peer domain name set succeeds.
    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    /// The handshake fails when the peer domain name is "foobar.com" but the
    /// server is google.com (presumably a certificate/name mismatch).
    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        ctx.handshake(stream).expect_err("expected failure");
    }

    /// The handshake completes over a stream whose writes are buffered; this
    /// relies on the write callback flushing after each write.
    #[test]
    fn connect_buffered_stream() {
        use std::io::BufWriter;

        /// Small wrapper around a `TcpStream` to provide buffered writes.
        #[derive(Debug)]
        struct BufferedTcpStream {
            reader: TcpStream,
            writer: BufWriter<TcpStream>,
        }

        impl BufferedTcpStream {
            fn new(tcp: TcpStream) -> std::io::Result<Self> {
                Ok(Self {
                    writer: BufWriter::with_capacity(500, tcp.try_clone()?),
                    reader: tcp,
                })
            }
        }

        impl Read for BufferedTcpStream {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                self.reader.read(buf)
            }
        }

        impl Write for BufferedTcpStream {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                self.writer.write(buf)
            }

            fn flush(&mut self) -> std::io::Result<()> {
                self.writer.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = p!(BufferedTcpStream::new(stream));
        p!(ctx.handshake(stream));
    }

    /// Sends an HTTP/1.0 GET over a raw `SslContext` stream and reads to EOF.
    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    /// Without session tickets, every connection must stop at the server-auth
    /// break (enabled by `ctx_into_stream`) before completing.
    #[test]
    fn client_no_session_ticket_resumption() {
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    /// With session tickets enabled, the second connection resumes and skips
    /// the server-auth break.
    #[test]
    fn client_session_ticket_resumption() {
        // The first time through this loop, we should do a full handshake. The second time, we
        // should immediately finish the handshake without breaking on server auth.
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(i, 0, "Session ticket resumption did not work, server auth was not skipped");
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    /// Offering only "h2" via ALPN results in "h2" being negotiated.
    #[test]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    /// Offering only "h2c": the handshake succeeds but no ALPN result is available.
    #[test]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    /// Trusting only the (empty) anchor set must reject the server.
    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    /// `ClientBuilder` rejects a domain that does not match the server.
    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    /// The same mismatch is accepted once hostname verification is disabled.
    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    /// With certificate validation disabled, a host that (per its name) serves
    /// an expired certificate is accepted.
    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    /// Sends an HTTP/1.0 GET over a `ClientBuilder` stream and reads to EOF.
    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    /// Enabling every other default cipher round-trips through `enabled_ciphers`.
    #[test]
    #[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] // FIXME what's going on with ios?
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        let ciphers = p!(ctx.enabled_ciphers());
        let ciphers = ciphers
            .iter()
            .enumerate()
            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
            .collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    /// Whitelisting a single cipher leaves exactly one enabled.
    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        assert!(p!(ctx.enabled_ciphers()).len() > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    /// Blacklisting a single default cipher removes exactly one from the enabled set.
    #[test]
    #[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] // FIXME same issue as cipher_configuration
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        let num = p!(ctx.enabled_ciphers()).len();
        assert!(num > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    /// A fresh server context that never handshook has no peer trust.
    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        assert!(ctx.peer_trust2().is_err());
    }

    /// The peer id starts unset and round-trips after being set.
    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    /// The peer domain name starts empty and round-trips after being set.
    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    /// A panic inside the stream's `write` during the handshake reaches the caller.
    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    /// A panic inside the stream's `read` during the handshake reaches the caller.
    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    /// Zero-length writes and reads on an established stream return `Ok(0)`.
    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}