1use std::collections::HashMap;
2use std::fmt;
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use instant_acme::{
8 Account, AccountCredentials, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
9 OrderStatus,
10};
11use rustls_pki_types::{CertificateDer, PrivateKeyDer};
12use std::sync::RwLock;
13use tokio_rustls::rustls::server::ResolvesServerCert;
14use tokio_rustls::rustls::sign::CertifiedKey;
15use tokio_rustls::rustls::ServerConfig;
16
17use crate::config::{ConfigManagerTrait, LetsEncryptConfig};
18
/// Shared, thread-safe map from an ACME HTTP-01 challenge token to the key
/// authorization string that answers it.
pub type ChallengeStore = Arc<RwLock<HashMap<String, String>>>;

/// Returns a fresh, empty [`ChallengeStore`].
pub fn new_challenge_store() -> ChallengeStore {
    let empty: HashMap<String, String> = HashMap::default();
    Arc::new(RwLock::new(empty))
}
26
27pub struct AcmeCertResolver {
30 certs: RwLock<HashMap<String, Arc<CertifiedKey>>>,
31 fallback: RwLock<Option<Arc<CertifiedKey>>>,
32}
33
34impl fmt::Debug for AcmeCertResolver {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 let domain_count = self.certs.read().map(|c| c.len()).unwrap_or(0);
37 let has_fallback = self.fallback.read().map(|f| f.is_some()).unwrap_or(false);
38 f.debug_struct("AcmeCertResolver")
39 .field("domains", &domain_count)
40 .field("has_fallback", &has_fallback)
41 .finish()
42 }
43}
44
45impl Default for AcmeCertResolver {
46 fn default() -> Self {
47 Self {
48 certs: RwLock::new(HashMap::new()),
49 fallback: RwLock::new(None),
50 }
51 }
52}
53
54impl AcmeCertResolver {
55 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn set_fallback(&self, key: Arc<CertifiedKey>) {
60 if let Ok(mut fallback) = self.fallback.write() {
61 *fallback = Some(key);
62 }
63 }
64
65 pub fn set_cert(&self, domain: &str, key: Arc<CertifiedKey>) {
66 if let Ok(mut certs) = self.certs.write() {
67 certs.insert(domain.to_string(), key);
68 }
69 }
70}
71
72impl ResolvesServerCert for AcmeCertResolver {
73 fn resolve(
74 &self,
75 client_hello: tokio_rustls::rustls::server::ClientHello<'_>,
76 ) -> Option<Arc<CertifiedKey>> {
77 if let Some(sni) = client_hello.server_name() {
78 if let Ok(certs) = self.certs.read() {
79 if let Some(key) = certs.get(sni) {
80 return Some(key.clone());
81 }
82 }
83 }
84
85 if let Ok(fallback) = self.fallback.read() {
86 return fallback.clone();
87 }
88
89 None
90 }
91}
92
93pub struct AcmeService {
95 account: Arc<Account>,
96 challenge_store: ChallengeStore,
97 resolver: Arc<AcmeCertResolver>,
98 cache_dir: PathBuf,
99}
100
101impl AcmeService {
102 pub fn new(
103 account: Arc<Account>,
104 challenge_store: ChallengeStore,
105 resolver: Arc<AcmeCertResolver>,
106 cache_dir: PathBuf,
107 ) -> Self {
108 Self {
109 account,
110 challenge_store,
111 resolver,
112 cache_dir,
113 }
114 }
115
116 pub async fn ensure_certificate(&self, domain: &str) -> Result<()> {
118 if !cert_expires_soon(&self.cache_dir, domain) {
119 return Ok(());
120 }
121 tracing::info!("Issuing certificate for new domain: {}", domain);
122 let (cert_pem, key_pem) =
123 issue_certificate(&self.account, &[domain.to_string()], &self.challenge_store).await?;
124 save_certificate(&self.cache_dir, domain, &cert_pem, &key_pem)?;
125 let ck = certified_key_from_pem(cert_pem.as_bytes(), key_pem.as_bytes())?;
126 self.resolver.set_cert(domain, Arc::new(ck));
127 Ok(())
128 }
129}
130
131pub fn build_server_config(resolver: Arc<AcmeCertResolver>) -> Result<Arc<ServerConfig>> {
133 let mut config = ServerConfig::builder()
134 .with_no_client_auth()
135 .with_cert_resolver(resolver);
136 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
137 Ok(Arc::new(config))
138}
139
140pub async fn get_or_create_account(
142 le_config: &LetsEncryptConfig,
143 cache_dir: &Path,
144) -> Result<Account> {
145 let creds_path = cache_dir.join("account_credentials.json");
146
147 if creds_path.exists() {
148 let json = std::fs::read_to_string(&creds_path)
149 .context("Failed to read ACME account credentials")?;
150 let credentials: AccountCredentials =
151 serde_json::from_str(&json).context("Failed to parse ACME account credentials")?;
152 let account = Account::builder()
153 .map_err(|e| anyhow::anyhow!("Failed to create account builder: {:?}", e))?
154 .from_credentials(credentials)
155 .await
156 .map_err(|e| anyhow::anyhow!("Failed to restore ACME account: {:?}", e))?;
157 tracing::info!("Restored ACME account from {}", creds_path.display());
158 return Ok(account);
159 }
160
161 let directory_url = if le_config.staging {
162 LetsEncrypt::Staging.url().to_string()
163 } else {
164 LetsEncrypt::Production.url().to_string()
165 };
166
167 let contact = format!("mailto:{}", le_config.email);
168 let (account, credentials) = Account::builder()
169 .map_err(|e| anyhow::anyhow!("Failed to create account builder: {:?}", e))?
170 .create(
171 &NewAccount {
172 contact: &[&contact],
173 terms_of_service_agreed: le_config.terms_agreed,
174 only_return_existing: false,
175 },
176 directory_url,
177 None,
178 )
179 .await
180 .map_err(|e| anyhow::anyhow!("Failed to create ACME account: {:?}", e))?;
181
182 std::fs::create_dir_all(cache_dir)?;
183 let json = serde_json::to_string_pretty(&credentials)?;
184 std::fs::write(&creds_path, &json)?;
185
186 #[cfg(unix)]
187 {
188 use std::os::unix::fs::PermissionsExt;
189 std::fs::set_permissions(&creds_path, std::fs::Permissions::from_mode(0o600))?;
190 }
191
192 tracing::info!(
193 "Created new ACME account, credentials saved to {}",
194 creds_path.display()
195 );
196 Ok(account)
197}
198
199pub async fn issue_certificate(
201 account: &Account,
202 domains: &[String],
203 challenge_store: &ChallengeStore,
204) -> Result<(String, String)> {
205 let identifiers: Vec<Identifier> = domains.iter().map(|d| Identifier::Dns(d.clone())).collect();
206
207 let mut order = account
208 .new_order(&NewOrder::new(&identifiers))
209 .await
210 .map_err(|e| anyhow::anyhow!("Failed to create ACME order: {:?}", e))?;
211
212 let mut challenge_tokens: Vec<String> = Vec::new();
214
215 {
217 let mut authorizations = order.authorizations();
218 while let Some(result) = authorizations.next().await {
219 let mut authz =
220 result.map_err(|e| anyhow::anyhow!("Failed to get authorization: {:?}", e))?;
221
222 if authz.status == instant_acme::AuthorizationStatus::Valid {
223 continue;
224 }
225
226 let mut challenge = authz
227 .challenge(ChallengeType::Http01)
228 .ok_or_else(|| anyhow::anyhow!("No HTTP-01 challenge found"))?;
229
230 let key_auth = challenge.key_authorization();
231 let key_auth_str = key_auth.as_str().to_string();
232 let token = challenge.token.clone();
233
234 {
236 let mut store = challenge_store
237 .write()
238 .map_err(|_| anyhow::anyhow!("Challenge store poisoned"))?;
239 store.insert(token.clone(), key_auth_str);
240 }
241
242 challenge_tokens.push(token.clone());
243 tracing::info!("ACME challenge set for token: {}", token);
244
245 challenge
247 .set_ready()
248 .await
249 .map_err(|e| anyhow::anyhow!("Failed to set challenge ready: {:?}", e))?;
250 }
251 }
252
253 let status = order
255 .poll_ready(&instant_acme::RetryPolicy::default())
256 .await
257 .map_err(|e| anyhow::anyhow!("ACME order failed to become ready: {:?}", e))?;
258
259 if status != OrderStatus::Ready {
260 anyhow::bail!("ACME order not ready, status: {:?}", status);
261 }
262
263 let private_key_pem = order
265 .finalize()
266 .await
267 .map_err(|e| anyhow::anyhow!("Failed to finalize ACME order: {:?}", e))?;
268
269 let cert_chain_pem = order
271 .certificate()
272 .await
273 .map_err(|e| anyhow::anyhow!("Failed to retrieve certificate: {:?}", e))?
274 .ok_or_else(|| anyhow::anyhow!("No certificate returned"))?;
275
276 {
278 let mut store = challenge_store
279 .write()
280 .map_err(|_| anyhow::anyhow!("Challenge store poisoned"))?;
281 for token in &challenge_tokens {
282 store.remove(token);
283 }
284 }
285
286 Ok((cert_chain_pem, private_key_pem))
287}
288
289pub fn save_certificate(
291 cache_dir: &Path,
292 domain: &str,
293 cert_pem: &str,
294 key_pem: &str,
295) -> Result<()> {
296 std::fs::create_dir_all(cache_dir)?;
297
298 let cert_path = cache_dir.join(format!("{}.cert.pem", domain));
299 let key_path = cache_dir.join(format!("{}.key.pem", domain));
300
301 std::fs::write(&cert_path, cert_pem)?;
302 std::fs::write(&key_path, key_pem)?;
303
304 #[cfg(unix)]
305 {
306 use std::os::unix::fs::PermissionsExt;
307 std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600))?;
308 }
309
310 tracing::info!(
311 "Saved certificate for {} to {}",
312 domain,
313 cert_path.display()
314 );
315 Ok(())
316}
317
318pub fn load_certificate(cache_dir: &Path, domain: &str) -> Result<Option<Arc<CertifiedKey>>> {
320 let cert_path = cache_dir.join(format!("{}.cert.pem", domain));
321 let key_path = cache_dir.join(format!("{}.key.pem", domain));
322
323 if !cert_path.exists() || !key_path.exists() {
324 return Ok(None);
325 }
326
327 let ck = certified_key_from_pem(&std::fs::read(&cert_path)?, &std::fs::read(&key_path)?)?;
328
329 Ok(Some(Arc::new(ck)))
330}
331
332pub fn certified_key_from_pem(cert_pem: &[u8], key_pem: &[u8]) -> Result<CertifiedKey> {
334 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut &*cert_pem)
335 .collect::<std::result::Result<Vec<_>, _>>()
336 .context("Failed to parse certificate PEM")?;
337
338 let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut &*key_pem)
339 .context("Failed to parse private key PEM")?
340 .ok_or_else(|| anyhow::anyhow!("No private key found in PEM"))?;
341
342 let provider = tokio_rustls::rustls::crypto::CryptoProvider::get_default()
343 .ok_or_else(|| anyhow::anyhow!("No default CryptoProvider installed"))?;
344
345 let certified_key = CertifiedKey::from_der(certs, key, provider)
346 .map_err(|e| anyhow::anyhow!("Failed to build CertifiedKey: {:?}", e))?;
347
348 Ok(certified_key)
349}
350
351pub fn cert_expires_soon(cache_dir: &Path, domain: &str) -> bool {
353 let cert_path = cache_dir.join(format!("{}.cert.pem", domain));
354
355 if !cert_path.exists() {
356 return true;
357 }
358
359 let cert_pem = match std::fs::read(&cert_path) {
360 Ok(data) => data,
361 Err(_) => return true,
362 };
363
364 let certs: Vec<CertificateDer<'static>> =
365 match rustls_pemfile::certs(&mut &*cert_pem).collect::<std::result::Result<Vec<_>, _>>() {
366 Ok(c) if !c.is_empty() => c,
367 _ => return true,
368 };
369
370 match x509_parser::parse_x509_certificate(&certs[0]) {
371 Ok((_, cert)) => {
372 let not_after = cert.validity().not_after.timestamp();
373 let now = std::time::SystemTime::now()
374 .duration_since(std::time::UNIX_EPOCH)
375 .map(|d| d.as_secs() as i64)
376 .unwrap_or(0);
377 let thirty_days = 30 * 24 * 3600;
378 not_after - now < thirty_days
379 }
380 Err(_) => true,
381 }
382}
383
384pub fn spawn_renewal_task(
387 account: Arc<Account>,
388 config_manager: Arc<dyn ConfigManagerTrait + Send + Sync>,
389 cache_dir: PathBuf,
390 challenge_store: ChallengeStore,
391 resolver: Arc<AcmeCertResolver>,
392) {
393 tokio::spawn(async move {
394 let interval = tokio::time::Duration::from_secs(12 * 3600);
395
396 loop {
397 tokio::time::sleep(interval).await;
398
399 let domains = config_manager.get_config().acme_domains();
401 tracing::info!("ACME renewal check: examining {} domain(s)", domains.len());
402
403 for domain in &domains {
404 if !cert_expires_soon(&cache_dir, domain) {
405 tracing::debug!("Certificate for {} is still valid", domain);
406 continue;
407 }
408
409 tracing::info!("Certificate for {} needs renewal, issuing...", domain);
410
411 match issue_certificate(&account, std::slice::from_ref(domain), &challenge_store)
412 .await
413 {
414 Ok((cert_pem, key_pem)) => {
415 if let Err(e) = save_certificate(&cache_dir, domain, &cert_pem, &key_pem) {
416 tracing::error!("Failed to save renewed cert for {}: {}", domain, e);
417 continue;
418 }
419
420 match certified_key_from_pem(cert_pem.as_bytes(), key_pem.as_bytes()) {
421 Ok(ck) => {
422 resolver.set_cert(domain, Arc::new(ck));
423 tracing::info!("Successfully renewed certificate for {}", domain);
424 }
425 Err(e) => {
426 tracing::error!(
427 "Failed to parse renewed cert for {}: {}",
428 domain,
429 e
430 );
431 }
432 }
433 }
434 Err(e) => {
435 tracing::error!("Failed to renew certificate for {}: {}", domain, e);
436 }
437 }
438 }
439 }
440 });
441}