Skip to main content

tcp_stream/
lib.rs

1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4    clippy::must_use_candidate,
5    clippy::unwrap_in_result,
6    clippy::panic_in_result_fn
7)]
8#![allow(clippy::large_enum_variant, clippy::result_large_err)]
9
10//! A [`TcpStream`] with pluggable TLS support.
11//!
12//! Wraps `std::net::TcpStream` in a [`TcpStream`] enum that can be upgraded
13//! to a TLS-encrypted stream via [`TcpStream::into_tls`]. Supported backends
14//! are rustls (default), native-tls, and OpenSSL, selected through feature
15//! flags. Async variants of the encrypted stream are provided via the
16//! `futures-io` traits when the `futures` feature is enabled.
17//!
18//! # Feature flags
19//!
20//! ## Async runtime (pick exactly one)
21//!
22//! | Flag | Notes |
23//! |------|-------|
24//! | `tokio` *(default)* | Requires a running Tokio runtime |
25//! | `smol` | Uses the smol executor |
26//! | `async-global-executor` | Uses async-global-executor |
27//!
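//! ## TLS backend (`rustls` is the default)
//!
//! More than one backend may be enabled at once. [`TcpStream::into_tls`] then uses rustls if
//! available, otherwise OpenSSL, otherwise native-tls; with no TLS backend enabled it returns
//! the plain stream unchanged.
//!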
30//! | Flag | Notes |
31//! |------|-------|
32//! | `rustls` *(default)* | TLS via rustls |
33//! | `native-tls` | TLS via the platform's native library |
34//! | `openssl` | TLS via OpenSSL |
35//!
36//! ## Rustls certificate store (only when `rustls` is active)
37//!
38//! | Flag | Notes |
39//! |------|-------|
40//! | `rustls-platform-verifier` *(default)* | Uses the platform trust store |
41//! | `rustls-native-certs` | Loads native root certificates |
42//! | `rustls-webpki-roots-certs` | Uses the webpki bundled root set |
43//!
44//! ## Rustls crypto provider (at least one must be enabled)
45//!
46//! | Flag | Notes |
47//! |------|-------|
48//! | `rustls--aws_lc_rs` *(default)* | Uses aws-lc-rs |
49//! | `rustls--ring` | Uses ring (more portable) |
50//!
51//! ## Miscellaneous
52//!
53//! | Flag | Notes |
54//! |------|-------|
55//! | `futures` | Enable `futures-io` async trait impls on the encrypted stream |
56//! | `vendored-openssl` | Build a vendored OpenSSL (requires `openssl` feature) |
57//!
58//! # Example
59//!
60//! To connect to a remote server:
61//!
62//! ```rust
63//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
64//!
65//! use std::io::{self, Read, Write};
66//!
67//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
68//! stream.set_nonblocking(true).unwrap();
69//!
70//! loop {
71//!     if stream.try_connect().unwrap() {
72//!         break;
73//!     }
74//! }
75//!
76//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
77//!
78//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
79//!     stream = mid_handshake.handshake();
80//! }
81//!
82//! let mut stream = stream.unwrap();
83//!
84//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
85//!     if err.kind() != io::ErrorKind::WouldBlock {
86//!         panic!("error: {:?}", err);
87//!     }
88//! }
89//!
90//! while let Err(err) = stream.flush() {
91//!     if err.kind() != io::ErrorKind::WouldBlock {
92//!         panic!("error: {:?}", err);
93//!     }
94//! }
95//!
96//! let mut res = vec![];
97//! while let Err(err) = stream.read_to_end(&mut res) {
98//!     if err.kind() != io::ErrorKind::WouldBlock {
99//!         panic!("stream error: {:?}", err);
100//!     }
101//! }
102//! println!("{}", String::from_utf8_lossy(&res));
103//! ```
104
105use cfg_if::cfg_if;
106use std::{
107    convert::TryFrom,
108    error::Error,
109    fmt,
110    io::{self, IoSlice, IoSliceMut, Read, Write},
111    net::{TcpStream as StdTcpStream, ToSocketAddrs},
112    ops::{Deref, DerefMut},
113    time::Duration,
114};
115
116#[cfg(feature = "rustls")]
117mod rustls_impl;
118#[cfg(feature = "rustls")]
119pub use rustls_impl::*;
120
121#[cfg(feature = "native-tls")]
122mod native_tls_impl;
123#[cfg(feature = "native-tls")]
124pub use native_tls_impl::*;
125
126#[cfg(feature = "openssl")]
127mod openssl_impl;
128#[cfg(feature = "openssl")]
129pub use openssl_impl::*;
130
// Dispatch a method call to the concrete stream held by the current variant of `Self`.
//
// `$this` is the receiver identifier (normally `self`), `$call` the method name, and any
// trailing expressions are passed through verbatim as the call's arguments. Only the
// variants whose backend feature is enabled are emitted.
macro_rules! fwd_impl {
    ($this:ident, $call:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(inner) => inner.$call($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(inner) => inner.$call($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(inner) => inner.$call($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(inner) => inner.$call($($arg),*),
        }
    };
}
144
// Pinned counterpart of `fwd_impl!`: unpins `$this` (a `Pin<&mut Self>`), re-pins the
// concrete stream held by the active variant and calls `$call` on it with the given
// arguments. `Pin` must be in scope at the expansion site. The TLS arms are gated on the
// `*-futures` features.
macro_rules! fwd_pin_impl {
    ($this:ident, $call:ident, $($arg:expr),*) => {
        match $this.get_mut() {
            Self::Plain(inner) => Pin::new(inner).$call($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$call($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$call($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$call($($arg),*),
        }
    };
}
158
159#[cfg(feature = "futures")]
160mod futures;
161#[cfg(feature = "futures")]
162pub use futures::*;
163
/// A TCP connection that is either plain or wrapped by one of the enabled TLS backends.
///
/// Which TLS variants exist depends on the enabled feature flags. The enum is marked
/// `#[non_exhaustive]`, so code outside this crate must include a wildcard arm when matching.
#[non_exhaustive]
pub enum TcpStream {
    /// An unencrypted `std::net::TcpStream`
    Plain(StdTcpStream),
    /// A TLS stream driven by native-tls
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A TLS stream driven by openssl
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// A TLS stream driven by rustls
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
179
/// Holds extra TLS configuration, borrowing all of its data from the caller.
///
/// Every field is a borrow, so the configuration is `Copy`: the same value can be passed to
/// [`TcpStream::into_tls`] (which takes it by value) for several connections.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Use for client certificate authentication
    pub identity: Option<Identity<'data, 'key>>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<&'chain str>,
}

/// Holds extra TLS configuration, owning all of its data.
///
/// Use [`OwnedTLSConfig::as_ref`] to obtain the borrowed [`TLSConfig`] expected by the
/// connection APIs.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct OwnedTLSConfig {
    /// Use for client certificate authentication
    pub identity: Option<OwnedIdentity>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<String>,
}

impl OwnedTLSConfig {
    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
    #[must_use]
    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
        TLSConfig {
            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
            cert_chain: self.cert_chain.as_deref(),
        }
    }
}

/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
///
/// The `Debug` output never includes key material or passwords: the `der`/`pem` payloads are
/// shown only by length and `password`/`key` are redacted.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PEM-encoded PKCS#8 private key with PEM-encoded certificate
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}

impl fmt::Debug for Identity<'_, '_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Debug output routinely ends up in logs, so never print secrets: the PKCS#12 blob
        // embeds the private key, and `password`/`key` are secrets in their own right.
        match self {
            Self::PKCS12 { der, .. } => f
                .debug_struct("PKCS12")
                .field("der", &format_args!("[{} bytes]", der.len()))
                .field("password", &format_args!("<redacted>"))
                .finish(),
            Self::PKCS8 { pem, .. } => f
                .debug_struct("PKCS8")
                .field("pem", &format_args!("[{} bytes]", pem.len()))
                .field("key", &format_args!("<redacted>"))
                .finish(),
        }
    }
}

/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
///
/// The `Debug` output redacts secrets exactly like [`Identity`]'s.
#[derive(Clone, PartialEq, Eq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PEM-encoded PKCS#8 private key with PEM-encoded certificate
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}

impl fmt::Debug for OwnedIdentity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the borrowed form so both types redact identically.
        fmt::Debug::fmt(&self.as_ref(), f)
    }
}

impl OwnedIdentity {
    /// Get the ephemeral `Identity` corresponding to the `OwnedIdentity`
    #[must_use]
    pub fn as_ref(&self) -> Identity<'_, '_> {
        match self {
            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
        }
    }
}
261
262/// Holds either the TLS `TcpStream` result or the current handshake state
263pub type HandshakeResult = Result<TcpStream, HandshakeError>;
264
265impl TcpStream {
266    /// Wrapper around `std::net::TcpStream::connect`
267    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
268        connect_std(addr, None).and_then(Self::try_from)
269    }
270
271    /// Wrapper around `std::net::TcpStream::connect_timeout`
272    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
273        connect_std(addr, Some(timeout)).and_then(Self::try_from)
274    }
275
276    /// Convert from a `std::net::TcpStream`
277    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
278        Self::try_from(stream)
279    }
280
281    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
282    pub fn is_readable(&self) -> io::Result<()> {
283        self.deref().read(&mut []).map(|_| ())
284    }
285
286    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
287    pub fn is_writable(&self) -> io::Result<()> {
288        is_writable(self.deref())
289    }
290
291    /// Retry the connection. Returns:
292    /// - Ok(true) if connected
293    /// - Ok(false) if connecting
294    /// - Err(_) if an error is encountered
295    pub fn try_connect(&mut self) -> io::Result<bool> {
296        try_connect(self)
297    }
298
299    /// Upgrade this plain stream to TLS, returning either a completed [`TcpStream`] or a
300    /// [`HandshakeError::WouldBlock`] mid-handshake value that can be retried.
301    ///
302    /// The TLS backend is selected by feature flag (rustls by default; native-tls or openssl
303    /// as alternatives). Use [`into_rustls`](Self::into_rustls),
304    /// [`into_native_tls`](Self::into_native_tls), or [`into_openssl`](Self::into_openssl) to
305    /// target a specific backend explicitly.
306    pub fn into_tls(
307        self,
308        domain: &str,
309        config: TLSConfig<'_, '_, '_>,
310    ) -> Result<Self, HandshakeError> {
311        into_tls_impl(self, domain, config)
312    }
313
314    #[cfg(feature = "native-tls")]
315    /// Enable TLS using native-tls
316    pub fn into_native_tls(
317        self,
318        connector: &NativeTlsConnector,
319        domain: &str,
320    ) -> Result<Self, HandshakeError> {
321        Ok(connector.connect(domain, self.into_plain()?)?.into())
322    }
323
324    #[cfg(feature = "openssl")]
325    /// Enable TLS using openssl
326    pub fn into_openssl(
327        self,
328        connector: &OpensslConnector,
329        domain: &str,
330    ) -> Result<Self, HandshakeError> {
331        Ok(connector.connect(domain, self.into_plain()?)?.into())
332    }
333
334    #[cfg(feature = "rustls")]
335    /// Enable TLS using rustls
336    pub fn into_rustls(
337        self,
338        connector: &RustlsConnector,
339        domain: &str,
340    ) -> Result<Self, HandshakeError> {
341        Ok(connector.connect(domain, self.into_plain()?)?.into())
342    }
343
344    #[allow(irrefutable_let_patterns)]
345    fn into_plain(self) -> Result<StdTcpStream, io::Error> {
346        if let Self::Plain(plain) = self {
347            Ok(plain)
348        } else {
349            Err(io::Error::new(
350                io::ErrorKind::AlreadyExists,
351                "already a TLS stream",
352            ))
353        }
354    }
355}
356
/// Resolve `addr` and open a blocking TCP connection to it.
///
/// Without a timeout this defers entirely to `std::net::TcpStream::connect`. With a timeout,
/// every resolved address is attempted in order with `connect_timeout`. The first success
/// wins; otherwise the error from the final attempt is returned, or an `AddrNotAvailable`
/// error if resolution yielded no addresses at all.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_failure: Option<io::Error> = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, limit) {
            Ok(connected) => return Ok(connected),
            Err(failure) => last_failure = Some(failure),
        }
    }

    Err(last_failure.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
374
375fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
376    match is_writable(stream) {
377        Ok(()) => Ok(true),
378        Err(err)
379            if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
380        {
381            Ok(false)
382        }
383        Err(err) => Err(err),
384    }
385}
386
/// Probe a socket with a zero-length write, yielding `Ok(())` if the write succeeds.
///
/// `Write` is implemented for `&std::net::TcpStream`, so a shared borrow is enough.
fn is_writable(stream: &StdTcpStream) -> io::Result<()> {
    let mut probe = stream;
    probe.write(&[]).map(|_written| ())
}
390
/// Dispatch [`TcpStream::into_tls`] to the TLS backend selected at compile time.
///
/// When several backend features are enabled, the first match in this order wins:
/// rustls, then openssl, then native-tls. With no TLS feature enabled, the stream is
/// handed back as `TcpStream::Plain`.
fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
    cfg_if! {
        if #[cfg(feature = "rustls")] {
            into_rustls_impl(s, domain, config)
        } else if #[cfg(feature = "openssl")] {
            into_openssl_impl(s, domain, config)
        } else if #[cfg(feature = "native-tls")] {
            into_native_tls_impl(s, domain, config)
        } else {
            // No TLS backend compiled in: mark the TLS-only parameters as used and return
            // the plain stream.
            let _ = (domain, config);
            Ok(TcpStream::Plain(s.into_plain()?))
        }
    }
}
405
406impl TryFrom<StdTcpStream> for TcpStream {
407    type Error = io::Error;
408
409    fn try_from(s: StdTcpStream) -> io::Result<Self> {
410        s.set_nodelay(true)?;
411        let mut this = Self::Plain(s);
412        this.try_connect()?;
413        Ok(this)
414    }
415}
416
417impl Deref for TcpStream {
418    type Target = StdTcpStream;
419
420    fn deref(&self) -> &Self::Target {
421        match self {
422            Self::Plain(plain) => plain,
423            #[cfg(feature = "native-tls")]
424            Self::NativeTls(tls) => tls.get_ref(),
425            #[cfg(feature = "openssl")]
426            Self::Openssl(tls) => tls.get_ref(),
427            #[cfg(feature = "rustls")]
428            Self::Rustls(tls) => tls.get_ref(),
429        }
430    }
431}
432
433impl DerefMut for TcpStream {
434    fn deref_mut(&mut self) -> &mut Self::Target {
435        match self {
436            Self::Plain(plain) => plain,
437            #[cfg(feature = "native-tls")]
438            Self::NativeTls(tls) => tls.get_mut(),
439            #[cfg(feature = "openssl")]
440            Self::Openssl(tls) => tls.get_mut(),
441            #[cfg(feature = "rustls")]
442            Self::Rustls(tls) => tls.get_mut(),
443        }
444    }
445}
446
447impl Read for TcpStream {
448    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
449        fwd_impl!(self, read, buf)
450    }
451
452    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
453        fwd_impl!(self, read_vectored, bufs)
454    }
455
456    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
457        fwd_impl!(self, read_to_end, buf)
458    }
459
460    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
461        fwd_impl!(self, read_to_string, buf)
462    }
463
464    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
465        fwd_impl!(self, read_exact, buf)
466    }
467}
468
469impl Write for TcpStream {
470    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
471        fwd_impl!(self, write, buf)
472    }
473
474    fn flush(&mut self) -> io::Result<()> {
475        fwd_impl!(self, flush,)
476    }
477
478    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
479        fwd_impl!(self, write_vectored, bufs)
480    }
481
482    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
483        fwd_impl!(self, write_all, buf)
484    }
485
486    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
487        fwd_impl!(self, write_fmt, fmt)
488    }
489}
490
491impl fmt::Debug for TcpStream {
492    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493        f.debug_struct("TcpStream")
494            .field("inner", self.deref())
495            .finish()
496    }
497}
498
499/// A TLS stream which has been interrupted during the handshake
500#[derive(Debug)]
501pub enum MidHandshakeTlsStream {
502    /// Not a TLS stream
503    Plain(TcpStream),
504    #[cfg(feature = "native-tls")]
505    /// A native-tls MidHandshakeTlsStream
506    NativeTls(NativeTlsMidHandshakeTlsStream),
507    #[cfg(feature = "openssl")]
508    /// An openssl MidHandshakeTlsStream
509    Openssl(OpensslMidHandshakeTlsStream),
510    #[cfg(feature = "rustls")]
511    /// A rustls-connector MidHandshakeTlsStream
512    Rustls(RustlsMidHandshakeTlsStream),
513}
514
515impl MidHandshakeTlsStream {
516    /// Get a reference to the inner stream
517    #[must_use]
518    pub fn get_ref(&self) -> &StdTcpStream {
519        match self {
520            Self::Plain(mid) => mid,
521            #[cfg(feature = "native-tls")]
522            Self::NativeTls(mid) => mid.get_ref(),
523            #[cfg(feature = "openssl")]
524            Self::Openssl(mid) => mid.get_ref(),
525            #[cfg(feature = "rustls")]
526            Self::Rustls(mid) => mid.get_ref(),
527        }
528    }
529
530    /// Get a mutable reference to the inner stream
531    #[must_use]
532    pub fn get_mut(&mut self) -> &mut StdTcpStream {
533        match self {
534            Self::Plain(mid) => mid,
535            #[cfg(feature = "native-tls")]
536            Self::NativeTls(mid) => mid.get_mut(),
537            #[cfg(feature = "openssl")]
538            Self::Openssl(mid) => mid.get_mut(),
539            #[cfg(feature = "rustls")]
540            Self::Rustls(mid) => mid.get_mut(),
541        }
542    }
543
544    /// Retry a non-blocking TLS handshake that previously returned [`HandshakeError::WouldBlock`].
545    ///
546    /// Returns the completed [`TcpStream`] on success, or a new [`HandshakeError`] if the
547    /// handshake would block again or fails permanently.
548    pub fn handshake(mut self) -> HandshakeResult {
549        if !try_connect(self.get_mut())? {
550            return Err(HandshakeError::WouldBlock(self));
551        }
552
553        Ok(match self {
554            Self::Plain(mid) => mid,
555            #[cfg(feature = "native-tls")]
556            Self::NativeTls(mid) => mid.handshake()?.into(),
557            #[cfg(feature = "openssl")]
558            Self::Openssl(mid) => mid.handshake()?.into(),
559            #[cfg(feature = "rustls")]
560            Self::Rustls(mid) => mid.handshake()?.into(),
561        })
562    }
563}
564
565impl From<TcpStream> for MidHandshakeTlsStream {
566    fn from(mid: TcpStream) -> Self {
567        Self::Plain(mid)
568    }
569}
570
571impl fmt::Display for MidHandshakeTlsStream {
572    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573        f.write_str("MidHandshakeTlsStream")
574    }
575}
576
577/// An error returned while performing the handshake
578#[derive(Debug)]
579pub enum HandshakeError {
580    /// We hit WouldBlock during handshake
581    WouldBlock(MidHandshakeTlsStream),
582    /// We hit a critical failure
583    Failure(io::Error),
584}
585
586impl HandshakeError {
587    /// Try and get the inner mid handshake TLS stream from this error
588    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
589        match self {
590            Self::WouldBlock(mid) => Ok(mid),
591            Self::Failure(error) => Err(error),
592        }
593    }
594}
595
596impl fmt::Display for HandshakeError {
597    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
598        match self {
599            Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
600            Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
601        }
602    }
603}
604
605impl Error for HandshakeError {
606    fn source(&self) -> Option<&(dyn Error + 'static)> {
607        match self {
608            Self::Failure(err) => Some(err),
609            _ => None,
610        }
611    }
612}
613
614impl From<io::Error> for HandshakeError {
615    fn from(err: io::Error) -> Self {
616        Self::Failure(err)
617    }
618}
619
620mod sys;