tcp_stream/
lib.rs

#![deny(missing_docs)]
// `HandshakeError::WouldBlock` carries a whole `MidHandshakeTlsStream` by value, which presumably
// makes `Result<_, HandshakeError>` large enough to trip this clippy lint.
#![allow(clippy::result_large_err)]

//! # `std::net::TcpStream` on steroids
//!
//! tcp-stream is a library that adds TLS support on top of `std::net::TcpStream`.
//!
//! The TLS backend used by `TcpStream::into_tls` is chosen at compile time from the enabled
//! cargo features, in this order of precedence: `rustls-native-certs`,
//! `rustls-webpki-roots-certs`, `rustls-common`, `openssl`, `native-tls`. When none of them is
//! enabled, `into_tls` hands back the plain stream unchanged.
//!
//! # Examples
//!
//! To connect to a remote server (non-blocking connect, then TLS handshake, then a simple HTTP
//! request), retrying each step for as long as it would block:
//!
//! ```rust
//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
//!
//! use std::io::{self, Read, Write};
//!
//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
//! stream.set_nonblocking(true).unwrap();
//!
//! while !stream.is_connected() {
//!     if stream.try_connect().unwrap() {
//!         break;
//!     }
//! }
//!
//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
//!
//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
//!     stream = mid_handshake.handshake();
//! }
//!
//! let mut stream = stream.unwrap();
//!
//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
//!     if err.kind() != io::ErrorKind::WouldBlock {
//!         panic!("error: {:?}", err);
//!     }
//! }
//! stream.flush().unwrap();
//! let mut res = vec![];
//! while let Err(err) = stream.read_to_end(&mut res) {
//!     if err.kind() != io::ErrorKind::WouldBlock {
//!         panic!("stream error: {:?}", err);
//!     }
//! }
//! println!("{}", String::from_utf8_lossy(&res));
//! ```
48
49use cfg_if::cfg_if;
50use std::{
51    convert::TryFrom,
52    error::Error,
53    fmt,
54    io::{self, IoSlice, IoSliceMut, Read, Write},
55    net::{TcpStream as StdTcpStream, ToSocketAddrs},
56    ops::{Deref, DerefMut},
57    time::Duration,
58};
59
#[cfg(feature = "native-tls")]
/// Re-export of native-tls's `TlsConnector`, to pass to `TcpStream::into_native_tls`
pub use native_tls::TlsConnector as NativeTlsConnector;

#[cfg(feature = "native-tls")]
/// A native-tls TLS stream running over this crate's `TcpStream`
pub type NativeTlsStream = native_tls::TlsStream<TcpStream>;

#[cfg(feature = "native-tls")]
/// A native-tls handshake that has not completed yet (wrapped by `MidHandshakeTlsStream::NativeTls`)
pub type NativeTlsMidHandshakeTlsStream = native_tls::MidHandshakeTlsStream<TcpStream>;

#[cfg(feature = "native-tls")]
/// A `HandshakeError` from native-tls; converts into this crate's `HandshakeError`
pub type NativeTlsHandshakeError = native_tls::HandshakeError<TcpStream>;
75
#[cfg(feature = "openssl")]
/// Re-export of openssl's `SslConnector` (as `OpenSslConnector`) and `SslMethod`
/// (as `OpenSslMethod`), for building a connector to pass to `TcpStream::into_openssl`
pub use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod as OpenSslMethod};

#[cfg(feature = "openssl")]
/// An openssl TLS stream running over this crate's `TcpStream`
pub type OpenSslStream = openssl::ssl::SslStream<TcpStream>;

#[cfg(feature = "openssl")]
/// An openssl handshake that has not completed yet (wrapped by `MidHandshakeTlsStream::Openssl`)
pub type OpenSslMidHandshakeTlsStream = openssl::ssl::MidHandshakeSslStream<TcpStream>;

#[cfg(feature = "openssl")]
/// A `HandshakeError` from openssl; converts into this crate's `HandshakeError`
pub type OpenSslHandshakeError = openssl::ssl::HandshakeError<TcpStream>;

