Skip to main content

soli_proxy/
acme.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use instant_acme::{
8    Account, AccountCredentials, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
9    OrderStatus,
10};
11use rustls_pki_types::{CertificateDer, PrivateKeyDer};
12use std::sync::RwLock;
13use tokio_rustls::rustls::server::ResolvesServerCert;
14use tokio_rustls::rustls::sign::CertifiedKey;
15use tokio_rustls::rustls::ServerConfig;
16
17use crate::config::{ConfigManagerTrait, LetsEncryptConfig};
18
/// Thread-safe map of pending ACME HTTP-01 challenges, shared between the issuer
/// (which publishes entries) and whatever serves the challenge responses.
///
/// Key: challenge token. Value: the key authorization string to return for it.
pub type ChallengeStore =
    std::sync::Arc<std::sync::RwLock<std::collections::HashMap<String, String>>>;
22
23pub fn new_challenge_store() -> ChallengeStore {
24    Arc::new(RwLock::new(HashMap::new()))
25}
26
27/// Dynamic certificate resolver that supports per-domain ACME certs
28/// with a self-signed fallback.
29pub struct AcmeCertResolver {
30    certs: RwLock<HashMap<String, Arc<CertifiedKey>>>,
31    fallback: RwLock<Option<Arc<CertifiedKey>>>,
32}
33
34impl fmt::Debug for AcmeCertResolver {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        let domain_count = self.certs.read().map(|c| c.len()).unwrap_or(0);
37        let has_fallback = self.fallback.read().map(|f| f.is_some()).unwrap_or(false);
38        f.debug_struct("AcmeCertResolver")
39            .field("domains", &domain_count)
40            .field("has_fallback", &has_fallback)
41            .finish()
42    }
43}
44
45impl Default for AcmeCertResolver {
46    fn default() -> Self {
47        Self {
48            certs: RwLock::new(HashMap::new()),
49            fallback: RwLock::new(None),
50        }
51    }
52}
53
54impl AcmeCertResolver {
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    pub fn set_fallback(&self, key: Arc<CertifiedKey>) {
60        if let Ok(mut fallback) = self.fallback.write() {
61            *fallback = Some(key);
62        }
63    }
64
65    pub fn set_cert(&self, domain: &str, key: Arc<CertifiedKey>) {
66        if let Ok(mut certs) = self.certs.write() {
67            certs.insert(domain.to_string(), key);
68        }
69    }
70}
71
72impl ResolvesServerCert for AcmeCertResolver {
73    fn resolve(
74        &self,
75        client_hello: tokio_rustls::rustls::server::ClientHello<'_>,
76    ) -> Option<Arc<CertifiedKey>> {
77        if let Some(sni) = client_hello.server_name() {
78            if let Ok(certs) = self.certs.read() {
79                if let Some(key) = certs.get(sni) {
80                    return Some(key.clone());
81                }
82            }
83        }
84
85        if let Ok(fallback) = self.fallback.read() {
86            return fallback.clone();
87        }
88
89        None
90    }
91}
92
93/// Service that encapsulates ACME cert issuance logic, shareable across components.
94pub struct AcmeService {
95    account: Arc<Account>,
96    challenge_store: ChallengeStore,
97    resolver: Arc<AcmeCertResolver>,
98    cache_dir: PathBuf,
99}
100
101impl AcmeService {
102    pub fn new(
103        account: Arc<Account>,
104        challenge_store: ChallengeStore,
105        resolver: Arc<AcmeCertResolver>,
106        cache_dir: PathBuf,
107    ) -> Self {
108        Self {
109            account,
110            challenge_store,
111            resolver,
112            cache_dir,
113        }
114    }
115
116    /// Issue a cert for `domain` if one doesn't exist or is expiring.
117    pub async fn ensure_certificate(&self, domain: &str) -> Result<()> {
118        if !cert_expires_soon(&self.cache_dir, domain) {
119            return Ok(());
120        }
121        tracing::info!("Issuing certificate for new domain: {}", domain);
122        let (cert_pem, key_pem) =
123            issue_certificate(&self.account, &[domain.to_string()], &self.challenge_store).await?;
124        save_certificate(&self.cache_dir, domain, &cert_pem, &key_pem)?;
125        let ck = certified_key_from_pem(cert_pem.as_bytes(), key_pem.as_bytes())?;
126        self.resolver.set_cert(domain, Arc::new(ck));
127        Ok(())
128    }
129}
130
131/// Build a ServerConfig using the AcmeCertResolver.
132pub fn build_server_config(resolver: Arc<AcmeCertResolver>) -> Result<Arc<ServerConfig>> {
133    let mut config = ServerConfig::builder()
134        .with_no_client_auth()
135        .with_cert_resolver(resolver);
136    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
137    Ok(Arc::new(config))
138}
139
140/// Get or create an ACME account, persisting credentials to disk.
141pub async fn get_or_create_account(
142    le_config: &LetsEncryptConfig,
143    cache_dir: &Path,
144) -> Result<Account> {
145    let creds_path = cache_dir.join("account_credentials.json");
146
147    if creds_path.exists() {
148        let json = std::fs::read_to_string(&creds_path)
149            .context("Failed to read ACME account credentials")?;
150        let credentials: AccountCredentials =
151            serde_json::from_str(&json).context("Failed to parse ACME account credentials")?;
152        let account = Account::builder()
153            .map_err(|e| anyhow::anyhow!("Failed to create account builder: {:?}", e))?
154            .from_credentials(credentials)
155            .await
156            .map_err(|e| anyhow::anyhow!("Failed to restore ACME account: {:?}", e))?;
157        tracing::info!("Restored ACME account from {}", creds_path.display());
158        return Ok(account);
159    }
160
161    let directory_url = if le_config.staging {
162        LetsEncrypt::Staging.url().to_string()
163    } else {
164        LetsEncrypt::Production.url().to_string()
165    };
166
167    let contact = format!("mailto:{}", le_config.email);
168    let (account, credentials) = Account::builder()
169        .map_err(|e| anyhow::anyhow!("Failed to create account builder: {:?}", e))?
170        .create(
171            &NewAccount {
172                contact: &[&contact],
173                terms_of_service_agreed: le_config.terms_agreed,
174                only_return_existing: false,
175            },
176            directory_url,
177            None,
178        )
179        .await
180        .map_err(|e| anyhow::anyhow!("Failed to create ACME account: {:?}", e))?;
181
182    std::fs::create_dir_all(cache_dir)?;
183    let json = serde_json::to_string_pretty(&credentials)?;
184    std::fs::write(&creds_path, &json)?;
185
186    #[cfg(unix)]
187    {
188        use std::os::unix::fs::PermissionsExt;
189        std::fs::set_permissions(&creds_path, std::fs::Permissions::from_mode(0o600))?;
190    }
191
192    tracing::info!(
193        "Created new ACME account, credentials saved to {}",
194        creds_path.display()
195    );
196    Ok(account)
197}
198
199/// Issue a certificate for the given domains using HTTP-01 challenges.
200pub async fn issue_certificate(
201    account: &Account,
202    domains: &[String],
203    challenge_store: &ChallengeStore,
204) -> Result<(String, String)> {
205    let identifiers: Vec<Identifier> = domains.iter().map(|d| Identifier::Dns(d.clone())).collect();
206
207    let mut order = account
208        .new_order(&NewOrder::new(&identifiers))
209        .await
210        .map_err(|e| anyhow::anyhow!("Failed to create ACME order: {:?}", e))?;
211
212    // Collect challenge tokens to clean up later
213    let mut challenge_tokens: Vec<String> = Vec::new();
214
215    // Process authorizations using the stream-like API
216    {
217        let mut authorizations = order.authorizations();
218        while let Some(result) = authorizations.next().await {
219            let mut authz =
220                result.map_err(|e| anyhow::anyhow!("Failed to get authorization: {:?}", e))?;
221
222            if authz.status == instant_acme::AuthorizationStatus::Valid {
223                continue;
224            }
225
226            let mut challenge = authz
227                .challenge(ChallengeType::Http01)
228                .ok_or_else(|| anyhow::anyhow!("No HTTP-01 challenge found"))?;
229
230            let key_auth = challenge.key_authorization();
231            let key_auth_str = key_auth.as_str().to_string();
232            let token = challenge.token.clone();
233
234            // Store challenge for the HTTP server to serve
235            {
236                let mut store = challenge_store
237                    .write()
238                    .map_err(|_| anyhow::anyhow!("Challenge store poisoned"))?;
239                store.insert(token.clone(), key_auth_str);
240            }
241
242            challenge_tokens.push(token.clone());
243            tracing::info!("ACME challenge set for token: {}", token);
244
245            // Signal that we're ready for validation
246            challenge
247                .set_ready()
248                .await
249                .map_err(|e| anyhow::anyhow!("Failed to set challenge ready: {:?}", e))?;
250        }
251    }
252
253    // Poll until order is ready
254    let status = order
255        .poll_ready(&instant_acme::RetryPolicy::default())
256        .await
257        .map_err(|e| anyhow::anyhow!("ACME order failed to become ready: {:?}", e))?;
258
259    if status != OrderStatus::Ready {
260        anyhow::bail!("ACME order not ready, status: {:?}", status);
261    }
262
263    // Finalize order — generates private key and CSR automatically
264    let private_key_pem = order
265        .finalize()
266        .await
267        .map_err(|e| anyhow::anyhow!("Failed to finalize ACME order: {:?}", e))?;
268
269    // Get cert chain
270    let cert_chain_pem = order
271        .certificate()
272        .await
273        .map_err(|e| anyhow::anyhow!("Failed to retrieve certificate: {:?}", e))?
274        .ok_or_else(|| anyhow::anyhow!("No certificate returned"))?;
275
276    // Clean up challenge tokens
277    {
278        let mut store = challenge_store
279            .write()
280            .map_err(|_| anyhow::anyhow!("Challenge store poisoned"))?;
281        for token in &challenge_tokens {
282            store.remove(token);
283        }
284    }
285
286    Ok((cert_chain_pem, private_key_pem))
287}
288
289/// Save certificate and key PEM files to disk.
290pub fn save_certificate(
291    cache_dir: &Path,
292    domain: &str,
293    cert_pem: &str,
294    key_pem: &str,
295) -> Result<()> {
296    std::fs::create_dir_all(cache_dir)?;
297
298    let cert_path = cache_dir.join(format!("{}.cert.pem", domain));
299    let key_path = cache_dir.join(format!("{}.key.pem", domain));
300
301    std::fs::write(&cert_path, cert_pem)?;
302    std::fs::write(&key_path, key_pem)?;
303
304    #[cfg(unix)]
305    {
306        use std::os::unix::fs::PermissionsExt;
307        std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600))?;
308    }
309
310    tracing::info!(
311        "Saved certificate for {} to {}",
312        domain,
313        cert_path.display()
314    );
315    Ok(())
316}
317
318/// Load certificate and key from PEM files on disk.
319pub fn load_certificate(cache_dir: &Path, domain: &str) -> Result<Option<Arc<CertifiedKey>>> {
320    let cert_path = cache_dir.join(format!("{}.cert.pem", domain));
321    let key_path = cache_dir.join(format!("{}.key.pem", domain));
322
323    if !cert_path.exists() || !key_path.exists() {
324        return Ok(None);
325    }
326
327    let ck = certified_key_from_pem(&std::fs::read(&cert_path)?, &std::fs::read(&key_path)?)?;
328
329    Ok(Some(Arc::new(ck)))
330}
331
332/// Parse PEM-encoded cert chain and private key into a CertifiedKey.
333pub fn certified_key_from_pem(cert_pem: &[u8], key_pem: &[u8]) -> Result<CertifiedKey> {
334    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut &*cert_pem)
335        .collect::<std::result::Result<Vec<_>, _>>()
336        .context("Failed to parse certificate PEM")?;
337
338    let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut &*key_pem)
339        .context("Failed to parse private key PEM")?
340        .ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
341
342    let provider = tokio_rustls::rustls::crypto::CryptoProvider::get_default()
343        .ok_or_else(|| anyhow::anyhow!("No default CryptoProvider installed"))?;
344
345    let certified_key = CertifiedKey::from_der(certs, key, provider)
346        .map_err(|e| anyhow::anyhow!("Failed to build CertifiedKey: {:?}", e))?;
347
348    Ok(certified_key)
349}
350
351/// Check if a certificate for the given domain expires within 30 days.
352pub fn cert_expires_soon(cache_dir: &Path, domain: &str) -> bool {
353    let cert_path = cache_dir.join(format!("{}.cert.pem", domain));
354
355    if !cert_path.exists() {
356        return true;
357    }
358
359    let cert_pem = match std::fs::read(&cert_path) {
360        Ok(data) => data,
361        Err(_) => return true,
362    };
363
364    let certs: Vec<CertificateDer<'static>> =
365        match rustls_pemfile::certs(&mut &*cert_pem).collect::<std::result::Result<Vec<_>, _>>() {
366            Ok(c) if !c.is_empty() => c,
367            _ => return true,
368        };
369
370    match x509_parser::parse_x509_certificate(&certs[0]) {
371        Ok((_, cert)) => {
372            let not_after = cert.validity().not_after.timestamp();
373            let now = std::time::SystemTime::now()
374                .duration_since(std::time::UNIX_EPOCH)
375                .map(|d| d.as_secs() as i64)
376                .unwrap_or(0);
377            let thirty_days = 30 * 24 * 3600;
378            not_after - now < thirty_days
379        }
380        Err(_) => true,
381    }
382}
383
384/// Spawn a background task that checks and renews certificates every 12 hours.
385/// Reads the current domain list from config each cycle to pick up dynamically added domains.
386pub fn spawn_renewal_task(
387    account: Arc<Account>,
388    config_manager: Arc<dyn ConfigManagerTrait + Send + Sync>,
389    cache_dir: PathBuf,
390    challenge_store: ChallengeStore,
391    resolver: Arc<AcmeCertResolver>,
392) {
393    tokio::spawn(async move {
394        let interval = tokio::time::Duration::from_secs(12 * 3600);
395
396        loop {
397            tokio::time::sleep(interval).await;
398
399            // Re-read domains from config each cycle to include dynamically added ones
400            let domains = config_manager.get_config().acme_domains();
401            tracing::info!("ACME renewal check: examining {} domain(s)", domains.len());
402
403            for domain in &domains {
404                if !cert_expires_soon(&cache_dir, domain) {
405                    tracing::debug!("Certificate for {} is still valid", domain);
406                    continue;
407                }
408
409                tracing::info!("Certificate for {} needs renewal, issuing...", domain);
410
411                match issue_certificate(&account, std::slice::from_ref(domain), &challenge_store)
412                    .await
413                {
414                    Ok((cert_pem, key_pem)) => {
415                        if let Err(e) = save_certificate(&cache_dir, domain, &cert_pem, &key_pem) {
416                            tracing::error!("Failed to save renewed cert for {}: {}", domain, e);
417                            continue;
418                        }
419
420                        match certified_key_from_pem(cert_pem.as_bytes(), key_pem.as_bytes()) {
421                            Ok(ck) => {
422                                resolver.set_cert(domain, Arc::new(ck));
423                                tracing::info!("Successfully renewed certificate for {}", domain);
424                            }
425                            Err(e) => {
426                                tracing::error!(
427                                    "Failed to parse renewed cert for {}: {}",
428                                    domain,
429                                    e
430                                );
431                            }
432                        }
433                    }
434                    Err(e) => {
435                        tracing::error!("Failed to renew certificate for {}: {}", domain, e);
436                    }
437                }
438            }
439        }
440    });
441}