Skip to main content

rust_serv/server/
tls.rs

1//! TLS/HTTPS support for the server
2//!
3//! This module provides TLS configuration and utilities for HTTPS connections.
4
5use crate::error::{Error, Result};
6use rustls::ServerConfig;
7use rustls_pemfile::{certs, private_key};
8use std::fs::File;
9use std::io::BufReader;
10use std::path::Path;
11use std::sync::Arc;
12
13/// Load TLS configuration from certificate and key files
14pub fn load_tls_config(cert_path: &Path, key_path: &Path) -> Result<Arc<ServerConfig>> {
15    // Load certificate chain
16    let cert_file = File::open(cert_path)
17        .map_err(|e| Error::Internal(format!("Failed to open certificate file: {}", e)))?;
18    let mut cert_reader = BufReader::new(cert_file);
19    let cert_chain: Vec<_> = certs(&mut cert_reader)
20        .filter_map(|c| c.ok())
21        .collect();
22
23    if cert_chain.is_empty() {
24        return Err(Error::Internal("No certificates found in certificate file".to_string()));
25    }
26
27    // Load private key
28    let key_file = File::open(key_path)
29        .map_err(|e| Error::Internal(format!("Failed to open private key file: {}", e)))?;
30    let mut key_reader = BufReader::new(key_file);
31    let private_key_result = private_key(&mut key_reader);
32
33    let private_key = private_key_result
34        .map_err(|e| Error::Internal(format!("Failed to read private key: {}", e)))?
35        .ok_or_else(|| Error::Internal("No private key found in key file".to_string()))?;
36
37    // Build server config
38    let config = ServerConfig::builder()
39        .with_no_client_auth()
40        .with_single_cert(cert_chain, private_key)
41        .map_err(|e| Error::Internal(format!("Failed to build TLS config: {}", e)))?;
42
43    Ok(Arc::new(config))
44}
45
46/// Validate TLS configuration paths
47pub fn validate_tls_config(cert_path: Option<&str>, key_path: Option<&str>) -> Result<()> {
48    match (cert_path, key_path) {
49        (Some(cert), Some(key)) => {
50            let cert_path = Path::new(cert);
51            let key_path = Path::new(key);
52
53            // Check if certificate file exists
54            if !cert_path.exists() {
55                return Err(Error::Internal(format!(
56                    "Certificate file not found: {}",
57                    cert_path.display()
58                )));
59            }
60
61            // Check if key file exists
62            if !key_path.exists() {
63                return Err(Error::Internal(format!(
64                    "Private key file not found: {}",
65                    key_path.display()
66                )));
67            }
68
69            // Check if both files are readable
70            if let Err(e) = File::open(cert_path) {
71                return Err(Error::Internal(format!(
72                    "Cannot read certificate file: {}",
73                    e
74                )));
75            }
76
77            if let Err(e) = File::open(key_path) {
78                return Err(Error::Internal(format!(
79                    "Cannot read private key file: {}",
80                    e
81                )));
82            }
83
84            Ok(())
85        }
86        (None, None) => Ok(()),
87        _ => Err(Error::Internal(
88            "Both tls_cert and tls_key must be specified together or both omitted".to_string(),
89        )),
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::Once;
    use tempfile::TempDir;

    /// Guards one-time installation of the process-wide rustls crypto provider.
    static INIT: Once = Once::new();

    /// Install the `ring` crypto provider exactly once for this test binary.
    ///
    /// `load_tls_config` builds its config via `ServerConfig::builder()`, so
    /// tests that expect to reach that call install a provider first.
    fn init_crypto_provider() {
        INIT.call_once(|| {
            rustls::crypto::ring::default_provider()
                .install_default()
                .expect("Failed to install crypto provider");
        });
    }

    #[test]
    fn test_validate_tls_config_missing_both() {
        let result = validate_tls_config(None, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_tls_config_missing_cert() {
        // Use a temp dir so the outcome does not depend on the current working directory.
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("nonexistent.pem");
        let key_path = temp_dir.path().join("key.pem");
        fs::write(&key_path, "dummy key").unwrap();

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Certificate file not found"),
            "Expected 'Certificate file not found' error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_validate_tls_config_missing_key() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        fs::write(&cert_path, "dummy cert").unwrap();

        let result = validate_tls_config(Some(cert_path.to_str().unwrap()), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_tls_config_both_provided() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "dummy cert").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );
        // Files exist but are not valid TLS files.
        // This should pass validation (it only checks the filesystem).
        assert!(result.is_ok());
    }

    #[test]
    fn test_load_tls_config_missing_cert_file() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("nonexistent.pem");
        let key_path = temp_dir.path().join("key.pem");
        fs::write(&key_path, "dummy key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_missing_key_file() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("nonexistent.pem");
        fs::write(&cert_path, "dummy cert").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_empty_cert_file() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_invalid_cert_format() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "Not a valid certificate").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_invalid_key_format() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        // Write a valid-looking but invalid certificate
        fs::write(&cert_path, "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHHCgVZU65BMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnNl\n-----END CERTIFICATE-----\n").unwrap();
        fs::write(&key_path, "Not a valid private key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_tls_config_only_cert_provided() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        fs::write(&cert_path, "dummy cert").unwrap();

        let result = validate_tls_config(Some(cert_path.to_str().unwrap()), None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("both omitted"));
    }

    #[test]
    fn test_validate_tls_config_only_key_provided() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("key.pem");
        fs::write(&key_path, "dummy key").unwrap();

        let result = validate_tls_config(None, Some(key_path.to_str().unwrap()));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("both omitted"));
    }

    #[test]
    fn test_validate_tls_config_unreadable_cert() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        // Create files
        fs::write(&cert_path, "dummy cert").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        // Make cert unreadable (Unix only)
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            fs::set_permissions(&cert_path, fs::Permissions::from_mode(0o000)).ok();
        }

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );

        #[cfg(unix)]
        {
            // Root (or a process with CAP_DAC_OVERRIDE) can still open a 0o000
            // file, and the permission change above may have failed. So derive
            // the expected outcome from whether this process can actually open
            // the cert instead of asserting a fixed result.
            let readable = fs::File::open(&cert_path).is_ok();
            assert_eq!(
                result.is_err(),
                !readable,
                "validation result {:?} disagrees with cert readability ({})",
                result,
                readable
            );
            if !readable {
                let err_msg = result.unwrap_err().to_string();
                assert!(
                    err_msg.contains("Cannot read certificate file"),
                    "Expected 'Cannot read certificate file' error, got: {}",
                    err_msg
                );
            }
        }

        #[cfg(not(unix))]
        {
            // On non-Unix systems, should pass since files exist
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_load_tls_config_with_valid_files() {
        init_crypto_provider();

        // Generate a valid self-signed certificate using rcgen
        let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let cert_pem = cert.cert.pem();
        let key_pem = cert.key_pair.serialize_pem();

        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, cert_pem).unwrap();
        fs::write(&key_path, key_pem).unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_ok(), "Should successfully load valid TLS config: {:?}", result.err());

        // A freshly built config is handed back with no other owners.
        let config = result.unwrap();
        assert_eq!(Arc::strong_count(&config), 1);
    }

    #[test]
    fn test_load_tls_config_with_certificate_chain() {
        init_crypto_provider();

        // Generate a CA key pair and certificate
        let ca_key_pair = rcgen::KeyPair::generate().unwrap();
        let ca_params = {
            let mut params = rcgen::CertificateParams::new(vec!["Test CA".to_string()]).unwrap();
            params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
            params.key_usages = vec![rcgen::KeyUsagePurpose::KeyCertSign, rcgen::KeyUsagePurpose::CrlSign];
            params
        };
        let ca_cert = ca_params.self_signed(&ca_key_pair).unwrap();

        // Generate an end-entity key pair and certificate signed by the CA
        let ee_key_pair = rcgen::KeyPair::generate().unwrap();
        let mut ee_params = rcgen::CertificateParams::new(vec!["localhost".to_string()]).unwrap();
        ee_params.key_usages = vec![rcgen::KeyUsagePurpose::DigitalSignature];
        let ee_cert = ee_params.signed_by(&ee_key_pair, &ca_cert, &ca_key_pair).unwrap();

        // Serialize certificates
        let ca_pem = ca_cert.pem();
        let ee_pem = ee_cert.pem();
        let key_pem = ee_key_pair.serialize_pem();

        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert_chain.pem");
        let key_path = temp_dir.path().join("key.pem");

        // Write certificate chain (end-entity cert + CA cert)
        let cert_chain = format!("{}{}", ee_pem, ca_pem);
        fs::write(&cert_path, cert_chain).unwrap();
        fs::write(&key_path, key_pem).unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_ok(), "Should successfully load TLS config with certificate chain: {:?}", result.err());
    }

    #[test]
    fn test_validate_tls_config_unreadable_key() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        // Create files
        fs::write(&cert_path, "dummy cert").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        // Make key unreadable (Unix only)
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            fs::set_permissions(&key_path, fs::Permissions::from_mode(0o000)).ok();
        }

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );

        #[cfg(unix)]
        {
            // See test_validate_tls_config_unreadable_cert: root can still read a
            // 0o000 file, so the expectation follows actual readability.
            let readable = fs::File::open(&key_path).is_ok();
            assert_eq!(
                result.is_err(),
                !readable,
                "validation result {:?} disagrees with key readability ({})",
                result,
                readable
            );
            if !readable {
                let err_msg = result.unwrap_err().to_string();
                assert!(
                    err_msg.contains("Cannot read private key file"),
                    "Expected 'Cannot read private key file' error, got: {}",
                    err_msg
                );
            }
        }

        #[cfg(not(unix))]
        {
            // On non-Unix systems, should pass since files exist
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_validate_tls_config_key_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("nonexistent_key.pem");

        // Create cert file but not key file
        fs::write(&cert_path, "dummy cert").unwrap();
        // Ensure key file does not exist
        assert!(!key_path.exists());

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Private key file not found"), "Expected 'Private key file not found' error, got: {}", err_msg);
    }
}