// sqlx_core_guts/net/tls/mod.rs

1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
10
11use crate::error::Error;
12use std::mem::replace;
13
/// Source of one or more X.509 certificates to be trusted by a TLS connection.
///
/// The certificate(s) can either be embedded directly as PEM text, or referenced
/// by the path of a file holding PEM text whose contents are loaded on demand.
#[derive(Debug, Clone)]
pub enum CertificateInput {
    /// One or more PEM-encoded certificates held in memory.
    Inline(Vec<u8>),
    /// Filesystem path of a file containing one or more PEM-encoded certificates.
    File(PathBuf),
}
22
23impl From<String> for CertificateInput {
24    fn from(value: String) -> Self {
25        let trimmed = value.trim();
26        // Some heuristics according to https://tools.ietf.org/html/rfc7468
27        if trimmed.starts_with("-----BEGIN CERTIFICATE-----")
28            && trimmed.contains("-----END CERTIFICATE-----")
29        {
30            CertificateInput::Inline(value.as_bytes().to_vec())
31        } else {
32            CertificateInput::File(PathBuf::from(value))
33        }
34    }
35}
36
37impl CertificateInput {
38    async fn data(&self) -> Result<Vec<u8>, std::io::Error> {
39        use sqlx_rt::fs;
40        match self {
41            CertificateInput::Inline(v) => Ok(v.clone()),
42            CertificateInput::File(path) => fs::read(path).await,
43        }
44    }
45}
46
47impl std::fmt::Display for CertificateInput {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            CertificateInput::Inline(v) => write!(f, "{}", String::from_utf8_lossy(v.as_slice())),
51            CertificateInput::File(path) => write!(f, "file: {}", path.display()),
52        }
53    }
54}
55
56#[cfg(feature = "_tls-rustls")]
57mod rustls;
58
59pub enum MaybeTlsStream<S>
60where
61    S: AsyncRead + AsyncWrite + Unpin,
62{
63    Raw(S),
64    Tls(TlsStream<S>),
65    Upgrading,
66}
67
68impl<S> MaybeTlsStream<S>
69where
70    S: AsyncRead + AsyncWrite + Unpin,
71{
72    #[inline]
73    pub fn is_tls(&self) -> bool {
74        matches!(self, Self::Tls(_))
75    }
76
77    pub async fn upgrade(
78        &mut self,
79        host: &str,
80        accept_invalid_certs: bool,
81        accept_invalid_hostnames: bool,
82        root_cert_path: Option<&CertificateInput>,
83    ) -> Result<(), Error> {
84        let connector = configure_tls_connector(
85            accept_invalid_certs,
86            accept_invalid_hostnames,
87            root_cert_path,
88        )
89        .await?;
90
91        let stream = match replace(self, MaybeTlsStream::Upgrading) {
92            MaybeTlsStream::Raw(stream) => stream,
93
94            MaybeTlsStream::Tls(_) => {
95                // ignore upgrade, we are already a TLS connection
96                return Ok(());
97            }
98
99            MaybeTlsStream::Upgrading => {
100                // we previously failed to upgrade and now hold no connection
101                // this should only happen from an internal misuse of this method
102                return Err(Error::Io(io::ErrorKind::ConnectionAborted.into()));
103            }
104        };
105
106        #[cfg(feature = "_tls-rustls")]
107        let host = ::rustls::ServerName::try_from(host).map_err(|err| Error::Tls(err.into()))?;
108
109        *self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
110
111        Ok(())
112    }
113}
114
/// Builds a native-tls connector from the connection's TLS options.
///
/// * `accept_invalid_certs` and `accept_invalid_hostnames` are forwarded to the
///   matching `danger_*` switches on the builder.
/// * `root_cert_path`, when given, is loaded and added as an extra trusted root.
///   This only happens when certificate validation is enabled. If invalid
///   certificates are accepted anyway, the input is never read.
///
/// # Errors
///
/// Fails if the root certificate cannot be loaded or parsed as PEM. Outside of
/// the async-std runtime, it also fails if building the connector fails.
#[cfg(feature = "_tls-native-tls")]
async fn configure_tls_connector(
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    root_cert_path: Option<&CertificateInput>,
) -> Result<sqlx_rt::TlsConnector, Error> {
    use sqlx_rt::native_tls::{Certificate, TlsConnector};

    let mut builder = TlsConnector::builder();
    builder.danger_accept_invalid_certs(accept_invalid_certs);
    builder.danger_accept_invalid_hostnames(accept_invalid_hostnames);

    // An extra trust root is irrelevant when every certificate is accepted.
    let extra_root = if accept_invalid_certs {
        None
    } else {
        root_cert_path
    };

    if let Some(ca) = extra_root {
        let pem = ca.data().await?;
        builder.add_root_certificate(Certificate::from_pem(&pem)?);
    }

    #[cfg(not(feature = "_rt-async-std"))]
    let connector = builder.build()?.into();

    #[cfg(feature = "_rt-async-std")]
    let connector = builder.into();

    Ok(connector)
}
145
146#[cfg(feature = "_tls-rustls")]
147use self::rustls::configure_tls_connector;
148
149impl<S> AsyncRead for MaybeTlsStream<S>
150where
151    S: Unpin + AsyncWrite + AsyncRead,
152{
153    fn poll_read(
154        mut self: Pin<&mut Self>,
155        cx: &mut Context<'_>,
156        buf: &mut super::PollReadBuf<'_>,
157    ) -> Poll<io::Result<super::PollReadOut>> {
158        match &mut *self {
159            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
160            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
161
162            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
163        }
164    }
165}
166
167impl<S> AsyncWrite for MaybeTlsStream<S>
168where
169    S: Unpin + AsyncWrite + AsyncRead,
170{
171    fn poll_write(
172        mut self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &[u8],
175    ) -> Poll<io::Result<usize>> {
176        match &mut *self {
177            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
178            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
179
180            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
181        }
182    }
183
184    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185        match &mut *self {
186            MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
187            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
188
189            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
190        }
191    }
192
193    #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
194    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195        match &mut *self {
196            MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
197            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
198
199            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
200        }
201    }
202
203    #[cfg(feature = "_rt-async-std")]
204    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205        match &mut *self {
206            MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
207            MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
208
209            MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
210        }
211    }
212}
213
214impl<S> Deref for MaybeTlsStream<S>
215where
216    S: Unpin + AsyncWrite + AsyncRead,
217{
218    type Target = S;
219
220    fn deref(&self) -> &Self::Target {
221        match self {
222            MaybeTlsStream::Raw(s) => s,
223
224            #[cfg(feature = "_tls-rustls")]
225            MaybeTlsStream::Tls(s) => s.get_ref().0,
226
227            #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
228            MaybeTlsStream::Tls(s) => s.get_ref(),
229
230            #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
231            MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
232
233            MaybeTlsStream::Upgrading => {
234                panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
235            }
236        }
237    }
238}
239
240impl<S> DerefMut for MaybeTlsStream<S>
241where
242    S: Unpin + AsyncWrite + AsyncRead,
243{
244    fn deref_mut(&mut self) -> &mut Self::Target {
245        match self {
246            MaybeTlsStream::Raw(s) => s,
247
248            #[cfg(feature = "_tls-rustls")]
249            MaybeTlsStream::Tls(s) => s.get_mut().0,
250
251            #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
252            MaybeTlsStream::Tls(s) => s.get_mut(),
253
254            #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
255            MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
256
257            MaybeTlsStream::Upgrading => {
258                panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
259            }
260        }
261    }
262}