Skip to main content

tcp_stream/
lib.rs

1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
2#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
3#![warn(
4    clippy::must_use_candidate,
5    clippy::unwrap_in_result,
6    clippy::panic_in_result_fn
7)]
8#![allow(clippy::large_enum_variant, clippy::result_large_err)]
9
10//! # std::net::TCP stream on steroids
11//!
//! tcp-stream is a library that adds TLS support to `std::net::TcpStream`.
13//!
14//! # Examples
15//!
16//! To connect to a remote server:
17//!
18//! ```rust
19//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
20//!
21//! use std::io::{self, Read, Write};
22//!
23//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
24//! stream.set_nonblocking(true).unwrap();
25//!
26//! loop {
27//!     if stream.try_connect().unwrap() {
28//!         break;
29//!     }
30//! }
31//!
32//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
33//!
34//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
35//!     stream = mid_handshake.handshake();
36//! }
37//!
38//! let mut stream = stream.unwrap();
39//!
40//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
41//!     if err.kind() != io::ErrorKind::WouldBlock {
42//!         panic!("error: {:?}", err);
43//!     }
44//! }
45//!
46//! while let Err(err) = stream.flush() {
47//!     if err.kind() != io::ErrorKind::WouldBlock {
48//!         panic!("error: {:?}", err);
49//!     }
50//! }
51//!
52//! let mut res = vec![];
53//! while let Err(err) = stream.read_to_end(&mut res) {
54//!     if err.kind() != io::ErrorKind::WouldBlock {
55//!         panic!("stream error: {:?}", err);
56//!     }
57//! }
58//! println!("{}", String::from_utf8_lossy(&res));
59//! ```
60
61use cfg_if::cfg_if;
62use std::{
63    convert::TryFrom,
64    error::Error,
65    fmt,
66    io::{self, IoSlice, IoSliceMut, Read, Write},
67    net::{TcpStream as StdTcpStream, ToSocketAddrs},
68    ops::{Deref, DerefMut},
69    time::Duration,
70};
71
72#[cfg(feature = "rustls")]
73mod rustls_impl;
74#[cfg(feature = "rustls")]
75pub use rustls_impl::*;
76
77#[cfg(feature = "native-tls")]
78mod native_tls_impl;
79#[cfg(feature = "native-tls")]
80pub use native_tls_impl::*;
81
82#[cfg(feature = "openssl")]
83mod openssl_impl;
84#[cfg(feature = "openssl")]
85pub use openssl_impl::*;
86
87#[cfg(feature = "futures")]
88mod futures;
89#[cfg(feature = "futures")]
90pub use futures::*;
91
/// A TCP connection that is either plain or wrapped by one of the TLS backends
/// enabled through cargo features.
#[non_exhaustive]
pub enum TcpStream {
    /// Unencrypted `std::net::TcpStream`
    Plain(StdTcpStream),
    /// TLS stream handled by native-tls
    #[cfg(feature = "native-tls")]
    NativeTls(NativeTlsStream),
    /// TLS stream handled by openssl
    #[cfg(feature = "openssl")]
    Openssl(OpensslStream),
    /// TLS stream handled by rustls
    #[cfg(feature = "rustls")]
    Rustls(RustlsStream),
}
107
/// Borrowed TLS settings passed to [`TcpStream::into_tls`].
///
/// Every field is a borrow, so the struct is `Copy` and can be reused across
/// several `into_tls` calls. Use [`OwnedTLSConfig`] to keep an owned version.
#[derive(Clone, Copy, Default, Debug, PartialEq)]
pub struct TLSConfig<'data, 'key, 'chain> {
    /// Identity used for client certificate authentication
    pub identity: Option<Identity<'data, 'key>>,
    /// Custom certificate chain in PEM format
    pub cert_chain: Option<&'chain str>,
}

/// Owned TLS settings; call [`OwnedTLSConfig::as_ref`] to obtain a [`TLSConfig`].
#[derive(Clone, Default, Debug, PartialEq)]
pub struct OwnedTLSConfig {
    /// Identity used for client certificate authentication
    pub identity: Option<OwnedIdentity>,
    /// Custom certificate chain in PEM format
    pub cert_chain: Option<String>,
}

impl OwnedTLSConfig {
    /// Borrow this configuration as the ephemeral [`TLSConfig`].
    #[must_use]
    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
        TLSConfig {
            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
            cert_chain: self.cert_chain.as_deref(),
        }
    }
}

