1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7 io::{self, IoSlice, IoSliceMut},
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12#[cfg(feature = "native-tls-futures")]
13use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
14#[cfg(feature = "openssl-futures")]
15use crate::{OpensslAsyncStream, OpensslConnector};
16#[cfg(feature = "rustls-futures")]
17use crate::{RustlsAsyncStream, RustlsConnector, RustlsConnectorConfig};
18
19#[non_exhaustive]
21pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
22 Plain(S),
24 #[cfg(feature = "native-tls-futures")]
25 NativeTls(NativeTlsAsyncStream<S>),
27 #[cfg(feature = "openssl-futures")]
28 Openssl(OpensslAsyncStream<S>),
30 #[cfg(feature = "rustls-futures")]
31 Rustls(RustlsAsyncStream<S>),
33}
34
35impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
36 pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
38 reactor: &R,
39 addr: A,
40 ) -> io::Result<Self> {
41 Ok(Self::Plain(reactor.tcp_connect(addr).await?))
42 }
43
44 pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
46 into_tls_impl(self, domain, config).await
47 }
48
49 #[cfg(feature = "native-tls-futures")]
50 pub async fn into_native_tls(
52 self,
53 connector: NativeTlsConnectorBuilder,
54 domain: &str,
55 ) -> io::Result<Self> {
56 Ok(Self::NativeTls(
57 async_native_tls::TlsConnector::from(connector)
58 .connect(domain, self.into_plain()?)
59 .await
60 .map_err(io::Error::other)?,
61 ))
62 }
63
64 #[cfg(feature = "openssl-futures")]
65 pub async fn into_openssl(
67 self,
68 connector: &OpensslConnector,
69 domain: &str,
70 ) -> io::Result<Self> {
71 let mut stream = async_openssl::SslStream::new(
72 connector.configure()?.into_ssl(domain)?,
73 self.into_plain()?,
74 )?;
75 Pin::new(&mut stream)
76 .connect()
77 .await
78 .map_err(io::Error::other)?;
79 Ok(Self::Openssl(stream))
80 }
81
82 #[cfg(feature = "rustls-futures")]
83 pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
85 Ok(Self::Rustls(
86 connector.connect_async(domain, self.into_plain()?).await?,
87 ))
88 }
89
90 #[allow(irrefutable_let_patterns, dead_code)]
91 fn into_plain(self) -> io::Result<S> {
92 if let Self::Plain(plain) = self {
93 Ok(plain)
94 } else {
95 Err(io::Error::new(
96 io::ErrorKind::AlreadyExists,
97 "already a TLS stream",
98 ))
99 }
100 }
101}
102
/// Upgrades `s` to TLS using the backend selected by enabled cargo features.
///
/// The `cfg_if!` chain below picks exactly one branch, in this order of precedence:
/// 1. `rustls-futures` + `rustls-native-certs`: rustls with `RustlsConnectorConfig::new_with_native_certs()`.
/// 2. `rustls-futures` + `rustls-webpki-roots-certs`: rustls with `RustlsConnectorConfig::new_with_webpki_roots_certs()`.
/// 3. `rustls-futures`: rustls with `RustlsConnectorConfig::default()`.
/// 4. `openssl-futures`.
/// 5. `native-tls-futures`.
/// 6. None of the above: no TLS at all (see note below).
///
/// # Errors
///
/// Returns whatever error the selected `crate::into_*_impl_async` helper returns. For
/// branch 1 it can also fail while building the native-certs config.
///
/// NOTE(review): in branch 6 this returns `Ok` with the stream still in plaintext.
/// `into_plain` cannot fail there, because `Plain` is the only compiled variant. Callers
/// that rely on `into_tls` for confidentiality should ensure a TLS feature is enabled.
async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
    s: AsyncTcpStream<S>,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream<S>> {
    cfg_if! {
        if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
        } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
        } else if #[cfg(feature = "rustls-futures")] {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
        } else if #[cfg(feature = "openssl-futures")] {
            crate::into_openssl_impl_async(s, domain, config).await
        } else if #[cfg(feature = "native-tls-futures")] {
            crate::into_native_tls_impl_async(s, domain, config).await
        } else {
            // No TLS backend compiled in: consume the otherwise-unused arguments to avoid
            // unused-variable warnings, and hand back the plain stream.
            let _ = (domain, config);
            Ok(AsyncTcpStream::Plain(s.into_plain()?))
        }
    }
}
125
/// Forwards a poll-style method call to the concrete stream held by `AsyncTcpStream`.
///
/// Usage: `fwd_impl!(self, method_name, arg1, arg2, ...)`, where `self` is the
/// `Pin<&mut Self>` receiver. The active variant's inner stream is re-pinned with
/// `Pin::new`, which requires every inner stream type to be `Unpin`.
macro_rules! fwd_impl {
    ($this:ident, $poll_fn:ident, $($arg:expr),*) => {
        match Pin::get_mut($this) {
            Self::Plain(inner) => Pin::new(inner).$poll_fn($($arg),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(inner) => Pin::new(inner).$poll_fn($($arg),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(inner) => Pin::new(inner).$poll_fn($($arg),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(inner) => Pin::new(inner).$poll_fn($($arg),*),
        }
    };
}
139
140impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
141 fn poll_read(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &mut [u8],
145 ) -> Poll<io::Result<usize>> {
146 fwd_impl!(self, poll_read, cx, buf)
147 }
148
149 fn poll_read_vectored(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 bufs: &mut [IoSliceMut<'_>],
153 ) -> Poll<io::Result<usize>> {
154 fwd_impl!(self, poll_read_vectored, cx, bufs)
155 }
156}
157
158impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
159 fn poll_write(
160 self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 buf: &[u8],
163 ) -> Poll<io::Result<usize>> {
164 fwd_impl!(self, poll_write, cx, buf)
165 }
166
167 fn poll_write_vectored(
168 self: Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 bufs: &[IoSlice<'_>],
171 ) -> Poll<io::Result<usize>> {
172 fwd_impl!(self, poll_write_vectored, cx, bufs)
173 }
174
175 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176 fwd_impl!(self, poll_flush, cx)
177 }
178
179 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180 fwd_impl!(self, poll_close, cx)
181 }
182}