tokio_postgres_rustls_improved/
lib.rs1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code)]
4#![warn(clippy::all, clippy::pedantic)]
5
6use std::{convert::TryFrom, sync::Arc};
7
8use rustls::{pki_types::ServerName, ClientConfig};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_postgres::tls::MakeTlsConnect;
11
12mod private {
13 use std::{
14 future::Future,
15 io,
16 pin::Pin,
17 task::{Context, Poll},
18 };
19
20 use rustls::pki_types::ServerName;
21 use sha2::digest::const_oid::db::{
22 rfc5912::{
23 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
24 SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
25 SHA_512_WITH_RSA_ENCRYPTION,
26 },
27 rfc8410::ID_ED_25519,
28 };
29 use sha2::{Digest, Sha256, Sha384, Sha512};
30 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31 use tokio_postgres::tls::{ChannelBinding, TlsConnect};
32 use tokio_rustls::{client::TlsStream, TlsConnector};
33 use x509_cert::{der::Decode, Certificate};
34
35 pub struct TlsConnectFuture<S> {
36 inner: tokio_rustls::Connect<S>,
37 }
38
39 impl<S> Future for TlsConnectFuture<S>
40 where
41 S: AsyncRead + AsyncWrite + Unpin,
42 {
43 type Output = io::Result<RustlsStream<S>>;
44
45 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46 Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
47 }
48 }
49
50 pub struct RustlsConnect(pub RustlsConnectData);
51
52 pub struct RustlsConnectData {
53 pub hostname: ServerName<'static>,
54 pub connector: TlsConnector,
55 }
56
57 impl<S> TlsConnect<S> for RustlsConnect
58 where
59 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
60 {
61 type Stream = RustlsStream<S>;
62 type Error = io::Error;
63 type Future = TlsConnectFuture<S>;
64
65 fn connect(self, stream: S) -> Self::Future {
66 TlsConnectFuture {
67 inner: self.0.connector.connect(self.0.hostname, stream),
68 }
69 }
70 }
71
72 pub struct RustlsStream<S>(TlsStream<S>);
73
74 impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
75 where
76 S: AsyncRead + AsyncWrite + Unpin,
77 {
78 fn channel_binding(&self) -> ChannelBinding {
79 let (_, session) = self.0.get_ref();
80 match session.peer_certificates() {
81 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
82 |_| ChannelBinding::none(),
83 |cert| {
84 match cert.signature_algorithm.oid {
85 ID_SHA_1
87 | ID_SHA_256
88 | SHA_1_WITH_RSA_ENCRYPTION
89 | SHA_256_WITH_RSA_ENCRYPTION
90 | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
91 Sha256::digest(certs[0].as_ref()).to_vec(),
92 ),
93 ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
94 ChannelBinding::tls_server_end_point(
95 Sha384::digest(certs[0].as_ref()).to_vec(),
96 )
97 }
98 ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
99 ChannelBinding::tls_server_end_point(
100 Sha512::digest(certs[0].as_ref()).to_vec(),
101 )
102 }
103 _ => ChannelBinding::none(),
104 }
105 },
106 ),
107 _ => ChannelBinding::none(),
108 }
109 }
110 }
111
112 impl<S> AsyncRead for RustlsStream<S>
113 where
114 S: AsyncRead + AsyncWrite + Unpin,
115 {
116 fn poll_read(
117 mut self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 buf: &mut ReadBuf<'_>,
120 ) -> Poll<tokio::io::Result<()>> {
121 Pin::new(&mut self.0).poll_read(cx, buf)
122 }
123 }
124
125 impl<S> AsyncWrite for RustlsStream<S>
126 where
127 S: AsyncRead + AsyncWrite + Unpin,
128 {
129 fn poll_write(
130 mut self: Pin<&mut Self>,
131 cx: &mut Context<'_>,
132 buf: &[u8],
133 ) -> Poll<tokio::io::Result<usize>> {
134 Pin::new(&mut self.0).poll_write(cx, buf)
135 }
136
137 fn poll_flush(
138 mut self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 ) -> Poll<tokio::io::Result<()>> {
141 Pin::new(&mut self.0).poll_flush(cx)
142 }
143
144 fn poll_shutdown(
145 mut self: Pin<&mut Self>,
146 cx: &mut Context<'_>,
147 ) -> Poll<tokio::io::Result<()>> {
148 Pin::new(&mut self.0).poll_shutdown(cx)
149 }
150 }
151}
152
153#[derive(Clone)]
157pub struct MakeRustlsConnect {
158 config: Arc<ClientConfig>,
159}
160
161impl MakeRustlsConnect {
162 #[must_use]
164 pub fn new(config: ClientConfig) -> Self {
165 Self {
166 config: Arc::new(config),
167 }
168 }
169}
170
171impl<S> MakeTlsConnect<S> for MakeRustlsConnect
172where
173 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
174{
175 type Stream = private::RustlsStream<S>;
176 type TlsConnect = private::RustlsConnect;
177 type Error = rustls::pki_types::InvalidDnsNameError;
178
179 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
180 ServerName::try_from(hostname).map(|dns_name| {
181 private::RustlsConnect(private::RustlsConnectData {
182 hostname: dns_name.to_owned(),
183 connector: Arc::clone(&self.config).into(),
184 })
185 })
186 }
187}