tcp_stream/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::large_enum_variant, clippy::result_large_err)]
3
4//! # std::net::TCP stream on steroids
5//!
6//! tcp-stream is a library aiming at providing TLS support to std::net::TcpStream
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust
13//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
14//!
15//! use std::io::{self, Read, Write};
16//!
17//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
18//! stream.set_nonblocking(true).unwrap();
19//!
20//! loop {
21//!     if stream.try_connect().unwrap() {
22//!         break;
23//!     }
24//! }
25//!
26//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
27//!
28//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
29//!     stream = mid_handshake.handshake();
30//! }
31//!
32//! let mut stream = stream.unwrap();
33//!
34//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
35//!     if err.kind() != io::ErrorKind::WouldBlock {
36//!         panic!("error: {:?}", err);
37//!     }
38//! }
39//! stream.flush().unwrap();
40//! let mut res = vec![];
41//! while let Err(err) = stream.read_to_end(&mut res) {
42//!     if err.kind() != io::ErrorKind::WouldBlock {
43//!         panic!("stream error: {:?}", err);
44//!     }
45//! }
46//! println!("{}", String::from_utf8_lossy(&res));
47//! ```
48
49use cfg_if::cfg_if;
50use std::{
51    convert::TryFrom,
52    error::Error,
53    fmt,
54    io::{self, IoSlice, IoSliceMut, Read, Write},
55    net::{TcpStream as StdTcpStream, ToSocketAddrs},
56    ops::{Deref, DerefMut},
57    time::Duration,
58};
59
60#[cfg(feature = "rustls")]
61mod rustls_impl;
62#[cfg(feature = "rustls")]
63pub use rustls_impl::*;
64
65#[cfg(feature = "native-tls")]
66mod native_tls_impl;
67#[cfg(feature = "native-tls")]
68pub use native_tls_impl::*;
69
70#[cfg(feature = "openssl")]
71mod openssl_impl;
72#[cfg(feature = "openssl")]
73pub use openssl_impl::*;
74
75#[cfg(feature = "futures")]
76mod futures;
77#[cfg(feature = "futures")]
78pub use futures::*;
79
/// Wrapper around plain or TLS TCP streams
///
/// The `Plain` variant is always available; every other variant only exists
/// when the matching cargo feature (`native-tls`, `openssl`, `rustls`) is
/// enabled.
#[non_exhaustive]
pub enum TcpStream {
    /// Wrapper around `std::net::TcpStream`
    Plain(StdTcpStream),
    #[cfg(feature = "native-tls")]
    /// Wrapper around a TLS stream handled by native-tls
    NativeTls(NativeTlsStream),
    #[cfg(feature = "openssl")]
    /// Wrapper around a TLS stream handled by openssl
    Openssl(OpensslStream),
    #[cfg(feature = "rustls")]
    /// Wrapper around a TLS stream handled by rustls
    Rustls(RustlsStream),
}
95
/// Holds extra TLS configuration, borrowing all of its data
///
/// Both fields are optional; the derived `Default` leaves them set to `None`.
#[derive(Default, Debug, PartialEq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Identity used for client certificate authentication
    pub identity: Option<Identity<'data, 'key>>,
    /// Custom certificate chain in PEM format
    pub cert_chain: Option<&'chain str>,
}
104
/// Holds extra TLS configuration, owning all of its data
///
/// Both fields are optional; the derived `Default` leaves them set to `None`.
#[derive(Default, Debug, PartialEq)]
pub struct OwnedTLSConfig {
    /// Identity used for client certificate authentication
    pub identity: Option<OwnedIdentity>,
    /// Custom certificate chain in PEM format
    pub cert_chain: Option<String>,
}
113
114impl OwnedTLSConfig {
115    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
116    #[must_use]
117    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
118        TLSConfig {
119            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
120            cert_chain: self.cert_chain.as_deref(),
121        }
122    }
123}
124
/// Borrowed client identity. Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PKCS#8 PEM-encoded certificate and key (without decryption password)
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}
145
/// Owned client identity. Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PKCS#8 PEM-encoded certificate and key (without decryption password)
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}
166
167impl OwnedIdentity {
168    /// Get the ephemeral `Identity` corresponding to the `OwnedIdentity`
169    #[must_use]
170    pub fn as_ref(&self) -> Identity<'_, '_> {
171        match self {
172            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
173            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
174        }
175    }
176}
177
/// Holds either the TLS `TcpStream` result or the current handshake state
///
/// `Ok` carries the resulting `TcpStream`; `Err` carries a `HandshakeError`.
pub type HandshakeResult = Result<TcpStream, HandshakeError>;
180
181impl TcpStream {
182    /// Wrapper around `std::net::TcpStream::connect`
183    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
184        connect_std(addr, None).and_then(Self::try_from)
185    }
186
187    /// Wrapper around `std::net::TcpStream::connect_timeout`
188    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
189        connect_std(addr, Some(timeout)).and_then(Self::try_from)
190    }
191
192    /// Convert from a `std::net::TcpStream`
193    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
194        Self::try_from(stream)
195    }
196
197    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
198    pub fn is_readable(&self) -> io::Result<()> {
199        self.deref().read(&mut []).map(|_| ())
200    }
201
202    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
203    pub fn is_writable(&mut self) -> io::Result<()> {
204        is_writable(self.deref_mut())
205    }
206
207    /// Retry the connection. Returns:
208    /// - Ok(true) if connected
209    /// - Ok(false) if connecting
210    /// - Err(_) if an error is encountered
211    pub fn try_connect(&mut self) -> io::Result<bool> {
212        try_connect(self)
213    }
214
215    /// Enable TLS
216    pub fn into_tls(
217        self,
218        domain: &str,
219        config: TLSConfig<'_, '_, '_>,
220    ) -> Result<Self, HandshakeError> {
221        into_tls_impl(self, domain, config)
222    }
223
224    #[cfg(feature = "native-tls")]
225    /// Enable TLS using native-tls
226    pub fn into_native_tls(
227        self,
228        connector: &NativeTlsConnector,
229        domain: &str,
230    ) -> Result<Self, HandshakeError> {
231        Ok(connector.connect(domain, self.into_plain()?)?.into())
232    }
233
234    #[cfg(feature = "openssl")]
235    /// Enable TLS using openssl
236    pub fn into_openssl(
237        self,
238        connector: &OpensslConnector,
239        domain: &str,
240    ) -> Result<Self, HandshakeError> {
241        Ok(connector.connect(domain, self.into_plain()?)?.into())
242    }
243
244    #[cfg(feature = "rustls")]
245    /// Enable TLS using rustls
246    pub fn into_rustls(
247        self,
248        connector: &RustlsConnector,
249        domain: &str,
250    ) -> Result<Self, HandshakeError> {
251        Ok(connector.connect(domain, self.into_plain()?)?.into())
252    }
253
254    #[allow(irrefutable_let_patterns)]
255    fn into_plain(self) -> Result<StdTcpStream, io::Error> {
256        if let Self::Plain(plain) = self {
257            Ok(plain)
258        } else {
259            Err(io::Error::new(
260                io::ErrorKind::AlreadyExists,
261                "already a TLS stream",
262            ))
263        }
264    }
265}
266
267fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
268    let stream = connect_std_raw(addr, timeout)?;
269    stream.set_nodelay(true)?;
270    Ok(stream)
271}
272
/// Connect to `addr`, with or without a timeout.
///
/// - Without a timeout this is `std::net::TcpStream::connect`.
/// - With a timeout, every resolved address is tried in order with
///   `connect_timeout`; the first success wins. If all attempts fail the last
///   error is returned, and if resolution produced no address at all an
///   `AddrNotAvailable` error is returned.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    match timeout {
        None => StdTcpStream::connect(addr),
        Some(timeout) => {
            // Remembers the most recent failure so it can be reported if no
            // candidate address accepts the connection.
            let mut last_error = None;
            addr.to_socket_addrs()?
                .find_map(|candidate| {
                    StdTcpStream::connect_timeout(&candidate, timeout)
                        .map_err(|error| last_error = Some(error))
                        .ok()
                })
                .ok_or_else(|| {
                    last_error.unwrap_or_else(|| {
                        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
                    })
                })
        }
    }
}
293
294fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
295    match is_writable(stream) {
296        Ok(()) => Ok(true),
297        Err(err)
298            if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
299        {
300            Ok(false)
301        }
302        Err(err) => Err(err),
303    }
304}
305
/// Probe `stream` by writing an empty buffer to it.
///
/// The number of bytes written is discarded; only success or failure of the
/// write is reported.
fn is_writable(stream: &mut StdTcpStream) -> io::Result<()> {
    let _written = stream.write(&[])?;
    Ok(())
}
309
/// Upgrade `s` to TLS for `domain` using the backend selected at compile time.
///
/// `cfg_if!` keeps only the first branch whose feature is enabled, in this
/// priority order: `rustls-native-certs`, `rustls-webpki-roots-certs`,
/// `rustls`, `openssl`, `native-tls`. When none of them is enabled the stream
/// is returned as a plain stream and `domain`/`config` are ignored.
fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
    cfg_if! {
        if #[cfg(feature = "rustls-native-certs")] {
            // Building this connector config is fallible, hence the `?`.
            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
        } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
        } else if #[cfg(feature = "rustls")] {
            // rustls without either certificate feature: default connector config.
            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
        } else if #[cfg(feature = "openssl")] {
            into_openssl_impl(s, domain, config)
        } else if #[cfg(feature = "native-tls")] {
            into_native_tls_impl(s, domain, config)
        } else {
            // No TLS backend compiled in: consume the otherwise unused
            // arguments to avoid warnings and hand back a plain stream.
            let _ = (domain, config);
            Ok(TcpStream::Plain(s.into_plain()?))
        }
    }
}
328
329impl TryFrom<StdTcpStream> for TcpStream {
330    type Error = io::Error;
331
332    fn try_from(s: StdTcpStream) -> io::Result<Self> {
333        let mut this = Self::Plain(s);
334        this.try_connect()?;
335        Ok(this)
336    }
337}
338
339impl Deref for TcpStream {
340    type Target = StdTcpStream;
341
342    fn deref(&self) -> &Self::Target {
343        match self {
344            Self::Plain(plain) => plain,
345            #[cfg(feature = "native-tls")]
346            Self::NativeTls(tls) => tls.get_ref(),
347            #[cfg(feature = "openssl")]
348            Self::Openssl(tls) => tls.get_ref(),
349            #[cfg(feature = "rustls")]
350            Self::Rustls(tls) => tls.get_ref(),
351        }
352    }
353}
354
355impl DerefMut for TcpStream {
356    fn deref_mut(&mut self) -> &mut Self::Target {
357        match self {
358            Self::Plain(plain) => plain,
359            #[cfg(feature = "native-tls")]
360            Self::NativeTls(tls) => tls.get_mut(),
361            #[cfg(feature = "openssl")]
362            Self::Openssl(tls) => tls.get_mut(),
363            #[cfg(feature = "rustls")]
364            Self::Rustls(tls) => tls.get_mut(),
365        }
366    }
367}
368
// Forwards a method call to whichever stream the current variant wraps.
//
// Usage: `fwd_impl!(self, method, arg1, arg2)`. For a method without
// arguments keep the comma after the method name: `fwd_impl!(self, flush,)`.
// Must be expanded where `Self` is an enum with a `Plain` variant, plus the
// feature-gated TLS variants.
macro_rules! fwd_impl {
    ($this:ident, $call:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(stream) => stream.$call($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(stream) => stream.$call($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(stream) => stream.$call($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(stream) => stream.$call($($arg),*),
        }
    };
}
382
/// Reads go to the stream wrapped by the active variant.
///
/// Besides the required `read`, the provided methods are forwarded one by one
/// through `fwd_impl!`, so each call lands on the wrapped stream's method of
/// the same name.
impl Read for TcpStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        fwd_impl!(self, read, buf)
    }

    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
        fwd_impl!(self, read_vectored, bufs)
    }

    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        fwd_impl!(self, read_to_end, buf)
    }

    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
        fwd_impl!(self, read_to_string, buf)
    }

    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        fwd_impl!(self, read_exact, buf)
    }
}
404
/// Writes go to the stream wrapped by the active variant.
///
/// Besides the required `write` and `flush`, the provided methods are
/// forwarded one by one through `fwd_impl!`, so each call lands on the wrapped
/// stream's method of the same name.
impl Write for TcpStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        fwd_impl!(self, write, buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        // The trailing comma is required by the macro pattern when there are
        // no arguments.
        fwd_impl!(self, flush,)
    }

    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        fwd_impl!(self, write_vectored, bufs)
    }

    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        fwd_impl!(self, write_all, buf)
    }

    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
        fwd_impl!(self, write_fmt, fmt)
    }
}
426
427impl fmt::Debug for TcpStream {
428    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429        f.debug_struct("TcpStream")
430            .field("inner", self.deref())
431            .finish()
432    }
433}
434
/// A TLS stream which has been interrupted during the handshake
///
/// The `Plain` variant is always available; every other variant only exists
/// when the matching cargo feature (`native-tls`, `openssl`, `rustls`) is
/// enabled.
#[derive(Debug)]
pub enum MidHandshakeTlsStream {
    /// Not a TLS stream
    Plain(TcpStream),
    #[cfg(feature = "native-tls")]
    /// A native-tls `MidHandshakeTlsStream`
    NativeTls(NativeTlsMidHandshakeTlsStream),
    #[cfg(feature = "openssl")]
    /// An openssl `MidHandshakeTlsStream`
    Openssl(OpensslMidHandshakeTlsStream),
    #[cfg(feature = "rustls")]
    /// A rustls-connector `MidHandshakeTlsStream`
    Rustls(RustlsMidHandshakeTlsStream),
}
450
451impl MidHandshakeTlsStream {
452    /// Get a reference to the inner stream
453    #[must_use]
454    pub fn get_ref(&self) -> &StdTcpStream {
455        match self {
456            Self::Plain(mid) => mid,
457            #[cfg(feature = "native-tls")]
458            Self::NativeTls(mid) => mid.get_ref(),
459            #[cfg(feature = "openssl")]
460            Self::Openssl(mid) => mid.get_ref(),
461            #[cfg(feature = "rustls")]
462            Self::Rustls(mid) => mid.get_ref(),
463        }
464    }
465
466    /// Get a mutable reference to the inner stream
467    #[must_use]
468    pub fn get_mut(&mut self) -> &mut StdTcpStream {
469        match self {
470            Self::Plain(mid) => mid,
471            #[cfg(feature = "native-tls")]
472            Self::NativeTls(mid) => mid.get_mut(),
473            #[cfg(feature = "openssl")]
474            Self::Openssl(mid) => mid.get_mut(),
475            #[cfg(feature = "rustls")]
476            Self::Rustls(mid) => mid.get_mut(),
477        }
478    }
479
480    /// Retry the handshake
481    pub fn handshake(mut self) -> HandshakeResult {
482        if !try_connect(self.get_mut())? {
483            return Err(HandshakeError::WouldBlock(self));
484        }
485
486        Ok(match self {
487            Self::Plain(mid) => mid,
488            #[cfg(feature = "native-tls")]
489            Self::NativeTls(mid) => mid.handshake()?.into(),
490            #[cfg(feature = "openssl")]
491            Self::Openssl(mid) => mid.handshake()?.into(),
492            #[cfg(feature = "rustls")]
493            Self::Rustls(mid) => mid.handshake()?.into(),
494        })
495    }
496}
497
498impl From<TcpStream> for MidHandshakeTlsStream {
499    fn from(mid: TcpStream) -> Self {
500        Self::Plain(mid)
501    }
502}
503
504impl fmt::Display for MidHandshakeTlsStream {
505    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
506        f.write_str("MidHandshakeTlsStream")
507    }
508}
509
/// An error returned while performing the handshake
#[derive(Debug)]
pub enum HandshakeError {
    /// We hit WouldBlock during handshake; carries the interrupted stream so
    /// the handshake can be retried
    WouldBlock(MidHandshakeTlsStream),
    /// We hit a critical failure; carries the underlying I/O error
    Failure(io::Error),
}
518
519impl HandshakeError {
520    /// Try and get the inner mid handshake TLS stream from this error
521    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
522        match self {
523            Self::WouldBlock(mid) => Ok(mid),
524            Self::Failure(error) => Err(error),
525        }
526    }
527}
528
529impl fmt::Display for HandshakeError {
530    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531        match self {
532            Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
533            Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
534        }
535    }
536}
537
538impl Error for HandshakeError {
539    fn source(&self) -> Option<&(dyn Error + 'static)> {
540        match self {
541            Self::Failure(err) => Some(err),
542            _ => None,
543        }
544    }
545}
546
547impl From<io::Error> for HandshakeError {
548    fn from(err: io::Error) -> Self {
549        Self::Failure(err)
550    }
551}
552
553mod sys;