tcp_stream/
futures.rs

1use crate::TLSConfig;
2
3use cfg_if::cfg_if;
4use futures_io::{AsyncRead, AsyncWrite};
5use reactor_trait::{AsyncIOHandle, AsyncToSocketAddrs, TcpReactor};
6use std::{
7    io::{self, IoSlice, IoSliceMut},
8    ops::Deref,
9    pin::{Pin, pin},
10    task::{Context, Poll},
11};
12
13#[cfg(feature = "native-tls-futures")]
14use crate::NativeTlsConnectorBuilder;
15#[cfg(feature = "openssl-futures")]
16use crate::OpenSslConnector;
17#[cfg(feature = "rustls-futures")]
18use crate::{RustlsConnector, RustlsConnectorConfig};
19
20type AsyncStream = Pin<Box<dyn AsyncIOHandle + Send>>;
21
22/// Wrapper around plain or TLS async TCP streams
23pub enum AsyncTcpStream {
24    /// Wrapper around plain async TCP stream
25    Plain(AsyncStream),
26    /// Wrapper around a TLS async TCP stream
27    TLS(AsyncStream),
28}
29
30impl AsyncTcpStream {
31    /// Wrapper around `reactor_trait::TcpReactor::connect`
32    pub async fn connect<R: Deref, A: AsyncToSocketAddrs>(reactor: R, addr: A) -> io::Result<Self>
33    where
34        R::Target: TcpReactor,
35    {
36        let addrs = addr.to_socket_addrs().await?;
37        let mut err = None;
38        for addr in addrs {
39            match reactor.connect(addr).await {
40                Ok(stream) => return Ok(Self::Plain(stream.into())),
41                Err(e) => err = Some(e),
42            }
43        }
44        Err(err.unwrap_or_else(|| {
45            io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
46        }))
47    }
48
49    /// Enable TLS
50    pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
51        into_tls_impl(self, domain, config).await
52    }
53
54    #[cfg(feature = "native-tls-futures")]
55    /// Enable TLS using native-tls
56    pub async fn into_native_tls(
57        self,
58        connector: NativeTlsConnectorBuilder,
59        domain: &str,
60    ) -> io::Result<Self> {
61        Ok(Self::TLS(Box::pin(
62            async_native_tls::TlsConnector::from(connector)
63                .connect(domain, self.into_plain()?)
64                .await
65                .map_err(io::Error::other)?,
66        )))
67    }
68
69    #[cfg(feature = "openssl-futures")]
70    /// Enable TLS using openssl
71    pub async fn into_openssl(
72        self,
73        connector: &OpenSslConnector,
74        domain: &str,
75    ) -> io::Result<Self> {
76        let mut stream = async_openssl::SslStream::new(
77            connector.configure()?.into_ssl(domain)?,
78            self.into_plain()?,
79        )?;
80        Pin::new(&mut stream)
81            .connect()
82            .await
83            .map_err(io::Error::other)?;
84        Ok(Self::TLS(Box::pin(stream)))
85    }
86
87    #[cfg(feature = "rustls-futures")]
88    /// Enable TLS using rustls
89    pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
90        Ok(Self::TLS(Box::pin(
91            connector.connect_async(domain, self.into_plain()?).await?,
92        )))
93    }
94
95    #[allow(irrefutable_let_patterns, dead_code)]
96    fn into_plain(self) -> io::Result<AsyncStream> {
97        if let AsyncTcpStream::Plain(plain) = self {
98            Ok(plain)
99        } else {
100            Err(io::Error::new(
101                io::ErrorKind::AlreadyExists,
102                "already a TLS stream",
103            ))
104        }
105    }
106}
107
108cfg_if! {
109    if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
110        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
111            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
112        }
113    } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
114        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
115            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
116        }
117    } else if #[cfg(feature = "rustls-futures")] {
118        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
119            crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
120        }
121    } else if #[cfg(feature = "openssl-futures")] {
122        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
123            crate::into_openssl_impl_async(s, domain, config).await
124        }
125    } else if #[cfg(feature = "native-tls-futures")] {
126        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
127            crate::into_native_tls_impl_async(s, domain, config).await
128        }
129    } else {
130        async fn into_tls_impl(s: AsyncTcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
131            Ok(AsyncTcpStream::Plain(s.into_plain()?))
132        }
133    }
134}
135
136impl AsyncRead for AsyncTcpStream {
137    fn poll_read(
138        self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140        buf: &mut [u8],
141    ) -> Poll<io::Result<usize>> {
142        match self.get_mut() {
143            Self::Plain(plain) => pin!(plain).poll_read(cx, buf),
144            Self::TLS(tls) => pin!(tls).poll_read(cx, buf),
145        }
146    }
147
148    fn poll_read_vectored(
149        self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        bufs: &mut [IoSliceMut<'_>],
152    ) -> Poll<io::Result<usize>> {
153        match self.get_mut() {
154            Self::Plain(plain) => pin!(plain).poll_read_vectored(cx, bufs),
155            Self::TLS(tls) => pin!(tls).poll_read_vectored(cx, bufs),
156        }
157    }
158}
159
160impl AsyncWrite for AsyncTcpStream {
161    fn poll_write(
162        self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164        buf: &[u8],
165    ) -> Poll<io::Result<usize>> {
166        match self.get_mut() {
167            Self::Plain(plain) => pin!(plain).poll_write(cx, buf),
168            Self::TLS(tls) => pin!(tls).poll_write(cx, buf),
169        }
170    }
171
172    fn poll_write_vectored(
173        self: Pin<&mut Self>,
174        cx: &mut Context<'_>,
175        bufs: &[IoSlice<'_>],
176    ) -> Poll<io::Result<usize>> {
177        match self.get_mut() {
178            Self::Plain(plain) => pin!(plain).poll_write_vectored(cx, bufs),
179            Self::TLS(tls) => pin!(tls).poll_write_vectored(cx, bufs),
180        }
181    }
182
183    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184        match self.get_mut() {
185            Self::Plain(plain) => pin!(plain).poll_flush(cx),
186            Self::TLS(tls) => pin!(tls).poll_flush(cx),
187        }
188    }
189
190    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
191        match self.get_mut() {
192            Self::Plain(plain) => pin!(plain).poll_close(cx),
193            Self::TLS(tls) => pin!(tls).poll_close(cx),
194        }
195    }
196}