rustls_tokio_postgres/lib.rs
1#![forbid(unsafe_code)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A [`tokio_postgres`] TLS connector backed by [`rustls`].
5//!
6//! # Example
7//!
8//! ```rust,no_run
9//! use rustls_tokio_postgres::{config_no_verify, MakeRustlsConnect};
10//! use tokio_postgres::{connect, Config};
11//!
12//! #[tokio::main]
13//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
14//! static CONFIG: &str = "host=localhost user=postgres";
15//!
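//! // WARNING: this rustls config does not verify the server certificate, so it
//! // offers no protection against man-in-the-middle attacks. For production use,
//! // construct your own rustls `ClientConfig` (e.g. one that trusts specific root certificates).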
18//! let tls = MakeRustlsConnect::new(config_no_verify());
19//!
20//! // Create the client with the TLS configuration.
21//! let (_client, _conn) = connect(CONFIG, tls).await?;
22//!
23//! Ok(())
24//! }
25//! ```
26//!
27//! # Features
28//!
//! - **channel-binding**: enables TLS channel binding (used by SCRAM-SHA-256-PLUS authentication) when the server supports it.
//! - **native-roots**: enables `config_native_roots`, a helper that builds a [`rustls::ClientConfig`] trusting your OS's native root certificates.
//! - **webpki-roots**: enables `config_webpki_roots`, a helper that builds a [`rustls::ClientConfig`] trusting the root certificates bundled in the `webpki-roots` crate.
32
33use std::{io, sync::Arc};
34
35use rustls::ClientConfig;
36use rustls_pki_types::{InvalidDnsNameError, ServerName};
37use tokio::io::{AsyncRead, AsyncWrite};
38use tokio_postgres::tls::MakeTlsConnect;
39
40mod config;
41mod connect;
42
43#[cfg(feature = "native-roots")]
44pub use config::config_native_roots;
45pub use config::config_no_verify;
46#[cfg(feature = "webpki-roots")]
47pub use config::config_webpki_roots;
48pub use rustls;
49pub use tokio_postgres;
50
51/// A MakeTlsConnect implementation that uses rustls.
52#[derive(Clone)]
53pub struct MakeRustlsConnect {
54 config: Arc<ClientConfig>,
55}
56
57impl MakeRustlsConnect {
58 /// Construct a new `MakeRustlsConnect` instance with the provided [`ClientConfig`].
59 pub fn new(config: ClientConfig) -> Self {
60 Self {
61 config: Arc::new(config),
62 }
63 }
64}
65
66impl<S> MakeTlsConnect<S> for MakeRustlsConnect
67where
68 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
69{
70 type Stream = connect::TlsStream<S>;
71 type TlsConnect = connect::RustlsConnect;
72 type Error = io::Error;
73
74 fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
75 let server_name = ServerName::try_from(hostname)
76 .map_err(|e: InvalidDnsNameError| {
77 io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
78 })?
79 .to_owned();
80
81 Ok(connect::RustlsConnect {
82 config: self.config.clone(),
83 server_name,
84 })
85 }
86}