tcp_stream/
futures.rs

1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7    io::{self, IoSlice, IoSliceMut},
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12#[cfg(feature = "native-tls-futures")]
13use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
14#[cfg(feature = "openssl-futures")]
15use crate::{OpensslAsyncStream, OpensslConnector};
16#[cfg(feature = "rustls-futures")]
17use crate::{RustlsAsyncStream, RustlsConnector, RustlsConnectorConfig};
18
19/// Wrapper around plain or TLS async TCP streams
20#[non_exhaustive]
21pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
22    /// Wrapper around plain async TCP stream
23    Plain(S),
24    #[cfg(feature = "native-tls-futures")]
25    /// Wrapper around a TLS async stream hanled by native-tls
26    NativeTls(NativeTlsAsyncStream<S>),
27    #[cfg(feature = "openssl-futures")]
28    /// Wrapper around a TLS async stream hanled by openssl
29    Openssl(OpensslAsyncStream<S>),
30    #[cfg(feature = "rustls-futures")]
31    /// Wrapper around a TLS async stream hanled by rustls
32    Rustls(RustlsAsyncStream<S>),
33}
34
35impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
36    /// Wrapper around `reactor_trait::TcpReactor::connect`
37    pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
38        reactor: &R,
39        addr: A,
40    ) -> io::Result<Self> {
41        Ok(Self::Plain(reactor.tcp_connect(addr).await?))
42    }
43
44    /// Enable TLS
45    pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
46        into_tls_impl(self, domain, config).await
47    }
48
49    #[cfg(feature = "native-tls-futures")]
50    /// Enable TLS using native-tls
51    pub async fn into_native_tls(
52        self,
53        connector: NativeTlsConnectorBuilder,
54        domain: &str,
55    ) -> io::Result<Self> {
56        Ok(Self::NativeTls(
57            async_native_tls::TlsConnector::from(connector)
58                .connect(domain, self.into_plain()?)
59                .await
60                .map_err(io::Error::other)?,
61        ))
62    }
63
64    #[cfg(feature = "openssl-futures")]
65    /// Enable TLS using openssl
66    pub async fn into_openssl(
67        self,
68        connector: &OpensslConnector,
69        domain: &str,
70    ) -> io::Result<Self> {
71        let mut stream = async_openssl::SslStream::new(
72            connector.configure()?.into_ssl(domain)?,
73            self.into_plain()?,
74        )?;
75        Pin::new(&mut stream)
76            .connect()
77            .await
78            .map_err(io::Error::other)?;
79        Ok(Self::Openssl(stream))
80    }
81
82    #[cfg(feature = "rustls-futures")]
83    /// Enable TLS using rustls
84    pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
85        Ok(Self::Rustls(
86            connector.connect_async(domain, self.into_plain()?).await?,
87        ))
88    }
89
90    #[allow(irrefutable_let_patterns, dead_code)]
91    fn into_plain(self) -> io::Result<S> {
92        if let Self::Plain(plain) = self {
93            Ok(plain)
94        } else {
95            Err(io::Error::new(
96                io::ErrorKind::AlreadyExists,
97                "already a TLS stream",
98            ))
99        }
100    }
101}
102
103async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
104    s: AsyncTcpStream<S>,
105    domain: &str,
106    config: TLSConfig<'_, '_, '_>,
107) -> io::Result<AsyncTcpStream<S>> {
108    cfg_if! {
109        if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
110            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
111        } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
112            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
113        } else if #[cfg(feature = "rustls-futures")] {
114            crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
115        } else if #[cfg(feature = "openssl-futures")] {
116            crate::into_openssl_impl_async(s, domain, config).await
117        } else if #[cfg(feature = "native-tls-futures")] {
118            crate::into_native_tls_impl_async(s, domain, config).await
119        } else {
120            let _ = (domain, config);
121            Ok(AsyncTcpStream::Plain(s.into_plain()?))
122        }
123    }
124}
125
// Forward a pinned poll-style call to whichever variant is currently active.
//
// `$this` is a `Pin<&mut Self>` (Self must be `Unpin`). The active inner
// stream is re-pinned with `Pin::new`, and `$method` is called on it with the
// remaining arguments. Arms for TLS variants only exist when their feature is
// enabled.
macro_rules! fwd_impl {
    ($this:ident, $method:ident, $($arg:expr),*) => {{
        let active = Pin::get_mut($this);
        match active {
            Self::Plain(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$method($($arg),*),
        }
    }};
}
139
140impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
141    fn poll_read(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &mut [u8],
145    ) -> Poll<io::Result<usize>> {
146        fwd_impl!(self, poll_read, cx, buf)
147    }
148
149    fn poll_read_vectored(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        bufs: &mut [IoSliceMut<'_>],
153    ) -> Poll<io::Result<usize>> {
154        fwd_impl!(self, poll_read_vectored, cx, bufs)
155    }
156}
157
158impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
159    fn poll_write(
160        self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        buf: &[u8],
163    ) -> Poll<io::Result<usize>> {
164        fwd_impl!(self, poll_write, cx, buf)
165    }
166
167    fn poll_write_vectored(
168        self: Pin<&mut Self>,
169        cx: &mut Context<'_>,
170        bufs: &[IoSlice<'_>],
171    ) -> Poll<io::Result<usize>> {
172        fwd_impl!(self, poll_write_vectored, cx, bufs)
173    }
174
175    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176        fwd_impl!(self, poll_flush, cx)
177    }
178
179    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180        fwd_impl!(self, poll_close, cx)
181    }
182}