tcp_stream/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::result_large_err)]
3
4//! # std::net::TCP stream on steroids
5//!
6//! tcp-stream is a library aiming at providing TLS support to std::net::TcpStream
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust
13//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
14//!
15//! use std::io::{self, Read, Write};
16//!
17//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
18//! stream.set_nonblocking(true).unwrap();
19//!
20//! while !stream.is_connected() {
21//!     if stream.try_connect().unwrap() {
22//!         break;
23//!     }
24//! }
25//!
26//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
27//!
28//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
29//!     stream = mid_handshake.handshake();
30//! }
31//!
32//! let mut stream = stream.unwrap();
33//!
34//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
35//!     if err.kind() != io::ErrorKind::WouldBlock {
36//!         panic!("error: {:?}", err);
37//!     }
38//! }
39//! stream.flush().unwrap();
40//! let mut res = vec![];
41//! while let Err(err) = stream.read_to_end(&mut res) {
42//!     if err.kind() != io::ErrorKind::WouldBlock {
43//!         panic!("stream error: {:?}", err);
44//!     }
45//! }
46//! println!("{}", String::from_utf8_lossy(&res));
47//! ```
48
49use cfg_if::cfg_if;
50use std::{
51    convert::TryFrom,
52    error::Error,
53    fmt,
54    io::{self, IoSlice, IoSliceMut, Read, Write},
55    net::{TcpStream as StdTcpStream, ToSocketAddrs},
56    ops::{Deref, DerefMut},
57    time::Duration,
58};
59
60#[cfg(feature = "rustls-common")]
61mod rustls_impl;
62#[cfg(feature = "rustls-common")]
63pub use rustls_impl::*;
64
65#[cfg(feature = "native-tls")]
66mod native_tls_impl;
67#[cfg(feature = "native-tls")]
68pub use native_tls_impl::*;
69
70#[cfg(feature = "openssl")]
71mod openssl_impl;
72#[cfg(feature = "openssl")]
73pub use openssl_impl::*;
74
/// Wrapper around plain or TLS TCP streams
pub enum TcpStream {
    /// Wrapper around std::net::TcpStream
    ///
    /// The `bool` records whether the connection is known to be established:
    /// it starts as `false` when wrapping a fresh socket and is set to `true`
    /// by `TcpStream::try_connect` once the socket accepts a zero-length write.
    Plain(StdTcpStream, bool),
    #[cfg(feature = "native-tls")]
    /// Wrapper around a TLS stream handled by native-tls
    NativeTls(Box<NativeTlsStream>),
    #[cfg(feature = "openssl")]
    /// Wrapper around a TLS stream handled by openssl
    OpenSsl(Box<OpenSslStream>),
    #[cfg(feature = "rustls-common")]
    /// Wrapper around a TLS stream handled by rustls
    Rustls(Box<RustlsStream>),
}
89
90/// Holds extra TLS configuration
91#[derive(Default, Debug, PartialEq)]
92pub struct TLSConfig<'data, 'key, 'chain> {
93    /// Use for client certificate authentication
94    pub identity: Option<Identity<'data, 'key>>,
95    /// The custom certificates chain in PEM format
96    pub cert_chain: Option<&'chain str>,
97}
98
99/// Holds extra TLS configuration
100#[derive(Default, Debug, PartialEq)]
101pub struct OwnedTLSConfig {
102    /// Use for client certificate authentication
103    pub identity: Option<OwnedIdentity>,
104    /// The custom certificates chain in PEM format
105    pub cert_chain: Option<String>,
106}
107
108impl OwnedTLSConfig {
109    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
110    #[must_use]
111    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
112        TLSConfig {
113            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
114            cert_chain: self.cert_chain.as_deref(),
115        }
116    }
117}
118
/// Borrowed client identity, in one of two formats:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with its decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Password used to decrypt `der`
        password: &'key str,
    },
    /// PEM-encoded certificate with a PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded private key
        key: &'key [u8],
    },
}

/// Owned client identity, in one of two formats:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with its decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Password used to decrypt `der`
        password: String,
    },
    /// PEM-encoded certificate with a PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded private key
        key: Vec<u8>,
    },
}

