Skip to main content

tcp_stream/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::large_enum_variant, clippy::result_large_err)]
3
4//! # std::net::TCP stream on steroids
5//!
6//! tcp-stream is a library aiming at providing TLS support to std::net::TcpStream
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust
13//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
14//!
15//! use std::io::{self, Read, Write};
16//!
17//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
18//! stream.set_nonblocking(true).unwrap();
19//!
20//! loop {
21//!     if stream.try_connect().unwrap() {
22//!         break;
23//!     }
24//! }
25//!
26//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
27//!
28//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
29//!     stream = mid_handshake.handshake();
30//! }
31//!
32//! let mut stream = stream.unwrap();
33//!
34//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
35//!     if err.kind() != io::ErrorKind::WouldBlock {
36//!         panic!("error: {:?}", err);
37//!     }
38//! }
39//!
40//! while let Err(err) = stream.flush() {
41//!     if err.kind() != io::ErrorKind::WouldBlock {
42//!         panic!("error: {:?}", err);
43//!     }
44//! }
45//!
46//! let mut res = vec![];
47//! while let Err(err) = stream.read_to_end(&mut res) {
48//!     if err.kind() != io::ErrorKind::WouldBlock {
49//!         panic!("stream error: {:?}", err);
50//!     }
51//! }
52//! println!("{}", String::from_utf8_lossy(&res));
53//! ```
54
55use cfg_if::cfg_if;
56use std::{
57    convert::TryFrom,
58    error::Error,
59    fmt,
60    io::{self, IoSlice, IoSliceMut, Read, Write},
61    net::{TcpStream as StdTcpStream, ToSocketAddrs},
62    ops::{Deref, DerefMut},
63    time::Duration,
64};
65
66#[cfg(feature = "rustls")]
67mod rustls_impl;
68#[cfg(feature = "rustls")]
69pub use rustls_impl::*;
70
71#[cfg(feature = "native-tls")]
72mod native_tls_impl;
73#[cfg(feature = "native-tls")]
74pub use native_tls_impl::*;
75
76#[cfg(feature = "openssl")]
77mod openssl_impl;
78#[cfg(feature = "openssl")]
79pub use openssl_impl::*;
80
81#[cfg(feature = "futures")]
82mod futures;
83#[cfg(feature = "futures")]
84pub use futures::*;
85
/// Wrapper around plain or TLS TCP streams
///
/// Whatever the variant, the underlying `std::net::TcpStream` stays reachable
/// through this type's `Deref`/`DerefMut` implementations.
#[non_exhaustive]
pub enum TcpStream {
    /// Wrapper around std::net::TcpStream
    Plain(StdTcpStream),
    #[cfg(feature = "native-tls")]
    /// Wrapper around a TLS stream handled by native-tls
    NativeTls(NativeTlsStream),
    #[cfg(feature = "openssl")]
    /// Wrapper around a TLS stream handled by openssl
    Openssl(OpensslStream),
    #[cfg(feature = "rustls")]
    /// Wrapper around a TLS stream handled by rustls
    Rustls(RustlsStream),
}
101
/// Holds extra TLS configuration
///
/// This is the borrowed form; see [`OwnedTLSConfig`] for a version that owns its
/// data. It only holds references, so it is `Copy` and can be reused across
/// several TLS upgrade attempts.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Use for client certificate authentication
    pub identity: Option<Identity<'data, 'key>>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<&'chain str>,
}

/// Holds extra TLS configuration, owning all of its data
///
/// Use [`OwnedTLSConfig::as_ref`] to borrow it as a [`TLSConfig`].
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct OwnedTLSConfig {
    /// Use for client certificate authentication
    pub identity: Option<OwnedIdentity>,
    /// The custom certificates chain in PEM format
    pub cert_chain: Option<String>,
}

impl OwnedTLSConfig {
    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
    ///
    /// The returned value borrows from `self`; no data is copied.
    #[must_use]
    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
        TLSConfig {
            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
            cert_chain: self.cert_chain.as_deref(),
        }
    }
}

