rustls_tokio_postgres/
lib.rs

1#![forbid(unsafe_code)]
2
//! A [`tokio_postgres`] TLS connector backed by [`rustls`].
//!
//! # Example
//!
//! ```rust,no_run
//! use rustls_tokio_postgres::{config_no_verify, MakeRustlsConnect};
//! use tokio_postgres::{connect, Config};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     static CONFIG: &str = "host=localhost user=postgres";
//!
//!     // This rustls config does not verify the server certificate, which leaves the
//!     // connection open to man-in-the-middle attacks; only use it for testing.
//!     // You can construct your own rustls ClientConfig, if needed.
//!     let tls = MakeRustlsConnect::new(config_no_verify());
//!
//!     // Create the client with the TLS configuration.
//!     let (client, conn) = connect(CONFIG, tls).await?;
//!
//!     // The connection object performs the actual communication with the database,
//!     // so it must be driven (here: spawned onto the runtime) for `client` to work.
//!     tokio::spawn(async move {
//!         if let Err(e) = conn.await {
//!             eprintln!("connection error: {e}");
//!         }
//!     });
//!
//!     client.simple_query("SELECT 1").await?;
//!
//!     Ok(())
//! }
//! ```
//!
//! # Features
//!
//! - **channel-binding**: enables TLS channel binding, if supported.
//! - **native-roots**: enables `config_native_roots`, a helper function for creating a [`rustls::ClientConfig`] using the native roots of your OS.
//! - **webpki-roots**: enables `config_webpki_roots`, a helper function for creating a [`rustls::ClientConfig`] using the webpki roots.
31
32use std::{io, sync::Arc};
33
34use rustls::ClientConfig;
35use rustls_pki_types::{InvalidDnsNameError, ServerName};
36use tokio::io::{AsyncRead, AsyncWrite};
37use tokio_postgres::tls::MakeTlsConnect;
38
39mod config;
40mod connect;
41
42#[cfg(feature = "native-roots")]
43pub use config::config_native_roots;
44pub use config::config_no_verify;
45#[cfg(feature = "webpki-roots")]
46pub use config::config_webpki_roots;
47pub use rustls;
48pub use tokio_postgres;
49
50/// A MakeTlsConnect implementation that uses rustls.
51#[derive(Clone)]
52pub struct MakeRustlsConnect {
53    config: Arc<ClientConfig>,
54}
55
56impl MakeRustlsConnect {
57    /// Construct a new `MakeRustlsConnect` instance with the provided [`ClientConfig`].
58    pub fn new(config: ClientConfig) -> Self {
59        Self {
60            config: Arc::new(config),
61        }
62    }
63}
64
65impl<S> MakeTlsConnect<S> for MakeRustlsConnect
66where
67    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
68{
69    type Stream = connect::TlsStream<S>;
70    type TlsConnect = connect::RustlsConnect;
71    type Error = io::Error;
72
73    fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
74        let server_name = ServerName::try_from(hostname)
75            .map_err(|e: InvalidDnsNameError| {
76                io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
77            })?
78            .to_owned();
79
80        Ok(connect::RustlsConnect {
81            config: self.config.clone(),
82            server_name,
83        })
84    }
85}