impl OwnedIdentity {
    /// Borrow this identity as an ephemeral [`Identity`]
    #[must_use]
    pub fn as_ref(&self) -> Identity<'_, '_> {
        match self {
            Self::PKCS12 { der, password } => Identity::PKCS12 {
                der: der.as_slice(),
                password: password.as_str(),
            },
            Self::PKCS8 { pem, key } => Identity::PKCS8 {
                pem: pem.as_slice(),
                key: key.as_slice(),
            },
        }
    }
}
171
172/// Holds either the TLS `TcpStream` result or the current handshake state
173pub type HandshakeResult = Result<TcpStream, HandshakeError>;
174
175impl TcpStream {
176    /// Wrapper around `std::net::TcpStream::connect`
177    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
178        connect_std(addr, None).and_then(Self::try_from)
179    }
180
181    /// Wrapper around `std::net::TcpStream::connect_timeout`
182    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
183        connect_std(addr, Some(timeout)).and_then(Self::try_from)
184    }
185
186    /// Convert from a `std::net::TcpStream`
187    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
188        Self::try_from(stream)
189    }
190
191    /// Check whether the stream is connected or not
192    #[must_use]
193    #[allow(irrefutable_let_patterns)]
194    pub fn is_connected(&self) -> bool {
195        if let Self::Plain(_, connected) = self {
196            *connected
197        } else {
198            true
199        }
200    }
201
202    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
203    pub fn is_readable(&self) -> io::Result<()> {
204        self.deref().read(&mut []).map(|_| ())
205    }
206
207    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
208    pub fn is_writable(&self) -> io::Result<()> {
209        self.deref().write(&[]).map(|_| ())
210    }
211
212    /// Retry the connection. Returns:
213    /// - Ok(true) if connected
214    /// - Ok(false) if connecting
215    /// - Err(_) if an error is encountered
216    #[allow(irrefutable_let_patterns)]
217    pub fn try_connect(&mut self) -> io::Result<bool> {
218        if self.is_connected() {
219            return Ok(true);
220        }
221        match self.is_writable() {
222            Ok(()) => {
223                if let Self::Plain(_, connected) = self {
224                    *connected = true;
225                }
226                Ok(true)
227            }
228            Err(err)
229                if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
230                    .contains(&err.kind()) =>
231            {
232                Ok(false)
233            }
234            Err(err) => Err(err),
235        }
236    }
237
238    /// Enable TLS
239    pub fn into_tls(
240        self,
241        domain: &str,
242        config: TLSConfig<'_, '_, '_>,
243    ) -> Result<Self, HandshakeError> {
244        into_tls_impl(self, domain, config)
245    }
246
247    #[cfg(feature = "native-tls")]
248    /// Enable TLS using native-tls
249    pub fn into_native_tls(
250        self,
251        connector: &NativeTlsConnector,
252        domain: &str,
253    ) -> Result<Self, HandshakeError> {
254        Ok(connector.connect(domain, self.into_plain()?)?.into())
255    }
256
257    #[cfg(feature = "openssl")]
258    /// Enable TLS using openssl
259    pub fn into_openssl(
260        self,
261        connector: &OpenSslConnector,
262        domain: &str,
263    ) -> Result<Self, HandshakeError> {
264        Ok(connector.connect(domain, self.into_plain()?)?.into())
265    }
266
267    #[cfg(feature = "rustls-common")]
268    /// Enable TLS using rustls
269    pub fn into_rustls(
270        self,
271        connector: &RustlsConnector,
272        domain: &str,
273    ) -> Result<Self, HandshakeError> {
274        Ok(connector.connect(domain, self.into_plain()?)?.into())
275    }
276
277    #[allow(irrefutable_let_patterns)]
278    fn into_plain(self) -> Result<TcpStream, io::Error> {
279        if let TcpStream::Plain(plain, connected) = self {
280            Ok(TcpStream::Plain(plain, connected))
281        } else {
282            Err(io::Error::new(
283                io::ErrorKind::AlreadyExists,
284                "already a TLS stream",
285            ))
286        }
287    }
288}
289
290fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
291    let stream = connect_std_raw(addr, timeout)?;
292    stream.set_nodelay(true)?;
293    Ok(stream)
294}
295
/// Open a TCP connection to `addr`.
///
/// Without a timeout this defers to `std::net::TcpStream::connect`. With one,
/// each resolved address is tried in order with `connect_timeout`, so the
/// timeout applies per address. The first successful connection is returned.
/// Otherwise the error from the last attempt is returned, or an
/// `AddrNotAvailable` error if resolution produced no address at all.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let timeout = match timeout {
        Some(timeout) => timeout,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_failure = None;
    let connected = addr.to_socket_addrs()?.find_map(|candidate| {
        match StdTcpStream::connect_timeout(&candidate, timeout) {
            Ok(socket) => Some(socket),
            Err(failure) => {
                last_failure = Some(failure);
                None
            }
        }
    });

    match connected {
        Some(socket) => Ok(socket),
        None => Err(last_failure.unwrap_or_else(|| {
            io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
        })),
    }
}
316
317cfg_if! {
318    if #[cfg(feature = "rustls-native-certs")] {
319        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
320            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
321        }
322    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
323        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
324            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
325        }
326    } else if #[cfg(feature = "rustls-common")] {
327        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
328            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
329        }
330    } else if #[cfg(feature = "openssl")] {
331        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
332            into_openssl_impl(s, domain, config)
333        }
334    } else if #[cfg(feature = "native-tls")] {
335        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
336            into_native_tls_impl(s, domain, config)
337        }
338    } else {
339        fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
340            Ok(s.into_plain()?)
341        }
342    }
343}
344
345impl TryFrom<StdTcpStream> for TcpStream {
346    type Error = io::Error;
347
348    fn try_from(s: StdTcpStream) -> io::Result<Self> {
349        let mut this = TcpStream::Plain(s, false);
350        this.try_connect()?;
351        Ok(this)
352    }
353}
354
355impl Deref for TcpStream {
356    type Target = StdTcpStream;
357
358    fn deref(&self) -> &Self::Target {
359        match self {
360            TcpStream::Plain(plain, _) => plain,
361            #[cfg(feature = "native-tls")]
362            TcpStream::NativeTls(tls) => tls.get_ref(),
363            #[cfg(feature = "openssl")]
364            TcpStream::OpenSsl(tls) => tls.get_ref(),
365            #[cfg(feature = "rustls-common")]
366            TcpStream::Rustls(tls) => tls.get_ref(),
367        }
368    }
369}
370
371impl DerefMut for TcpStream {
372    fn deref_mut(&mut self) -> &mut Self::Target {
373        match self {
374            TcpStream::Plain(plain, _) => plain,
375            #[cfg(feature = "native-tls")]
376            TcpStream::NativeTls(tls) => tls.get_mut(),
377            #[cfg(feature = "openssl")]
378            TcpStream::OpenSsl(tls) => tls.get_mut(),
379            #[cfg(feature = "rustls-common")]
380            TcpStream::Rustls(tls) => tls.get_mut(),
381        }
382    }
383}
384
// Forward `$method($arg, ...)` to the stream held by whichever `TcpStream`
// variant `$this` is. Callers with no arguments still write a trailing comma,
// e.g. `fwd_impl!(self, flush,)`, because the matcher requires it.
macro_rules! fwd_impl {
    ($this:ident, $method:ident, $($arg:expr),*) => {
        match $this {
            TcpStream::Plain(inner, _) => inner.$method($($arg),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(inner) => inner.$method($($arg),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(inner) => inner.$method($($arg),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(inner) => inner.$method($($arg),*),
        }
    };
}
398
399impl Read for TcpStream {
400    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
401        fwd_impl!(self, read, buf)
402    }
403
404    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
405        fwd_impl!(self, read_vectored, bufs)
406    }
407
408    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
409        fwd_impl!(self, read_to_end, buf)
410    }
411
412    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
413        fwd_impl!(self, read_to_string, buf)
414    }
415
416    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
417        fwd_impl!(self, read_exact, buf)
418    }
419}
420
421impl Write for TcpStream {
422    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
423        fwd_impl!(self, write, buf)
424    }
425
426    fn flush(&mut self) -> io::Result<()> {
427        fwd_impl!(self, flush,)
428    }
429
430    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
431        fwd_impl!(self, write_vectored, bufs)
432    }
433
434    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
435        fwd_impl!(self, write_all, buf)
436    }
437
438    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
439        fwd_impl!(self, write_fmt, fmt)
440    }
441}
442
443impl fmt::Debug for TcpStream {
444    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445        f.debug_struct("TcpStream")
446            .field("inner", self.deref())
447            .finish()
448    }
449}
450
451/// A TLS stream which has been interrupted during the handshake
452#[allow(clippy::large_enum_variant)]
453#[derive(Debug)]
454pub enum MidHandshakeTlsStream {
455    /// Not a TLS stream
456    Plain(TcpStream),
457    #[cfg(feature = "native-tls")]
458    /// A native-tls MidHandshakeTlsStream
459    NativeTls(NativeTlsMidHandshakeTlsStream),
460    #[cfg(feature = "openssl")]
461    /// An openssl MidHandshakeTlsStream
462    Openssl(OpenSslMidHandshakeTlsStream),
463    #[cfg(feature = "rustls-common")]
464    /// A rustls-connector MidHandshakeTlsStream
465    Rustls(RustlsMidHandshakeTlsStream),
466}
467
468impl MidHandshakeTlsStream {
469    /// Get a reference to the inner stream
470    #[must_use]
471    pub fn get_ref(&self) -> &TcpStream {
472        match self {
473            MidHandshakeTlsStream::Plain(mid) => mid,
474            #[cfg(feature = "native-tls")]
475            MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
476            #[cfg(feature = "openssl")]
477            MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
478            #[cfg(feature = "rustls-common")]
479            MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
480        }
481    }
482
483    /// Get a mutable reference to the inner stream
484    #[must_use]
485    pub fn get_mut(&mut self) -> &mut TcpStream {
486        match self {
487            MidHandshakeTlsStream::Plain(mid) => mid,
488            #[cfg(feature = "native-tls")]
489            MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
490            #[cfg(feature = "openssl")]
491            MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
492            #[cfg(feature = "rustls-common")]
493            MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
494        }
495    }
496
497    /// Retry the handshake
498    pub fn handshake(self) -> HandshakeResult {
499        Ok(match self {
500            MidHandshakeTlsStream::Plain(mut mid) => {
501                if !mid.try_connect()? {
502                    return Err(HandshakeError::WouldBlock(mid.into()));
503                }
504                mid
505            }
506            #[cfg(feature = "native-tls")]
507            MidHandshakeTlsStream::NativeTls(mut mid) => {
508                if !mid.get_mut().try_connect()? {
509                    return Err(HandshakeError::WouldBlock(mid.into()));
510                }
511                mid.handshake()?.into()
512            }
513            #[cfg(feature = "openssl")]
514            MidHandshakeTlsStream::Openssl(mut mid) => {
515                if !mid.get_mut().try_connect()? {
516                    return Err(HandshakeError::WouldBlock(mid.into()));
517                }
518                mid.handshake()?.into()
519            }
520            #[cfg(feature = "rustls-common")]
521            MidHandshakeTlsStream::Rustls(mut mid) => {
522                if !mid.get_mut().try_connect()? {
523                    return Err(HandshakeError::WouldBlock(mid.into()));
524                }
525                mid.handshake()?.into()
526            }
527        })
528    }
529}
530
531impl From<TcpStream> for MidHandshakeTlsStream {
532    fn from(mid: TcpStream) -> Self {
533        MidHandshakeTlsStream::Plain(mid)
534    }
535}
536
537impl fmt::Display for MidHandshakeTlsStream {
538    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539        f.write_str("MidHandshakeTlsStream")
540    }
541}
542
543/// An error returned while performing the handshake
544#[allow(clippy::large_enum_variant)]
545#[derive(Debug)]
546pub enum HandshakeError {
547    /// We hit WouldBlock during handshake
548    WouldBlock(MidHandshakeTlsStream),
549    /// We hit a critical failure
550    Failure(io::Error),
551}
552
553impl HandshakeError {
554    /// Try and get the inner mid handshake TLS stream from this error
555    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
556        match self {
557            Self::WouldBlock(mid) => Ok(mid),
558            Self::Failure(error) => Err(error),
559        }
560    }
561}
562
563impl fmt::Display for HandshakeError {
564    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565        match self {
566            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
567            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
568        }
569    }
570}
571
572impl Error for HandshakeError {
573    fn source(&self) -> Option<&(dyn Error + 'static)> {
574        match self {
575            HandshakeError::Failure(err) => Some(err),
576            _ => None,
577        }
578    }
579}
580
581impl From<io::Error> for HandshakeError {
582    fn from(err: io::Error) -> Self {
583        HandshakeError::Failure(err)
584    }
585}
586
587mod sys;