1use crate::TLSConfig;
2
3use cfg_if::cfg_if;
4use futures_io::{AsyncRead, AsyncWrite};
5use reactor_trait::{AsyncIOHandle, AsyncToSocketAddrs, TcpReactor};
6use std::{
7 io::{self, IoSlice, IoSliceMut},
8 pin::{Pin, pin},
9 task::{Context, Poll},
10};
11
12#[cfg(feature = "native-tls-futures")]
13use crate::NativeTlsConnectorBuilder;
14#[cfg(feature = "openssl-futures")]
15use crate::OpenSslConnector;
16#[cfg(feature = "rustls-futures")]
17use crate::{RustlsConnector, RustlsConnectorConfig};
18
19type AsyncStream = Pin<Box<dyn AsyncIOHandle + Send>>;
20
21pub enum AsyncTcpStream {
23 Plain(AsyncStream),
25 TLS(AsyncStream),
27}
28
29impl AsyncTcpStream {
30 pub async fn connect<R: TcpReactor, A: AsyncToSocketAddrs>(
32 _reactor: R,
33 addr: A,
34 ) -> io::Result<Self> {
35 let addrs = addr.to_socket_addrs().await?;
36 let mut err = None;
37 for addr in addrs {
38 match R::connect(addr).await {
39 Ok(stream) => return Ok(Self::Plain(stream.into())),
40 Err(e) => err = Some(e),
41 }
42 }
43 Err(err.unwrap_or_else(|| {
44 io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
45 }))
46 }
47
48 pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
50 into_tls_impl(self, domain, config).await
51 }
52
53 #[cfg(feature = "native-tls-futures")]
54 pub async fn into_native_tls(
56 self,
57 connector: NativeTlsConnectorBuilder,
58 domain: &str,
59 ) -> io::Result<Self> {
60 Ok(Self::TLS(Box::pin(
61 async_native_tls::TlsConnector::from(connector)
62 .connect(domain, self.into_plain()?)
63 .await
64 .map_err(io::Error::other)?,
65 )))
66 }
67
68 #[cfg(feature = "openssl-futures")]
69 pub async fn into_openssl(
71 self,
72 connector: &OpenSslConnector,
73 domain: &str,
74 ) -> io::Result<Self> {
75 let mut stream = async_openssl::SslStream::new(
76 connector.configure()?.into_ssl(domain)?,
77 self.into_plain()?,
78 )?;
79 Pin::new(&mut stream)
80 .connect()
81 .await
82 .map_err(io::Error::other)?;
83 Ok(Self::TLS(Box::pin(stream)))
84 }
85
86 #[cfg(feature = "rustls-futures")]
87 pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
89 Ok(Self::TLS(Box::pin(
90 connector.connect_async(domain, self.into_plain()?).await?,
91 )))
92 }
93
94 #[allow(irrefutable_let_patterns, dead_code)]
95 fn into_plain(self) -> io::Result<AsyncStream> {
96 if let AsyncTcpStream::Plain(plain) = self {
97 Ok(plain)
98 } else {
99 Err(io::Error::new(
100 io::ErrorKind::AlreadyExists,
101 "already a TLS stream",
102 ))
103 }
104 }
105}
106
107cfg_if! {
108 if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
109 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
110 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
111 }
112 } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
113 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
114 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
115 }
116 } else if #[cfg(feature = "rustls-futures")] {
117 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
118 crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
119 }
120 } else if #[cfg(feature = "openssl-futures")] {
121 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
122 crate::into_openssl_impl_async(s, domain, config).await
123 }
124 } else if #[cfg(feature = "native-tls-futures")] {
125 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
126 crate::into_native_tls_impl_async(s, domain, config).await
127 }
128 } else {
129 async fn into_tls_impl(s: AsyncTcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
130 Ok(AsyncTcpStream::Plain(s.into_plain()?))
131 }
132 }
133}
134
135impl AsyncRead for AsyncTcpStream {
136 fn poll_read(
137 self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &mut [u8],
140 ) -> Poll<io::Result<usize>> {
141 match self.get_mut() {
142 Self::Plain(plain) => pin!(plain).poll_read(cx, buf),
143 Self::TLS(tls) => pin!(tls).poll_read(cx, buf),
144 }
145 }
146
147 fn poll_read_vectored(
148 self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 bufs: &mut [IoSliceMut<'_>],
151 ) -> Poll<io::Result<usize>> {
152 match self.get_mut() {
153 Self::Plain(plain) => pin!(plain).poll_read_vectored(cx, bufs),
154 Self::TLS(tls) => pin!(tls).poll_read_vectored(cx, bufs),
155 }
156 }
157}
158
159impl AsyncWrite for AsyncTcpStream {
160 fn poll_write(
161 self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 buf: &[u8],
164 ) -> Poll<io::Result<usize>> {
165 match self.get_mut() {
166 Self::Plain(plain) => pin!(plain).poll_write(cx, buf),
167 Self::TLS(tls) => pin!(tls).poll_write(cx, buf),
168 }
169 }
170
171 fn poll_write_vectored(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 bufs: &[IoSlice<'_>],
175 ) -> Poll<io::Result<usize>> {
176 match self.get_mut() {
177 Self::Plain(plain) => pin!(plain).poll_write_vectored(cx, bufs),
178 Self::TLS(tls) => pin!(tls).poll_write_vectored(cx, bufs),
179 }
180 }
181
182 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
183 match self.get_mut() {
184 Self::Plain(plain) => pin!(plain).poll_flush(cx),
185 Self::TLS(tls) => pin!(tls).poll_flush(cx),
186 }
187 }
188
189 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190 match self.get_mut() {
191 Self::Plain(plain) => pin!(plain).poll_close(cx),
192 Self::TLS(tls) => pin!(tls).poll_close(cx),
193 }
194 }
195}