1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7 fmt,
8 io::{self, IoSlice, IoSliceMut},
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13#[cfg(feature = "native-tls-futures")]
14use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
15#[cfg(feature = "openssl-futures")]
16use crate::{OpensslAsyncStream, OpensslConnector};
17#[cfg(feature = "rustls-futures")]
18use crate::{RustlsAsyncStream, RustlsConnector};
19
20#[non_exhaustive]
22pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
23 Plain(S),
25 #[cfg(feature = "native-tls-futures")]
26 NativeTls(NativeTlsAsyncStream<S>),
28 #[cfg(feature = "openssl-futures")]
29 Openssl(OpensslAsyncStream<S>),
31 #[cfg(feature = "rustls-futures")]
32 Rustls(RustlsAsyncStream<S>),
34}
35
36impl<S: AsyncRead + AsyncWrite + fmt::Debug + Send + Unpin + 'static> fmt::Debug
37 for AsyncTcpStream<S>
38{
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 f.debug_struct("AsyncTcpStream").finish_non_exhaustive()
41 }
42}
43
44impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
45 pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
47 reactor: &R,
48 addr: A,
49 ) -> io::Result<Self> {
50 Ok(Self::Plain(reactor.tcp_connect(addr).await?))
51 }
52
53 pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
59 into_tls_impl(self, domain, config).await
60 }
61
62 #[cfg(feature = "native-tls-futures")]
63 pub async fn into_native_tls(
65 self,
66 connector: NativeTlsConnectorBuilder,
67 domain: &str,
68 ) -> io::Result<Self> {
69 Ok(Self::NativeTls(
70 async_native_tls::TlsConnector::from(connector)
71 .connect(domain, self.into_plain()?)
72 .await
73 .map_err(io::Error::other)?,
74 ))
75 }
76
77 #[cfg(feature = "openssl-futures")]
78 pub async fn into_openssl(
80 self,
81 connector: &OpensslConnector,
82 domain: &str,
83 ) -> io::Result<Self> {
84 let mut stream = async_openssl::SslStream::new(
85 connector.configure()?.into_ssl(domain)?,
86 self.into_plain()?,
87 )?;
88 Pin::new(&mut stream)
89 .connect()
90 .await
91 .map_err(io::Error::other)?;
92 Ok(Self::Openssl(stream))
93 }
94
95 #[cfg(feature = "rustls-futures")]
96 pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
98 Ok(Self::Rustls(
99 connector.connect_async(domain, self.into_plain()?).await?,
100 ))
101 }
102
103 #[allow(irrefutable_let_patterns, dead_code)]
104 fn into_plain(self) -> io::Result<S> {
105 if let Self::Plain(plain) = self {
106 Ok(plain)
107 } else {
108 Err(io::Error::new(
109 io::ErrorKind::AlreadyExists,
110 "already a TLS stream",
111 ))
112 }
113 }
114}
115
/// Upgrades `s` to TLS for `domain` using the backend selected at compile time.
///
/// If several TLS features are enabled, the first matching branch wins, in this
/// order: `rustls-futures`, then `openssl-futures`, then `native-tls-futures`.
///
/// # Errors
///
/// Returns whatever error the selected crate-level helper produces.
///
/// NOTE(review): with no TLS feature enabled, this does NOT fail. It returns the
/// stream unchanged as `AsyncTcpStream::Plain`, so a caller that asked for TLS
/// silently gets an unencrypted connection. Confirm this fallback is intended.
async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
    s: AsyncTcpStream<S>,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream<S>> {
    cfg_if! {
        if #[cfg(feature = "rustls-futures")] {
            crate::into_rustls_impl_async(s, domain, config).await
        } else if #[cfg(feature = "openssl-futures")] {
            crate::into_openssl_impl_async(s, domain, config).await
        } else if #[cfg(feature = "native-tls-futures")] {
            crate::into_native_tls_impl_async(s, domain, config).await
        } else {
            // No TLS backend compiled in: silence unused-parameter warnings. Only the
            // `Plain` variant exists here, so `into_plain` cannot fail.
            let _ = (domain, config);
            Ok(AsyncTcpStream::Plain(s.into_plain()?))
        }
    }
}
134
135impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
136 fn poll_read(
137 self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &mut [u8],
140 ) -> Poll<io::Result<usize>> {
141 fwd_pin_impl!(self, poll_read, cx, buf)
142 }
143
144 fn poll_read_vectored(
145 self: Pin<&mut Self>,
146 cx: &mut Context<'_>,
147 bufs: &mut [IoSliceMut<'_>],
148 ) -> Poll<io::Result<usize>> {
149 fwd_pin_impl!(self, poll_read_vectored, cx, bufs)
150 }
151}
152
153impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
154 fn poll_write(
155 self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &[u8],
158 ) -> Poll<io::Result<usize>> {
159 fwd_pin_impl!(self, poll_write, cx, buf)
160 }
161
162 fn poll_write_vectored(
163 self: Pin<&mut Self>,
164 cx: &mut Context<'_>,
165 bufs: &[IoSlice<'_>],
166 ) -> Poll<io::Result<usize>> {
167 fwd_pin_impl!(self, poll_write_vectored, cx, bufs)
168 }
169
170 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
171 fwd_pin_impl!(self, poll_flush, cx)
172 }
173
174 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175 fwd_pin_impl!(self, poll_close, cx)
176 }
177}