#[cfg(feature = "openssl")]
/// An `ErrorStack` from openssl; converts into a `HandshakeError::Failure`
pub type OpenSslErrorStack = openssl::error::ErrorStack;
95
#[cfg(feature = "rustls-common")]
/// Re-export of rustls-connector's `RustlsConnector` and `RustlsConnectorConfig`, for building
/// a connector to pass to `TcpStream::into_rustls`
pub use rustls_connector::{RustlsConnector, RustlsConnectorConfig};

#[cfg(feature = "rustls-common")]
/// A rustls TLS stream running over this crate's `TcpStream`
pub type RustlsStream = rustls_connector::TlsStream<TcpStream>;

#[cfg(feature = "rustls-common")]
/// A rustls-connector handshake that has not completed yet (wrapped by
/// `MidHandshakeTlsStream::Rustls`)
pub type RustlsMidHandshakeTlsStream = rustls_connector::MidHandshakeTlsStream<TcpStream>;

#[cfg(feature = "rustls-common")]
/// A `HandshakeError` from rustls-connector; converts into this crate's `HandshakeError`
pub type RustlsHandshakeError = rustls_connector::HandshakeError<TcpStream>;
111
/// Wrapper around plain or TLS TCP streams
///
/// Every variant dereferences to the underlying `std::net::TcpStream`.
pub enum TcpStream {
    /// Wrapper around `std::net::TcpStream`; the `bool` records whether the connection is known
    /// to be established (it is set by `TcpStream::try_connect`)
    Plain(StdTcpStream, bool),
    #[cfg(feature = "native-tls")]
    /// Wrapper around a TLS stream handled by native-tls
    NativeTls(Box<NativeTlsStream>),
    #[cfg(feature = "openssl")]
    /// Wrapper around a TLS stream handled by openssl
    OpenSsl(Box<OpenSslStream>),
    #[cfg(feature = "rustls-common")]
    /// Wrapper around a TLS stream handled by rustls
    Rustls(Box<RustlsStream>),
}
126
/// Holds extra TLS configuration, borrowed from the caller, for `TcpStream::into_tls`
///
/// `TLSConfig::default()` means no client identity and no extra trusted certificates.
#[derive(Default, Debug, PartialEq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Use for client certificate authentication
    pub identity: Option<Identity<'data, 'key>>,
    /// The custom certificates chain in PEM format, added as extra trusted certificates
    pub cert_chain: Option<&'chain str>,
}
135
/// Owned counterpart of `TLSConfig`, for storing TLS settings without borrowing
///
/// Call `OwnedTLSConfig::as_ref` to get the borrowed `TLSConfig` that `into_tls` expects.
#[derive(Default, Debug, PartialEq)]
pub struct OwnedTLSConfig {
    /// Use for client certificate authentication
    pub identity: Option<OwnedIdentity>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<String>,
}
144
145impl OwnedTLSConfig {
146    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
147    #[must_use]
148    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
149        TLSConfig {
150            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
151            cert_chain: self.cert_chain.as_deref(),
152        }
153    }
154}
155
/// Client identity for TLS client certificate authentication, borrowed from the caller.
///
/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PEM-encoded certificate and PKCS#8 private key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PEM-encoded certificate with a PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate; it may be followed by the rest of its certificate chain
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}
176
/// Owned counterpart of `Identity`; call `OwnedIdentity::as_ref` to borrow it as an `Identity`.
///
/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PEM-encoded certificate and PKCS#8 private key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PEM-encoded certificate with a PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate; it may be followed by the rest of its certificate chain
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}
197
198impl OwnedIdentity {
199    /// Get the ephemeral `Identity` corresponding to the `OwnedIdentity`
200    #[must_use]
201    pub fn as_ref(&self) -> Identity<'_, '_> {
202        match self {
203            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
204            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
205        }
206    }
207}
208
/// Holds either the TLS `TcpStream` result or the current handshake state
///
/// On `Err`, `HandshakeError::WouldBlock` carries the paused handshake so it can be resumed,
/// while `HandshakeError::Failure` carries the error that aborted it.
pub type HandshakeResult = Result<TcpStream, HandshakeError>;
211
212impl TcpStream {
213    /// Wrapper around `std::net::TcpStream::connect`
214    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
215        connect_std(addr, None).and_then(Self::try_from)
216    }
217
218    /// Wrapper around `std::net::TcpStream::connect_timeout`
219    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
220        connect_std(addr, Some(timeout)).and_then(Self::try_from)
221    }
222
223    /// Convert from a `std::net::TcpStream`
224    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
225        Self::try_from(stream)
226    }
227
228    /// Check whether the stream is connected or not
229    #[must_use]
230    #[allow(irrefutable_let_patterns)]
231    pub fn is_connected(&self) -> bool {
232        if let Self::Plain(_, connected) = self {
233            *connected
234        } else {
235            true
236        }
237    }
238
239    /// Retry the connection. Returns:
240    /// - Ok(true) if connected
241    /// - Ok(false) if connecting
242    /// - Err(_) if an error is encountered
243    #[allow(irrefutable_let_patterns)]
244    pub fn try_connect(&mut self) -> io::Result<bool> {
245        if self.is_connected() {
246            return Ok(true);
247        }
248        match self.is_writable() {
249            Ok(()) => {
250                if let Self::Plain(_, connected) = self {
251                    *connected = true;
252                }
253                Ok(true)
254            }
255            Err(err)
256                if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
257                    .contains(&err.kind()) =>
258            {
259                Ok(false)
260            }
261            Err(err) => Err(err),
262        }
263    }
264
265    /// Enable TLS
266    pub fn into_tls(
267        self,
268        domain: &str,
269        config: TLSConfig<'_, '_, '_>,
270    ) -> Result<Self, HandshakeError> {
271        into_tls_impl(self, domain, config)
272    }
273
274    #[cfg(feature = "native-tls")]
275    /// Enable TLS using native-tls
276    pub fn into_native_tls(
277        self,
278        connector: &NativeTlsConnector,
279        domain: &str,
280    ) -> Result<Self, HandshakeError> {
281        Ok(connector.connect(domain, self.into_plain()?)?.into())
282    }
283
284    #[cfg(feature = "openssl")]
285    /// Enable TLS using openssl
286    pub fn into_openssl(
287        self,
288        connector: &OpenSslConnector,
289        domain: &str,
290    ) -> Result<Self, HandshakeError> {
291        Ok(connector.connect(domain, self.into_plain()?)?.into())
292    }
293
294    #[cfg(feature = "rustls-common")]
295    /// Enable TLS using rustls
296    pub fn into_rustls(
297        self,
298        connector: &RustlsConnector,
299        domain: &str,
300    ) -> Result<Self, HandshakeError> {
301        Ok(connector.connect(domain, self.into_plain()?)?.into())
302    }
303
304    #[allow(irrefutable_let_patterns)]
305    fn into_plain(self) -> Result<TcpStream, io::Error> {
306        if let TcpStream::Plain(plain, connected) = self {
307            Ok(TcpStream::Plain(plain, connected))
308        } else {
309            Err(io::Error::new(
310                io::ErrorKind::AlreadyExists,
311                "already a TLS stream",
312            ))
313        }
314    }
315}
316
317fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
318    let stream = connect_std_raw(addr, timeout)?;
319    stream.set_nodelay(true)?;
320    Ok(stream)
321}
322
/// Open a std TCP connection to `addr`, optionally bounding each attempt by `timeout`.
///
/// Without a timeout this is exactly `std::net::TcpStream::connect`. With one, every resolved
/// address is tried in order with `connect_timeout`, and the first success is returned. If all
/// attempts fail, the last error is returned; if nothing resolved, an `AddrNotAvailable` error
/// is returned.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let Some(limit) = timeout else {
        return StdTcpStream::connect(addr);
    };

    // `connect_timeout` accepts a single address, so walk the resolved list by hand and keep
    // the most recent failure to report if no address accepts the connection.
    let mut last_failure = None;
    let connected = addr.to_socket_addrs()?.find_map(|candidate| {
        StdTcpStream::connect_timeout(&candidate, limit)
            .map_err(|failure| last_failure = Some(failure))
            .ok()
    });
    connected.ok_or_else(|| {
        last_failure.unwrap_or_else(|| {
            io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
        })
    })
}
343
#[cfg(feature = "rustls-common")]
/// Finish configuring the rustls connector config `c` from `config`, then start the TLS
/// handshake for `domain` over `s`.
///
/// - `config.cert_chain`: every PEM certificate it contains is passed to
///   `add_parsable_certificates` on `c`.
/// - `config.identity`: converted into a certificate list and a private key for client
///   authentication. Without an identity, the connector is built with no client auth.
///
/// Parsing and configuration errors are returned as `HandshakeError::Failure` (via the
/// `From<io::Error>` conversion).
fn into_rustls_common(
    s: TcpStream,
    mut c: RustlsConnectorConfig,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> HandshakeResult {
    use rustls_connector::rustls_pki_types::{
        CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject,
    };

    if let Some(cert_chain) = config.cert_chain {
        let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
        // Any malformed PEM entry aborts with an `InvalidData` error.
        let certs = rustls_pemfile::certs(&mut cert_chain)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
        c.add_parsable_certificates(certs);
    }
    let connector = if let Some(identity) = config.identity {
        // Normalise both identity formats to (certificate list, private key).
        let (certs, key) = match identity {
            Identity::PKCS12 { der, password } => {
                let pfx =
                    p12_keystore::KeyStore::from_pkcs12(der, password).map_err(io::Error::other)?;
                // Only the first private-key entry of the keystore is used.
                let Some((_, keychain)) = pfx.private_key_chain() else {
                    return Err(io::Error::other("No private key in pkcs12 DER").into());
                };
                let certs = keychain
                    .chain()
                    .iter()
                    .map(|cert| CertificateDer::from(cert.as_der().to_vec()))
                    .collect();
                (
                    certs,
                    // NOTE(review): assumes p12_keystore hands back the key as PKCS#8 DER —
                    // confirm against the p12_keystore documentation.
                    PrivateKeyDer::from(PrivatePkcs8KeyDer::from(keychain.key().to_vec())),
                )
            }
            Identity::PKCS8 { pem, key } => {
                let mut cert_reader = std::io::BufReader::new(pem);
                let certs = rustls_pemfile::certs(&mut cert_reader)
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                (
                    certs,
                    PrivateKeyDer::from_pem_slice(key)
                        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
                )
            }
        };
        c.connector_with_single_cert(certs, key)
            .map_err(io::Error::other)?
    } else {
        c.connector_with_no_client_auth()
    };
    s.into_rustls(&connector, domain)
}
399
400cfg_if! {
401    if #[cfg(feature = "rustls-native-certs")] {
402        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
403            into_rustls_common(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
404        }
405    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
406        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
407            into_rustls_common(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
408        }
409    } else if #[cfg(feature = "rustls-common")] {
410        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
411            into_rustls_common(s, RustlsConnectorConfig::default(), domain, config)
412        }
413    } else if #[cfg(feature = "openssl")] {
414        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
415            use openssl::x509::X509;
416
417            let mut builder = OpenSslConnector::builder(OpenSslMethod::tls())?;
418            if let Some(identity) = config.identity {
419                let (cert, pkey, chain) = match identity {
420                    Identity::PKCS8 { pem, key } => {
421                        let pkey = openssl::pkey::PKey::private_key_from_pem(key)?;
422                        let mut chain = openssl::x509::X509::stack_from_pem(pem)?.into_iter();
423                        let cert = chain.next();
424                        (cert, Some(pkey), Some(chain.collect()))
425                    }
426                    Identity::PKCS12 { der, password } => {
427                        let mut openssl_identity = openssl::pkcs12::Pkcs12::from_der(der)?.parse2(password)?;
428                        (openssl_identity.cert, openssl_identity.pkey, openssl_identity.ca.take().map(|stack| stack.into_iter().collect::<Vec<_>>()))
429                    },
430                };
431                if let Some(cert) = cert.as_ref() {
432                    builder.set_certificate(cert)?;
433                }
434                if let Some(pkey) = pkey.as_ref() {
435                    builder.set_private_key(pkey)?;
436                }
437                if let Some(chain) = chain.as_ref() {
438                    for cert in chain.iter().rev() {
439                        builder.add_extra_chain_cert(cert.to_owned())?;
440                    }
441                }
442            }
443            if let Some(cert_chain) = config.cert_chain.as_ref() {
444                for cert in X509::stack_from_pem(cert_chain.as_bytes())?.drain(..).rev() {
445                    builder.cert_store_mut().add_cert(cert)?;
446                }
447            }
448            s.into_openssl(&builder.build(), domain)
449        }
450    } else if #[cfg(feature = "native-tls")] {
451        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
452            use native_tls::Certificate;
453
454            let mut builder = NativeTlsConnector::builder();
455            if let Some(identity) = config.identity {
456                let native_identity = match identity {
457                    Identity::PKCS8 { pem, key } => native_tls::Identity::from_pkcs8(pem, key),
458                    Identity::PKCS12 { der, password } => native_tls::Identity::from_pkcs12(der, password),
459                };
460                builder.identity(native_identity.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
461            }
462            if let Some(cert_chain) = config.cert_chain {
463                let mut cert_chain = std::io::BufReader::new(cert_chain.as_bytes());
464                for cert in rustls_pemfile::certs(&mut cert_chain).collect::<Result<Vec<_>, _>>()? {
465                    builder.add_root_certificate(Certificate::from_der(&cert[..]).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?);
466                }
467            }
468            s.into_native_tls(&builder.build().map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, domain)
469        }
470    } else {
471        fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
472            Ok(s.into_plain()?)
473        }
474    }
475}
476
477impl TryFrom<StdTcpStream> for TcpStream {
478    type Error = io::Error;
479
480    fn try_from(s: StdTcpStream) -> io::Result<Self> {
481        let mut this = TcpStream::Plain(s, false);
482        this.try_connect()?;
483        Ok(this)
484    }
485}
486
#[cfg(feature = "native-tls")]
impl From<NativeTlsStream> for TcpStream {
    /// Wrap an established native-tls stream (stored boxed).
    fn from(stream: NativeTlsStream) -> Self {
        Self::NativeTls(stream.into())
    }
}

