tcp_stream/
lib.rs

1#![deny(missing_docs)]
2
3//! # std::net::TCP stream on steroids
4//!
5//! tcp-stream is a library aiming at providing TLS support to std::net::TcpStream
6//!
7//! # Examples
8//!
9//! To connect to a remote server:
10//!
11//! ```rust
12//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
13//!
14//! use std::io::{self, Read, Write};
15//!
16//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
17//! stream.set_nonblocking(true).unwrap();
18//!
19//! while !stream.is_connected() {
20//!     if stream.try_connect().unwrap() {
21//!         break;
22//!     }
23//! }
24//!
25//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
26//!
27//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
28//!     stream = mid_handshake.handshake();
29//! }
30//!
31//! let mut stream = stream.unwrap();
32//!
33//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
34//!     if err.kind() != io::ErrorKind::WouldBlock {
35//!         panic!("error: {:?}", err);
36//!     }
37//! }
38//! stream.flush().unwrap();
39//! let mut res = vec![];
40//! while let Err(err) = stream.read_to_end(&mut res) {
41//!     if err.kind() != io::ErrorKind::WouldBlock {
42//!         panic!("stream error: {:?}", err);
43//!     }
44//! }
45//! println!("{}", String::from_utf8_lossy(&res));
46//! ```
47
48use cfg_if::cfg_if;
49use std::{
50    convert::TryFrom,
51    error::Error,
52    fmt,
53    io::{self, IoSlice, IoSliceMut, Read, Write},
54    net::{TcpStream as StdTcpStream, ToSocketAddrs},
55    ops::{Deref, DerefMut},
56    time::Duration,
57};
58
#[cfg(feature = "native-tls")]
/// Reexport of native-tls's `TlsConnector`, renamed to `NativeTlsConnector`
pub use native_tls::TlsConnector as NativeTlsConnector;

#[cfg(feature = "native-tls")]
/// A `TcpStream` wrapped by native-tls
pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;

#[cfg(feature = "native-tls")]
/// native-tls's `MidHandshakeTlsStream` over a `TcpStream`
pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;

#[cfg(feature = "native-tls")]
/// native-tls's `HandshakeError` over a `TcpStream`
pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;

#[cfg(feature = "openssl")]
/// Reexport of openssl's `SslConnector` and `SslMethod`, renamed to `OpenSslConnector` and
/// `OpenSslMethod`
pub use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod as OpenSslMethod};

#[cfg(feature = "openssl")]
/// A `TcpStream` wrapped by openssl (`openssl::ssl::SslStream`)
pub type OpenSslStream = openssl::ssl::SslStream<TcpStream>;

#[cfg(feature = "openssl")]
/// openssl's `MidHandshakeSslStream` over a `TcpStream`
pub type OpenSslMidHandshakeTlsStream = openssl::ssl::MidHandshakeSslStream<TcpStream>;

#[cfg(feature = "openssl")]
/// openssl's `HandshakeError` over a `TcpStream`
pub type OpenSslHandshakeError = openssl::ssl::HandshakeError<TcpStream>;

#[cfg(feature = "openssl")]
/// openssl's `ErrorStack`
pub type OpenSslErrorStack = openssl::error::ErrorStack;

#[cfg(feature = "rustls-common")]
/// Reexport of rustls-connector's `RustlsConnector` and `RustlsConnectorConfig`
pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};

#[cfg(feature = "rustls-common")]
/// A `TcpStream` wrapped by rustls (`rustls_connector::TlsStream`)
pub type RustlsStream = rustls_connector::TlsStream<TcpStream>;

#[cfg(feature = "rustls-common")]
/// rustls-connector's `MidHandshakeTlsStream` over a `TcpStream`
pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<TcpStream>;

