tcp_stream/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::result_large_err)]
3
4//! # std::net::TCP stream on steroids
5//!
6//! tcp-stream is a library aiming at providing TLS support to std::net::TcpStream
7//!
8//! # Examples
9//!
10//! To connect to a remote server:
11//!
12//! ```rust
13//! use tcp_stream::{HandshakeError, TcpStream, TLSConfig};
14//!
15//! use std::io::{self, Read, Write};
16//!
17//! let mut stream = TcpStream::connect("www.rust-lang.org:443").unwrap();
18//! stream.set_nonblocking(true).unwrap();
19//!
20//! while !stream.is_connected() {
21//!     if stream.try_connect().unwrap() {
22//!         break;
23//!     }
24//! }
25//!
26//! let mut stream = stream.into_tls("www.rust-lang.org", TLSConfig::default());
27//!
28//! while let Err(HandshakeError::WouldBlock(mid_handshake)) = stream {
29//!     stream = mid_handshake.handshake();
30//! }
31//!
32//! let mut stream = stream.unwrap();
33//!
34//! while let Err(err) = stream.write_all(b"GET / HTTP/1.0\r\n\r\n") {
35//!     if err.kind() != io::ErrorKind::WouldBlock {
36//!         panic!("error: {:?}", err);
37//!     }
38//! }
39//! stream.flush().unwrap();
40//! let mut res = vec![];
41//! while let Err(err) = stream.read_to_end(&mut res) {
42//!     if err.kind() != io::ErrorKind::WouldBlock {
43//!         panic!("stream error: {:?}", err);
44//!     }
45//! }
46//! println!("{}", String::from_utf8_lossy(&res));
47//! ```
48
49use cfg_if::cfg_if;
50use std::{
51    convert::TryFrom,
52    error::Error,
53    fmt,
54    io::{self, IoSlice, IoSliceMut, Read, Write},
55    net::{TcpStream as StdTcpStream, ToSocketAddrs},
56    ops::{Deref, DerefMut},
57    time::Duration,
58};
59
60#[cfg(feature = "rustls-common")]
61mod rustls_impl;
62#[cfg(feature = "rustls-common")]
63pub use rustls_impl::*;
64
65#[cfg(feature = "native-tls")]
66mod native_tls_impl;
67#[cfg(feature = "native-tls")]
68pub use native_tls_impl::*;
69
70#[cfg(feature = "openssl")]
71mod openssl_impl;
72#[cfg(feature = "openssl")]
73pub use openssl_impl::*;
74
75#[cfg(feature = "futures")]
76mod futures;
77#[cfg(feature = "futures")]
78pub use futures::*;
79
/// Wrapper around plain or TLS TCP streams
pub enum TcpStream {
    /// Wrapper around `std::net::TcpStream`.
    ///
    /// The `bool` records whether the connection is known to be established:
    /// it starts out `false` when the stream is built through `TryFrom` and is
    /// set to `true` by [`TcpStream::try_connect`] once the socket is writable.
    Plain(StdTcpStream, bool),
    #[cfg(feature = "native-tls")]
    /// Wrapper around a TLS stream handled by native-tls
    NativeTls(Box<NativeTlsStream>),
    #[cfg(feature = "openssl")]
    /// Wrapper around a TLS stream handled by openssl
    OpenSsl(Box<OpenSslStream>),
    #[cfg(feature = "rustls-common")]
    /// Wrapper around a TLS stream handled by rustls
    Rustls(Box<RustlsStream>),
}
94
95/// Holds extra TLS configuration
96#[derive(Default, Debug, PartialEq)]
97pub struct TLSConfig<'data, 'key, 'chain> {
98    /// Use for client certificate authentication
99    pub identity: Option<Identity<'data, 'key>>,
100    /// The custom certificates chain in PEM format
101    pub cert_chain: Option<&'chain str>,
102}
103
104/// Holds extra TLS configuration
105#[derive(Default, Debug, PartialEq)]
106pub struct OwnedTLSConfig {
107    /// Use for client certificate authentication
108    pub identity: Option<OwnedIdentity>,
109    /// The custom certificates chain in PEM format
110    pub cert_chain: Option<String>,
111}
112
113impl OwnedTLSConfig {
114    /// Get the ephemeral `TLSConfig` corresponding to the `OwnedTLSConfig`
115    #[must_use]
116    pub fn as_ref(&self) -> TLSConfig<'_, '_, '_> {
117        TLSConfig {
118            identity: self.identity.as_ref().map(OwnedIdentity::as_ref),
119            cert_chain: self.cert_chain.as_deref(),
120        }
121    }
122}
123
/// Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
///
/// Every field is a borrowed slice, so an `Identity` is cheap to copy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Identity<'data, 'key> {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: &'data [u8],
        /// Decryption password
        password: &'key str,
    },
    /// PKCS#8 PEM-encoded private key with PEM-encoded certificate
    PKCS8 {
        /// PEM-encoded certificate
        pem: &'data [u8],
        /// PEM-encoded key
        key: &'key [u8],
    },
}
144
/// Owned counterpart of [`Identity`]. Holds one of:
/// - PKCS#12 DER-encoded identity and decryption password
/// - PKCS#8 PEM-encoded certificate and key (without decryption password)
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OwnedIdentity {
    /// PKCS#12 DER-encoded identity with decryption password
    PKCS12 {
        /// PKCS#12 DER-encoded identity
        der: Vec<u8>,
        /// Decryption password
        password: String,
    },
    /// PKCS#8 PEM-encoded private key with PEM-encoded certificate
    PKCS8 {
        /// PEM-encoded certificate
        pem: Vec<u8>,
        /// PEM-encoded key
        key: Vec<u8>,
    },
}
165
166impl OwnedIdentity {
167    /// Get the ephemeral `Identity` corresponding to the `OwnedIdentity`
168    #[must_use]
169    pub fn as_ref(&self) -> Identity<'_, '_> {
170        match self {
171            Self::PKCS8 { pem, key } => Identity::PKCS8 { pem, key },
172            Self::PKCS12 { der, password } => Identity::PKCS12 { der, password },
173        }
174    }
175}
176
177/// Holds either the TLS `TcpStream` result or the current handshake state
178pub type HandshakeResult = Result<TcpStream, HandshakeError>;
179
180impl TcpStream {
181    /// Wrapper around `std::net::TcpStream::connect`
182    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
183        connect_std(addr, None).and_then(Self::try_from)
184    }
185
186    /// Wrapper around `std::net::TcpStream::connect_timeout`
187    pub fn connect_timeout<A: ToSocketAddrs>(addr: A, timeout: Duration) -> io::Result<Self> {
188        connect_std(addr, Some(timeout)).and_then(Self::try_from)
189    }
190
191    /// Convert from a `std::net::TcpStream`
192    pub fn from_std(stream: StdTcpStream) -> io::Result<Self> {
193        Self::try_from(stream)
194    }
195
196    /// Check whether the stream is connected or not
197    #[must_use]
198    #[allow(irrefutable_let_patterns)]
199    pub fn is_connected(&self) -> bool {
200        if let Self::Plain(_, connected) = self {
201            *connected
202        } else {
203            true
204        }
205    }
206
207    /// Attempt reading from underlying stream, returning Ok(()) if the stream is readable
208    pub fn is_readable(&self) -> io::Result<()> {
209        self.deref().read(&mut []).map(|_| ())
210    }
211
212    /// Attempt writing to underlying stream, returning Ok(()) if the stream is writable
213    pub fn is_writable(&self) -> io::Result<()> {
214        self.deref().write(&[]).map(|_| ())
215    }
216
217    /// Retry the connection. Returns:
218    /// - Ok(true) if connected
219    /// - Ok(false) if connecting
220    /// - Err(_) if an error is encountered
221    #[allow(irrefutable_let_patterns)]
222    pub fn try_connect(&mut self) -> io::Result<bool> {
223        if self.is_connected() {
224            return Ok(true);
225        }
226        match self.is_writable() {
227            Ok(()) => {
228                if let Self::Plain(_, connected) = self {
229                    *connected = true;
230                }
231                Ok(true)
232            }
233            Err(err)
234                if [io::ErrorKind::WouldBlock, io::ErrorKind::NotConnected]
235                    .contains(&err.kind()) =>
236            {
237                Ok(false)
238            }
239            Err(err) => Err(err),
240        }
241    }
242
243    /// Enable TLS
244    pub fn into_tls(
245        self,
246        domain: &str,
247        config: TLSConfig<'_, '_, '_>,
248    ) -> Result<Self, HandshakeError> {
249        into_tls_impl(self, domain, config)
250    }
251
252    #[cfg(feature = "native-tls")]
253    /// Enable TLS using native-tls
254    pub fn into_native_tls(
255        self,
256        connector: &NativeTlsConnector,
257        domain: &str,
258    ) -> Result<Self, HandshakeError> {
259        Ok(connector.connect(domain, self.into_plain()?)?.into())
260    }
261
262    #[cfg(feature = "openssl")]
263    /// Enable TLS using openssl
264    pub fn into_openssl(
265        self,
266        connector: &OpenSslConnector,
267        domain: &str,
268    ) -> Result<Self, HandshakeError> {
269        Ok(connector.connect(domain, self.into_plain()?)?.into())
270    }
271
272    #[cfg(feature = "rustls-common")]
273    /// Enable TLS using rustls
274    pub fn into_rustls(
275        self,
276        connector: &RustlsConnector,
277        domain: &str,
278    ) -> Result<Self, HandshakeError> {
279        Ok(connector.connect(domain, self.into_plain()?)?.into())
280    }
281
282    #[allow(irrefutable_let_patterns)]
283    fn into_plain(self) -> Result<TcpStream, io::Error> {
284        if let TcpStream::Plain(plain, connected) = self {
285            Ok(TcpStream::Plain(plain, connected))
286        } else {
287            Err(io::Error::new(
288                io::ErrorKind::AlreadyExists,
289                "already a TLS stream",
290            ))
291        }
292    }
293}
294
295fn connect_std<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<StdTcpStream> {
296    let stream = connect_std_raw(addr, timeout)?;
297    stream.set_nodelay(true)?;
298    Ok(stream)
299}
300
/// Open a std TCP connection to `addr`.
///
/// Without a timeout this is `std::net::TcpStream::connect`. With a timeout,
/// each resolved address is tried in order with `connect_timeout`. The first
/// success is returned; otherwise the error from the last attempt is returned,
/// or an `AddrNotAvailable` error if `addr` resolved to nothing.
fn connect_std_raw<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> io::Result<StdTcpStream> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return StdTcpStream::connect(addr),
    };

    let mut last_failure = None;
    for candidate in addr.to_socket_addrs()? {
        match StdTcpStream::connect_timeout(&candidate, limit) {
            Ok(socket) => return Ok(socket),
            Err(failure) => last_failure = Some(failure),
        }
    }

    Err(last_failure.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
    }))
}
321
// Select the implementation behind `TcpStream::into_tls` at compile time.
// `cfg_if!` uses the first branch whose `cfg` holds, so when several TLS
// features are enabled together the priority is:
//   1. `rustls-native-certs`       -> rustls via `new_with_native_certs()`
//   2. `rustls-webpki-roots-certs` -> rustls via `new_with_webpki_roots_certs()`
//   3. `rustls-common`             -> rustls via `RustlsConnectorConfig::default()`
//   4. `openssl`
//   5. `native-tls`
// With none of them enabled, `into_tls` hands the stream back as plain TCP.
cfg_if! {
    if #[cfg(feature = "rustls-native-certs")] {
        // Building this connector config is fallible (note the `?`); its error
        // is propagated as a `HandshakeError`.
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config)
        }
    } else if #[cfg(feature = "rustls-webpki-roots-certs")] {
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_impl(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config)
        }
    } else if #[cfg(feature = "rustls-common")] {
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_rustls_impl(s, RustlsConnectorConfig::default(), domain, config)
        }
    } else if #[cfg(feature = "openssl")] {
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_openssl_impl(s, domain, config)
        }
    } else if #[cfg(feature = "native-tls")] {
        fn into_tls_impl(s: TcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            into_native_tls_impl(s, domain, config)
        }
    } else {
        // No TLS backend compiled in: `domain` and the TLS config are ignored
        // and the stream stays plain TCP.
        fn into_tls_impl(s: TcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> HandshakeResult {
            Ok(s.into_plain()?)
        }
    }
}
349
350impl TryFrom<StdTcpStream> for TcpStream {
351    type Error = io::Error;
352
353    fn try_from(s: StdTcpStream) -> io::Result<Self> {
354        let mut this = TcpStream::Plain(s, false);
355        this.try_connect()?;
356        Ok(this)
357    }
358}
359
360impl Deref for TcpStream {
361    type Target = StdTcpStream;
362
363    fn deref(&self) -> &Self::Target {
364        match self {
365            TcpStream::Plain(plain, _) => plain,
366            #[cfg(feature = "native-tls")]
367            TcpStream::NativeTls(tls) => tls.get_ref(),
368            #[cfg(feature = "openssl")]
369            TcpStream::OpenSsl(tls) => tls.get_ref(),
370            #[cfg(feature = "rustls-common")]
371            TcpStream::Rustls(tls) => tls.get_ref(),
372        }
373    }
374}
375
376impl DerefMut for TcpStream {
377    fn deref_mut(&mut self) -> &mut Self::Target {
378        match self {
379            TcpStream::Plain(plain, _) => plain,
380            #[cfg(feature = "native-tls")]
381            TcpStream::NativeTls(tls) => tls.get_mut(),
382            #[cfg(feature = "openssl")]
383            TcpStream::OpenSsl(tls) => tls.get_mut(),
384            #[cfg(feature = "rustls-common")]
385            TcpStream::Rustls(tls) => tls.get_mut(),
386        }
387    }
388}
389
// Forward a method call to whichever stream a `TcpStream` currently wraps.
// Usage: `fwd_impl!(self, method, arg1, arg2)`. A call without arguments is
// written with a trailing comma, e.g. `fwd_impl!(self, flush,)`.
macro_rules! fwd_impl {
    ($stream:ident, $call:ident, $($arg:expr),*) => {{
        match $stream {
            TcpStream::Plain(inner, _) => inner.$call($($arg),*),
            #[cfg(feature = "native-tls")]
            TcpStream::NativeTls(inner) => inner.$call($($arg),*),
            #[cfg(feature = "openssl")]
            TcpStream::OpenSsl(inner) => inner.$call($($arg),*),
            #[cfg(feature = "rustls-common")]
            TcpStream::Rustls(inner) => inner.$call($($arg),*),
        }
    }};
}
403
404impl Read for TcpStream {
405    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
406        fwd_impl!(self, read, buf)
407    }
408
409    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
410        fwd_impl!(self, read_vectored, bufs)
411    }
412
413    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
414        fwd_impl!(self, read_to_end, buf)
415    }
416
417    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
418        fwd_impl!(self, read_to_string, buf)
419    }
420
421    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
422        fwd_impl!(self, read_exact, buf)
423    }
424}
425
426impl Write for TcpStream {
427    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
428        fwd_impl!(self, write, buf)
429    }
430
431    fn flush(&mut self) -> io::Result<()> {
432        fwd_impl!(self, flush,)
433    }
434
435    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
436        fwd_impl!(self, write_vectored, bufs)
437    }
438
439    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
440        fwd_impl!(self, write_all, buf)
441    }
442
443    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
444        fwd_impl!(self, write_fmt, fmt)
445    }
446}
447
448impl fmt::Debug for TcpStream {
449    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450        f.debug_struct("TcpStream")
451            .field("inner", self.deref())
452            .finish()
453    }
454}
455
456/// A TLS stream which has been interrupted during the handshake
457#[allow(clippy::large_enum_variant)]
458#[derive(Debug)]
459pub enum MidHandshakeTlsStream {
460    /// Not a TLS stream
461    Plain(TcpStream),
462    #[cfg(feature = "native-tls")]
463    /// A native-tls MidHandshakeTlsStream
464    NativeTls(NativeTlsMidHandshakeTlsStream),
465    #[cfg(feature = "openssl")]
466    /// An openssl MidHandshakeTlsStream
467    Openssl(OpenSslMidHandshakeTlsStream),
468    #[cfg(feature = "rustls-common")]
469    /// A rustls-connector MidHandshakeTlsStream
470    Rustls(RustlsMidHandshakeTlsStream),
471}
472
473impl MidHandshakeTlsStream {
474    /// Get a reference to the inner stream
475    #[must_use]
476    pub fn get_ref(&self) -> &TcpStream {
477        match self {
478            MidHandshakeTlsStream::Plain(mid) => mid,
479            #[cfg(feature = "native-tls")]
480            MidHandshakeTlsStream::NativeTls(mid) => mid.get_ref(),
481            #[cfg(feature = "openssl")]
482            MidHandshakeTlsStream::Openssl(mid) => mid.get_ref(),
483            #[cfg(feature = "rustls-common")]
484            MidHandshakeTlsStream::Rustls(mid) => mid.get_ref(),
485        }
486    }
487
488    /// Get a mutable reference to the inner stream
489    #[must_use]
490    pub fn get_mut(&mut self) -> &mut TcpStream {
491        match self {
492            MidHandshakeTlsStream::Plain(mid) => mid,
493            #[cfg(feature = "native-tls")]
494            MidHandshakeTlsStream::NativeTls(mid) => mid.get_mut(),
495            #[cfg(feature = "openssl")]
496            MidHandshakeTlsStream::Openssl(mid) => mid.get_mut(),
497            #[cfg(feature = "rustls-common")]
498            MidHandshakeTlsStream::Rustls(mid) => mid.get_mut(),
499        }
500    }
501
502    /// Retry the handshake
503    pub fn handshake(self) -> HandshakeResult {
504        Ok(match self {
505            MidHandshakeTlsStream::Plain(mut mid) => {
506                if !mid.try_connect()? {
507                    return Err(HandshakeError::WouldBlock(mid.into()));
508                }
509                mid
510            }
511            #[cfg(feature = "native-tls")]
512            MidHandshakeTlsStream::NativeTls(mut mid) => {
513                if !mid.get_mut().try_connect()? {
514                    return Err(HandshakeError::WouldBlock(mid.into()));
515                }
516                mid.handshake()?.into()
517            }
518            #[cfg(feature = "openssl")]
519            MidHandshakeTlsStream::Openssl(mut mid) => {
520                if !mid.get_mut().try_connect()? {
521                    return Err(HandshakeError::WouldBlock(mid.into()));
522                }
523                mid.handshake()?.into()
524            }
525            #[cfg(feature = "rustls-common")]
526            MidHandshakeTlsStream::Rustls(mut mid) => {
527                if !mid.get_mut().try_connect()? {
528                    return Err(HandshakeError::WouldBlock(mid.into()));
529                }
530                mid.handshake()?.into()
531            }
532        })
533    }
534}
535
536impl From<TcpStream> for MidHandshakeTlsStream {
537    fn from(mid: TcpStream) -> Self {
538        MidHandshakeTlsStream::Plain(mid)
539    }
540}
541
542impl fmt::Display for MidHandshakeTlsStream {
543    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
544        f.write_str("MidHandshakeTlsStream")
545    }
546}
547
548/// An error returned while performing the handshake
549#[allow(clippy::large_enum_variant)]
550#[derive(Debug)]
551pub enum HandshakeError {
552    /// We hit WouldBlock during handshake
553    WouldBlock(MidHandshakeTlsStream),
554    /// We hit a critical failure
555    Failure(io::Error),
556}
557
558impl HandshakeError {
559    /// Try and get the inner mid handshake TLS stream from this error
560    pub fn into_mid_handshake_tls_stream(self) -> io::Result<MidHandshakeTlsStream> {
561        match self {
562            Self::WouldBlock(mid) => Ok(mid),
563            Self::Failure(error) => Err(error),
564        }
565    }
566}
567
568impl fmt::Display for HandshakeError {
569    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570        match self {
571            HandshakeError::WouldBlock(_) => f.write_str("WouldBlock hit during handshake"),
572            HandshakeError::Failure(err) => f.write_fmt(format_args!("IO error: {err}")),
573        }
574    }
575}
576
577impl Error for HandshakeError {
578    fn source(&self) -> Option<&(dyn Error + 'static)> {
579        match self {
580            HandshakeError::Failure(err) => Some(err),
581            _ => None,
582        }
583    }
584}
585
586impl From<io::Error> for HandshakeError {
587    fn from(err: io::Error) -> Self {
588        HandshakeError::Failure(err)
589    }
590}
591
592mod sys;