Skip to main content

tcp_stream/
futures.rs

1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7    fmt,
8    io::{self, IoSlice, IoSliceMut},
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13#[cfg(feature = "native-tls-futures")]
14use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
15#[cfg(feature = "openssl-futures")]
16use crate::{OpensslAsyncStream, OpensslConnector};
17#[cfg(feature = "rustls-futures")]
18use crate::{RustlsAsyncStream, RustlsConnector, RustlsConnectorConfig};
19
20/// Wrapper around plain or TLS async TCP streams
21#[non_exhaustive]
22pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
23    /// Wrapper around plain async TCP stream
24    Plain(S),
25    #[cfg(feature = "native-tls-futures")]
26    /// Wrapper around a TLS async stream hanled by native-tls
27    NativeTls(NativeTlsAsyncStream<S>),
28    #[cfg(feature = "openssl-futures")]
29    /// Wrapper around a TLS async stream hanled by openssl
30    Openssl(OpensslAsyncStream<S>),
31    #[cfg(feature = "rustls-futures")]
32    /// Wrapper around a TLS async stream hanled by rustls
33    Rustls(RustlsAsyncStream<S>),
34}
35
36impl<S: AsyncRead + AsyncWrite + fmt::Debug + Send + Unpin + 'static> fmt::Debug
37    for AsyncTcpStream<S>
38{
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.debug_struct("AsyncTcpStream").finish_non_exhaustive()
41    }
42}
43
44impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
45    /// Wrapper around `reactor_trait::TcpReactor::connect`
46    pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
47        reactor: &R,
48        addr: A,
49    ) -> io::Result<Self> {
50        Ok(Self::Plain(reactor.tcp_connect(addr).await?))
51    }
52
53    /// Enable TLS
54    pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
55        into_tls_impl(self, domain, config).await
56    }
57
58    #[cfg(feature = "native-tls-futures")]
59    /// Enable TLS using native-tls
60    pub async fn into_native_tls(
61        self,
62        connector: NativeTlsConnectorBuilder,
63        domain: &str,
64    ) -> io::Result<Self> {
65        Ok(Self::NativeTls(
66            async_native_tls::TlsConnector::from(connector)
67                .connect(domain, self.into_plain()?)
68                .await
69                .map_err(io::Error::other)?,
70        ))
71    }
72
73    #[cfg(feature = "openssl-futures")]
74    /// Enable TLS using openssl
75    pub async fn into_openssl(
76        self,
77        connector: &OpensslConnector,
78        domain: &str,
79    ) -> io::Result<Self> {
80        let mut stream = async_openssl::SslStream::new(
81            connector.configure()?.into_ssl(domain)?,
82            self.into_plain()?,
83        )?;
84        Pin::new(&mut stream)
85            .connect()
86            .await
87            .map_err(io::Error::other)?;
88        Ok(Self::Openssl(stream))
89    }
90
91    #[cfg(feature = "rustls-futures")]
92    /// Enable TLS using rustls
93    pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
94        Ok(Self::Rustls(
95            connector.connect_async(domain, self.into_plain()?).await?,
96        ))
97    }
98
99    #[allow(irrefutable_let_patterns, dead_code)]
100    fn into_plain(self) -> io::Result<S> {
101        if let Self::Plain(plain) = self {
102            Ok(plain)
103        } else {
104            Err(io::Error::new(
105                io::ErrorKind::AlreadyExists,
106                "already a TLS stream",
107            ))
108        }
109    }
110}
111
112async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
113    s: AsyncTcpStream<S>,
114    domain: &str,
115    config: TLSConfig<'_, '_, '_>,
116) -> io::Result<AsyncTcpStream<S>> {
117    cfg_if! {
118        if #[cfg(all(feature = "rustls-futures", feature = "rustls-platform-verifier"))] {
119            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_platform_verifier(), domain, config).await
120        } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
121            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
122        } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
123            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_root_certs(), domain, config).await
124        } else if #[cfg(feature = "rustls-futures")] {
125            crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
126        } else if #[cfg(feature = "openssl-futures")] {
127            crate::into_openssl_impl_async(s, domain, config).await
128        } else if #[cfg(feature = "native-tls-futures")] {
129            crate::into_native_tls_impl_async(s, domain, config).await
130        } else {
131            let _ = (domain, config);
132            Ok(AsyncTcpStream::Plain(s.into_plain()?))
133        }
134    }
135}
136
// Forwards a `poll_*` call to whichever stream the `AsyncTcpStream` currently wraps.
//
// Usage: `fwd_impl!(self, method, arg1, arg2, ...)` inside an `impl` on the enum,
// where `self` is `Pin<&mut Self>`. Each inner stream is re-pinned with `Pin::new`,
// so every wrapped stream type must be `Unpin`.
macro_rules! fwd_impl {
    ($this:ident, $method:ident, $($arg:expr),*) => {{
        let this = $this.get_mut();
        match this {
            Self::Plain(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$method($($arg),*),
        }
    }};
}
150
151impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
152    fn poll_read(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut [u8],
156    ) -> Poll<io::Result<usize>> {
157        fwd_impl!(self, poll_read, cx, buf)
158    }
159
160    fn poll_read_vectored(
161        self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        bufs: &mut [IoSliceMut<'_>],
164    ) -> Poll<io::Result<usize>> {
165        fwd_impl!(self, poll_read_vectored, cx, bufs)
166    }
167}
168
169impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
170    fn poll_write(
171        self: Pin<&mut Self>,
172        cx: &mut Context<'_>,
173        buf: &[u8],
174    ) -> Poll<io::Result<usize>> {
175        fwd_impl!(self, poll_write, cx, buf)
176    }
177
178    fn poll_write_vectored(
179        self: Pin<&mut Self>,
180        cx: &mut Context<'_>,
181        bufs: &[IoSlice<'_>],
182    ) -> Poll<io::Result<usize>> {
183        fwd_impl!(self, poll_write_vectored, cx, bufs)
184    }
185
186    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
187        fwd_impl!(self, poll_flush, cx)
188    }
189
190    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
191        fwd_impl!(self, poll_close, cx)
192    }
193}