1use crate::TLSConfig;
2
3use cfg_if::cfg_if;
4use futures_io::{AsyncRead, AsyncWrite};
5use reactor_trait::{AsyncIOHandle, AsyncToSocketAddrs, TcpReactor};
6use std::{
7 io::{self, IoSlice, IoSliceMut},
8 ops::Deref,
9 pin::{Pin, pin},
10 task::{Context, Poll},
11};
12
13#[cfg(feature = "native-tls-futures")]
14use crate::NativeTlsConnectorBuilder;
15#[cfg(feature = "openssl-futures")]
16use crate::OpenSslConnector;
17#[cfg(feature = "rustls-futures")]
18use crate::{RustlsConnector, RustlsConnectorConfig};
19
/// Heap-allocated, pinned, type-erased I/O handle that is `Send`.
///
/// Both variants of `AsyncTcpStream` store their underlying stream as this type.
type AsyncStream = Pin<Box<dyn AsyncIOHandle + Send>>;
21
/// An asynchronous TCP stream that is either plain or has been upgraded to TLS.
///
/// Both variants hold a type-erased `AsyncStream`; the variant only records whether a
/// TLS upgrade has already been performed on it.
pub enum AsyncTcpStream {
    /// Stream as produced by `connect`, with no TLS upgrade applied.
    Plain(AsyncStream),
    /// Stream produced by one of the TLS upgrade methods (`into_native_tls`,
    /// `into_openssl`, `into_rustls`).
    TLS(AsyncStream),
}
29
30impl AsyncTcpStream {
31 pub async fn connect<R: Deref, A: AsyncToSocketAddrs>(reactor: R, addr: A) -> io::Result<Self>
33 where
34 R::Target: TcpReactor,
35 {
36 let addrs = addr.to_socket_addrs().await?;
37 let mut err = None;
38 for addr in addrs {
39 match reactor.connect(addr).await {
40 Ok(stream) => return Ok(Self::Plain(stream.into())),
41 Err(e) => err = Some(e),
42 }
43 }
44 Err(err.unwrap_or_else(|| {
45 io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
46 }))
47 }
48
49 pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
51 into_tls_impl(self, domain, config).await
52 }
53
54 #[cfg(feature = "native-tls-futures")]
55 pub async fn into_native_tls(
57 self,
58 connector: NativeTlsConnectorBuilder,
59 domain: &str,
60 ) -> io::Result<Self> {
61 Ok(Self::TLS(Box::pin(
62 async_native_tls::TlsConnector::from(connector)
63 .connect(domain, self.into_plain()?)
64 .await
65 .map_err(io::Error::other)?,
66 )))
67 }
68
69 #[cfg(feature = "openssl-futures")]
70 pub async fn into_openssl(
72 self,
73 connector: &OpenSslConnector,
74 domain: &str,
75 ) -> io::Result<Self> {
76 let mut stream = async_openssl::SslStream::new(
77 connector.configure()?.into_ssl(domain)?,
78 self.into_plain()?,
79 )?;
80 Pin::new(&mut stream)
81 .connect()
82 .await
83 .map_err(io::Error::other)?;
84 Ok(Self::TLS(Box::pin(stream)))
85 }
86
87 #[cfg(feature = "rustls-futures")]
88 pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
90 Ok(Self::TLS(Box::pin(
91 connector.connect_async(domain, self.into_plain()?).await?,
92 )))
93 }
94
95 #[allow(irrefutable_let_patterns, dead_code)]
96 fn into_plain(self) -> io::Result<AsyncStream> {
97 if let AsyncTcpStream::Plain(plain) = self {
98 Ok(plain)
99 } else {
100 Err(io::Error::new(
101 io::ErrorKind::AlreadyExists,
102 "already a TLS stream",
103 ))
104 }
105 }
106}
107
108cfg_if! {
109 if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
110 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
111 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
112 }
113 } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
114 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
115 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
116 }
117 } else if #[cfg(feature = "rustls-futures")] {
118 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
119 crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
120 }
121 } else if #[cfg(feature = "openssl-futures")] {
122 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
123 crate::into_openssl_impl_async(s, domain, config).await
124 }
125 } else if #[cfg(feature = "native-tls-futures")] {
126 async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
127 crate::into_native_tls_impl_async(s, domain, config).await
128 }
129 } else {
130 async fn into_tls_impl(s: AsyncTcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
131 Ok(AsyncTcpStream::Plain(s.into_plain()?))
132 }
133 }
134}
135
136impl AsyncRead for AsyncTcpStream {
137 fn poll_read(
138 self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 buf: &mut [u8],
141 ) -> Poll<io::Result<usize>> {
142 match self.get_mut() {
143 Self::Plain(plain) => pin!(plain).poll_read(cx, buf),
144 Self::TLS(tls) => pin!(tls).poll_read(cx, buf),
145 }
146 }
147
148 fn poll_read_vectored(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 bufs: &mut [IoSliceMut<'_>],
152 ) -> Poll<io::Result<usize>> {
153 match self.get_mut() {
154 Self::Plain(plain) => pin!(plain).poll_read_vectored(cx, bufs),
155 Self::TLS(tls) => pin!(tls).poll_read_vectored(cx, bufs),
156 }
157 }
158}
159
160impl AsyncWrite for AsyncTcpStream {
161 fn poll_write(
162 self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 buf: &[u8],
165 ) -> Poll<io::Result<usize>> {
166 match self.get_mut() {
167 Self::Plain(plain) => pin!(plain).poll_write(cx, buf),
168 Self::TLS(tls) => pin!(tls).poll_write(cx, buf),
169 }
170 }
171
172 fn poll_write_vectored(
173 self: Pin<&mut Self>,
174 cx: &mut Context<'_>,
175 bufs: &[IoSlice<'_>],
176 ) -> Poll<io::Result<usize>> {
177 match self.get_mut() {
178 Self::Plain(plain) => pin!(plain).poll_write_vectored(cx, bufs),
179 Self::TLS(tls) => pin!(tls).poll_write_vectored(cx, bufs),
180 }
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184 match self.get_mut() {
185 Self::Plain(plain) => pin!(plain).poll_flush(cx),
186 Self::TLS(tls) => pin!(tls).poll_flush(cx),
187 }
188 }
189
190 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
191 match self.get_mut() {
192 Self::Plain(plain) => pin!(plain).poll_close(cx),
193 Self::TLS(tls) => pin!(tls).poll_close(cx),
194 }
195 }
196}