1use std::{
2 convert::TryFrom,
3 future::Future,
4 io,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use ring::digest;
11use rustls::{pki_types::ServerName, ClientConfig};
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
14use tokio_rustls::{client::TlsStream, TlsConnector};
15use x509_certificate::{DigestAlgorithm, SignatureAlgorithm, X509Certificate};
16use DigestAlgorithm::{Sha1, Sha256, Sha384, Sha512};
17use SignatureAlgorithm::{
18 EcdsaSha256, EcdsaSha384, Ed25519, NoSignature, RsaSha1, RsaSha256, RsaSha384, RsaSha512,
19};
20
21#[derive(Clone)]
22pub struct MakeRustlsConnect {
23 pub config: Arc<ClientConfig>,
24}
25
26impl MakeRustlsConnect {
27 pub fn new(config: ClientConfig) -> Self {
28 Self {
29 config: Arc::new(config),
30 }
31 }
32}
33
34impl<S> MakeTlsConnect<S> for MakeRustlsConnect
35where
36 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
37{
38 type Stream = RustlsStream<S>;
39 type TlsConnect = RustlsConnect;
40 type Error = rustls::pki_types::InvalidDnsNameError;
41
42 fn make_tls_connect(&mut self, hostname: &str) -> Result<RustlsConnect, Self::Error> {
43 ServerName::try_from(hostname).map(|dns_name| {
44 RustlsConnect(RustlsConnectData {
45 hostname: dns_name.to_owned(),
46 connector: Arc::clone(&self.config).into(),
47 })
48 })
49 }
50}
51
52pub struct RustlsConnect(RustlsConnectData);
53
54struct RustlsConnectData {
55 hostname: ServerName<'static>,
56 connector: TlsConnector,
57}
58
59impl<S> TlsConnect<S> for RustlsConnect
60where
61 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
62{
63 type Stream = RustlsStream<S>;
64 type Error = io::Error;
65 type Future = Pin<Box<dyn Future<Output = io::Result<RustlsStream<S>>> + Send>>;
66
67 fn connect(self, stream: S) -> Self::Future {
68 Box::pin(async move {
69 self
70 .0
71 .connector
72 .connect(self.0.hostname, stream)
73 .await
74 .map(|s| RustlsStream(Box::pin(s)))
75 })
76 }
77}
78
79pub struct RustlsStream<S>(Pin<Box<TlsStream<S>>>);
80
81impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
82where
83 S: AsyncRead + AsyncWrite + Unpin,
84{
85 fn channel_binding(&self) -> ChannelBinding {
86 let (_, session) = self.0.get_ref();
87 match session.peer_certificates() {
88 Some(certs) if !certs.is_empty() => X509Certificate::from_der(&certs[0])
89 .ok()
90 .and_then(|cert| cert.signature_algorithm())
91 .map(|algorithm| match algorithm {
92 RsaSha1 | RsaSha256 | EcdsaSha256 => &digest::SHA256,
94 RsaSha384 | EcdsaSha384 => &digest::SHA384,
95 RsaSha512 => &digest::SHA512,
96 Ed25519 => &digest::SHA512,
97 NoSignature(algo) => match algo {
98 Sha1 | Sha256 => &digest::SHA256,
99 Sha384 => &digest::SHA384,
100 Sha512 => &digest::SHA512,
101 },
102 })
103 .map(|algorithm| {
104 let hash = digest::digest(algorithm, certs[0].as_ref());
105 ChannelBinding::tls_server_end_point(hash.as_ref().into())
106 })
107 .unwrap_or(ChannelBinding::none()),
108 _ => ChannelBinding::none(),
109 }
110 }
111}
112
113impl<S> AsyncRead for RustlsStream<S>
114where
115 S: AsyncRead + AsyncWrite + Unpin,
116{
117 fn poll_read(
118 mut self: Pin<&mut Self>,
119 cx: &mut Context,
120 buf: &mut ReadBuf<'_>,
121 ) -> Poll<tokio::io::Result<()>> {
122 self.0.as_mut().poll_read(cx, buf)
123 }
124}
125
126impl<S> AsyncWrite for RustlsStream<S>
127where
128 S: AsyncRead + AsyncWrite + Unpin,
129{
130 fn poll_write(
131 mut self: Pin<&mut Self>,
132 cx: &mut Context,
133 buf: &[u8],
134 ) -> Poll<tokio::io::Result<usize>> {
135 self.0.as_mut().poll_write(cx, buf)
136 }
137
138 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
139 self.0.as_mut().poll_flush(cx)
140 }
141
142 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
143 self.0.as_mut().poll_shutdown(cx)
144 }
145}
146
#[cfg(test)]
mod tests {
    use rustls::{
        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
        pki_types::{CertificateDer, UnixTime},
        Error, SignatureScheme,
    };

    use super::*;

    /// Certificate verifier that accepts every server certificate and every
    /// handshake signature.
    ///
    /// Test-only: it disables all TLS server authentication.
    #[derive(Debug)]
    struct AcceptAllVerifier {}

    impl ServerCertVerifier for AcceptAllVerifier {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &rustls::DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            // This list is what the client advertises as acceptable signature
            // schemes. The PKCS#1 v1.5 RSA schemes are included so that TLS 1.2
            // servers which only sign with them are not rejected during
            // negotiation.
            vec![
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PKCS1_SHA512,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// Connects over TLS to a PostgreSQL server and runs `SELECT 1`.
    ///
    /// The server must listen at `localhost:5432` and accept user `postgres`
    /// with `sslmode=require`.
    #[tokio::test]
    async fn it_works() {
        // `try_init` fails when a logger is already installed, for example by
        // another test in the same process. That is harmless, so it is ignored
        // rather than unwrapped.
        let _ = env_logger::builder().is_test(true).try_init();

        let mut config = rustls::ClientConfig::builder()
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth();
        config
            .dangerous()
            .set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
        let tls = super::MakeRustlsConnect::new(config);
        let (client, conn) = tokio_postgres::connect(
            "sslmode=require host=localhost port=5432 user=postgres",
            tls,
        )
        .await
        .expect("connect");
        tokio::spawn(async move { conn.await.map_err(|e| panic!("{:?}", e)) });
        let stmt = client.prepare("SELECT 1").await.expect("prepare");
        let _ = client.query(&stmt, &[]).await.expect("query");
    }
}