Skip to main content

soli_proxy/
tls.rs

1use anyhow::{Context, Result};
2use rcgen::{Certificate, CertificateParams};
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio_rustls::rustls::ServerConfig;
6
7use crate::acme::{
8    build_server_config, certified_key_from_pem, load_certificate, AcmeCertResolver,
9};
10use crate::config::TlsConfig;
11
12pub struct TlsManager {
13    server_config: Option<Arc<ServerConfig>>,
14    resolver: Arc<AcmeCertResolver>,
15    cache_dir: PathBuf,
16}
17
18impl TlsManager {
19    pub fn new(tls_config: &TlsConfig) -> Result<Self> {
20        let cache_dir = PathBuf::from(&tls_config.cache_dir);
21        std::fs::create_dir_all(&cache_dir).ok();
22
23        let resolver = Arc::new(AcmeCertResolver::new());
24
25        Ok(Self {
26            server_config: None,
27            resolver,
28            cache_dir,
29        })
30    }
31
32    /// Load self-signed fallback cert. Always called to ensure TLS works.
33    pub fn load_self_signed_fallback(&self) -> Result<()> {
34        let cert_path = self.cache_dir.join("self-signed.cert.pem");
35        let key_path = self.cache_dir.join("self-signed.key.pem");
36
37        if cert_path.exists() && key_path.exists() {
38            let cert_pem = std::fs::read(&cert_path)?;
39            let key_pem = std::fs::read(&key_path)?;
40            let ck = certified_key_from_pem(&cert_pem, &key_pem)?;
41            self.resolver.set_fallback(Arc::new(ck));
42            tracing::info!("Loaded existing self-signed fallback certificate");
43            return Ok(());
44        }
45
46        tracing::info!("Generating self-signed TLS certificate...");
47        let (cert_pem, key_pem) = generate_self_signed_cert()?;
48
49        std::fs::create_dir_all(&self.cache_dir)?;
50        std::fs::write(&cert_path, &cert_pem).context("Failed to write self-signed certificate")?;
51        std::fs::write(&key_path, &key_pem).context("Failed to write self-signed key")?;
52
53        let ck = certified_key_from_pem(cert_pem.as_bytes(), key_pem.as_bytes())?;
54        self.resolver.set_fallback(Arc::new(ck));
55
56        tracing::info!(
57            "Generated self-signed certificate at {}",
58            cert_path.display()
59        );
60        Ok(())
61    }
62
63    /// Load cached ACME certs from disk into the resolver.
64    pub fn load_cached_certs(&self, domains: &[String]) -> Result<()> {
65        for domain in domains {
66            match load_certificate(&self.cache_dir, domain) {
67                Ok(Some(ck)) => {
68                    self.resolver.set_cert(domain, ck);
69                    tracing::info!("Loaded cached certificate for {}", domain);
70                }
71                Ok(None) => {
72                    tracing::debug!("No cached certificate for {}", domain);
73                }
74                Err(e) => {
75                    tracing::warn!("Failed to load cached cert for {}: {}", domain, e);
76                }
77            }
78        }
79        Ok(())
80    }
81
82    /// Build the ServerConfig using the cert resolver. Call after loading certs.
83    pub fn build(&mut self) -> Result<()> {
84        let config = build_server_config(self.resolver.clone())?;
85        self.server_config = Some(config);
86        Ok(())
87    }
88
89    pub fn server_config(&self) -> Option<&Arc<ServerConfig>> {
90        self.server_config.as_ref()
91    }
92
93    pub fn cert_resolver(&self) -> Arc<AcmeCertResolver> {
94        self.resolver.clone()
95    }
96
97    pub fn cache_dir(&self) -> &PathBuf {
98        &self.cache_dir
99    }
100}
101
102fn generate_self_signed_cert() -> Result<(String, String)> {
103    let mut params = CertificateParams::default();
104
105    params.subject_alt_names = vec![
106        rcgen::SanType::DnsName("localhost".to_string()),
107        rcgen::SanType::IpAddress([127, 0, 0, 1].into()),
108    ];
109
110    let cert = Certificate::from_params(params).context("Failed to generate certificate")?;
111    let cert_pem = cert
112        .serialize_pem()
113        .context("Failed to serialize certificate")?;
114    let key_pem = cert.serialize_private_key_pem();
115
116    Ok((cert_pem, key_pem))
117}