/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PEM-encoded certificate with a PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}

/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
///
/// Owned counterpart of [`Identity`]; use [`OwnedIdentity::as_ref`] to borrow it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PEM-encoded certificate with a PEM-encoded PKCS#8 private key
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}

impl OwnedIdentity {
    /// Get the ephemeral `Identity` corresponding to the `OwnedIdentity`
    ///
    /// The returned value borrows from `self`; no data is copied.
    #[must_use]
    pub fn as_ref(&self) -> Identity<'_, '_> {
        match self {
            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
        }
    }
}
183
184/// Holds either the TLS `TcpStream` result or the current handshake state
185pub type HandshakeResult = Result<TcpStream, HandshakeError>;
186
187impl TcpStream {
188    /// Wrapper around `std::net::TcpStream::connect`
189    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
190        connect_std(addr, None).and_then(Self::try_from)
191    }
192
193    /// Wrapper around `std::net::TcpStream::connect_timeout`
194    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
195        connect_std(addr, Some(timeout)).and_then(Self::try_from)
196    }
197
198    /// Convert from a `std::net::TcpStream`
199    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
200        Self::try_from(stream)
201    }
202
203    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
204    pub fn is_readable(&self) -> io::Result<()> {
205        self.deref().read(&mut []).map(|_| ())
206    }
207
208    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
209    pub fn is_writable(&mut self) -> io::Result<()> {
210        is_writable(self.deref_mut())
211    }
212
213    /// Retry the connection. Returns:
214    /// - Ok(true) if connected
215    /// - Ok(false) if connecting
216    /// - Err(_) if an error is encountered
217    pub fn try_connect(&mut self) -> io::Result<bool> {
218        try_connect(self)
219    }
220
221    /// Enable TLS
222    pub fn into_tls(
223        self,
224        domain: &str,
225        config: TLSConfig<'_, '_, '_>,
226    ) -> Result<Self, HandshakeError> {
227        into_tls_impl(self, domain, config)
228    }
229
230    #[cfg(feature = "native-tls")]
231    /// Enable TLS using native-tls
232    pub fn into_native_tls(
233        self,
234        connector: &NativeTlsConnector,
235        domain: &str,
236    ) -> Result<Self, HandshakeError> {
237        Ok(connector.connect(domain, self.into_plain()?)?.into())
238    }
239
240    #[cfg(feature = "openssl")]
241    /// Enable TLS using openssl
242    pub fn into_openssl(
243        self,
244        connector: &OpensslConnector,
245        domain: &str,
246    ) -> Result<Self, HandshakeError> {
247        Ok(connector.connect(domain, self.into_plain()?)?.into())
248    }
249
250    #[cfg(feature = "rustls")]
251    /// Enable TLS using rustls
252    pub fn into_rustls(
253        self,
254        connector: &RustlsConnector,
255        domain: &str,
256    ) -> Result<Self, HandshakeError> {
257        Ok(connector.connect(domain, self.into_plain()?)?.into())
258    }
259
260    #[allow(irrefutable_let_patterns)]
261    fn into_plain(self) -> Result<StdTcpStream, io::Error> {
262        if let Self::Plain(plain) = self {
263            Ok(plain)
264        } else {
265            Err(io::Error::new(
266                io::ErrorKind::AlreadyExists,
267                "already a TLS stream",
268            ))
269        }
270    }
271}
272
273fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
274    let stream = connect_std_raw(addr, timeout)?;
275    stream.set_nodelay(true)?;
276    Ok(stream)
277}
278
279fn connect_std_raw<A: ToSocketAddrs>(
280    addr: A,
281    timeout: Option<Duration>,
282) -> io::Result<StdTcpStream> {
283    if let Some(timeout) = timeout {
284        let addrs = addr.to_socket_addrs()?;
285        let mut err = None;
286        for addr in addrs {
287            match StdTcpStream::connect_timeout(&addr, timeout) {
288                Ok(stream) => return Ok(stream),
289                Err(error) => err = Some(error),
290            }
291        }
292        Err(err.unwrap_or_else(|| {
293            io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
294        }))
295    } else {
296        StdTcpStream::connect(addr)
297    }
298}
299
300fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
301    match is_writable(stream) {
302        Ok(()) => Ok(true),
303        Err(err)
304            if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
305        {
306            Ok(false)
307        }
308        Err(err) => Err(err),
309    }
310}
311
/// Issue a zero-length write on `stream`, reporting only success or failure.
fn is_writable(stream: &mut StdTcpStream) -> io::Result<()> {
    let nothing: &[u8] = &[];
    stream.write(nothing).map(|_written| ())
}
315
316fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
317    cfg_if! {
318        if #[cfg(feature = "rustls-platform-verifier")] {
319            into_rustls_impl(s, RustlsConnectorConfig::new_with_platform_verifier(), domain, config)
320        } else if #[cfg(feature = "rustls-native-certs")] {
321            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
322        } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
323            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
324        } else if #[cfg(feature = "rustls")] {
325            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
326        } else if #[cfg(feature = "openssl")] {
327            into_openssl_impl(s, domain, config)
328        } else if #[cfg(feature = "native-tls")] {
329            into_native_tls_impl(s, domain, config)
330        } else {
331            let _ = (domain, config);
332            Ok(TcpStream::Plain(s.into_plain()?))
333        }
334    }
335}
336
337impl TryFrom<StdTcpStream> for TcpStream {
338    type Error = io::Error;
339
340    fn try_from(s: StdTcpStream) -> io::Result<Self> {
341        let mut this = Self::Plain(s);
342        this.try_connect()?;
343        Ok(this)
344    }
345}
346
347impl Deref for TcpStream {
348    type Target = StdTcpStream;
349
350    fn deref(&self) -> &Self::Target {
351        match self {
352            Self::Plain(plain) => plain,
353            #[cfg(feature = "native-tls")]
354            Self::NativeTls(tls) => tls.get_ref(),
355            #[cfg(feature = "openssl")]
356            Self::Openssl(tls) => tls.get_ref(),
357            #[cfg(feature = "rustls")]
358            Self::Rustls(tls) => tls.get_ref(),
359        }
360    }
361}
362
363impl DerefMut for TcpStream {
364    fn deref_mut(&mut self) -> &mut Self::Target {
365        match self {
366            Self::Plain(plain) => plain,
367            #[cfg(feature = "native-tls")]
368            Self::NativeTls(tls) => tls.get_mut(),
369            #[cfg(feature = "openssl")]
370            Self::Openssl(tls) => tls.get_mut(),
371            #[cfg(feature = "rustls")]
372            Self::Rustls(tls) => tls.get_mut(),
373        }
374    }
375}
376
// Forward a method call to the stream held by whichever variant `$this` currently is.
// Shape: `fwd_impl!(self, method, arg1, arg2)`. A call with no arguments still needs
// the comma after the method name, e.g. `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($this:ident, $method:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(inner) => inner.$method($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(inner) => inner.$method($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(inner) => inner.$method($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(inner) => inner.$method($($arg),*),
        }
    };
}
390
391impl Read for TcpStream {
392    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
393        fwd_impl!(self, read, buf)
394    }
395
396    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
397        fwd_impl!(self, read_vectored, bufs)
398    }
399
400    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
401        fwd_impl!(self, read_to_end, buf)
402    }
403
404    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
405        fwd_impl!(self, read_to_string, buf)
406    }
407
408    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
409        fwd_impl!(self, read_exact, buf)
410    }
411}
412
413impl Write for TcpStream {
414    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
415        fwd_impl!(self, write, buf)
416    }
417
418    fn flush(&mut self) -> io::Result<()> {
419        fwd_impl!(self, flush,)
420    }
421
422    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
423        fwd_impl!(self, write_vectored, bufs)
424    }
425
426    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
427        fwd_impl!(self, write_all, buf)
428    }
429
430    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
431        fwd_impl!(self, write_fmt, fmt)
432    }
433}
434
435impl fmt::Debug for TcpStream {
436    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
437        f.debug_struct("TcpStream")
438            .field("inner", self.deref())
439            .finish()
440    }
441}
442
443/// A TLS stream which has been interrupted during the handshake
444#[derive(Debug)]
445pub enum MidHandshakeTlsStream {
446    /// Not a TLS stream
447    Plain(TcpStream),
448    #[cfg(feature = "native-tls")]
449    /// A native-tls MidHandshakeTlsStream
450    NativeTls(NativeTlsMidHandshakeTlsStream),
451    #[cfg(feature = "openssl")]
452    /// An openssl MidHandshakeTlsStream
453    Openssl(OpensslMidHandshakeTlsStream),
454    #[cfg(feature = "rustls")]
455    /// A rustls-connector MidHandshakeTlsStream
456    Rustls(RustlsMidHandshakeTlsStream),
457}
458
459impl MidHandshakeTlsStream {
460    /// Get a reference to the inner stream
461    #[must_use]
462    pub fn get_ref(&self) -> &StdTcpStream {
463        match self {
464            Self::Plain(mid) => mid,
465            #[cfg(feature = "native-tls")]
466            Self::NativeTls(mid) => mid.get_ref(),
467            #[cfg(feature = "openssl")]
468            Self::Openssl(mid) => mid.get_ref(),
469            #[cfg(feature = "rustls")]
470            Self::Rustls(mid) => mid.get_ref(),
471        }
472    }
473
474    /// Get a mutable reference to the inner stream
475    #[must_use]
476    pub fn get_mut(&mut self) -> &mut StdTcpStream {
477        match self {
478            Self::Plain(mid) => mid,
479            #[cfg(feature = "native-tls")]
480            Self::NativeTls(mid) => mid.get_mut(),
481            #[cfg(feature = "openssl")]
482            Self::Openssl(mid) => mid.get_mut(),
483            #[cfg(feature = "rustls")]
484            Self::Rustls(mid) => mid.get_mut(),
485        }
486    }
487
488    /// Retry the handshake
489    pub fn handshake(mut self) -> HandshakeResult {
490        if !try_connect(self.get_mut())? {
491            return Err(HandshakeError::WouldBlock(self));
492        }
493
494        Ok(match self {
495            Self::Plain(mid) => mid,
496            #[cfg(feature = "native-tls")]
497            Self::NativeTls(mid) => mid.handshake()?.into(),
498            #[cfg(feature = "openssl")]
499            Self::Openssl(mid) => mid.handshake()?.into(),
500            #[cfg(feature = "rustls")]
501            Self::Rustls(mid) => mid.handshake()?.into(),
502        })
503    }
504}
505
506impl From<TcpStream> for MidHandshakeTlsStream {
507    fn from(mid: TcpStream) -> Self {
508        Self::Plain(mid)
509    }
510}
511
512impl fmt::Display for MidHandshakeTlsStream {
513    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514        f.write_str("MidHandshakeTlsStream")
515    }
516}
517
518/// An error returned while performing the handshake
519#[derive(Debug)]
520pub enum HandshakeError {
521    /// We hit WouldBlock during handshake
522    WouldBlock(MidHandshakeTlsStream),
523    /// We hit a critical failure
524    Failure(io::Error),
525}
526
527impl HandshakeError {
528    /// Try and get the inner mid handshake TLS stream from this error
529    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
530        match self {
531            Self::WouldBlock(mid) => Ok(mid),
532            Self::Failure(error) => Err(error),
533        }
534    }
535}
536
537impl fmt::Display for HandshakeError {
538    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539        match self {
540            Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
541            Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
542        }
543    }
544}
545
546impl Error for HandshakeError {
547    fn source(&self) -> Option<&(dyn Error + 'static)> {
548        match self {
549            Self::Failure(err) => Some(err),
550            _ => None,
551        }
552    }
553}
554
555impl From<io::Error> for HandshakeError {
556    fn from(err: io::Error) -> Self {
557        Self::Failure(err)
558    }
559}
560
561mod sys;