rustls_tokio_postgres/
lib.rs

1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A [`tokio_postgres`] TLS connector backed by [`rustls`].
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use rustls_tokio_postgres::{config_no_verify, MakeRustlsConnect};
10//! use tokio_postgres::{connect, Config};
11//!
12//! #[tokio::main]
13//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
14//!     static CONFIG: &str = "host=localhost user=postgres";
15//!
16//!     // This rustls config does not verify the server certificate.
17//!     // You can construct your own rustls ClientConfig, if needed.
18//!     let tls = MakeRustlsConnect::new(config_no_verify());
19//!
20//!     // Create the client with the TLS configuration.
21//!     let (_client, _conn) = connect(CONFIG, tls).await?;
22//!
23//!     Ok(())
24//! }
25//! ```
26//!
27//! # Features
28//!
29//! - **channel-binding**: enables TLS channel binding, if supported.
30//! - **native-roots**: enables a helper function for creating a [`rustls::ClientConfig`] using the native roots of your OS.
31//! - **webpki-roots**: enables a helper function for creating a [`rustls::ClientConfig`] using the webpki roots.
32
33use std::{io, sync::Arc};
34
35use rustls::ClientConfig;
36use rustls_pki_types::{InvalidDnsNameError, ServerName};
37use tokio::io::{AsyncRead, AsyncWrite};
38use tokio_postgres::tls::MakeTlsConnect;
39
40mod config;
41mod connect;
42
43#[cfg(feature = "native-roots")]
44pub use config::config_native_roots;
45pub use config::config_no_verify;
46#[cfg(feature = "webpki-roots")]
47pub use config::config_webpki_roots;
48pub use rustls;
49pub use tokio_postgres;
50
51/// A MakeTlsConnect implementation that uses rustls.
52#[derive(Clone)]
53pub struct MakeRustlsConnect {
54    config: Arc<ClientConfig>,
55}
56
57impl MakeRustlsConnect {
58    /// Construct a new `MakeRustlsConnect` instance with the provided [`ClientConfig`].
59    pub fn new(config: ClientConfig) -> Self {
60        Self {
61            config: Arc::new(config),
62        }
63    }
64}
65
66impl<S> MakeTlsConnect<S> for MakeRustlsConnect
67where
68    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
69{
70    type Stream = connect::TlsStream<S>;
71    type TlsConnect = connect::RustlsConnect;
72    type Error = io::Error;
73
74    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
75        let server_name = ServerName::try_from(hostname)
76            .map_err(|e: InvalidDnsNameError| {
77                io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
78            })?
79            .to_owned();
80
81        Ok(connect::RustlsConnect {
82            config: self.config.clone(),
83            server_name,
84        })
85    }
86}