#[cfg(feature = "rustls-common")]
/// rustls-connector's `HandshakeError` over a `TcpStream`
pub type RustlsHandshakeError = rustls_connector::HandshakeError<TcpStream>;
110
/// Wrapper around plain or TLS TCP streams
pub enum TcpStream {
    /// Wrapper around `std::net::TcpStream`.
    ///
    /// The `bool` records whether the connection is known to be established: it starts as
    /// `false` for freshly wrapped sockets and is set to `true` by `TcpStream::try_connect`
    /// once the socket accepts a zero-length write.
    Plain(StdTcpStream, bool),
    #[cfg(feature = "native-tls")]
    /// Wrapper around a TLS stream handled by native-tls
    NativeTls(Box<NativeTlsStream>),
    #[cfg(feature = "openssl")]
    /// Wrapper around a TLS stream handled by openssl
    OpenSsl(Box<OpenSslStream>),
    #[cfg(feature = "rustls-common")]
    /// Wrapper around a TLS stream handled by rustls
    Rustls(Box<RustlsStream>),
}
125
/// Holds extra TLS configuration, borrowing its data from the caller
///
/// `TLSConfig::default()` has neither a client identity nor a custom certificate chain.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Used for client certificate authentication
    pub identity: Option<Identity<'data, 'key>>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<&'chain str>,
}

/// Owned counterpart of `TLSConfig`, convenient for storing configuration long-term
///
/// Use `OwnedTLSConfig::as_ref` to obtain the borrowed `TLSConfig` expected by the TLS APIs.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct OwnedTLSConfig {
    /// Used for client certificate authentication
    pub identity: Option<OwnedIdentity>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<String>,
}

impl OwnedTLSConfig {
    /// Get the ephemeral `TLSConfig` borrowing from this `OwnedTLSConfig`
    #[must_use]
    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
        TLSConfig {
            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
            cert_chain: self.cert_chain.as_deref(),
        }
    }
}

/// Client identity used for TLS client certificate authentication (borrowed data).
///
/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PEM-encoded certificate and PEM-encoded PKCS#8 private key (without decryption password)
///
/// Note: the derived `Debug` output includes the password and the raw key material.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PEM-encoded certificate with PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}

/// Owned counterpart of `Identity`.
///
/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PEM-encoded certificate and PEM-encoded PKCS#8 private key (without decryption password)
///
/// Note: the derived `Debug` output includes the password and the raw key material.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PEM-encoded certificate with PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}

impl OwnedIdentity {
    /// Get the ephemeral `Identity` borrowing from this `OwnedIdentity`
    #[must_use]
    pub fn as_ref(&self) -> Identity<'_, '_> {
        match self {
            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
        }
    }
}
207
208/// Holds either the TLS `TcpStream` result or the current handshake state
209pub type HandshakeResult = Result<TcpStream, HandshakeError>;
210
211impl TcpStream {
212    /// Wrapper around `std::net::TcpStream::connect`
213    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
214        connect_std(addr, None).and_then(Self::try_from)
215    }
216
217    /// Wrapper around `std::net::TcpStream::connect_timeout`
218    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
219        connect_std(addr, Some(timeout)).and_then(Self::try_from)
220    }
221
222    /// Convert from a `std::net::TcpStream`
223    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
224        Self::try_from(stream)
225    }
226
227    /// Check whether the stream is connected or not
228    #[must_use]
229    #[allow(irrefutable_let_patterns)]
230    pub fn is_connected(&self) -> bool {
231        if let Self::Plain(_, connected) = self {
232            *connected
233        } else {
234            true
235        }
236    }
237
238    /// Retry the connection. Returns:
239    /// - Ok(true) if connected
240    /// - Ok(false) if connecting
241    /// - Err(_) if an error is encountered
242    #[allow(irrefutable_let_patterns)]
243    pub fn try_connect(&mut self) -> io::Result<bool> {
244        if self.is_connected() {
245            return Ok(true);
246        }
247        match self.is_writable() {
248            Ok(()) => {
249                if let Self::Plain(_, connected) = self {
250                    *connected = true;
251                }
252                Ok(true)
253            }
254            Err(err)
255                if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
256                    .contains(&err.kind()) =>
257            {
258                Ok(false)
259            }
260            Err(err) => Err(err),
261        }
262    }
263
264    /// Enable TLS
265    pub fn into_tls(
266        self,
267        domain: &str,
268        config: TLSConfig<'_, '_, '_>,
269    ) -> Result<Self, HandshakeError> {
270        into_tls_impl(self, domain, config)
271    }
272
273    #[cfg(feature = "native-tls")]
274    /// Enable TLS using native-tls
275    pub fn into_native_tls(
276        self,
277        connector: &NativeTlsConnector,
278        domain: &str,
279    ) -> Result<Self, HandshakeError> {
280        Ok(connector.connect(domain, self.into_plain()?)?.into())
281    }
282
283    #[cfg(feature = "openssl")]
284    /// Enable TLS using openssl
285    pub fn into_openssl(
286        self,
287        connector: &OpenSslConnector,
288        domain: &str,
289    ) -> Result<Self, HandshakeError> {
290        Ok(connector.connect(domain, self.into_plain()?)?.into())
291    }
292
293    #[cfg(feature = "rustls-common")]
294    /// Enable TLS using rustls
295    pub fn into_rustls(
296        self,
297        connector: &RustlsConnector,
298        domain: &str,
299    ) -> Result<Self, HandshakeError> {
300        Ok(connector.connect(domain, self.into_plain()?)?.into())
301    }
302
303    #[allow(irrefutable_let_patterns)]
304    fn into_plain(self) -> Result<TcpStream, io::Error> {
305        if let TcpStream::Plain(plain, connected) = self {
306            Ok(TcpStream::Plain(plain, connected))
307        } else {
308            Err(io::Error::new(
309                io::ErrorKind::AlreadyExists,
310                "already a TLS stream",
311            ))
312        }
313    }
314}
315
/// Connect to `addr` via `connect_std_raw` and enable `TCP_NODELAY` on the resulting socket.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let stream = connect_std_raw(addr, timeout)?;
    stream.set_nodelay(true)?;
    Ok(stream)
}

