Skip to main content

rush_sync_server/server/
tls.rs

1use crate::core::prelude::*;
2use rcgen::{Certificate, CertificateParams, DistinguishedName};
3use rustls::{Certificate as RustlsCertificate, PrivateKey, ServerConfig};
4use rustls_pemfile::{certs, pkcs8_private_keys};
5use std::fs;
6use std::io::BufReader;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
/// Creates, stores and loads TLS certificates for the server.
///
/// Self-signed material is kept as `<server>-<port>.cert` / `<server>-<port>.key`
/// inside `cert_dir`; Let's Encrypt material is read from
/// `<domain>.fullchain.pem` / `<domain>.privkey.pem` in the same directory.
#[derive(Debug)]
pub struct TlsManager {
    /// Directory holding every certificate and key file (base dir joined with
    /// the configured directory name).
    cert_dir: PathBuf,
    /// Lifetime, in days, given to newly generated self-signed certificates.
    validity_days: u32,
}
15
16impl TlsManager {
17    pub fn new(cert_dir: &str, validity_days: u32) -> Result<Self> {
18        let base_dir = crate::core::helpers::get_base_dir()?;
19
20        let cert_path = base_dir.join(cert_dir);
21        fs::create_dir_all(&cert_path).map_err(AppError::Io)?;
22
23        Ok(Self {
24            cert_dir: cert_path,
25            validity_days,
26        })
27    }
28
29    pub fn get_rustls_config(&self, server_name: &str, port: u16) -> Result<Arc<ServerConfig>> {
30        self.get_rustls_config_for_domain(server_name, port, "localhost")
31    }
32
33    pub fn get_rustls_config_for_domain(
34        &self,
35        server_name: &str,
36        port: u16,
37        production_domain: &str,
38    ) -> Result<Arc<ServerConfig>> {
39        let cert_file = self.get_cert_path(server_name, port);
40        let key_file = self.get_key_path(server_name, port);
41
42        // Generate certificate if it doesn't exist
43        if !cert_file.exists() || !key_file.exists() {
44            self.generate_certificate_with_domain(server_name, port, production_domain)?;
45        }
46
47        // Load certificate and key
48        let cert_chain = self.load_certificates(&cert_file)?;
49        let private_key = self.load_private_key(&key_file)?;
50
51        // Build rustls configuration
52        let config = ServerConfig::builder()
53            .with_safe_defaults()
54            .with_no_client_auth()
55            .with_single_cert(cert_chain, private_key)
56            .map_err(|e| AppError::Validation(format!("TLS config error: {}", e)))?;
57
58        Ok(Arc::new(config))
59    }
60
61    fn generate_certificate_with_domain(
62        &self,
63        server_name: &str,
64        port: u16,
65        production_domain: &str,
66    ) -> Result<()> {
67        log::info!("Generating TLS certificate for {}:{}", server_name, port);
68
69        // SANs: wildcard for proxy, specific subdomain for individual servers
70        let mut subject_alt_names = if server_name == "proxy" {
71            vec![
72                "localhost".to_string(),
73                "127.0.0.1".to_string(),
74                "*.localhost".to_string(),
75                "proxy.localhost".to_string(),
76            ]
77        } else {
78            vec![
79                "localhost".to_string(),
80                "127.0.0.1".to_string(),
81                format!("{}.localhost", server_name),
82                format!("{}:{}", server_name, port),
83            ]
84        };
85
86        // Add production domain SANs if configured
87        if production_domain != "localhost" {
88            subject_alt_names.push(production_domain.to_string());
89            subject_alt_names.push(format!("*.{}", production_domain));
90            if server_name != "proxy" {
91                subject_alt_names.push(format!("{}.{}", server_name, production_domain));
92            }
93        }
94
95        let mut params = CertificateParams::new(subject_alt_names);
96
97        // Distinguished Name
98        let mut dn = DistinguishedName::new();
99        dn.push(rcgen::DnType::OrganizationName, "Rush Sync Server");
100
101        let common_name = if production_domain != "localhost" {
102            if server_name == "proxy" {
103                format!("*.{}", production_domain)
104            } else {
105                format!("{}.{}", server_name, production_domain)
106            }
107        } else if server_name == "proxy" {
108            "*.localhost".to_string()
109        } else {
110            format!("{}.localhost", server_name)
111        };
112        let common_name = &common_name;
113
114        dn.push(rcgen::DnType::CommonName, common_name);
115        params.distinguished_name = dn;
116
117        // Validity period and key usage
118        params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
119        params.not_after =
120            time::OffsetDateTime::now_utc() + time::Duration::days(self.validity_days as i64);
121
122        params.key_usages = vec![
123            rcgen::KeyUsagePurpose::DigitalSignature,
124            rcgen::KeyUsagePurpose::KeyEncipherment,
125        ];
126
127        params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];
128
129        // Generate and serialize certificate
130        let cert = Certificate::from_params(params)
131            .map_err(|e| AppError::Validation(format!("Certificate generation failed: {}", e)))?;
132
133        let cert_pem = cert.serialize_pem().map_err(|e| {
134            AppError::Validation(format!("Certificate serialization failed: {}", e))
135        })?;
136        let key_pem = cert.serialize_private_key_pem();
137
138        let cert_file = self.get_cert_path(server_name, port);
139        let key_file = self.get_key_path(server_name, port);
140
141        fs::write(&cert_file, cert_pem).map_err(AppError::Io)?;
142        fs::write(&key_file, key_pem).map_err(AppError::Io)?;
143
144        #[cfg(unix)]
145        {
146            use std::os::unix::fs::PermissionsExt;
147            let mut perms = fs::metadata(&key_file).map_err(AppError::Io)?.permissions();
148            perms.set_mode(0o600);
149            fs::set_permissions(&key_file, perms).map_err(AppError::Io)?;
150        }
151
152        log::info!("TLS certificate generated with CN: {}", common_name);
153        log::info!("Certificate: {:?}", cert_file);
154        log::info!("Private Key: {:?}", key_file);
155
156        Ok(())
157    }
158
159    fn load_certificates(&self, path: &Path) -> Result<Vec<RustlsCertificate>> {
160        let cert_file = fs::File::open(path).map_err(AppError::Io)?;
161        let mut reader = BufReader::new(cert_file);
162
163        let cert_chain = certs(&mut reader)
164            .map_err(|e| AppError::Validation(format!("Certificate parsing error: {}", e)))?
165            .into_iter()
166            .map(RustlsCertificate)
167            .collect();
168
169        Ok(cert_chain)
170    }
171
172    fn load_private_key(&self, path: &Path) -> Result<PrivateKey> {
173        let key_file = fs::File::open(path).map_err(AppError::Io)?;
174        let mut reader = BufReader::new(key_file);
175
176        let keys = pkcs8_private_keys(&mut reader)
177            .map_err(|e| AppError::Validation(format!("Private key parsing error: {}", e)))?;
178
179        if keys.is_empty() {
180            return Err(AppError::Validation("No private key found".to_string()));
181        }
182
183        Ok(PrivateKey(keys[0].clone()))
184    }
185
186    fn get_cert_path(&self, server_name: &str, port: u16) -> PathBuf {
187        self.cert_dir.join(format!("{}-{}.cert", server_name, port))
188    }
189
190    fn get_key_path(&self, server_name: &str, port: u16) -> PathBuf {
191        self.cert_dir.join(format!("{}-{}.key", server_name, port))
192    }
193
194    pub fn certificate_exists(&self, server_name: &str, port: u16) -> bool {
195        let cert_file = self.get_cert_path(server_name, port);
196        let key_file = self.get_key_path(server_name, port);
197        cert_file.exists() && key_file.exists()
198    }
199
200    pub fn remove_certificate(&self, server_name: &str, port: u16) -> Result<()> {
201        let cert_file = self.get_cert_path(server_name, port);
202        let key_file = self.get_key_path(server_name, port);
203
204        if cert_file.exists() {
205            fs::remove_file(&cert_file).map_err(AppError::Io)?;
206            log::info!("Removed certificate: {:?}", cert_file);
207        }
208
209        if key_file.exists() {
210            fs::remove_file(&key_file).map_err(AppError::Io)?;
211            log::info!("Removed private key: {:?}", key_file);
212        }
213
214        Ok(())
215    }
216
217    pub fn get_certificate_info(&self, server_name: &str, port: u16) -> Option<CertificateInfo> {
218        let cert_file = self.get_cert_path(server_name, port);
219
220        if !cert_file.exists() {
221            return None;
222        }
223
224        let metadata = fs::metadata(&cert_file).ok()?;
225        let size = metadata.len();
226        let modified = metadata.modified().ok()?;
227
228        Some(CertificateInfo {
229            cert_path: cert_file,
230            key_path: self.get_key_path(server_name, port),
231            file_size: size,
232            created: modified,
233            valid_days: self.validity_days,
234        })
235    }
236
237    pub fn list_certificates(&self) -> Result<Vec<CertificateInfo>> {
238        let mut certificates = Vec::new();
239
240        let entries = fs::read_dir(&self.cert_dir).map_err(AppError::Io)?;
241
242        for entry in entries {
243            let entry = entry.map_err(AppError::Io)?;
244            let path = entry.path();
245
246            if path.extension().and_then(|s| s.to_str()) == Some("cert") {
247                if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
248                    // Parse server-port from filename
249                    if let Some((server, port_str)) = stem.rsplit_once('-') {
250                        if let Ok(port) = port_str.parse::<u16>() {
251                            if let Some(info) = self.get_certificate_info(server, port) {
252                                certificates.push(info);
253                            }
254                        }
255                    }
256                }
257            }
258        }
259
260        certificates.sort_by(|a, b| b.created.cmp(&a.created));
261        Ok(certificates)
262    }
263
264    pub fn get_production_config(&self, domain: &str) -> Result<Arc<ServerConfig>> {
265        // Check for existing Let's Encrypt certificate
266        let cert_file = self.cert_dir.join(format!("{}.fullchain.pem", domain));
267        let key_file = self.cert_dir.join(format!("{}.privkey.pem", domain));
268
269        if cert_file.exists() && key_file.exists() {
270            log::info!("Loading Let's Encrypt certificate for {}", domain);
271            let cert_chain = match self.load_certificates(&cert_file) {
272                Ok(c) => c,
273                Err(e) => {
274                    log::error!("LE cert corrupt for {}: {} — deleting for re-provision", domain, e);
275                    let _ = fs::remove_file(&cert_file);
276                    let _ = fs::remove_file(&key_file);
277                    return Err(e);
278                }
279            };
280            let private_key = match self.load_private_key(&key_file) {
281                Ok(k) => k,
282                Err(e) => {
283                    log::error!("LE key corrupt for {}: {} — deleting for re-provision", domain, e);
284                    let _ = fs::remove_file(&cert_file);
285                    let _ = fs::remove_file(&key_file);
286                    return Err(e);
287                }
288            };
289
290            match ServerConfig::builder()
291                .with_safe_defaults()
292                .with_no_client_auth()
293                .with_single_cert(cert_chain, private_key)
294            {
295                Ok(config) => return Ok(Arc::new(config)),
296                Err(e) => {
297                    // Cert/key mismatch (e.g. key was overwritten by a failed ACME attempt).
298                    // Delete both files so ACME will re-provision on the next cycle.
299                    log::error!(
300                        "LE cert/key mismatch for {}: {} — deleting for re-provision",
301                        domain, e
302                    );
303                    let _ = fs::remove_file(&cert_file);
304                    let _ = fs::remove_file(&key_file);
305                    return Err(AppError::Validation(format!(
306                        "Cert/key mismatch for {}: {}",
307                        domain, e
308                    )));
309                }
310            }
311        }
312
313        log::warn!("No Let's Encrypt certificate found for {}", domain);
314
315        // Return error so the caller can generate a proper self-signed cert
316        // with the correct production domain SANs (not *.localhost)
317        Err(AppError::Validation(format!(
318            "No Let's Encrypt certificate found for {}",
319            domain
320        )))
321    }
322}
323
/// Summary of a certificate file and its expiry bookkeeping.
///
/// Expiry is computed from `created` plus `valid_days`; the certificate itself
/// is not parsed. When built by `TlsManager::get_certificate_info`, `created`
/// is the certificate file's modification time.
#[derive(Debug)]
pub struct CertificateInfo {
    /// Path of the PEM certificate file.
    pub cert_path: PathBuf,
    /// Path of the matching PEM private-key file.
    pub key_path: PathBuf,
    /// Size of the certificate file in bytes.
    pub file_size: u64,
    /// Point in time treated as the certificate's creation.
    pub created: std::time::SystemTime,
    /// Number of days the certificate counts as valid after `created`.
    pub valid_days: u32,
}

impl CertificateInfo {
    /// Returns `true` once more than `valid_days` days (measured in seconds)
    /// have elapsed since `created`.
    ///
    /// A `created` timestamp that lies in the future is also treated as expired.
    pub fn is_expired(&self) -> bool {
        let lifetime_secs = u64::from(self.valid_days) * 24 * 60 * 60;
        match self.created.elapsed() {
            Ok(age) => age.as_secs() > lifetime_secs,
            Err(_) => true,
        }
    }

    /// Returns `valid_days` minus the whole days elapsed since `created`.
    ///
    /// Elapsed time is rounded down to whole days, so the result may be negative
    /// after expiry. Returns `0` when `created` lies in the future.
    pub fn days_until_expiry(&self) -> i64 {
        const SECS_PER_DAY: u64 = 24 * 60 * 60;
        self.created.elapsed().map_or(0, |age| {
            // u64::MAX / 86_400 fits comfortably in i64, so this cast cannot truncate.
            let days_elapsed = (age.as_secs() / SECS_PER_DAY) as i64;
            i64::from(self.valid_days) - days_elapsed
        })
    }
}
350}