Skip to main content

rustgate/
tls.rs

1use crate::cert::CertificateAuthority;
2use crate::error::Result;
3use rustls::ServerConfig;
4use std::sync::Arc;
5use tokio::net::TcpStream;
6use tokio_rustls::TlsAcceptor;
7
8/// Create a `TlsAcceptor` for the given domain using a dynamically generated certificate.
9pub async fn make_tls_acceptor(
10    ca: &CertificateAuthority,
11    domain: &str,
12) -> Result<TlsAcceptor> {
13    let ck = ca.get_or_create_cert(domain).await?;
14
15    let config = ServerConfig::builder()
16        .with_no_client_auth()
17        .with_single_cert(
18            vec![ck.cert_der.clone()],
19            rustls::pki_types::PrivateKeyDer::Pkcs8(ck.key_der.clone_key()),
20        )?;
21
22    Ok(TlsAcceptor::from(Arc::new(config)))
23}
24
25/// Connect to an upstream server over TLS and return the stream.
26pub async fn connect_tls_upstream(
27    host: &str,
28    addr: &str,
29) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
30    let tcp = TcpStream::connect(addr).await?;
31
32    let mut root_store = rustls::RootCertStore::empty();
33    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
34
35    let config = rustls::ClientConfig::builder()
36        .with_root_certificates(root_store)
37        .with_no_client_auth();
38
39    let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
40    let server_name = rustls::pki_types::ServerName::try_from(host.to_string())
41        .map_err(|e| crate::error::ProxyError::Other(e.to_string()))?;
42
43    let tls_stream = connector.connect(server_name, tcp).await?;
44    Ok(tls_stream)
45}