Skip to main content

tokio_postgres_rustls/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::sync::Arc;
7
8use rustls::ClientConfig;
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13    use std::{
14        convert::TryFrom,
15        future::Future,
16        io,
17        pin::Pin,
18        task::{Context, Poll},
19    };
20
21    use rustls::pki_types::ServerName;
22    use sha2::{Digest, Sha256, Sha384, Sha512};
23    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
25    use tokio_rustls::{TlsConnector, client::TlsStream};
26    use x509_cert::der::oid::db::rfc5912::{
27        ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
28        SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
29        SHA_512_WITH_RSA_ENCRYPTION,
30    };
31    use x509_cert::{Certificate, der::Decode, der::oid::ObjectIdentifier};
32
33    pub enum TlsConnectFuture<S> {
34        Connect(Box<tokio_rustls::Connect<S>>),
35        Error(Option<io::Error>),
36    }
37
38    impl<S> Future for TlsConnectFuture<S>
39    where
40        S: AsyncRead + AsyncWrite + Unpin,
41    {
42        type Output = io::Result<RustlsStream<S>>;
43
44        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45            match &mut *self {
46                Self::Connect(inner) => Pin::new(inner.as_mut()).poll(cx).map_ok(RustlsStream),
47                Self::Error(error) => Poll::Ready(Err(error
48                    .take()
49                    .expect("TlsConnectFuture polled after completion"))),
50            }
51        }
52    }
53
54    pub struct RustlsConnect(pub RustlsConnectData);
55
56    pub struct RustlsConnectData {
57        pub hostname: String,
58        pub connector: TlsConnector,
59    }
60
61    impl<S> TlsConnect<S> for RustlsConnect
62    where
63        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
64    {
65        type Stream = RustlsStream<S>;
66        type Error = io::Error;
67        type Future = TlsConnectFuture<S>;
68
69        fn connect(self, stream: S) -> Self::Future {
70            match ServerName::try_from(self.0.hostname) {
71                Ok(hostname) => {
72                    TlsConnectFuture::Connect(Box::new(self.0.connector.connect(hostname, stream)))
73                }
74                Err(error) => TlsConnectFuture::Error(Some(io::Error::new(
75                    io::ErrorKind::InvalidInput,
76                    error,
77                ))),
78            }
79        }
80    }
81
82    pub struct RustlsStream<S>(TlsStream<S>);
83
84    pub(super) enum ChannelBindingDigest {
85        Sha256,
86        Sha384,
87        Sha512,
88    }
89
90    impl ChannelBindingDigest {
91        pub(super) fn digest(&self, data: &[u8]) -> Vec<u8> {
92            match self {
93                Self::Sha256 => Sha256::digest(data).to_vec(),
94                Self::Sha384 => Sha384::digest(data).to_vec(),
95                Self::Sha512 => Sha512::digest(data).to_vec(),
96            }
97        }
98
99        #[cfg(test)]
100        pub(super) fn output_len(&self) -> usize {
101            match self {
102                Self::Sha256 => 32,
103                Self::Sha384 => 48,
104                Self::Sha512 => 64,
105            }
106        }
107    }
108
109    pub(super) fn channel_binding_digest(
110        signature_algorithm: ObjectIdentifier,
111    ) -> Option<ChannelBindingDigest> {
112        match signature_algorithm {
113            // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
114            ID_SHA_1
115            | ID_SHA_256
116            | SHA_1_WITH_RSA_ENCRYPTION
117            | SHA_256_WITH_RSA_ENCRYPTION
118            | ECDSA_WITH_SHA_256 => Some(ChannelBindingDigest::Sha256),
119            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
120                Some(ChannelBindingDigest::Sha384)
121            }
122            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION => Some(ChannelBindingDigest::Sha512),
123            // Unsupported algorithms, including pure signature algorithms like Ed25519, have no
124            // digest to use for tls-server-end-point channel binding.
125            _ => None,
126        }
127    }
128
129    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
130    where
131        S: AsyncRead + AsyncWrite + Unpin,
132    {
133        fn channel_binding(&self) -> ChannelBinding {
134            let (_, session) = self.0.get_ref();
135            match session.peer_certificates() {
136                Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0])
137                    .ok()
138                    .and_then(|cert| channel_binding_digest(cert.signature_algorithm.oid))
139                    .map_or_else(ChannelBinding::none, |algorithm| {
140                        ChannelBinding::tls_server_end_point(algorithm.digest(certs[0].as_ref()))
141                    }),
142                _ => ChannelBinding::none(),
143            }
144        }
145    }
146
147    impl<S> AsyncRead for RustlsStream<S>
148    where
149        S: AsyncRead + AsyncWrite + Unpin,
150    {
151        fn poll_read(
152            mut self: Pin<&mut Self>,
153            cx: &mut Context<'_>,
154            buf: &mut ReadBuf<'_>,
155        ) -> Poll<tokio::io::Result<()>> {
156            Pin::new(&mut self.0).poll_read(cx, buf)
157        }
158    }
159
160    impl<S> AsyncWrite for RustlsStream<S>
161    where
162        S: AsyncRead + AsyncWrite + Unpin,
163    {
164        fn poll_write(
165            mut self: Pin<&mut Self>,
166            cx: &mut Context<'_>,
167            buf: &[u8],
168        ) -> Poll<tokio::io::Result<usize>> {
169            Pin::new(&mut self.0).poll_write(cx, buf)
170        }
171
172        fn poll_flush(
173            mut self: Pin<&mut Self>,
174            cx: &mut Context<'_>,
175        ) -> Poll<tokio::io::Result<()>> {
176            Pin::new(&mut self.0).poll_flush(cx)
177        }
178
179        fn poll_shutdown(
180            mut self: Pin<&mut Self>,
181            cx: &mut Context<'_>,
182        ) -> Poll<tokio::io::Result<()>> {
183            Pin::new(&mut self.0).poll_shutdown(cx)
184        }
185    }
186}
187
188/// A `MakeTlsConnect` implementation using `rustls`.
189///
190/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
191#[derive(Clone)]
192pub struct MakeRustlsConnect {
193    config: Arc<ClientConfig>,
194}
195
196impl MakeRustlsConnect {
197    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
198    #[must_use]
199    pub fn new(config: ClientConfig) -> Self {
200        Self {
201            config: Arc::new(config),
202        }
203    }
204
205    #[cfg(any(feature = "native-certs", feature = "webpki-roots"))]
206    fn from_root_certificates(roots: rustls::RootCertStore) -> Self {
207        Self::new(
208            ClientConfig::builder()
209                .with_root_certificates(roots)
210                .with_no_client_auth(),
211        )
212    }
213
214    /// Creates a new `MakeRustlsConnect` using the Mozilla roots from `webpki-roots`.
215    ///
216    /// This uses rustls' process-level default crypto provider, so the application
217    /// must install or otherwise configure a process-default provider before use.
218    #[cfg(feature = "webpki-roots")]
219    #[must_use]
220    pub fn with_webpki_roots() -> Self {
221        Self::from_root_certificates(rustls::RootCertStore {
222            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
223        })
224    }
225
226    /// Creates a new `MakeRustlsConnect` using certificates from the platform's native store.
227    ///
228    /// This uses rustls' process-level default crypto provider, so the application
229    /// must install or otherwise configure a process-default provider before use.
230    ///
231    /// Returns the connector and any errors reported while loading the native
232    /// certificate store. If no certificates could be loaded, returns the
233    /// reported errors instead.
234    #[cfg(feature = "native-certs")]
235    pub fn with_native_certs()
236    -> Result<(Self, Vec<rustls_native_certs::Error>), Vec<rustls_native_certs::Error>> {
237        let result = rustls_native_certs::load_native_certs();
238        if !result.certs.is_empty() {
239            let mut roots = rustls::RootCertStore::empty();
240            roots.add_parsable_certificates(result.certs);
241            Ok((Self::from_root_certificates(roots), result.errors))
242        } else {
243            Err(result.errors)
244        }
245    }
246}
247
248impl<S> MakeTlsConnect<S> for MakeRustlsConnect
249where
250    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
251{
252    type Stream = private::RustlsStream<S>;
253    type TlsConnect = private::RustlsConnect;
254    type Error = std::convert::Infallible;
255
256    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
257        Ok(private::RustlsConnect(private::RustlsConnectData {
258            hostname: hostname.to_owned(),
259            connector: Arc::clone(&self.config).into(),
260        }))
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    use rustls::pki_types::{CertificateDer, UnixTime};
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    use rustls::{
        Error, SignatureScheme,
        client::danger::ServerCertVerifier,
        client::danger::{HandshakeSignatureValid, ServerCertVerified},
        pki_types::ServerName,
    };
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    use tokio::io::DuplexStream;
    use x509_cert::der::oid::db::{
        rfc5912::{SHA_1_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION},
        rfc8410::ID_ED_25519,
    };

    /// Builds a client config for `provider` with an empty root store.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    fn client_config_with_provider(
        provider: rustls::crypto::CryptoProvider,
    ) -> rustls::ClientConfig {
        rustls::ClientConfig::builder_with_provider(provider.into())
            .with_safe_default_protocol_versions()
            .expect("default protocol versions")
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth()
    }

    #[cfg(feature = "aws-lc-rs")]
    fn aws_lc_rs_client_config() -> rustls::ClientConfig {
        client_config_with_provider(rustls::crypto::aws_lc_rs::default_provider())
    }

    #[cfg(feature = "ring")]
    fn ring_client_config() -> rustls::ClientConfig {
        client_config_with_provider(rustls::crypto::ring::default_provider())
    }

    /// Test-only verifier that accepts any server certificate and signature.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    #[derive(Debug)]
    struct AcceptAllVerifier {}

    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    impl ServerCertVerifier for AcceptAllVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            vec![
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// Connects to a local PostgreSQL server over TLS and runs a trivial query.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    async fn connect_works(mut config: rustls::ClientConfig) {
        // Several provider tests can run in the same test binary, and only the
        // first logger initialization succeeds. Later failures are expected.
        let _ = env_logger::builder().is_test(true).try_init();

        config
            .dangerous()
            .set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
        let tls = super::MakeRustlsConnect::new(config);
        let (client, conn) = tokio_postgres::connect(
            "sslmode=require host=localhost port=5432 user=postgres",
            tls,
        )
        .await
        .expect("connect");
        tokio::spawn(async move { conn.await.map_err(|e| panic!("{e:?}")) });
        let stmt = client.prepare("SELECT 1").await.expect("prepare");
        let _ = client.query(&stmt, &[]).await.expect("query");
    }

    #[cfg(feature = "aws-lc-rs")]
    #[tokio::test]
    async fn it_works_with_aws_lc_rs() {
        connect_works(aws_lc_rs_client_config()).await;
    }

    #[cfg(feature = "ring")]
    #[tokio::test]
    async fn it_works_with_ring() {
        connect_works(ring_client_config()).await;
    }

    /// `make_tls_connect` must not reject non-DNS host names such as socket paths.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    fn accepts_unix_socket_hostname_before_tls_is_used(config: rustls::ClientConfig) {
        let mut tls = super::MakeRustlsConnect::new(config);

        let tls_connect =
            <super::MakeRustlsConnect as MakeTlsConnect<DuplexStream>>::make_tls_connect(
                &mut tls,
                "/var/run/postgresql",
            );

        assert!(tls_connect.is_ok());
    }

    #[cfg(feature = "aws-lc-rs")]
    #[test]
    fn accepts_unix_socket_hostname_before_tls_is_used_with_aws_lc_rs() {
        accepts_unix_socket_hostname_before_tls_is_used(aws_lc_rs_client_config());
    }

    #[cfg(feature = "ring")]
    #[test]
    fn accepts_unix_socket_hostname_before_tls_is_used_with_ring() {
        accepts_unix_socket_hostname_before_tls_is_used(ring_client_config());
    }

    #[test]
    fn ed25519_has_no_channel_binding_digest() {
        assert!(private::channel_binding_digest(ID_ED_25519).is_none());
    }

    #[test]
    fn sha512_with_rsa_has_channel_binding_digest() {
        let algorithm = private::channel_binding_digest(SHA_512_WITH_RSA_ENCRYPTION)
            .expect("SHA-512 signature algorithm should map to a digest");

        assert_eq!(algorithm.output_len(), 64);
    }

    /// RFC 5929 section 4.1: SHA-1 based signatures use SHA-256 for channel binding.
    #[test]
    fn sha1_with_rsa_is_upgraded_to_sha256() {
        let algorithm = private::channel_binding_digest(SHA_1_WITH_RSA_ENCRYPTION)
            .expect("SHA-1 signature algorithm should map to a digest");

        assert_eq!(algorithm.output_len(), 32);
    }
}