/// Try every address `addr` resolves to, in order, and return the first successful connection.
///
/// When `timeout` is `Some`, it bounds *each* connection attempt: it is passed to
/// `std::net::TcpStream::connect_timeout` for every resolved address, not only the first
/// one, so a slow first address cannot make later attempts fall back to the OS default
/// connect timeout. With `None`, a plain blocking `connect` is used.
///
/// # Errors
/// Returns the error of the last failed attempt, or an `AddrNotAvailable` error when `addr`
/// resolved to no address at all.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let mut last_err = None;
    for addr in addr.to_socket_addrs()? {
        let attempt = match timeout {
            Some(timeout) => StdTcpStream::connect_timeout(&addr, timeout),
            None => StdTcpStream::connect(addr),
        };
        match attempt {
            Ok(stream) => return Ok(stream),
            Err(error) => last_err = Some(error),
        }
    }
    Err(last_err.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
346
#[cfg(feature = "rustls-common")]
/// Finish configuring the rustls connector `c`, then start the TLS handshake over `s` for
/// `domain`.
///
/// - `config.cert_chain`: PEM certificates parsed with `rustls_pemfile` and handed to
///   `RustlsConnectorConfig::add_parsable_certificates`.
/// - `config.identity`: when present, turned into a certificate chain plus private key and
///   used for client authentication; otherwise no client authentication is configured.
///
/// Parsing failures are reported as `HandshakeError::Failure`.
fn into_rustls_common(
    s: TcpStream,
    mut c: RustlsConnectorConfig,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> HandshakeResult {
    use rustls_connector::rustls_pki_types::{
        pem::PemObject, CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer,
    };

    if let Some(pem_chain) = config.cert_chain {
        // `&[u8]` already implements `BufRead`, so no extra buffering layer is needed.
        let mut reader = pem_chain.as_bytes();
        let extra_certs = rustls_pemfile::certs(&mut reader)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
        c.add_parsable_certificates(extra_certs);
    }

    let connector = match config.identity {
        None => c.connector_with_no_client_auth(),
        Some(identity) => {
            let (chain, private_key) = match identity {
                Identity::PKCS12 { der, password } => {
                    let keystore = p12_keystore::KeyStore::from_pkcs12(der, password)
                        .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
                    let Some((_, key_chain)) = keystore.private_key_chain() else {
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            "No private key in pkcs12 DER",
                        )
                        .into());
                    };
                    let chain = key_chain
                        .chain()
                        .iter()
                        .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
                        .collect();
                    let private_key =
                        PrivateKeyDer::from(PrivatePkcs8KeyDer::from(key_chain.key().to_vec()));
                    (chain, private_key)
                }
                Identity::PKCS8 { pem, key } => {
                    let mut reader = pem;
                    let chain = rustls_pemfile::certs(&mut reader)
                        .collect::<Result<Vec<_>, _>>()
                        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                    let private_key = PrivateKeyDer::from_pem_slice(key)
                        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                    (chain, private_key)
                }
            };
            c.connector_with_single_cert(chain, private_key)
                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
        }
    };
    s.into_rustls(&connector, domain)
}
406
407cfg_if! {
408    if #[cfg(feature = "rustls-native-certs")] {
409        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
410            into_rustls_common(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
411        }
412    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
413        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
414            into_rustls_common(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
415        }
416    } else if #[cfg(feature = "rustls-common")] {
417        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
418            into_rustls_common(s, RustlsConnectorConfig::default(), domain, config)
419        }
420    } else if #[cfg(feature = "openssl")] {
421        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
422            use openssl::x509::X509;
423
424            let mut builder = OpenSslConnector::builder(OpenSslMethod::tls())?;
425            if let Some(identity) = config.identity {
426                let (cert, pkey, chain) = match identity {
427                    Identity::PKCS8 { pem, key } => {
428                        let pkey = openssl::pkey::PKey::private_key_from_pem(key)?;
429                        let mut chain = openssl::x509::X509::stack_from_pem(pem)?.into_iter();
430                        let cert = chain.next();
431                        (cert, Some(pkey), Some(chain.collect()))
432                    }
433                    Identity::PKCS12 { der, password } => {
434                        let mut openssl_identity = openssl::pkcs12::Pkcs12::from_der(der)?.parse2(password)?;
435                        (openssl_identity.cert, openssl_identity.pkey, openssl_identity.ca.take().map(|stack| stack.into_iter().collect::<Vec<_>>()))
436                    },
437                };
438                if let Some(cert) = cert.as_ref() {
439                    builder.set_certificate(cert)?;
440                }
441                if let Some(pkey) = pkey.as_ref() {
442                    builder.set_private_key(pkey)?;
443                }
444                if let Some(chain) = chain.as_ref() {
445                    for cert in chain.iter().rev() {
446                        builder.add_extra_chain_cert(cert.to_owned())?;
447                    }
448                }
449            }
450            if let Some(cert_chain) = config.cert_chain.as_ref() {
451                for cert in X509::stack_from_pem(cert_chain.as_bytes())?.drain(..).rev() {
452                    builder.cert_store_mut().add_cert(cert)?;
453                }
454            }
455            s.into_openssl(&builder.build(), domain)
456        }
457    } else if #[cfg(feature = "native-tls")] {
458        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
459            use native_tls::Certificate;
460
461            let mut builder = NativeTlsConnector::builder();
462            if let Some(identity) = config.identity {
463                let native_identity = match identity {
464                    Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
465                    Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
466                };
467                builder.identity(native_identity.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
468            }
469            if let Some(cert_chain) = config.cert_chain {
470                let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
471                for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
472                    builder.add_root_certificate(Certificate::from_der(&cert[..]).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
473                }
474            }
475            s.into_native_tls(&builder.build().map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, domain)
476        }
477    } else {
478        fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
479            Ok(s.into_plain()?)
480        }
481    }
482}
483
484impl TryFrom<StdTcpStream> for TcpStream {
485    type Error = io::Error;
486
487    fn try_from(s: StdTcpStream) -> io::Result<Self> {
488        let mut this = TcpStream::Plain(s, false);
489        this.try_connect()?;
490        Ok(this)
491    }
492}
493
#[cfg(feature = "native-tls")]
impl From<NativeTlsStream> for TcpStream {
    /// Box an established native-tls stream into a `TcpStream`
    fn from(s: NativeTlsStream) -> Self {
        Self::NativeTls(s.into())
    }
}

