1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7 fmt,
8 io::{self, IoSlice, IoSliceMut},
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13#[cfg(feature = "native-tls-futures")]
14use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
15#[cfg(feature = "openssl-futures")]
16use crate::{OpensslAsyncStream, OpensslConnector};
17#[cfg(feature = "rustls-futures")]
18use crate::{RustlsAsyncStream, RustlsConnector, RustlsConnectorConfig};
19
20#[non_exhaustive]
22pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
23 Plain(S),
25 #[cfg(feature = "native-tls-futures")]
26 NativeTls(NativeTlsAsyncStream<S>),
28 #[cfg(feature = "openssl-futures")]
29 Openssl(OpensslAsyncStream<S>),
31 #[cfg(feature = "rustls-futures")]
32 Rustls(RustlsAsyncStream<S>),
34}
35
36impl<S: AsyncRead + AsyncWrite + fmt::Debug + Send + Unpin + 'static> fmt::Debug
37 for AsyncTcpStream<S>
38{
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 f.debug_struct("AsyncTcpStream").finish_non_exhaustive()
41 }
42}
43
44impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
45 pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
47 reactor: &R,
48 addr: A,
49 ) -> io::Result<Self> {
50 Ok(Self::Plain(reactor.tcp_connect(addr).await?))
51 }
52
53 pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
55 into_tls_impl(self, domain, config).await
56 }
57
58 #[cfg(feature = "native-tls-futures")]
59 pub async fn into_native_tls(
61 self,
62 connector: NativeTlsConnectorBuilder,
63 domain: &str,
64 ) -> io::Result<Self> {
65 Ok(Self::NativeTls(
66 async_native_tls::TlsConnector::from(connector)
67 .connect(domain, self.into_plain()?)
68 .await
69 .map_err(io::Error::other)?,
70 ))
71 }
72
73 #[cfg(feature = "openssl-futures")]
74 pub async fn into_openssl(
76 self,
77 connector: &OpensslConnector,
78 domain: &str,
79 ) -> io::Result<Self> {
80 let mut stream = async_openssl::SslStream::new(
81 connector.configure()?.into_ssl(domain)?,
82 self.into_plain()?,
83 )?;
84 Pin::new(&mut stream)
85 .connect()
86 .await
87 .map_err(io::Error::other)?;
88 Ok(Self::Openssl(stream))
89 }
90
91 #[cfg(feature = "rustls-futures")]
92 pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
94 Ok(Self::Rustls(
95 connector.connect_async(domain, self.into_plain()?).await?,
96 ))
97 }
98
99 #[allow(irrefutable_let_patterns, dead_code)]
100 fn into_plain(self) -> io::Result<S> {
101 if let Self::Plain(plain) = self {
102 Ok(plain)
103 } else {
104 Err(io::Error::new(
105 io::ErrorKind::AlreadyExists,
106 "already a TLS stream",
107 ))
108 }
109 }
110}
111
112async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
113 s: AsyncTcpStream<S>,
114 domain: &str,
115 config: TLSConfig<'_, '_, '_>,
116) -> io::Result<AsyncTcpStream<S>> {
117 cfg_if! {
118 if #[cfg(all(feature = "rustls-futures", feature = "rustls-platform-verifier"))] {
119 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_platform_verifier(), domain, config).await
120 } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
121 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
122 } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
123 crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_root_certs(), domain, config).await
124 } else if #[cfg(feature = "rustls-futures")] {
125 crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
126 } else if #[cfg(feature = "openssl-futures")] {
127 crate::into_openssl_impl_async(s, domain, config).await
128 } else if #[cfg(feature = "native-tls-futures")] {
129 crate::into_native_tls_impl_async(s, domain, config).await
130 } else {
131 let _ = (domain, config);
132 Ok(AsyncTcpStream::Plain(s.into_plain()?))
133 }
134 }
135}
136
/// Forwards a poll-method call to whichever stream variant is active.
///
/// `$this` is the `Pin<&mut Self>` receiver. It is unpinned with `get_mut`, which
/// requires `Self: Unpin`. The active inner stream is then re-pinned with
/// `Pin::new`, and `$method` is called with the remaining arguments.
macro_rules! fwd_impl {
    ($this:ident, $method:ident $(, $arg:expr)*) => {
        match $this.get_mut() {
            Self::Plain(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$method($($arg),*),
        }
    };
}
150
151impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
152 fn poll_read(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut [u8],
156 ) -> Poll<io::Result<usize>> {
157 fwd_impl!(self, poll_read, cx, buf)
158 }
159
160 fn poll_read_vectored(
161 self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 bufs: &mut [IoSliceMut<'_>],
164 ) -> Poll<io::Result<usize>> {
165 fwd_impl!(self, poll_read_vectored, cx, bufs)
166 }
167}
168
169impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
170 fn poll_write(
171 self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 buf: &[u8],
174 ) -> Poll<io::Result<usize>> {
175 fwd_impl!(self, poll_write, cx, buf)
176 }
177
178 fn poll_write_vectored(
179 self: Pin<&mut Self>,
180 cx: &mut Context<'_>,
181 bufs: &[IoSlice<'_>],
182 ) -> Poll<io::Result<usize>> {
183 fwd_impl!(self, poll_write_vectored, cx, bufs)
184 }
185
186 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
187 fwd_impl!(self, poll_flush, cx)
188 }
189
190 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
191 fwd_impl!(self, poll_close, cx)
192 }
193}