1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
10
11use crate::error::Error;
12use std::mem::replace;
13
/// Source of a certificate: bytes supplied in memory or a path to read them from.
#[derive(Clone, Debug)]
pub enum CertificateInput {
    /// Certificate bytes held directly in memory (for example PEM text).
    Inline(Vec<u8>),
    /// Filesystem path of a file whose contents are the certificate bytes.
    File(std::path::PathBuf),
}
22
23impl From<String> for CertificateInput {
24 fn from(value: String) -> Self {
25 let trimmed = value.trim();
26 if trimmed.starts_with("-----BEGIN CERTIFICATE-----")
28 && trimmed.contains("-----END CERTIFICATE-----")
29 {
30 CertificateInput::Inline(value.as_bytes().to_vec())
31 } else {
32 CertificateInput::File(PathBuf::from(value))
33 }
34 }
35}
36
37impl CertificateInput {
38 async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
39 use sqlx_rt::fs;
40 match self {
41 CertificateInput::Inline(v) => Ok(v.clone()),
42 CertificateInput::File(path) => fs::read(path).await,
43 }
44 }
45}
46
47impl std::fmt::Display for CertificateInput {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
51 CertificateInput::File(path) => write!(f, "file: {}", path.display()),
52 }
53 }
54}
55
56#[cfg(feature = "_tls-rustls")]
57mod rustls;
58
59pub enum MaybeTlsStream<S>
60where
61 S: AsyncRead + AsyncWrite + Unpin,
62{
63 Raw(S),
64 Tls(TlsStream<S>),
65 Upgrading,
66}
67
68impl<S> MaybeTlsStream<S>
69where
70 S: AsyncRead + AsyncWrite + Unpin,
71{
72 #[inline]
73 pub fn is_tls(&self) -> bool {
74 matches!(self, Self::Tls(_))
75 }
76
77 pub async fn upgrade(
78 &mut self,
79 host: &str,
80 accept_invalid_certs: bool,
81 accept_invalid_hostnames: bool,
82 root_cert_path: Option<&CertificateInput>,
83 ) -> Result<(), Error> {
84 let connector = configure_tls_connector(
85 accept_invalid_certs,
86 accept_invalid_hostnames,
87 root_cert_path,
88 )
89 .await?;
90
91 let stream = match replace(self, MaybeTlsStream::Upgrading) {
92 MaybeTlsStream::Raw(stream) => stream,
93
94 MaybeTlsStream::Tls(_) => {
95 return Ok(());
97 }
98
99 MaybeTlsStream::Upgrading => {
100 return Err(Error::Io(io::ErrorKind::ConnectionAborted.into()));
103 }
104 };
105
106 #[cfg(feature = "_tls-rustls")]
107 let host = ::rustls::ServerName::try_from(host).map_err(|err| Error::Tls(err.into()))?;
108
109 *self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
110
111 Ok(())
112 }
113}
114
/// Builds a native-tls connector from the connection's TLS options.
///
/// The two `danger_*` flags are forwarded as-is. A custom root certificate is
/// loaded and parsed as PEM only when certificate validation stays enabled.
/// For async-std the builder itself is converted; otherwise it is built first.
#[cfg(feature = "_tls-native-tls")]
async fn configure_tls_connector(
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    root_cert_path: Option<&CertificateInput>,
) -> Result<sqlx_rt::TlsConnector, Error> {
    use sqlx_rt::native_tls::{Certificate, TlsConnector};

    let mut builder = TlsConnector::builder();
    builder.danger_accept_invalid_certs(accept_invalid_certs);
    builder.danger_accept_invalid_hostnames(accept_invalid_hostnames);

    // With validation disabled an extra trust anchor would be unused, so skip loading it.
    let extra_root = root_cert_path.filter(|_| !accept_invalid_certs);
    if let Some(ca) = extra_root {
        let pem = ca.data().await?;
        builder.add_root_certificate(Certificate::from_pem(&pem)?);
    }

    #[cfg(not(feature = "_rt-async-std"))]
    let tls_connector = builder.build()?.into();

    #[cfg(feature = "_rt-async-std")]
    let tls_connector = builder.into();

    Ok(tls_connector)
}
145
146#[cfg(feature = "_tls-rustls")]
147use self::rustls::configure_tls_connector;
148
149impl<S> AsyncRead for MaybeTlsStream<S>
150where
151 S: Unpin + AsyncWrite + AsyncRead,
152{
153 fn poll_read(
154 mut self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 buf: &mut super::PollReadBuf<'_>,
157 ) -> Poll<io::Result<super::PollReadOut>> {
158 match &mut *self {
159 MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
160 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
161
162 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
163 }
164 }
165}
166
167impl<S> AsyncWrite for MaybeTlsStream<S>
168where
169 S: Unpin + AsyncWrite + AsyncRead,
170{
171 fn poll_write(
172 mut self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 buf: &[u8],
175 ) -> Poll<io::Result<usize>> {
176 match &mut *self {
177 MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
178 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
179
180 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
181 }
182 }
183
184 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185 match &mut *self {
186 MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
187 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
188
189 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
190 }
191 }
192
193 #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
194 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195 match &mut *self {
196 MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
197 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
198
199 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
200 }
201 }
202
203 #[cfg(feature = "_rt-async-std")]
204 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205 match &mut *self {
206 MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
207 MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
208
209 MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
210 }
211 }
212}
213
214impl<S> Deref for MaybeTlsStream<S>
215where
216 S: Unpin + AsyncWrite + AsyncRead,
217{
218 type Target = S;
219
220 fn deref(&self) -> &Self::Target {
221 match self {
222 MaybeTlsStream::Raw(s) => s,
223
224 #[cfg(feature = "_tls-rustls")]
225 MaybeTlsStream::Tls(s) => s.get_ref().0,
226
227 #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
228 MaybeTlsStream::Tls(s) => s.get_ref(),
229
230 #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
231 MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
232
233 MaybeTlsStream::Upgrading => {
234 panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
235 }
236 }
237 }
238}
239
240impl<S> DerefMut for MaybeTlsStream<S>
241where
242 S: Unpin + AsyncWrite + AsyncRead,
243{
244 fn deref_mut(&mut self) -> &mut Self::Target {
245 match self {
246 MaybeTlsStream::Raw(s) => s,
247
248 #[cfg(feature = "_tls-rustls")]
249 MaybeTlsStream::Tls(s) => s.get_mut().0,
250
251 #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
252 MaybeTlsStream::Tls(s) => s.get_mut(),
253
254 #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
255 MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
256
257 MaybeTlsStream::Upgrading => {
258 panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
259 }
260 }
261 }
262}