Skip to main content

rust_serv/auto_tls/
store.rs

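//! Certificate storage: persists issued certificates and their private keys in a cache directory as JSON metadata plus PEM files.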
2
3use serde::{Deserialize, Serialize};
4use std::path::{Path, PathBuf};
5use tokio::fs;
6
7use super::{AutoTlsError, AutoTlsResult};
8
9/// Stored certificate data
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct StoredCertificate {
12    /// Domain name
13    pub domain: String,
14    /// Certificate chain (PEM format)
15    pub certificate: String,
16    /// Private key (PEM format)
17    pub private_key: String,
18    /// When the certificate was issued (Unix timestamp)
19    pub issued_at: u64,
20    /// When the certificate expires (Unix timestamp)
21    pub expires_at: u64,
22}
23
24impl StoredCertificate {
25    /// Create a new stored certificate
26    pub fn new(domain: String, certificate: String, private_key: String, expires_at: u64) -> Self {
27        let issued_at = std::time::SystemTime::now()
28            .duration_since(std::time::UNIX_EPOCH)
29            .unwrap()
30            .as_secs();
31
32        Self {
33            domain,
34            certificate,
35            private_key,
36            issued_at,
37            expires_at,
38        }
39    }
40
41    /// Check if certificate is expired
42    pub fn is_expired(&self) -> bool {
43        let now = std::time::SystemTime::now()
44            .duration_since(std::time::UNIX_EPOCH)
45            .unwrap()
46            .as_secs();
47        now >= self.expires_at
48    }
49
50    /// Get days until expiration
51    pub fn days_until_expiry(&self) -> i64 {
52        let now = std::time::SystemTime::now()
53            .duration_since(std::time::UNIX_EPOCH)
54            .unwrap()
55            .as_secs();
56        let secs_until = self.expires_at.saturating_sub(now);
57        (secs_until / 86400) as i64
58    }
59}
60
61/// Certificate store for persistence
62#[derive(Debug, Clone)]
63pub struct CertificateStore {
64    cache_dir: PathBuf,
65}
66
67impl CertificateStore {
68    /// Create a new certificate store
69    pub fn new(cache_dir: &Path) -> Self {
70        Self {
71            cache_dir: cache_dir.to_path_buf(),
72        }
73    }
74
75    /// Get path for account credentials
76    pub fn account_path(&self) -> PathBuf {
77        self.cache_dir.join("account.json")
78    }
79
80    /// Get path for certificate file
81    pub fn certificate_path(&self, domain: &str) -> PathBuf {
82        self.cache_dir.join(format!("{}.json", domain))
83    }
84
85    /// Get path for certificate PEM file
86    pub fn cert_pem_path(&self, domain: &str) -> PathBuf {
87        self.cache_dir.join(format!("{}.crt", domain))
88    }
89
90    /// Get path for private key PEM file
91    pub fn key_pem_path(&self, domain: &str) -> PathBuf {
92        self.cache_dir.join(format!("{}.key", domain))
93    }
94
95    /// Initialize storage directory
96    pub async fn init(&self) -> AutoTlsResult<()> {
97        fs::create_dir_all(&self.cache_dir)
98            .await
99            .map_err(|e| AutoTlsError::StoreError(format!("Failed to create cache dir: {}", e)))?;
100
101        // Set restrictive permissions (600)
102        #[cfg(unix)]
103        {
104            use std::os::unix::fs::PermissionsExt;
105            let perms = std::fs::Permissions::from_mode(0o700);
106            fs::set_permissions(&self.cache_dir, perms)
107                .await
108                .ok();
109        }
110
111        Ok(())
112    }
113
114    /// Save certificate to storage
115    pub async fn save_certificate(
116        &self,
117        domain: &str,
118        certificate: &str,
119        private_key: &str,
120    ) -> AutoTlsResult<()> {
121        self.save_certificate_with_expiry(domain, certificate, private_key, 0).await
122    }
123
124    /// Save certificate with custom expiry time
125    pub async fn save_certificate_with_expiry(
126        &self,
127        domain: &str,
128        certificate: &str,
129        private_key: &str,
130        expires_at: u64,
131    ) -> AutoTlsResult<()> {
132        self.init().await?;
133
134        // Use provided expiry or default to 90 days from now
135        let expires_at = if expires_at == 0 {
136            std::time::SystemTime::now()
137                .duration_since(std::time::UNIX_EPOCH)
138                .unwrap()
139                .as_secs()
140                + (90 * 24 * 60 * 60)
141        } else {
142            expires_at
143        };
144
145        let stored = StoredCertificate::new(
146            domain.to_string(),
147            certificate.to_string(),
148            private_key.to_string(),
149            expires_at,
150        );
151
152        // Save JSON metadata
153        let json_path = self.certificate_path(domain);
154        let content = serde_json::to_string_pretty(&stored)?;
155        fs::write(&json_path, content)
156            .await
157            .map_err(|e| AutoTlsError::StoreError(format!("Failed to write certificate: {}", e)))?;
158
159        // Save certificate PEM
160        let cert_path = self.cert_pem_path(domain);
161        fs::write(&cert_path, certificate)
162            .await
163            .map_err(|e| AutoTlsError::StoreError(format!("Failed to write cert PEM: {}", e)))?;
164
165        // Save private key PEM
166        let key_path = self.key_pem_path(domain);
167        fs::write(&key_path, private_key)
168            .await
169            .map_err(|e| AutoTlsError::StoreError(format!("Failed to write key PEM: {}", e)))?;
170
171        // Set restrictive permissions on key file
172        #[cfg(unix)]
173        {
174            use std::os::unix::fs::PermissionsExt;
175            let perms = std::fs::Permissions::from_mode(0o600);
176            fs::set_permissions(&key_path, perms)
177                .await
178                .ok();
179        }
180
181        tracing::info!(
182            "Saved certificate for {} to {}",
183            domain,
184            self.cache_dir.display()
185        );
186
187        Ok(())
188    }
189
190    /// Load certificate from storage
191    pub async fn load_certificate(&self, domain: &str) -> AutoTlsResult<Option<StoredCertificate>> {
192        let path = self.certificate_path(domain);
193
194        if !path.exists() {
195            return Ok(None);
196        }
197
198        let content = fs::read_to_string(&path)
199            .await
200            .map_err(|e| AutoTlsError::StoreError(format!("Failed to read certificate: {}", e)))?;
201
202        let cert: StoredCertificate = serde_json::from_str(&content)?;
203
204        Ok(Some(cert))
205    }
206
207    /// Delete certificate from storage
208    pub async fn delete_certificate(&self, domain: &str) -> AutoTlsResult<()> {
209        let json_path = self.certificate_path(domain);
210        let cert_path = self.cert_pem_path(domain);
211        let key_path = self.key_pem_path(domain);
212
213        if json_path.exists() {
214            fs::remove_file(&json_path).await.ok();
215        }
216        if cert_path.exists() {
217            fs::remove_file(&cert_path).await.ok();
218        }
219        if key_path.exists() {
220            fs::remove_file(&key_path).await.ok();
221        }
222
223        Ok(())
224    }
225
226    /// List all stored certificates
227    pub async fn list_certificates(&self) -> AutoTlsResult<Vec<String>> {
228        if !self.cache_dir.exists() {
229            return Ok(vec![]);
230        }
231
232        let mut entries = fs::read_dir(&self.cache_dir)
233            .await
234            .map_err(|e| AutoTlsError::StoreError(format!("Failed to read cache dir: {}", e)))?;
235
236        let mut domains = vec![];
237
238        while let Some(entry) = entries.next_entry().await.ok().flatten() {
239            let path = entry.path();
240            if let Some(ext) = path.extension() {
241                if ext == "json" {
242                    if let Some(stem) = path.file_stem() {
243                        if stem != "account" {
244                            domains.push(stem.to_string_lossy().to_string());
245                        }
246                    }
247                }
248            }
249        }
250
251        Ok(domains)
252    }
253
254    /// Clean up expired certificates
255    pub async fn cleanup_expired(&self) -> AutoTlsResult<Vec<String>> {
256        let domains = self.list_certificates().await?;
257        let mut cleaned = vec![];
258
259        for domain in domains {
260            if let Some(cert) = self.load_certificate(&domain).await? {
261                if cert.is_expired() {
262                    self.delete_certificate(&domain).await?;
263                    cleaned.push(domain);
264                }
265            }
266        }
267
268        Ok(cleaned)
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Current Unix time in whole seconds, for building expiry timestamps.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_stored_certificate_creation() {
        let cert = StoredCertificate::new(
            "example.com".to_string(),
            "cert_pem".to_string(),
            "key_pem".to_string(),
            now_secs() + 86400 * 90,
        );

        assert_eq!(cert.domain, "example.com");
        assert!(!cert.is_expired());
        assert!(cert.days_until_expiry() > 80);
    }

    #[test]
    fn test_stored_certificate_expired() {
        let cert = StoredCertificate::new(
            "example.com".to_string(),
            "cert_pem".to_string(),
            "key_pem".to_string(),
            0, // Expired at Unix epoch
        );

        assert!(cert.is_expired());
        assert!(cert.days_until_expiry() <= 0);
    }

    #[test]
    fn test_stored_certificate_serialization() {
        let cert = StoredCertificate::new(
            "example.com".to_string(),
            "cert_pem".to_string(),
            "key_pem".to_string(),
            1234567890,
        );

        let json = serde_json::to_string(&cert).unwrap();
        let parsed: StoredCertificate = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.domain, cert.domain);
        assert_eq!(parsed.certificate, cert.certificate);
        assert_eq!(parsed.private_key, cert.private_key);
        assert_eq!(parsed.issued_at, cert.issued_at);
        assert_eq!(parsed.expires_at, cert.expires_at);
    }

    // Path construction is pure, so no async runtime is needed.
    #[test]
    fn test_certificate_store_paths() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        assert_eq!(store.account_path(), dir.path().join("account.json"));
        assert_eq!(
            store.certificate_path("example.com"),
            dir.path().join("example.com.json")
        );
        assert_eq!(
            store.cert_pem_path("example.com"),
            dir.path().join("example.com.crt")
        );
        assert_eq!(
            store.key_pem_path("example.com"),
            dir.path().join("example.com.key")
        );
    }

    #[tokio::test]
    async fn test_certificate_store_save_load() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        store.save_certificate("example.com", "cert", "key").await.unwrap();

        let loaded = store.load_certificate("example.com").await.unwrap();
        assert!(loaded.is_some());

        let cert = loaded.unwrap();
        assert_eq!(cert.domain, "example.com");
        assert_eq!(cert.certificate, "cert");
        assert_eq!(cert.private_key, "key");
    }

    #[tokio::test]
    async fn test_certificate_store_writes_pem_files() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        store.save_certificate("example.com", "cert", "key").await.unwrap();

        assert_eq!(
            std::fs::read_to_string(store.cert_pem_path("example.com")).unwrap(),
            "cert"
        );
        assert_eq!(
            std::fs::read_to_string(store.key_pem_path("example.com")).unwrap(),
            "key"
        );
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_certificate_store_key_file_is_private() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        store.save_certificate("example.com", "cert", "key").await.unwrap();

        let mode = std::fs::metadata(store.key_pem_path("example.com"))
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    #[tokio::test]
    async fn test_load_missing_certificate_returns_none() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(&dir.path().join("missing"));

        assert!(store.load_certificate("example.com").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_certificate_store_delete() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        store.save_certificate("example.com", "cert", "key").await.unwrap();
        store.delete_certificate("example.com").await.unwrap();

        let loaded = store.load_certificate("example.com").await.unwrap();
        assert!(loaded.is_none());
        assert!(!store.cert_pem_path("example.com").exists());
        assert!(!store.key_pem_path("example.com").exists());
    }

    #[tokio::test]
    async fn test_delete_missing_certificate_is_ok() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        store.delete_certificate("never-saved.com").await.unwrap();
    }

    #[tokio::test]
    async fn test_certificate_store_list() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        store.save_certificate("example.com", "cert1", "key1").await.unwrap();
        store.save_certificate("test.com", "cert2", "key2").await.unwrap();

        let domains = store.list_certificates().await.unwrap();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"example.com".to_string()));
        assert!(domains.contains(&"test.com".to_string()));
    }

    #[tokio::test]
    async fn test_list_missing_dir_is_empty() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(&dir.path().join("missing"));

        assert!(store.list_certificates().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_certificate_store_cleanup_expired() {
        let dir = tempdir().unwrap();
        let store = CertificateStore::new(dir.path());

        // An explicit expiry of 1 (one second after the epoch) is long past.
        store
            .save_certificate_with_expiry("old.com", "cert", "key", 1)
            .await
            .unwrap();
        store.save_certificate("fresh.com", "cert", "key").await.unwrap();

        let cleaned = store.cleanup_expired().await.unwrap();
        assert_eq!(cleaned, vec!["old.com".to_string()]);
        assert!(store.load_certificate("old.com").await.unwrap().is_none());
        assert!(store.load_certificate("fresh.com").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_certificate_store_init() {
        let dir = tempdir().unwrap();
        let new_dir = dir.path().join("new_cache");
        let store = CertificateStore::new(&new_dir);

        assert!(!new_dir.exists());
        store.init().await.unwrap();
        assert!(new_dir.exists());
    }
}