tokio_postgres_rustls/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![deny(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::{convert::TryFrom, sync::Arc};
7
8use rustls::{pki_types::ServerName, ClientConfig};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13    use std::{
14        future::Future,
15        io,
16        pin::Pin,
17        task::{Context, Poll},
18    };
19
20    use const_oid::db::{
21        rfc5912::{
22            ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
23            SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
24            SHA_512_WITH_RSA_ENCRYPTION,
25        },
26        rfc8410::ID_ED_25519,
27    };
28    use ring::digest;
29    use rustls::pki_types::ServerName;
30    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31    use tokio_postgres::tls::{ChannelBinding, TlsConnect};
32    use tokio_rustls::{client::TlsStream, TlsConnector};
33    use x509_cert::{der::Decode, TbsCertificate};
34
35    pub struct TlsConnectFuture<S> {
36        pub inner: tokio_rustls::Connect<S>,
37    }
38
39    impl<S> Future for TlsConnectFuture<S>
40    where
41        S: AsyncRead + AsyncWrite + Unpin,
42    {
43        type Output = io::Result<RustlsStream<S>>;
44
45        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46            // SAFETY: If `self` is pinned, so is `inner`.
47            #[allow(unsafe_code)]
48            let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
49            fut.poll(cx).map_ok(RustlsStream)
50        }
51    }
52
53    pub struct RustlsConnect(pub RustlsConnectData);
54
55    pub struct RustlsConnectData {
56        pub hostname: ServerName<'static>,
57        pub connector: TlsConnector,
58    }
59
60    impl<S> TlsConnect<S> for RustlsConnect
61    where
62        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
63    {
64        type Stream = RustlsStream<S>;
65        type Error = io::Error;
66        type Future = TlsConnectFuture<S>;
67
68        fn connect(self, stream: S) -> Self::Future {
69            TlsConnectFuture {
70                inner: self.0.connector.connect(self.0.hostname, stream),
71            }
72        }
73    }
74
75    pub struct RustlsStream<S>(TlsStream<S>);
76
77    impl<S> RustlsStream<S> {
78        pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
79            // SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
80            #[allow(unsafe_code)]
81            unsafe {
82                self.map_unchecked_mut(|this| &mut this.0)
83            }
84        }
85    }
86
87    impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
88    where
89        S: AsyncRead + AsyncWrite + Unpin,
90    {
91        fn channel_binding(&self) -> ChannelBinding {
92            let (_, session) = self.0.get_ref();
93            match session.peer_certificates() {
94                Some(certs) if !certs.is_empty() => TbsCertificate::from_der(&certs[0])
95                    .ok()
96                    .and_then(|cert| {
97                        let digest = match cert.signature.oid {
98                            // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
99                            ID_SHA_1
100                            | ID_SHA_256
101                            | SHA_1_WITH_RSA_ENCRYPTION
102                            | SHA_256_WITH_RSA_ENCRYPTION
103                            | ECDSA_WITH_SHA_256 => &digest::SHA256,
104                            ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
105                                &digest::SHA384
106                            }
107                            ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
108                                &digest::SHA512
109                            }
110                            _ => return None,
111                        };
112
113                        Some(digest)
114                    })
115                    .map_or_else(ChannelBinding::none, |algorithm| {
116                        let hash = digest::digest(algorithm, certs[0].as_ref());
117                        ChannelBinding::tls_server_end_point(hash.as_ref().into())
118                    }),
119                _ => ChannelBinding::none(),
120            }
121        }
122    }
123
124    impl<S> AsyncRead for RustlsStream<S>
125    where
126        S: AsyncRead + AsyncWrite + Unpin,
127    {
128        fn poll_read(
129            self: Pin<&mut Self>,
130            cx: &mut Context<'_>,
131            buf: &mut ReadBuf<'_>,
132        ) -> Poll<tokio::io::Result<()>> {
133            self.project_stream().poll_read(cx, buf)
134        }
135    }
136
137    impl<S> AsyncWrite for RustlsStream<S>
138    where
139        S: AsyncRead + AsyncWrite + Unpin,
140    {
141        fn poll_write(
142            self: Pin<&mut Self>,
143            cx: &mut Context<'_>,
144            buf: &[u8],
145        ) -> Poll<tokio::io::Result<usize>> {
146            self.project_stream().poll_write(cx, buf)
147        }
148
149        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
150            self.project_stream().poll_flush(cx)
151        }
152
153        fn poll_shutdown(
154            self: Pin<&mut Self>,
155            cx: &mut Context<'_>,
156        ) -> Poll<tokio::io::Result<()>> {
157            self.project_stream().poll_shutdown(cx)
158        }
159    }
160}
161
162/// A `MakeTlsConnect` implementation using `rustls`.
163///
164/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
165#[derive(Clone)]
166pub struct MakeRustlsConnect {
167    config: Arc<ClientConfig>,
168}
169
170impl MakeRustlsConnect {
171    /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
172    #[must_use]
173    pub fn new(config: ClientConfig) -> Self {
174        Self {
175            config: Arc::new(config),
176        }
177    }
178}
179
180impl<S> MakeTlsConnect<S> for MakeRustlsConnect
181where
182    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
183{
184    type Stream = private::RustlsStream<S>;
185    type TlsConnect = private::RustlsConnect;
186    type Error = rustls::pki_types::InvalidDnsNameError;
187
188    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
189        ServerName::try_from(hostname).map(|dns_name| {
190            private::RustlsConnect(private::RustlsConnectData {
191                hostname: dns_name.to_owned(),
192                connector: Arc::clone(&self.config).into(),
193            })
194        })
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::{
        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
        pki_types::{CertificateDer, UnixTime},
        DigitallySignedStruct, Error, RootCertStore, SignatureScheme,
    };

    /// Test-only verifier that accepts every server certificate and handshake
    /// signature, so the test can reach a local server with a self-signed certificate.
    #[derive(Debug)]
    struct AcceptAllVerifier;

    impl ServerCertVerifier for AcceptAllVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            // Order is the preference advertised to the server.
            vec![
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// End-to-end check against a PostgreSQL server on localhost:5432 that requires TLS.
    #[tokio::test]
    async fn it_works() {
        env_logger::builder().is_test(true).try_init().unwrap();

        let mut config = ClientConfig::builder()
            .with_root_certificates(RootCertStore::empty())
            .with_no_client_auth();
        config
            .dangerous()
            .set_certificate_verifier(Arc::new(AcceptAllVerifier));
        let make_tls = MakeRustlsConnect::new(config);

        let (client, connection) = tokio_postgres::connect(
            "sslmode=require host=localhost port=5432 user=postgres",
            make_tls,
        )
        .await
        .expect("connect");
        // The connection object drives the socket; it must be polled in the background.
        tokio::spawn(async move { connection.await.map_err(|e| panic!("{:?}", e)) });

        let statement = client.prepare("SELECT 1").await.expect("prepare");
        let _ = client.query(&statement, &[]).await.expect("query");
    }
}