#[cfg(feature = "openssl")]
impl From<OpenSslStream> for TcpStream {
    /// Box an established openssl stream into a `TcpStream`
    fn from(s: OpenSslStream) -> Self {
        Self::OpenSsl(s.into())
    }
}

#[cfg(feature = "rustls-common")]
impl From<RustlsStream> for TcpStream {
    /// Box an established rustls stream into a `TcpStream`
    fn from(s: RustlsStream) -> Self {
        Self::Rustls(s.into())
    }
}
514
515impl TcpStream {
516    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
517    pub fn is_readable(&self) -> io::Result<()> {
518        self.deref().read(&mut []).map(|_| ())
519    }
520
521    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
522    pub fn is_writable(&self) -> io::Result<()> {
523        self.deref().write(&[]).map(|_| ())
524    }
525}
526
527impl Deref for TcpStream {
528    type Target = StdTcpStream;
529
530    fn deref(&self) -> &Self::Target {
531        match self {
532            TcpStream::Plain(plain, _) => plain,
533            #[cfg(feature = "native-tls")]
534            TcpStream::NativeTls(tls) => tls.get_ref(),
535            #[cfg(feature = "openssl")]
536            TcpStream::OpenSsl(tls) => tls.get_ref(),
537            #[cfg(feature = "rustls-common")]
538            TcpStream::Rustls(tls) => tls.get_ref(),
539        }
540    }
541}
542
543impl DerefMut for TcpStream {
544    fn deref_mut(&mut self) -> &mut Self::Target {
545        match self {
546            TcpStream::Plain(plain, _) => plain,
547            #[cfg(feature = "native-tls")]
548            TcpStream::NativeTls(tls) => tls.get_mut(),
549            #[cfg(feature = "openssl")]
550            TcpStream::OpenSsl(tls) => tls.get_mut(),
551            #[cfg(feature = "rustls-common")]
552            TcpStream::Rustls(tls) => tls.get_mut(),
553        }
554    }
555}
556
/// Forward a method call to whichever stream a `TcpStream` wraps (plain socket or TLS stream).
///
/// Usage: `fwd_impl!(self, method, arg1, arg2)`; a call without arguments is written
/// `fwd_impl!(self, method,)`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        match $self {
            TcpStream::Plain(inner, _) => inner.$method($($args),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(inner) => inner.$method($($args),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(inner) => inner.$method($($args),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(inner) => inner.$method($($args),*),
        }
    };
}
570
571impl Read for TcpStream {
572    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
573        fwd_impl!(self, read, buf)
574    }
575
576    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
577        fwd_impl!(self, read_vectored, bufs)
578    }
579
580    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
581        fwd_impl!(self, read_to_end, buf)
582    }
583
584    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
585        fwd_impl!(self, read_to_string, buf)
586    }
587
588    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
589        fwd_impl!(self, read_exact, buf)
590    }
591}
592
593impl Write for TcpStream {
594    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
595        fwd_impl!(self, write, buf)
596    }
597
598    fn flush(&mut self) -> io::Result<()> {
599        fwd_impl!(self, flush,)
600    }
601
602    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
603        fwd_impl!(self, write_vectored, bufs)
604    }
605
606    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
607        fwd_impl!(self, write_all, buf)
608    }
609
610    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
611        fwd_impl!(self, write_fmt, fmt)
612    }
613}
614
615impl fmt::Debug for TcpStream {
616    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
617        f.debug_struct("TcpStream")
618            .field("inner", self.deref())
619            .finish()
620    }
621}
622
623/// A TLS stream which has been interrupted during the handshake
624#[allow(clippy::large_enum_variant)]
625#[derive(Debug)]
626pub enum MidHandshakeTlsStream {
627    /// Not a TLS stream
628    Plain(TcpStream),
629    #[cfg(feature = "native-tls")]
630    /// A native-tls MidHandshakeTlsStream
631    NativeTls(NativeTlsMidHandshakeTlsStream),
632    #[cfg(feature = "openssl")]
633    /// An openssl MidHandshakeTlsStream
634    Openssl(OpenSslMidHandshakeTlsStream),
635    #[cfg(feature = "rustls-common")]
636    /// A rustls-connector MidHandshakeTlsStream
637    Rustls(RustlsMidHandshakeTlsStream),
638}
639
640impl MidHandshakeTlsStream {
641    /// Get a reference to the inner stream
642    #[must_use]
643    pub fn get_ref(&self) -> &TcpStream {
644        match self {
645            MidHandshakeTlsStream::Plain(mid) => mid,
646            #[cfg(feature = "native-tls")]
647            MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
648            #[cfg(feature = "openssl")]
649            MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
650            #[cfg(feature = "rustls-common")]
651            MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
652        }
653    }
654
655    /// Get a mutable reference to the inner stream
656    #[must_use]
657    pub fn get_mut(&mut self) -> &mut TcpStream {
658        match self {
659            MidHandshakeTlsStream::Plain(mid) => mid,
660            #[cfg(feature = "native-tls")]
661            MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
662            #[cfg(feature = "openssl")]
663            MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
664            #[cfg(feature = "rustls-common")]
665            MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
666        }
667    }
668
669    /// Retry the handshake
670    pub fn handshake(self) -> HandshakeResult {
671        Ok(match self {
672            MidHandshakeTlsStream::Plain(mut mid) => {
673                if !mid.try_connect()? {
674                    return Err(HandshakeError::WouldBlock(mid.into()));
675                }
676                mid
677            }
678            #[cfg(feature = "native-tls")]
679            MidHandshakeTlsStream::NativeTls(mut mid) => {
680                if !mid.get_mut().try_connect()? {
681                    return Err(HandshakeError::WouldBlock(mid.into()));
682                }
683                mid.handshake()?.into()
684            }
685            #[cfg(feature = "openssl")]
686            MidHandshakeTlsStream::Openssl(mut mid) => {
687                if !mid.get_mut().try_connect()? {
688                    return Err(HandshakeError::WouldBlock(mid.into()));
689                }
690                mid.handshake()?.into()
691            }
692            #[cfg(feature = "rustls-common")]
693            MidHandshakeTlsStream::Rustls(mut mid) => {
694                if !mid.get_mut().try_connect()? {
695                    return Err(HandshakeError::WouldBlock(mid.into()));
696                }
697                mid.handshake()?.into()
698            }
699        })
700    }
701}
702
703impl From<TcpStream> for MidHandshakeTlsStream {
704    fn from(mid: TcpStream) -> Self {
705        MidHandshakeTlsStream::Plain(mid)
706    }
707}
708
709#[cfg(feature = "native-tls")]
710impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
711    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
712        MidHandshakeTlsStream::NativeTls(mid)
713    }
714}
715
716#[cfg(feature = "openssl")]
717impl From<OpenSslMidHandshakeTlsStream> for MidHandshakeTlsStream {
718    fn from(mid: OpenSslMidHandshakeTlsStream) -> Self {
719        MidHandshakeTlsStream::Openssl(mid)
720    }
721}
722
723#[cfg(feature = "rustls-common")]
724impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
725    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
726        MidHandshakeTlsStream::Rustls(mid)
727    }
728}
729
730impl fmt::Display for MidHandshakeTlsStream {
731    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
732        f.write_str("MidHandshakeTlsStream")
733    }
734}
735
736/// An error returned while performing the handshake
737#[allow(clippy::large_enum_variant)]
738#[derive(Debug)]
739pub enum HandshakeError {
740    /// We hit WouldBlock during handshake
741    WouldBlock(MidHandshakeTlsStream),
742    /// We hit a critical failure
743    Failure(io::Error),
744}
745
746impl HandshakeError {
747    /// Try and get the inner mid handshake TLS stream from this error
748    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
749        match self {
750            Self::WouldBlock(mid) => Ok(mid),
751            Self::Failure(error) => Err(error),
752        }
753    }
754}
755
756impl fmt::Display for HandshakeError {
757    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
758        match self {
759            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
760            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {}", err)),
761        }
762    }
763}
764
765impl Error for HandshakeError {
766    fn source(&self) -> Option<&(dyn Error + 'static)> {
767        match self {
768            HandshakeError::Failure(err) => Some(err),
769            _ => None,
770        }
771    }
772}
773
774#[cfg(feature = "native-tls")]
775impl From<NativeTlsHandshakeError> for HandshakeError {
776    fn from(error: NativeTlsHandshakeError) -> Self {
777        match error {
778            native_tls::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
779            native_tls::HandshakeError::Failure(failure) => {
780                HandshakeError::Failure(io::Error::new(io::ErrorKind::Other, failure))
781            }
782        }
783    }
784}
785
786#[cfg(feature = "openssl")]
787impl From<OpenSslHandshakeError> for HandshakeError {
788    fn from(error: OpenSslHandshakeError) -> Self {
789        match error {
790            openssl::ssl::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
791            openssl::ssl::HandshakeError::Failure(failure) => {
792                HandshakeError::Failure(io::Error::new(io::ErrorKind::Other, failure.into_error()))
793            }
794            openssl::ssl::HandshakeError::SetupFailure(failure) => failure.into(),
795        }
796    }
797}
798
799#[cfg(feature = "openssl")]
800impl From<OpenSslErrorStack> for HandshakeError {
801    fn from(error: OpenSslErrorStack) -> Self {
802        Self::Failure(error.into())
803    }
804}
805
806#[cfg(feature = "rustls-common")]
807impl From<RustlsHandshakeError> for HandshakeError {
808    fn from(error: RustlsHandshakeError) -> Self {
809        match error {
810            rustls_connector::HandshakeError::WouldBlock(mid) => {
811                HandshakeError::WouldBlock((*mid).into())
812            }
813            rustls_connector::HandshakeError::Failure(failure) => HandshakeError::Failure(failure),
814        }
815    }
816}
817
818impl From<io::Error> for HandshakeError {
819    fn from(err: io::Error) -> Self {
820        HandshakeError::Failure(err)
821    }
822}
823
824#[cfg(unix)]
825mod sys {
826    use crate::TcpStream;
827    use std::{
828        net::TcpStream as StdTcpStream,
829        os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd},
830    };
831
832    impl AsFd for TcpStream {
833        fn as_fd(&self) -> BorrowedFd<'_> {
834            <StdTcpStream as AsFd>::as_fd(self)
835        }
836    }
837
838    impl AsRawFd for TcpStream {
839        fn as_raw_fd(&self) -> RawFd {
840            <StdTcpStream as AsRawFd>::as_raw_fd(self)
841        }
842    }
843
844    impl AsRawFd for &TcpStream {
845        fn as_raw_fd(&self) -> RawFd {
846            <StdTcpStream as AsRawFd>::as_raw_fd(self)
847        }
848    }
849
850    impl FromRawFd for TcpStream {
851        unsafe fn from_raw_fd(fd: RawFd) -> Self {
852            Self::Plain(unsafe { StdTcpStream::from_raw_fd(fd) }, false)
853        }
854    }
855}
856
#[cfg(windows)]
mod sys {
    //! Windows socket traits for `TcpStream`, all delegating to the underlying
    //! `std::net::TcpStream` (any TLS layer is bypassed).