/// A borrowed client identity. Holds one of:
/// - a PKCS#12 DER-encoded identity and its decryption password
/// - a PEM-encoded certificate and a PEM-encoded PKCS#8 private key (no password)
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PEM-encoded PKCS#8 private key with PEM-encoded certificate
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}

/// An owned client identity. Holds one of:
/// - a PKCS#12 DER-encoded identity and its decryption password
/// - a PEM-encoded certificate and a PEM-encoded PKCS#8 private key (no password)
#[derive(Clone, Debug, PartialEq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PEM-encoded PKCS#8 private key with PEM-encoded certificate
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}

impl OwnedIdentity {
    /// Borrow this identity as the ephemeral [`Identity`].
    #[must_use]
    pub fn as_ref(&self) -> Identity<'_, '_> {
        match self {
            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
        }
    }
}
189
190/// Holds either the TLS `TcpStream` result or the current handshake state
191pub type HandshakeResult = Result<TcpStream, HandshakeError>;
192
193impl TcpStream {
194    /// Wrapper around `std::net::TcpStream::connect`
195    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
196        connect_std(addr, None).and_then(Self::try_from)
197    }
198
199    /// Wrapper around `std::net::TcpStream::connect_timeout`
200    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
201        connect_std(addr, Some(timeout)).and_then(Self::try_from)
202    }
203
204    /// Convert from a `std::net::TcpStream`
205    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
206        Self::try_from(stream)
207    }
208
209    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
210    pub fn is_readable(&self) -> io::Result<()> {
211        self.deref().read(&mut []).map(|_| ())
212    }
213
214    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
215    pub fn is_writable(&self) -> io::Result<()> {
216        is_writable(self.deref())
217    }
218
219    /// Retry the connection. Returns:
220    /// - Ok(true) if connected
221    /// - Ok(false) if connecting
222    /// - Err(_) if an error is encountered
223    pub fn try_connect(&mut self) -> io::Result<bool> {
224        try_connect(self)
225    }
226
227    /// Enable TLS
228    pub fn into_tls(
229        self,
230        domain: &str,
231        config: TLSConfig<'_, '_, '_>,
232    ) -> Result<Self, HandshakeError> {
233        into_tls_impl(self, domain, config)
234    }
235
236    #[cfg(feature = "native-tls")]
237    /// Enable TLS using native-tls
238    pub fn into_native_tls(
239        self,
240        connector: &NativeTlsConnector,
241        domain: &str,
242    ) -> Result<Self, HandshakeError> {
243        Ok(connector.connect(domain, self.into_plain()?)?.into())
244    }
245
246    #[cfg(feature = "openssl")]
247    /// Enable TLS using openssl
248    pub fn into_openssl(
249        self,
250        connector: &OpensslConnector,
251        domain: &str,
252    ) -> Result<Self, HandshakeError> {
253        Ok(connector.connect(domain, self.into_plain()?)?.into())
254    }
255
256    #[cfg(feature = "rustls")]
257    /// Enable TLS using rustls
258    pub fn into_rustls(
259        self,
260        connector: &RustlsConnector,
261        domain: &str,
262    ) -> Result<Self, HandshakeError> {
263        Ok(connector.connect(domain, self.into_plain()?)?.into())
264    }
265
266    #[allow(irrefutable_let_patterns)]
267    fn into_plain(self) -> Result<StdTcpStream, io::Error> {
268        if let Self::Plain(plain) = self {
269            Ok(plain)
270        } else {
271            Err(io::Error::new(
272                io::ErrorKind::AlreadyExists,
273                "already a TLS stream",
274            ))
275        }
276    }
277}
278
/// Connect to `addr`, optionally bounding each attempt by `timeout`.
///
/// Without a timeout this defers entirely to `StdTcpStream::connect`. With one,
/// the resolved addresses are tried in order with `connect_timeout`, and the
/// first success is returned. If every attempt fails, the last error is returned.
/// If resolution produced no addresses, the result is an `AddrNotAvailable`
/// "couldn't resolve host" error.
fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
    let timeout = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_error: Option<io::Error> = None;
    // `find_map` is lazy, so no further address is tried after a success.
    let connected = addr.to_socket_addrs()?.find_map(|candidate| {
        match StdTcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => Some(stream),
            Err(error) => {
                last_error = Some(error);
                None
            }
        }
    });

    connected.ok_or_else(|| {
        last_error.unwrap_or_else(|| {
            io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
        })
    })
}
296
297fn try_connect(stream: &mut StdTcpStream) -> io::Result<bool> {
298    match is_writable(stream) {
299        Ok(()) => Ok(true),
300        Err(err)
301            if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected].contains(&err.kind()) =>
302        {
303            Ok(false)
304        }
305        Err(err) => Err(err),
306    }
307}
308
/// Probe whether `stream` accepts writes by writing zero bytes to it.
///
/// Returns `Ok(())` when the zero-length write succeeds, otherwise its I/O error.
/// (`try_connect` interprets `WouldBlock`/`NotConnected` as "still connecting".)
fn is_writable(stream: &StdTcpStream) -> io::Result<()> {
    let mut writer = stream;
    writer.write(&[]).map(drop)
}
312
313fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
314    cfg_if! {
315        if #[cfg(feature = "rustls-platform-verifier")] {
316            into_rustls_impl(s, RustlsConnectorConfig::new_with_platform_verifier(), domain, config)
317        } else if #[cfg(feature = "rustls-native-certs")] {
318            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
319        } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
320            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_root_certs(), domain, config)
321        } else if #[cfg(feature = "rustls")] {
322            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
323        } else if #[cfg(feature = "openssl")] {
324            into_openssl_impl(s, domain, config)
325        } else if #[cfg(feature = "native-tls")] {
326            into_native_tls_impl(s, domain, config)
327        } else {
328            let _ = (domain, config);
329            Ok(TcpStream::Plain(s.into_plain()?))
330        }
331    }
332}
333
334impl TryFrom<StdTcpStream> for TcpStream {
335    type Error = io::Error;
336
337    fn try_from(s: StdTcpStream) -> io::Result<Self> {
338        s.set_nodelay(true)?;
339        let mut this = Self::Plain(s);
340        this.try_connect()?;
341        Ok(this)
342    }
343}
344
345impl Deref for TcpStream {
346    type Target = StdTcpStream;
347
348    fn deref(&self) -> &Self::Target {
349        match self {
350            Self::Plain(plain) => plain,
351            #[cfg(feature = "native-tls")]
352            Self::NativeTls(tls) => tls.get_ref(),
353            #[cfg(feature = "openssl")]
354            Self::Openssl(tls) => tls.get_ref(),
355            #[cfg(feature = "rustls")]
356            Self::Rustls(tls) => tls.get_ref(),
357        }
358    }
359}
360
361impl DerefMut for TcpStream {
362    fn deref_mut(&mut self) -> &mut Self::Target {
363        match self {
364            Self::Plain(plain) => plain,
365            #[cfg(feature = "native-tls")]
366            Self::NativeTls(tls) => tls.get_mut(),
367            #[cfg(feature = "openssl")]
368            Self::Openssl(tls) => tls.get_mut(),
369            #[cfg(feature = "rustls")]
370            Self::Rustls(tls) => tls.get_mut(),
371        }
372    }
373}
374
// Dispatch `$this.$call($arg, ...)` to whichever stream variant is active.
// Expands to a `match` over the variants compiled in, so each forwarding method
// in the `Read`/`Write` impls stays a one-liner. Zero arguments: `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($this:ident, $call:ident, $($arg:expr),*) => {
        match $this {
            Self::Plain(inner) => inner.$call($($arg),*),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(inner) => inner.$call($($arg),*),
            #[cfg(feature = "openssl")]
            Self::Openssl(inner) => inner.$call($($arg),*),
            #[cfg(feature = "rustls")]
            Self::Rustls(inner) => inner.$call($($arg),*),
        }
    };
}
388
389impl Read for TcpStream {
390    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
391        fwd_impl!(self, read, buf)
392    }
393
394    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
395        fwd_impl!(self, read_vectored, bufs)
396    }
397
398    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
399        fwd_impl!(self, read_to_end, buf)
400    }
401
402    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
403        fwd_impl!(self, read_to_string, buf)
404    }
405
406    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
407        fwd_impl!(self, read_exact, buf)
408    }
409}
410
411impl Write for TcpStream {
412    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
413        fwd_impl!(self, write, buf)
414    }
415
416    fn flush(&mut self) -> io::Result<()> {
417        fwd_impl!(self, flush,)
418    }
419
420    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
421        fwd_impl!(self, write_vectored, bufs)
422    }
423
424    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
425        fwd_impl!(self, write_all, buf)
426    }
427
428    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
429        fwd_impl!(self, write_fmt, fmt)
430    }
431}
432
433impl fmt::Debug for TcpStream {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        f.debug_struct("TcpStream")
436            .field("inner", self.deref())
437            .finish()
438    }
439}
440
441/// A TLS stream which has been interrupted during the handshake
442#[derive(Debug)]
443pub enum MidHandshakeTlsStream {
444    /// Not a TLS stream
445    Plain(TcpStream),
446    #[cfg(feature = "native-tls")]
447    /// A native-tls MidHandshakeTlsStream
448    NativeTls(NativeTlsMidHandshakeTlsStream),
449    #[cfg(feature = "openssl")]
450    /// An openssl MidHandshakeTlsStream
451    Openssl(OpensslMidHandshakeTlsStream),
452    #[cfg(feature = "rustls")]
453    /// A rustls-connector MidHandshakeTlsStream
454    Rustls(RustlsMidHandshakeTlsStream),
455}
456
457impl MidHandshakeTlsStream {
458    /// Get a reference to the inner stream
459    #[must_use]
460    pub fn get_ref(&self) -> &StdTcpStream {
461        match self {
462            Self::Plain(mid) => mid,
463            #[cfg(feature = "native-tls")]
464            Self::NativeTls(mid) => mid.get_ref(),
465            #[cfg(feature = "openssl")]
466            Self::Openssl(mid) => mid.get_ref(),
467            #[cfg(feature = "rustls")]
468            Self::Rustls(mid) => mid.get_ref(),
469        }
470    }
471
472    /// Get a mutable reference to the inner stream
473    #[must_use]
474    pub fn get_mut(&mut self) -> &mut StdTcpStream {
475        match self {
476            Self::Plain(mid) => mid,
477            #[cfg(feature = "native-tls")]
478            Self::NativeTls(mid) => mid.get_mut(),
479            #[cfg(feature = "openssl")]
480            Self::Openssl(mid) => mid.get_mut(),
481            #[cfg(feature = "rustls")]
482            Self::Rustls(mid) => mid.get_mut(),
483        }
484    }
485
486    /// Retry the handshake
487    pub fn handshake(mut self) -> HandshakeResult {
488        if !try_connect(self.get_mut())? {
489            return Err(HandshakeError::WouldBlock(self));
490        }
491
492        Ok(match self {
493            Self::Plain(mid) => mid,
494            #[cfg(feature = "native-tls")]
495            Self::NativeTls(mid) => mid.handshake()?.into(),
496            #[cfg(feature = "openssl")]
497            Self::Openssl(mid) => mid.handshake()?.into(),
498            #[cfg(feature = "rustls")]
499            Self::Rustls(mid) => mid.handshake()?.into(),
500        })
501    }
502}
503
504impl From<TcpStream> for MidHandshakeTlsStream {
505    fn from(mid: TcpStream) -> Self {
506        Self::Plain(mid)
507    }
508}
509
510impl fmt::Display for MidHandshakeTlsStream {
511    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
512        f.write_str("MidHandshakeTlsStream")
513    }
514}
515
516/// An error returned while performing the handshake
517#[derive(Debug)]
518pub enum HandshakeError {
519    /// We hit WouldBlock during handshake
520    WouldBlock(MidHandshakeTlsStream),
521    /// We hit a critical failure
522    Failure(io::Error),
523}
524
525impl HandshakeError {
526    /// Try and get the inner mid handshake TLS stream from this error
527    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
528        match self {
529            Self::WouldBlock(mid) => Ok(mid),
530            Self::Failure(error) => Err(error),
531        }
532    }
533}
534
535impl fmt::Display for HandshakeError {
536    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537        match self {
538            Self::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
539            Self::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
540        }
541    }
542}
543
544impl Error for HandshakeError {
545    fn source(&self) -> Option<&(dyn Error + 'static)> {
546        match self {
547            Self::Failure(err) => Some(err),
548            _ => None,
549        }
550    }
551}
552
553impl From<io::Error> for HandshakeError {
554    fn from(err: io::Error) -> Self {
555        Self::Failure(err)
556    }
557}
558
559mod sys;