tokio_postgres_rustls_improved/
lib.rs1#![doc = include_str!("../README.md")]
2#![forbid(rust_2018_idioms)]
3#![forbid(missing_docs, unsafe_code, unused)]
4#![deny(
5 clippy::all,
6 clippy::pedantic,
7 clippy::unwrap_used,
8 clippy::expect_used,
9 clippy::nursery,
10 clippy::dbg_macro,
11 clippy::todo
12)]
13
14use std::{convert::TryFrom, sync::Arc};
15
16use rustls::{pki_types::ServerName, ClientConfig};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio_postgres::tls::MakeTlsConnect;
19
20mod private {
21 use std::{
22 future::Future,
23 io,
24 pin::Pin,
25 task::{Context, Poll},
26 };
27
28 use rustls::pki_types::ServerName;
29 use sha2::digest::const_oid::db::{
30 rfc5912::{
31 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, ID_SHA_1, ID_SHA_256, ID_SHA_384, ID_SHA_512,
32 SHA_1_WITH_RSA_ENCRYPTION, SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION,
33 SHA_512_WITH_RSA_ENCRYPTION,
34 },
35 rfc8410::ID_ED_25519,
36 };
37 use sha2::{Digest, Sha256, Sha384, Sha512};
38 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
39 use tokio_postgres::tls::{ChannelBinding, TlsConnect};
40 use tokio_rustls::{client::TlsStream, TlsConnector};
41 use x509_cert::{der::Decode, Certificate};
42
43 pub struct TlsConnectFuture<S> {
44 inner: tokio_rustls::Connect<S>,
45 }
46
47 impl<S> Future for TlsConnectFuture<S>
48 where
49 S: AsyncRead + AsyncWrite + Unpin,
50 {
51 type Output = io::Result<RustlsStream<S>>;
52
53 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54 Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
55 }
56 }
57
58 pub struct RustlsConnect(pub RustlsConnectData);
59
60 pub struct RustlsConnectData {
61 pub hostname: ServerName<'static>,
62 pub connector: TlsConnector,
63 }
64
65 impl<S> TlsConnect<S> for RustlsConnect
66 where
67 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
68 {
69 type Stream = RustlsStream<S>;
70 type Error = io::Error;
71 type Future = TlsConnectFuture<S>;
72
73 fn connect(self, stream: S) -> Self::Future {
74 TlsConnectFuture {
75 inner: self.0.connector.connect(self.0.hostname, stream),
76 }
77 }
78 }
79
80 pub struct RustlsStream<S>(TlsStream<S>);
81
82 impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
83 where
84 S: AsyncRead + AsyncWrite + Unpin,
85 {
86 fn channel_binding(&self) -> ChannelBinding {
87 let (_, session) = self.0.get_ref();
88 match session.peer_certificates() {
89 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0]).map_or_else(
90 |_| ChannelBinding::none(),
91 |cert| {
92 match cert.signature_algorithm.oid {
93 ID_SHA_1
95 | ID_SHA_256
96 | SHA_1_WITH_RSA_ENCRYPTION
97 | SHA_256_WITH_RSA_ENCRYPTION
98 | ECDSA_WITH_SHA_256 => ChannelBinding::tls_server_end_point(
99 Sha256::digest(certs[0].as_ref()).to_vec(),
100 ),
101 ID_SHA_384 | SHA_384_WITH_RSA_ENCRYPTION | ECDSA_WITH_SHA_384 => {
102 ChannelBinding::tls_server_end_point(
103 Sha384::digest(certs[0].as_ref()).to_vec(),
104 )
105 }
106 ID_SHA_512 | SHA_512_WITH_RSA_ENCRYPTION | ID_ED_25519 => {
107 ChannelBinding::tls_server_end_point(
108 Sha512::digest(certs[0].as_ref()).to_vec(),
109 )
110 }
111 _ => ChannelBinding::none(),
112 }
113 },
114 ),
115 _ => ChannelBinding::none(),
116 }
117 }
118 }
119
120 impl<S> AsyncRead for RustlsStream<S>
121 where
122 S: AsyncRead + AsyncWrite + Unpin,
123 {
124 fn poll_read(
125 mut self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 buf: &mut ReadBuf<'_>,
128 ) -> Poll<tokio::io::Result<()>> {
129 Pin::new(&mut self.0).poll_read(cx, buf)
130 }
131 }
132
133 impl<S> AsyncWrite for RustlsStream<S>
134 where
135 S: AsyncRead + AsyncWrite + Unpin,
136 {
137 fn poll_write(
138 mut self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 buf: &[u8],
141 ) -> Poll<tokio::io::Result<usize>> {
142 Pin::new(&mut self.0).poll_write(cx, buf)
143 }
144
145 fn poll_flush(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 ) -> Poll<tokio::io::Result<()>> {
149 Pin::new(&mut self.0).poll_flush(cx)
150 }
151
152 fn poll_shutdown(
153 mut self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 ) -> Poll<tokio::io::Result<()>> {
156 Pin::new(&mut self.0).poll_shutdown(cx)
157 }
158 }
159}
160
161#[derive(Clone)]
167pub struct MakeRustlsConnect {
168 config: Arc<ClientConfig>,
169}
170
171impl MakeRustlsConnect {
172 #[must_use]
174 pub fn new(config: ClientConfig) -> Self {
175 Self {
176 config: Arc::new(config),
177 }
178 }
179}
180
181impl<S> MakeTlsConnect<S> for MakeRustlsConnect
182where
183 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
184{
185 type Stream = private::RustlsStream<S>;
186 type TlsConnect = private::RustlsConnect;
187 type Error = rustls::pki_types::InvalidDnsNameError;
188
189 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
204 ServerName::try_from(hostname).map(|dns_name| {
205 private::RustlsConnect(private::RustlsConnectData {
206 hostname: dns_name.to_owned(),
207 connector: Arc::clone(&self.config).into(),
208 })
209 })
210 }
211}