Skip to main content

tcp_stream/
lib.rs

1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4    clippy::must_use_candidate,
5    clippy::unwrap_in_result,
6    clippy::panic_in_result_fn
7)]
8#![allow(clippy::large_enum_variant, clippy::result_large_err)]
9
10//! A [`TcpStream`] with pluggable TLS support.
11//!
12//! Wraps `std::net::TcpStream` in a [`TcpStream`] enum that can be upgraded
13//! to a TLS-encrypted stream via [`TcpStream::into_tls`]. Supported backends
14//! are rustls (default), native-tls, and OpenSSL, selected through feature
15//! flags. Async variants of the encrypted stream are provided via the
16//! `futures-io` traits when the `futures` feature is enabled.
17//!
18//! # Feature flags
19//!
20//! ## Async runtime (pick exactly one)
21//!
22//! | Flag | Notes |
23//! |------|-------|
24//! | `tokio` *(default)* | Requires a running Tokio runtime |
25//! | `smol` | Uses the smol executor |
26//! | `async-global-executor` | Uses async-global-executor |
27//!
28//! ## TLS backend (pick at most one; `rustls` is the default)
29//!
30//! | Flag | Notes |
31//! |------|-------|
32//! | `rustls` *(default)* | TLS via rustls |
33//! | `native-tls` | TLS via the platform's native library |
34//! | `openssl` | TLS via OpenSSL |
35//!
36//! ## Rustls certificate store (only when `rustls` is active)
37//!
38//! | Flag | Notes |
39//! |------|-------|
40//! | `rustls-platform-verifier` *(default)* | Uses the platform trust store |
41//! | `rustls-native-certs` | Loads native root certificates |
42//! | `rustls-webpki-roots-certs` | Uses the webpki bundled root set |
43//!
44//! ## Rustls crypto provider (at least one must be enabled)
45//!
46//! | Flag | Notes |
47//! |------|-------|
48//! | `rustls--aws_lc_rs` *(default)* | Uses aws-lc-rs |
49//! | `rustls--ring` | Uses ring (more portable) |
50//!
51//! ## Miscellaneous
52//!
53//! | Flag | Notes |
54//! |------|-------|
55//! | `futures` | Enable `futures-io` async trait impls on the encrypted stream |
56//! | `vendored-openssl` | Build a vendored OpenSSL (requires `openssl` feature) |
57//!
58//! # Example
59//!
60//! To connect to a remote server:
61//!
62//! ```rust
63//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
64//!
65//! use std::io::{self, Read, Write};
66//!
67//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
68//! stream.set_nonblocking(true).unwrap();
69//!
70//! loop {
71//!     if stream.try_connect().unwrap() {
72//!         break;
73//!     }
74//! }
75//!
76//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
77//!
78//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
79//!     stream = mid_handshake.handshake();
80//! }
81//!
82//! let mut stream = stream.unwrap();
83//!
84//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
85//!     if err.kind() != io::ErrorKind::WouldBlock {
86//!         panic!("error: {:?}", err);
87//!     }
88//! }
89//!
90//! while let Err(err) = stream.flush() {
91//!     if err.kind() != io::ErrorKind::WouldBlock {
92//!         panic!("error: {:?}", err);
93//!     }
94//! }
95//!
96//! let mut res = vec![];
97//! while let Err(err) = stream.read_to_end(&mut res) {
98//!     if err.kind() != io::ErrorKind::WouldBlock {
99//!         panic!("stream error: {:?}", err);
100//!     }
101//! }
102//! println!("{}", String::from_utf8_lossy(&res));
103//! ```
104
105use cfg_if::cfg_if;
106use std::{
107    convert::TryFrom,
108    error::Error,
109    fmt,
110    io::{self, IoSlice, IoSliceMut, Read, Write},
111    net::{TcpStream as StdTcpStream, ToSocketAddrs},
112    ops::{Deref, DerefMut},
113    time::Duration,
114};
115
116#[cfg(feature = "rustls")]
117mod rustls_impl;
118#[cfg(feature = "rustls")]
119pub use rustls_impl::*;
120
121#[cfg(feature = "native-tls")]
122mod native_tls_impl;
123#[cfg(feature = "native-tls")]
124pub use native_tls_impl::*;
125
126#[cfg(feature = "openssl")]
127mod openssl_impl;
128#[cfg(feature = "openssl")]
129pub use openssl_impl::*;
130
// Forwards `$this.$name($arg, ...)` to the stream held by the current enum
// variant. Each TLS arm carries the same feature gate as the matching
// `TcpStream` variant, so only the arms for enabled backends are compiled.
macro_rules! fwd_impl {
    ($this:ident, $name:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(stream) => stream.$name($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(stream) => stream.$name($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(stream) => stream.$name($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(stream) => stream.$name($($arg),*),
        }
    };
}
144
145#[cfg(feature = "futures")]
// Pinned counterpart of `fwd_impl`: unwraps the pinned `$this` with
// `get_mut()`, re-pins the stream held by the current variant and forwards
// `$name($arg, ...)` to it.
// NOTE(review): the TLS arms are gated on `*-futures` features, whereas the
// enum variants are gated on the plain backend features — presumably the
// crate manifest ties the two together; confirm against Cargo.toml.
macro_rules! fwd_pin_impl {
    ($this:ident, $name:ident, $($arg:expr),*) => {
        match $this.get_mut() {
            Self::Plain(stream) => Pin::new(stream).$name($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(stream) => Pin::new(stream).$name($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(stream) => Pin::new(stream).$name($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(stream) => Pin::new(stream).$name($($arg),*),
        }
    };
}
159
160#[cfg(feature = "futures")]
161mod futures;
162#[cfg(feature = "futures")]
163pub use futures::*;
164
/// A TCP stream that is either plain or wrapped by one of the TLS backends.
///
/// Only the variants whose backend feature is enabled exist in a given build.
#[non_exhaustive]
pub enum TcpStream {
    /// An unencrypted `std::net::TcpStream`
    Plain(StdTcpStream),
    /// A TLS stream handled by native-tls
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// A TLS stream handled by openssl
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// A TLS stream handled by rustls
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
180
181/// Holds extra TLS configuration
182#[derive(Default, Debug, PartialEq)]
183pub struct TLSConfig<'data, 'key, 'chain> {
184    /// Use for client certificate authentication
185    pub identity: Option<Identity<'data, 'key>>,
186    /// The custom certificates chain in PEM format
187    pub cert_chain: Option<&'chain str>,
188}
189
190/// Holds extra TLS configuration
191#[derive(Clone, Default, Debug, PartialEq)]
192pub struct OwnedTLSConfig {
193    /// Use for client certificate authentication
194    pub identity: Option<OwnedIdentity>,
195    /// The custom certificates chain in PEM format
196    pub cert_chain: Option<String>,
197}
198
199impl OwnedTLSConfig {
200    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
201    #[must_use]
202    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
203        TLSConfig {
204            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
205            cert_chain: self.cert_chain.as_deref(),
206        }
207    }
208}
209
/// A borrowed client identity, in one of two encodings:
/// - a PKCS#12 DER-encoded identity together with its decryption password
/// - a PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with its decryption password
    PKCS12 {
        /// The PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// The password used to decrypt the identity
        password: &'key str,
    },
    /// PEM-encoded certificate with its PEM-encoded PKCS#8 private key
    PKCS8 {
        /// The PEM-encoded certificate
        pem: &'data [u8],
        /// The PEM-encoded key
        key: &'key [u8],
    },
}
230
/// An owned client identity, in one of two encodings:
/// - a PKCS#12 DER-encoded identity together with its decryption password
/// - a PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Debug, Clone, PartialEq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with its decryption password
    PKCS12 {
        /// The PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// The password used to decrypt the identity
        password: String,
    },
    /// PEM-encoded certificate with its PEM-encoded PKCS#8 private key
    PKCS8 {
        /// The PEM-encoded certificate
        pem: Vec<u8>,
        /// The PEM-encoded key
        key: Vec<u8>,
    },
}
251
252impl OwnedIdentity {
253    /// Get the ephemeral `Identity` corresponding to the `OwnedIdentity`
254    #[must_use]
255    pub fn as_ref(&self) -> Identity<'_, '_> {
256        match self {
257            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
258            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
259        }
260    }
261}
262
263/// Holds either the TLS `TcpStream` result or the current handshake state
264pub type HandshakeResult = Result<TcpStream, HandshakeError>;
265
266impl TcpStream {
267    /// Wrapper around `std::net::TcpStream::connect`
268    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
269        connect_std(addr, None).and_then(Self::try_from)
270    }
271
272    /// Wrapper around `std::net::TcpStream::connect_timeout`
273    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
274        connect_std(addr, Some(timeout)).and_then(Self::try_from)
275    }
276
277    /// Convert from a `std::net::TcpStream`
278    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
279        Self::try_from(stream)
280    }
281
282    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
283    pub fn is_readable(&self) -> io::Result<()> {
284        self.deref().read(&mut []).map(|_| ())
285    }
286
287    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
288    pub fn is_writable(&self) -> io::Result<()> {
289        is_writable(self.deref())
290    }
291
292    /// Retry the connection. Returns:
293    /// - Ok(true) if connected
294    /// - Ok(false) if connecting
295    /// - Err(_) if an error is encountered
296    pub fn try_connect(&mut self) -> io::Result<bool> {
297        try_connect(self)
298    }
299
300    /// Upgrade this plain stream to TLS, returning either a completed [`TcpStream`] or a
301    /// [`HandshakeError::WouldBlock`] mid-handshake value that can be retried.
302    ///
303    /// The TLS backend is selected by feature flag (rustls by default; native-tls or openssl
304    /// as alternatives). Use [`into_rustls`](Self::into_rustls),
305    /// `into_native_tls`, or `into_openssl` to
306    /// target a specific backend explicitly.
307    pub fn into_tls(
308        self,
309        domain: &str,
310        config: TLSConfig<'_, '_, '_>,
311    ) -> Result<Self, HandshakeError> {
312        into_tls_impl(self, domain, config)
313    }
314
315    #[cfg(feature = "native-tls")]
316    /// Enable TLS using native-tls
317    pub fn into_native_tls(
318        self,
319        connector: &NativeTlsConnector,
320        domain: &str,
321    ) -> Result<Self, HandshakeError> {
322        Ok(connector.connect(domain, self.into_plain()?)?.into())
323    }
324
325    #[cfg(feature = "openssl")]
326    /// Enable TLS using openssl
327    pub fn into_openssl(
328        self,
329        connector: &OpensslConnector,
330        domain: &str,
331    ) -> Result<Self, HandshakeError> {
332        Ok(connector.connect(domain, self.into_plain()?)?.into())
333    }
334
335    #[cfg(feature = "rustls")]
336    /// Enable TLS using rustls
337    pub fn into_rustls(
338        self,
339        connector: &RustlsConnector,
340        domain: &str,
341    ) -> Result<Self, HandshakeError> {
342        Ok(connector.connect(domain, self.into_plain()?)?.into())
343    }
344
345    #[allow(irrefutable_let_patterns)]
346    fn into_plain(self) -> Result<StdTcpStream, io::Error> {
347        if let Self::Plain(plain) = self {
348            Ok(plain)
349        } else {
350            Err(io::Error::new(
351                io::ErrorKind::AlreadyExists,
352                "already a TLS stream",
353            ))
354        }
355    }
356}
357
// Connect a `std::net::TcpStream` to `addr`.
//
// Without a timeout this is `std::net::TcpStream::connect`. With a timeout,
// every resolved address is tried in turn using `connect_timeout`: the first
// successful stream is returned, otherwise the error of the last attempt, or
// an `AddrNotAvailable` error when resolution yielded no address at all.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let timeout = match timeout {
        Some(timeout) => timeout,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_error = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => return Ok(stream),
            Err(error) => last_error = Some(error),
        }
    }

    match last_error {
        Some(error) => Err(error),
        None => Err(io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "couldn't resolve host",
        )),
    }
}
375
376fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
377    match is_writable(stream) {
378        Ok(()) => Ok(true),
379        Err(err)
380            if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
381        {
382            Ok(false)
383        }
384        Err(err) => Err(err),
385    }
386}
387
// Issue a zero-length write on `stream` and discard the byte count, keeping
// only whether the write succeeded.
fn is_writable(mut stream: &StdTcpStream) -> io::Result<()> {
    stream.write(&[])?;
    Ok(())
}
391
392fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
393    cfg_if! {
394        if #[cfg(feature = "rustls")] {
395            into_rustls_impl(s, domain, config)
396        } else if #[cfg(feature = "openssl")] {
397            into_openssl_impl(s, domain, config)
398        } else if #[cfg(feature = "native-tls")] {
399            into_native_tls_impl(s, domain, config)
400        } else {
401            let _ = (domain, config);
402            Ok(TcpStream::Plain(s.into_plain()?))
403        }
404    }
405}
406
407impl TryFrom<StdTcpStream> for TcpStream {
408    type Error = io::Error;
409
410    fn try_from(s: StdTcpStream) -> io::Result<Self> {
411        s.set_nodelay(true)?;
412        let mut this = Self::Plain(s);
413        this.try_connect()?;
414        Ok(this)
415    }
416}
417
418impl Deref for TcpStream {
419    type Target = StdTcpStream;
420
421    fn deref(&self) -> &Self::Target {
422        match self {
423            Self::Plain(plain) => plain,
424            #[cfg(feature = "native-tls")]
425            Self::NativeTls(tls) => tls.get_ref(),
426            #[cfg(feature = "openssl")]
427            Self::Openssl(tls) => tls.get_ref(),
428            #[cfg(feature = "rustls")]
429            Self::Rustls(tls) => tls.get_ref(),
430        }
431    }
432}
433
434impl DerefMut for TcpStream {
435    fn deref_mut(&mut self) -> &mut Self::Target {
436        match self {
437            Self::Plain(plain) => plain,
438            #[cfg(feature = "native-tls")]
439            Self::NativeTls(tls) => tls.get_mut(),
440            #[cfg(feature = "openssl")]
441            Self::Openssl(tls) => tls.get_mut(),
442            #[cfg(feature = "rustls")]
443            Self::Rustls(tls) => tls.get_mut(),
444        }
445    }
446}
447
448impl Read for TcpStream {
449    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
450        fwd_impl!(self, read, buf)
451    }
452
453    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
454        fwd_impl!(self, read_vectored, bufs)
455    }
456
457    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
458        fwd_impl!(self, read_to_end, buf)
459    }
460
461    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
462        fwd_impl!(self, read_to_string, buf)
463    }
464
465    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
466        fwd_impl!(self, read_exact, buf)
467    }
468}
469
470impl Write for TcpStream {
471    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
472        fwd_impl!(self, write, buf)
473    }
474
475    fn flush(&mut self) -> io::Result<()> {
476        fwd_impl!(self, flush,)
477    }
478
479    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
480        fwd_impl!(self, write_vectored, bufs)
481    }
482
483    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
484        fwd_impl!(self, write_all, buf)
485    }
486
487    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
488        fwd_impl!(self, write_fmt, fmt)
489    }
490}
491
492impl fmt::Debug for TcpStream {
493    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494        f.debug_struct("TcpStream")
495            .field("inner", self.deref())
496            .finish()
497    }
498}
499
500/// A TLS stream which has been interrupted during the handshake
501#[derive(Debug)]
502pub enum MidHandshakeTlsStream {
503    /// Not a TLS stream
504    Plain(TcpStream),
505    #[cfg(feature = "native-tls")]
506    /// A native-tls MidHandshakeTlsStream
507    NativeTls(NativeTlsMidHandshakeTlsStream),
508    #[cfg(feature = "openssl")]
509    /// An openssl MidHandshakeTlsStream
510    Openssl(OpensslMidHandshakeTlsStream),
511    #[cfg(feature = "rustls")]
512    /// A rustls-connector MidHandshakeTlsStream
513    Rustls(RustlsMidHandshakeTlsStream),
514}
515
516impl MidHandshakeTlsStream {
517    /// Get a reference to the inner stream
518    #[must_use]
519    pub fn get_ref(&self) -> &StdTcpStream {
520        match self {
521            Self::Plain(mid) => mid,
522            #[cfg(feature = "native-tls")]
523            Self::NativeTls(mid) => mid.get_ref(),
524            #[cfg(feature = "openssl")]
525            Self::Openssl(mid) => mid.get_ref(),
526            #[cfg(feature = "rustls")]
527            Self::Rustls(mid) => mid.get_ref(),
528        }
529    }
530
531    /// Get a mutable reference to the inner stream
532    #[must_use]
533    pub fn get_mut(&mut self) -> &mut StdTcpStream {
534        match self {
535            Self::Plain(mid) => mid,
536            #[cfg(feature = "native-tls")]
537            Self::NativeTls(mid) => mid.get_mut(),
538            #[cfg(feature = "openssl")]
539            Self::Openssl(mid) => mid.get_mut(),
540            #[cfg(feature = "rustls")]
541            Self::Rustls(mid) => mid.get_mut(),
542        }
543    }
544
545    /// Retry a non-blocking TLS handshake that previously returned [`HandshakeError::WouldBlock`].
546    ///
547    /// Returns the completed [`TcpStream`] on success, or a new [`HandshakeError`] if the
548    /// handshake would block again or fails permanently.
549    pub fn handshake(mut self) -> HandshakeResult {
550        if !try_connect(self.get_mut())? {
551            return Err(HandshakeError::WouldBlock(self));
552        }
553
554        Ok(match self {
555            Self::Plain(mid) => mid,
556            #[cfg(feature = "native-tls")]
557            Self::NativeTls(mid) => mid.handshake()?.into(),
558            #[cfg(feature = "openssl")]
559            Self::Openssl(mid) => mid.handshake()?.into(),
560            #[cfg(feature = "rustls")]
561            Self::Rustls(mid) => mid.handshake()?.into(),
562        })
563    }
564}
565
566impl From<TcpStream> for MidHandshakeTlsStream {
567    fn from(mid: TcpStream) -> Self {
568        Self::Plain(mid)
569    }
570}
571
572impl fmt::Display for MidHandshakeTlsStream {
573    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
574        f.write_str("MidHandshakeTlsStream")
575    }
576}
577
578/// An error returned while performing the handshake
579#[derive(Debug)]
580pub enum HandshakeError {
581    /// We hit WouldBlock during handshake
582    WouldBlock(MidHandshakeTlsStream),
583    /// We hit a critical failure
584    Failure(io::Error),
585}
586
587impl HandshakeError {
588    /// Try and get the inner mid handshake TLS stream from this error
589    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
590        match self {
591            Self::WouldBlock(mid) => Ok(mid),
592            Self::Failure(error) => Err(error),
593        }
594    }
595}
596
597impl fmt::Display for HandshakeError {
598    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
599        match self {
600            Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
601            Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
602        }
603    }
604}
605
606impl Error for HandshakeError {
607    fn source(&self) -> Option<&(dyn Error + 'static)> {
608        match self {
609            Self::Failure(err) => Some(err),
610            _ => None,
611        }
612    }
613}
614
615impl From<io::Error> for HandshakeError {
616    fn from(err: io::Error) -> Self {
617        Self::Failure(err)
618    }
619}
620
621mod sys;