tokio_postgres_generic_rustls/lib.rs
1//! # tokio-postgres-generic-rustls
//! _An implementation of TLS based on rustls for tokio-postgres_
3//!
//! This crate allows users to select a crypto backend, or bring their own, rather than relying on
//! primitives provided by `ring` directly. This is done by using `x509-cert` for certificate
//! parsing instead of `x509-certificate`, and by adding an abstraction for computing
//! digests.
8//!
9//! By default, tokio-postgres-generic-rustls does not provide a digest implementation, but one or
10//! more are provided behind crate features.
11//!
//! | Feature      | Implementation     |
13//! | ------------ | ------------------ |
14//! | `aws-lc-rs` | `AwsLcRsDigest` |
15//! | `graviola` | `GraviolaDigest` |
16//! | `ring` | `RingDigest` |
17//! | `rustcrypto` | `RustcryptoDigest` |
18//!
19//! ## Usage
//! Using this crate is fairly straightforward. First, select your digest implementation via crate
//! features (or provide your own), then construct a rustls connector for tokio-postgres with your
//! rustls client configuration.
23//!
24//! The following example demonstrates providing a custom digest backend.
25//!
26//! ```rust
27//! use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm, MakeRustlsConnect};
28//!
29//! #[derive(Clone)]
30//! struct DemoDigest;
31//!
32//! impl DigestImplementation for DemoDigest {
33//! fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
34//! todo!("digest it")
35//! }
36//! }
37//!
38//! let cert_store = rustls::RootCertStore::empty();
39//!
40//! let config = rustls::ClientConfig::builder()
41//! .with_root_certificates(cert_store)
42//! .with_no_client_auth();
43//!
44//! let tls = MakeRustlsConnect::new(config, DemoDigest);
45//!
46//! let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
47//!
48//! // connect_future.await;
49//! ```
50//!
51//! ## License
52//! This project is licensed under either of
53//!
54//! - Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
55//! - MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
56//!
57//! at your option.
58
59#![deny(unsafe_code)]
60#![deny(missing_docs)]
61
62use std::{future::Future, pin::Pin, sync::Arc};
63
64use rustls::{pki_types::ServerName, ClientConfig};
65use tokio::io::{AsyncRead, AsyncWrite};
66use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
67use tokio_rustls::{client::TlsStream, Connect, TlsConnector};
68use x509_cert::{
69 der::{
70 oid::db::rfc5912::{
71 ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, SHA_1_WITH_RSA_ENCRYPTION,
72 SHA_256_WITH_RSA_ENCRYPTION, SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
73 },
74 Decode,
75 },
76 spki::ObjectIdentifier,
77 Certificate,
78};
79
80/// Trait used to provide a custom digest backend to tokio_postgres_generic_rustls
81///
82/// This trait is implementated for three types by default, each behind it's own feature flag. The
83/// provided backends are aws-lc-rs, ring, and rustcrypto.
84pub trait DigestImplementation {
85 /// Hash the provided bytes with the provided algorithm
86 ///
87 /// ```rust
88 /// use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
89 ///
90 /// struct CustomAwsLcRsDigest;
91 ///
92 /// impl DigestImplementation for CustomAwsLcRsDigest {
93 /// fn digest(&self, digest_algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
94 /// let digest_alg = match digest_algorithm {
95 /// // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
96 /// DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
97 /// DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
98 /// DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
99 /// };
100 ///
101 /// aws_lc_rs::digest::digest(digest_alg, bytes).as_ref().into()
102 /// }
103 /// }
104 /// ```
105 fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8>;
106}
107
108#[cfg(feature = "aws-lc-rs")]
109pub use aws_lc_rs_backend::AwsLcRsDigest;
110
111#[cfg(feature = "graviola")]
112pub use graviola_backend::GraviolaDigest;
113
114#[cfg(feature = "ring")]
115pub use ring_backend::RingDigest;
116
117#[cfg(feature = "rustcrypto")]
118pub use rustcrypto_backend::RustcryptoDigest;
119
120#[cfg(feature = "aws-lc-rs")]
121mod aws_lc_rs_backend {
122 use super::{DigestAlgorithm, DigestImplementation};
123
124 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
125 /// A digest backend provided by aws-lc-rs
126 ///
127 /// Usage:
128 /// ```rust
129 /// # let cert_store = rustls::RootCertStore::empty();
130 /// #
131 /// # let rustls_config = rustls::ClientConfig::builder()
132 /// # .with_root_certificates(cert_store)
133 /// # .with_no_client_auth();
134 /// #
135 /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, AwsLcRsDigest};
136 ///
137 /// let tls = MakeRustlsConnect::new(rustls_config, AwsLcRsDigest);
138 ///
139 /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
140 /// ```
141 pub struct AwsLcRsDigest;
142
143 impl DigestImplementation for AwsLcRsDigest {
144 fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
145 let digest_alg = match algorithm {
146 // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
147 DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &aws_lc_rs::digest::SHA256,
148 DigestAlgorithm::Sha384 => &aws_lc_rs::digest::SHA384,
149 DigestAlgorithm::Sha512 => &aws_lc_rs::digest::SHA512,
150 };
151
152 aws_lc_rs::digest::digest(digest_alg, bytes).as_ref().into()
153 }
154 }
155}
156
157#[cfg(feature = "graviola")]
158mod graviola_backend {
159 use super::{DigestAlgorithm, DigestImplementation};
160
161 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
162 /// A digest backend provided by aws-lc-rs
163 ///
164 /// Usage:
165 /// ```rust
166 /// # let cert_store = rustls::RootCertStore::empty();
167 /// #
168 /// # let rustls_config = rustls::ClientConfig::builder()
169 /// # .with_root_certificates(cert_store)
170 /// # .with_no_client_auth();
171 /// #
172 /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, GraviolaDigest};
173 ///
174 /// let tls = MakeRustlsConnect::new(rustls_config, GraviolaDigest);
175 ///
176 /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
177 /// ```
178 pub struct GraviolaDigest;
179
180 impl DigestImplementation for GraviolaDigest {
181 fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
182 use graviola::hashing::Hash;
183
184 match algorithm {
185 // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
186 DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
187 graviola::hashing::Sha256::hash(bytes).as_ref().to_vec()
188 }
189 DigestAlgorithm::Sha384 => graviola::hashing::Sha384::hash(bytes).as_ref().to_vec(),
190 DigestAlgorithm::Sha512 => graviola::hashing::Sha512::hash(bytes).as_ref().to_vec(),
191 }
192 }
193 }
194}
195
196#[cfg(feature = "ring")]
197mod ring_backend {
198 use super::{DigestAlgorithm, DigestImplementation};
199
200 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
201 /// a digest backend provided by ring
202 ///
203 /// Usage:
204 /// ```rust
205 /// # let cert_store = rustls::RootCertStore::empty();
206 /// #
207 /// # let rustls_config = rustls::ClientConfig::builder()
208 /// # .with_root_certificates(cert_store)
209 /// # .with_no_client_auth();
210 /// #
211 /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RingDigest};
212 ///
213 /// let tls = MakeRustlsConnect::new(rustls_config, RingDigest);
214 ///
215 /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
216 /// ```
217 pub struct RingDigest;
218
219 impl DigestImplementation for RingDigest {
220 fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
221 let digest_alg = match algorithm {
222 // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
223 DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => &ring::digest::SHA256,
224 DigestAlgorithm::Sha384 => &ring::digest::SHA384,
225 DigestAlgorithm::Sha512 => &ring::digest::SHA512,
226 };
227
228 ring::digest::digest(digest_alg, bytes).as_ref().into()
229 }
230 }
231}
232
233#[cfg(feature = "rustcrypto")]
234mod rustcrypto_backend {
235 use super::{DigestAlgorithm, DigestImplementation};
236 use sha2::{Digest, Sha256, Sha384, Sha512};
237
238 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
239 /// a digest backend provided by rustcrypto
240 ///
241 /// Usage:
242 /// ```rust
243 /// # let cert_store = rustls::RootCertStore::empty();
244 /// #
245 /// # let rustls_config = rustls::ClientConfig::builder()
246 /// # .with_root_certificates(cert_store)
247 /// # .with_no_client_auth();
248 /// #
249 /// use tokio_postgres_generic_rustls::{MakeRustlsConnect, RustcryptoDigest};
250 ///
251 /// let tls = MakeRustlsConnect::new(rustls_config, RustcryptoDigest);
252 ///
253 /// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
254 pub struct RustcryptoDigest;
255
256 impl DigestImplementation for RustcryptoDigest {
257 fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
258 match algorithm {
259 // Note: SHA1 is upgraded to SHA256 as per https://datatracker.ietf.org/doc/html/rfc5929#section-4.1
260 DigestAlgorithm::Sha1 | DigestAlgorithm::Sha256 => {
261 Sha256::digest(bytes).as_slice().into()
262 }
263 DigestAlgorithm::Sha384 => Sha384::digest(bytes).as_slice().into(),
264 DigestAlgorithm::Sha512 => Sha512::digest(bytes).as_slice().into(),
265 }
266 }
267 }
268}
269
270#[derive(Clone)]
271/// The primary interface for consumers of this crate
272///
273/// This type can be provided to tokio-postgres' `connect` method in order to add TLS to a postgres
274/// connection.
275/// ```rust
276/// # use tokio_postgres_generic_rustls::{DigestImplementation, DigestAlgorithm};
277/// #
278/// # #[derive(Clone)]
279/// # struct DemoDigest;
280/// #
281/// # impl DigestImplementation for DemoDigest {
282/// # fn digest(&self, algorithm: DigestAlgorithm, bytes: &[u8]) -> Vec<u8> {
283/// # todo!("digest it")
284/// # }
285/// # }
286/// #
287/// # let cert_store = rustls::RootCertStore::empty();
288/// #
289/// # let rustls_config = rustls::ClientConfig::builder()
290/// # .with_root_certificates(cert_store)
291/// # .with_no_client_auth();
292/// #
293/// use tokio_postgres_generic_rustls::MakeRustlsConnect;
294///
295/// let tls = MakeRustlsConnect::new(rustls_config, DemoDigest);
296///
297/// let connect_future = tokio_postgres::connect("postgres://username:password@localhost:5432/db", tls);
298/// ```
299pub struct MakeRustlsConnect<D> {
300 config: Arc<ClientConfig>,
301 digest_impl: D,
302}
303
304impl<D> MakeRustlsConnect<D>
305where
306 D: DigestImplementation,
307{
308 /// Create a new MakeRustlsConnect instance
309 pub fn new(config: ClientConfig, digest_impl: D) -> Self {
310 Self {
311 config: Arc::new(config),
312 digest_impl,
313 }
314 }
315}
316
317impl<D, S> MakeTlsConnect<S> for MakeRustlsConnect<D>
318where
319 D: DigestImplementation + Clone + Unpin,
320 S: AsyncRead + AsyncWrite + Unpin + Send,
321{
322 type Stream = RustlsStream<D, S>;
323 type TlsConnect = RustlsConnect<D>;
324 type Error = rustls::pki_types::InvalidDnsNameError;
325
326 fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
327 ServerName::try_from(domain).map(|dns_name| RustlsConnect {
328 dns_name: dns_name.to_owned(),
329 connector: Arc::clone(&self.config).into(),
330 digest_impl: self.digest_impl.clone(),
331 })
332 }
333}
334
335#[doc(hidden)]
336pub struct RustlsConnect<D> {
337 dns_name: ServerName<'static>,
338 connector: TlsConnector,
339 digest_impl: D,
340}
341
342#[doc(hidden)]
343pub struct ConnectFuture<D, S> {
344 connect: Connect<S>,
345 digest_impl: D,
346}
347
348impl<D, S> Future for ConnectFuture<D, S>
349where
350 D: DigestImplementation + Clone + Unpin,
351 S: AsyncRead + AsyncWrite + Unpin,
352{
353 type Output = std::io::Result<RustlsStream<D, S>>;
354
355 fn poll(
356 self: Pin<&mut Self>,
357 cx: &mut std::task::Context<'_>,
358 ) -> std::task::Poll<Self::Output> {
359 let this = self.get_mut();
360
361 let res = std::task::ready!(Pin::new(&mut this.connect).poll(cx));
362
363 std::task::Poll::Ready(res.map(|io| RustlsStream {
364 io,
365 digest_impl: this.digest_impl.clone(),
366 }))
367 }
368}
369
370impl<D, S> TlsConnect<S> for RustlsConnect<D>
371where
372 D: DigestImplementation + Clone + Unpin,
373 S: AsyncRead + AsyncWrite + Unpin + Send,
374{
375 type Stream = RustlsStream<D, S>;
376 type Error = std::io::Error;
377 type Future = ConnectFuture<D, S>;
378
379 fn connect(self, stream: S) -> Self::Future {
380 ConnectFuture {
381 connect: self.connector.connect(self.dns_name, stream),
382 digest_impl: self.digest_impl.clone(),
383 }
384 }
385}
386
387enum SignatureAlgorithm {
388 // 1.2.840.113549.1.1.5
389 Sha1Rsa,
390
391 // 1.2.840.113549.1.1.11
392 Sha256Rsa,
393
394 // 1.2.840.113549.1.1.12
395 Sha384Rsa,
396
397 // 1.2.840.113549.1.1.13
398 Sha512Rsa,
399
400 // 1.2.840.10045.4.3.2
401 EcdsaSha256,
402
403 // 1.2.840.10045.4.3.3
404 EcdsaSha384,
405}
406
407impl SignatureAlgorithm {
408 fn try_from_identifier(oid: &ObjectIdentifier) -> Option<Self> {
409 if oid == &SHA_1_WITH_RSA_ENCRYPTION {
410 Some(Self::Sha1Rsa)
411 } else if oid == &SHA_256_WITH_RSA_ENCRYPTION {
412 Some(Self::Sha256Rsa)
413 } else if oid == &SHA_384_WITH_RSA_ENCRYPTION {
414 Some(Self::Sha384Rsa)
415 } else if oid == &SHA_512_WITH_RSA_ENCRYPTION {
416 Some(Self::Sha512Rsa)
417 } else if oid == &ECDSA_WITH_SHA_256 {
418 Some(Self::EcdsaSha256)
419 } else if oid == &ECDSA_WITH_SHA_384 {
420 Some(Self::EcdsaSha384)
421 } else {
422 None
423 }
424 }
425
426 fn digest_algorithm(self) -> DigestAlgorithm {
427 match self {
428 Self::Sha1Rsa => DigestAlgorithm::Sha1,
429 Self::Sha256Rsa | Self::EcdsaSha256 => DigestAlgorithm::Sha256,
430 Self::Sha384Rsa | Self::EcdsaSha384 => DigestAlgorithm::Sha384,
431 Self::Sha512Rsa => DigestAlgorithm::Sha512,
432 }
433 }
434}
435
436/// Digest algorithms that can be used in tls-server-end-point channel bindings.
437///
438/// This type is only useful in the context of defining a custom digest impelementation, and
439/// otherwise can be ignored.
440pub enum DigestAlgorithm {
441 /// The provided certificate requests Sha1 hasing. as per
442 /// <https://datatracker.ietf.org/doc/html/rfc5929#section-4.1> sha256 hashing should be used
443 /// instead
444 Sha1,
445
446 /// The provided certificate requests Sha256 hasing
447 Sha256,
448
449 /// The provided certificate requests sha284 hashing
450 Sha384,
451
452 /// The provided certificate requests Sha512 hashing
453 Sha512,
454}
455
456#[doc(hidden)]
457pub struct RustlsStream<D, S> {
458 io: TlsStream<S>,
459 digest_impl: D,
460}
461
462impl<D, S> tokio_postgres::tls::TlsStream for RustlsStream<D, S>
463where
464 D: DigestImplementation + Unpin,
465 S: AsyncRead + AsyncWrite + Unpin,
466{
467 fn channel_binding(&self) -> tokio_postgres::tls::ChannelBinding {
468 let (_, session) = self.io.get_ref();
469
470 match session.peer_certificates() {
471 Some(certs) if !certs.is_empty() => Certificate::from_der(&certs[0])
472 .ok()
473 .and_then(|cert| {
474 SignatureAlgorithm::try_from_identifier(&cert.signature_algorithm.oid)
475 })
476 .map(|signature_algorithm| {
477 let digest = self
478 .digest_impl
479 .digest(signature_algorithm.digest_algorithm(), &certs[0]);
480
481 ChannelBinding::tls_server_end_point(digest)
482 })
483 .unwrap_or_else(ChannelBinding::none),
484 _ => ChannelBinding::none(),
485 }
486 }
487}
488
489impl<D, S> AsyncRead for RustlsStream<D, S>
490where
491 D: Unpin,
492 S: AsyncRead + AsyncWrite + Unpin,
493{
494 fn poll_read(
495 self: Pin<&mut Self>,
496 cx: &mut std::task::Context<'_>,
497 buf: &mut tokio::io::ReadBuf<'_>,
498 ) -> std::task::Poll<std::io::Result<()>> {
499 Pin::new(&mut self.get_mut().io).poll_read(cx, buf)
500 }
501}
502
503impl<D, S> AsyncWrite for RustlsStream<D, S>
504where
505 D: Unpin,
506 S: AsyncRead + AsyncWrite + Unpin,
507{
508 fn poll_write(
509 self: Pin<&mut Self>,
510 cx: &mut std::task::Context<'_>,
511 buf: &[u8],
512 ) -> std::task::Poll<Result<usize, std::io::Error>> {
513 Pin::new(&mut self.get_mut().io).poll_write(cx, buf)
514 }
515
516 fn poll_flush(
517 self: Pin<&mut Self>,
518 cx: &mut std::task::Context<'_>,
519 ) -> std::task::Poll<Result<(), std::io::Error>> {
520 Pin::new(&mut self.get_mut().io).poll_flush(cx)
521 }
522
523 fn poll_shutdown(
524 self: Pin<&mut Self>,
525 cx: &mut std::task::Context<'_>,
526 ) -> std::task::Poll<Result<(), std::io::Error>> {
527 Pin::new(&mut self.get_mut().io).poll_shutdown(cx)
528 }
529
530 fn is_write_vectored(&self) -> bool {
531 self.io.is_write_vectored()
532 }
533
534 fn poll_write_vectored(
535 self: Pin<&mut Self>,
536 cx: &mut std::task::Context<'_>,
537 bufs: &[std::io::IoSlice<'_>],
538 ) -> std::task::Poll<Result<usize, std::io::Error>> {
539 Pin::new(&mut self.get_mut().io).poll_write_vectored(cx, bufs)
540 }
541}