#[cfg(feature = "openssl")]
impl From<OpenSslStream> for TcpStream {
    /// Wrap an established openssl stream (stored boxed).
    fn from(stream: OpenSslStream) -> Self {
        Self::OpenSsl(stream.into())
    }
}

#[cfg(feature = "rustls-common")]
impl From<RustlsStream> for TcpStream {
    /// Wrap an established rustls stream (stored boxed).
    fn from(stream: RustlsStream) -> Self {
        Self::Rustls(stream.into())
    }
}
507
508impl TcpStream {
509    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
510    pub fn is_readable(&self) -> io::Result<()> {
511        self.deref().read(&mut []).map(|_| ())
512    }
513
514    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
515    pub fn is_writable(&self) -> io::Result<()> {
516        self.deref().write(&[]).map(|_| ())
517    }
518}
519
520impl Deref for TcpStream {
521    type Target = StdTcpStream;
522
523    fn deref(&self) -> &Self::Target {
524        match self {
525            TcpStream::Plain(plain, _) => plain,
526            #[cfg(feature = "native-tls")]
527            TcpStream::NativeTls(tls) => tls.get_ref(),
528            #[cfg(feature = "openssl")]
529            TcpStream::OpenSsl(tls) => tls.get_ref(),
530            #[cfg(feature = "rustls-common")]
531            TcpStream::Rustls(tls) => tls.get_ref(),
532        }
533    }
534}
535
536impl DerefMut for TcpStream {
537    fn deref_mut(&mut self) -> &mut Self::Target {
538        match self {
539            TcpStream::Plain(plain, _) => plain,
540            #[cfg(feature = "native-tls")]
541            TcpStream::NativeTls(tls) => tls.get_mut(),
542            #[cfg(feature = "openssl")]
543            TcpStream::OpenSsl(tls) => tls.get_mut(),
544            #[cfg(feature = "rustls-common")]
545            TcpStream::Rustls(tls) => tls.get_mut(),
546        }
547    }
548}
549
// Forward a method call to whichever variant `$self` (a `TcpStream`) currently holds. It calls
// `$method($args...)` on the raw socket for `Plain`, or on the boxed TLS stream otherwise. The
// TLS arms are compiled only when their cargo feature is enabled, mirroring the `TcpStream`
// variants. With no arguments, a trailing comma is still required: `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        match $self {
            TcpStream::Plain(plain, _) => plain.$method($($args),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(tls) => tls.$method($($args),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(tls) => tls.$method($($args),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(tls) => tls.$method($($args),*),
        }
    };
}
563
564impl Read for TcpStream {
565    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
566        fwd_impl!(self, read, buf)
567    }
568
569    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
570        fwd_impl!(self, read_vectored, bufs)
571    }
572
573    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
574        fwd_impl!(self, read_to_end, buf)
575    }
576
577    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
578        fwd_impl!(self, read_to_string, buf)
579    }
580
581    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
582        fwd_impl!(self, read_exact, buf)
583    }
584}
585
586impl Write for TcpStream {
587    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
588        fwd_impl!(self, write, buf)
589    }
590
591    fn flush(&mut self) -> io::Result<()> {
592        fwd_impl!(self, flush,)
593    }
594
595    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
596        fwd_impl!(self, write_vectored, bufs)
597    }
598
599    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
600        fwd_impl!(self, write_all, buf)
601    }
602
603    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
604        fwd_impl!(self, write_fmt, fmt)
605    }
606}
607
608impl fmt::Debug for TcpStream {
609    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
610        f.debug_struct("TcpStream")
611            .field("inner", self.deref())
612            .finish()
613    }
614}
615
/// A TLS stream which has been interrupted during the handshake
///
/// This also covers a plain stream whose TCP connection was still being established. Resume
/// it with `MidHandshakeTlsStream::handshake`.
// The TLS variants store each backend's mid-handshake stream unboxed, presumably the reason
// for this lint allowance.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// Not a TLS stream: a plain `TcpStream` whose connection may still be in progress
    Plain(TcpStream),
    #[cfg(feature = "native-tls")]
    /// A native-tls MidHandshakeTlsStream
    NativeTls(NativeTlsMidHandshakeTlsStream),
    #[cfg(feature = "openssl")]
    /// An openssl MidHandshakeTlsStream
    // NOTE(review): spelled `Openssl` here but `OpenSsl` in `TcpStream`; renaming would break
    // the public API.
    Openssl(OpenSslMidHandshakeTlsStream),
    #[cfg(feature = "rustls-common")]
    /// A rustls-connector MidHandshakeTlsStream
    Rustls(RustlsMidHandshakeTlsStream),
}
632
633impl MidHandshakeTlsStream {
634    /// Get a reference to the inner stream
635    #[must_use]
636    pub fn get_ref(&self) -> &TcpStream {
637        match self {
638            MidHandshakeTlsStream::Plain(mid) => mid,
639            #[cfg(feature = "native-tls")]
640            MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
641            #[cfg(feature = "openssl")]
642            MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
643            #[cfg(feature = "rustls-common")]
644            MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
645        }
646    }
647
648    /// Get a mutable reference to the inner stream
649    #[must_use]
650    pub fn get_mut(&mut self) -> &mut TcpStream {
651        match self {
652            MidHandshakeTlsStream::Plain(mid) => mid,
653            #[cfg(feature = "native-tls")]
654            MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
655            #[cfg(feature = "openssl")]
656            MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
657            #[cfg(feature = "rustls-common")]
658            MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
659        }
660    }
661
662    /// Retry the handshake
663    pub fn handshake(self) -> HandshakeResult {
664        Ok(match self {
665            MidHandshakeTlsStream::Plain(mut mid) => {
666                if !mid.try_connect()? {
667                    return Err(HandshakeError::WouldBlock(mid.into()));
668                }
669                mid
670            }
671            #[cfg(feature = "native-tls")]
672            MidHandshakeTlsStream::NativeTls(mut mid) => {
673                if !mid.get_mut().try_connect()? {
674                    return Err(HandshakeError::WouldBlock(mid.into()));
675                }
676                mid.handshake()?.into()
677            }
678            #[cfg(feature = "openssl")]
679            MidHandshakeTlsStream::Openssl(mut mid) => {
680                if !mid.get_mut().try_connect()? {
681                    return Err(HandshakeError::WouldBlock(mid.into()));
682                }
683                mid.handshake()?.into()
684            }
685            #[cfg(feature = "rustls-common")]
686            MidHandshakeTlsStream::Rustls(mut mid) => {
687                if !mid.get_mut().try_connect()? {
688                    return Err(HandshakeError::WouldBlock(mid.into()));
689                }
690                mid.handshake()?.into()
691            }
692        })
693    }
694}
695
696impl From<TcpStream> for MidHandshakeTlsStream {
697    fn from(mid: TcpStream) -> Self {
698        MidHandshakeTlsStream::Plain(mid)
699    }
700}
701
702#[cfg(feature = "native-tls")]
703impl From<NativeTlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
704    fn from(mid: NativeTlsMidHandshakeTlsStream) -> Self {
705        MidHandshakeTlsStream::NativeTls(mid)
706    }
707}
708
709#[cfg(feature = "openssl")]
710impl From<OpenSslMidHandshakeTlsStream> for MidHandshakeTlsStream {
711    fn from(mid: OpenSslMidHandshakeTlsStream) -> Self {
712        MidHandshakeTlsStream::Openssl(mid)
713    }
714}
715
716#[cfg(feature = "rustls-common")]
717impl From<RustlsMidHandshakeTlsStream> for MidHandshakeTlsStream {
718    fn from(mid: RustlsMidHandshakeTlsStream) -> Self {
719        MidHandshakeTlsStream::Rustls(mid)
720    }
721}
722
723impl fmt::Display for MidHandshakeTlsStream {
724    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
725        f.write_str("MidHandshakeTlsStream")
726    }
727}
728
/// An error returned while performing the handshake
///
/// `WouldBlock` is not a real failure. It hands back the paused handshake so it can be resumed
/// with `MidHandshakeTlsStream::handshake`, or extracted with `into_mid_handshake_tls_stream`.
// `WouldBlock` stores the mid-handshake stream by value, presumably the reason for this lint
// allowance.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum HandshakeError {
    /// We hit WouldBlock during handshake; retry later with the contained stream
    WouldBlock(MidHandshakeTlsStream),
    /// We hit a critical failure
    Failure(io::Error),
}
738
739impl HandshakeError {
740    /// Try and get the inner mid handshake TLS stream from this error
741    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
742        match self {
743            Self::WouldBlock(mid) => Ok(mid),
744            Self::Failure(error) => Err(error),
745        }
746    }
747}
748
749impl fmt::Display for HandshakeError {
750    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
751        match self {
752            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
753            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
754        }
755    }
756}
757
758impl Error for HandshakeError {
759    fn source(&self) -> Option<&(dyn Error + 'static)> {
760        match self {
761            HandshakeError::Failure(err) => Some(err),
762            _ => None,
763        }
764    }
765}
766
#[cfg(feature = "native-tls")]
impl From<NativeTlsHandshakeError> for HandshakeError {
    /// Translate native-tls's handshake outcome. A paused handshake is kept as `WouldBlock`,
    /// and any failure is wrapped in an `io::Error` of kind `Other`.
    fn from(error: NativeTlsHandshakeError) -> Self {
        match error {
            native_tls::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
            native_tls::HandshakeError::Failure(failure) => {
                HandshakeError::Failure(io::Error::other(failure))
            }
        }
    }
}
778
#[cfg(feature = "openssl")]
impl From<OpenSslHandshakeError> for HandshakeError {
    /// Translate openssl's handshake outcome. A paused handshake is kept as `WouldBlock`, a
    /// handshake failure becomes an `io::Error` of kind `Other`, and a setup failure goes
    /// through the `ErrorStack` conversion.
    fn from(error: OpenSslHandshakeError) -> Self {
        match error {
            openssl::ssl::HandshakeError::WouldBlock(mid) => HandshakeError::WouldBlock(mid.into()),
            openssl::ssl::HandshakeError::Failure(failure) => {
                HandshakeError::Failure(io::Error::other(failure.into_error()))
            }
            openssl::ssl::HandshakeError::SetupFailure(failure) => failure.into(),
        }
    }
}
791
#[cfg(feature = "openssl")]
impl From<OpenSslErrorStack> for HandshakeError {
    /// Report an OpenSSL error stack as a handshake `Failure`, via its `io::Error` conversion.
    fn from(error: OpenSslErrorStack) -> Self {
        HandshakeError::Failure(io::Error::from(error))
    }
}
798
#[cfg(feature = "rustls-common")]
impl From<RustlsHandshakeError> for HandshakeError {
    /// Translate rustls-connector's handshake outcome. Its paused handshake arrives boxed and
    /// is unboxed into `MidHandshakeTlsStream::Rustls`; failures are already `io::Error`s.
    fn from(error: RustlsHandshakeError) -> Self {
        match error {
            rustls_connector::HandshakeError::Failure(failure) => Self::Failure(failure),
            rustls_connector::HandshakeError::WouldBlock(mid) => {
                Self::WouldBlock(MidHandshakeTlsStream::from(*mid))
            }
        }
    }
}
810
811impl From<io::Error> for HandshakeError {
812    fn from(err: io::Error) -> Self {
813        HandshakeError::Failure(err)
814    }
815}
816
// Unix file-descriptor interop. The borrowing impls forward to the `std::net::TcpStream`
// reached through `Deref`, so for TLS variants they expose the raw socket underneath the TLS
// session.
#[cfg(unix)]
mod sys {
    use crate::TcpStream;
    use std::{
        net::TcpStream as StdTcpStream,
        os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd},
    };

    impl AsFd for TcpStream {
        fn as_fd(&self) -> BorrowedFd<'_> {
            // `self` deref-coerces to the underlying `StdTcpStream`.
            <StdTcpStream as AsFd>::as_fd(self)
        }
    }

    impl AsRawFd for TcpStream {
        fn as_raw_fd(&self) -> RawFd {
            <StdTcpStream as AsRawFd>::as_raw_fd(self)
        }
    }

    // Also lets a `&TcpStream` be used wherever an `AsRawFd` value is expected.
    impl AsRawFd for &TcpStream {
        fn as_raw_fd(&self) -> RawFd {
            // `self` is `&&TcpStream`; deref coercion walks through both references and `Deref`.
            <StdTcpStream as AsRawFd>::as_raw_fd(self)
        }
    }

    impl FromRawFd for TcpStream {
        /// Adopt an already-open socket descriptor as a plain stream.
        ///
        /// The stream starts flagged as not connected: `is_connected` reports `false` until
        /// `TcpStream::try_connect` succeeds.
        ///
        /// # Safety
        ///
        /// Same contract as `std::net::TcpStream::from_raw_fd`: `fd` must be an open socket
        /// descriptor that the caller owns and transfers to the returned stream.
        unsafe fn from_raw_fd(fd: RawFd) -> Self {
            // SAFETY: the caller upholds `FromRawFd`'s ownership contract for `fd`; it is passed
            // through unchanged to std.
            Self::Plain(unsafe { StdTcpStream::from_raw_fd(fd) }, false)
        }
    }
}
849
// Windows socket-handle interop. The borrowing impls forward to the `std::net::TcpStream`
// reached through `Deref`, so for TLS variants they expose the raw socket underneath the TLS
// session.
#[cfg(windows)]
mod sys {
    use crate::TcpStream;
    use std::{
        net::TcpStream as StdTcpStream,
        os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, RawSocket},
    };

    impl AsSocket for TcpStream {
        fn as_socket(&self) -> BorrowedSocket<'_> {
            // `self` deref-coerces to the underlying `StdTcpStream`.
            <StdTcpStream as AsSocket>::as_socket(self)
        }
    }

    impl AsRawSocket for TcpStream {
        fn as_raw_socket(&self) -> RawSocket {
            <StdTcpStream as AsRawSocket>::as_raw_socket(self)
        }
    }

    // Also lets a `&TcpStream` be used wherever an `AsRawSocket` value is expected.
    impl AsRawSocket for &TcpStream {
        fn as_raw_socket(&self) -> RawSocket {
            // `self` is `&&TcpStream`; deref coercion walks through both references and `Deref`.
            <StdTcpStream as AsRawSocket>::as_raw_socket(self)
        }
    }

    impl FromRawSocket for TcpStream {
        /// Adopt an already-open socket handle as a plain stream.
        ///
        /// The stream starts flagged as not connected: `is_connected` reports `false` until
        /// `TcpStream::try_connect` succeeds.
        ///
        /// # Safety
        ///
        /// Same contract as `std::net::TcpStream::from_raw_socket`: `socket` must be an open
        /// socket handle that the caller owns and transfers to the returned stream.
        unsafe fn from_raw_socket(socket: RawSocket) -> Self {
            // SAFETY: the caller upholds `FromRawSocket`'s ownership contract for `socket`; it
            // is passed through unchanged to std.
            Self::Plain(unsafe { StdTcpStream::from_raw_socket(socket) }, false)
        }
    }
}