    use crate::TcpStream;
    use std::{
        net::TcpStream as StdTcpStream,
        os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, RawSocket},
    };

    impl AsSocket for TcpStream {
        fn as_socket(&self) -> BorrowedSocket<'_> {
            let socket: &StdTcpStream = self;
            socket.as_socket()
        }
    }

    impl AsRawSocket for TcpStream {
        fn as_raw_socket(&self) -> RawSocket {
            let socket: &StdTcpStream = self;
            socket.as_raw_socket()
        }
    }

    impl AsRawSocket for &TcpStream {
        fn as_raw_socket(&self) -> RawSocket {
            let socket: &StdTcpStream = self;
            socket.as_raw_socket()
        }
    }

    impl FromRawSocket for TcpStream {
        /// Adopt `socket` as a plain, not-yet-confirmed-connected stream.
        ///
        /// # Safety
        /// Same contract as `std::net::TcpStream::from_raw_socket`: `socket` must be an open
        /// socket handle whose ownership is transferred to the returned value.
        unsafe fn from_raw_socket(socket: RawSocket) -> Self {
            // SAFETY: the caller upholds `FromRawSocket`'s contract; we forward it unchanged.
            let std_socket = unsafe { StdTcpStream::from_raw_socket(socket) };
            // The connection state of an adopted socket is unknown until `try_connect` runs.
            Self::Plain(std_socket, false)
        }
    }
}