1use crate::error::{Error, Result};
6use rustls::ServerConfig;
7use rustls_pemfile::{certs, private_key};
8use std::fs::File;
9use std::io::BufReader;
10use std::path::Path;
11use std::sync::Arc;
12
13pub fn load_tls_config(cert_path: &Path, key_path: &Path) -> Result<Arc<ServerConfig>> {
15 let cert_file = File::open(cert_path)
17 .map_err(|e| Error::Internal(format!("Failed to open certificate file: {}", e)))?;
18 let mut cert_reader = BufReader::new(cert_file);
19 let cert_chain: Vec<_> = certs(&mut cert_reader)
20 .filter_map(|c| c.ok())
21 .collect();
22
23 if cert_chain.is_empty() {
24 return Err(Error::Internal("No certificates found in certificate file".to_string()));
25 }
26
27 let key_file = File::open(key_path)
29 .map_err(|e| Error::Internal(format!("Failed to open private key file: {}", e)))?;
30 let mut key_reader = BufReader::new(key_file);
31 let private_key_result = private_key(&mut key_reader);
32
33 let private_key = private_key_result
34 .map_err(|e| Error::Internal(format!("Failed to read private key: {}", e)))?
35 .ok_or_else(|| Error::Internal("No private key found in key file".to_string()))?;
36
37 let config = ServerConfig::builder()
39 .with_no_client_auth()
40 .with_single_cert(cert_chain, private_key)
41 .map_err(|e| Error::Internal(format!("Failed to build TLS config: {}", e)))?;
42
43 Ok(Arc::new(config))
44}
45
46pub fn validate_tls_config(cert_path: Option<&str>, key_path: Option<&str>) -> Result<()> {
48 match (cert_path, key_path) {
49 (Some(cert), Some(key)) => {
50 let cert_path = Path::new(cert);
51 let key_path = Path::new(key);
52
53 if !cert_path.exists() {
55 return Err(Error::Internal(format!(
56 "Certificate file not found: {}",
57 cert_path.display()
58 )));
59 }
60
61 if !key_path.exists() {
63 return Err(Error::Internal(format!(
64 "Private key file not found: {}",
65 key_path.display()
66 )));
67 }
68
69 if let Err(e) = File::open(cert_path) {
71 return Err(Error::Internal(format!(
72 "Cannot read certificate file: {}",
73 e
74 )));
75 }
76
77 if let Err(e) = File::open(key_path) {
78 return Err(Error::Internal(format!(
79 "Cannot read private key file: {}",
80 e
81 )));
82 }
83
84 Ok(())
85 }
86 (None, None) => Ok(()),
87 _ => Err(Error::Internal(
88 "Both tls_cert and tls_key must be specified together or both omitted".to_string(),
89 )),
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    use std::sync::Once;

    // Guards provider installation so it runs at most once per test process.
    static INIT: Once = Once::new();

    /// Installs `ring` as the process-wide rustls crypto provider.
    ///
    /// All unit tests in the crate share one process, so another module may already
    /// have installed a provider. In that case `install_default` returns `Err`,
    /// which is harmless here: a default provider is available either way.
    fn init_crypto_provider() {
        INIT.call_once(|| {
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    #[test]
    fn test_validate_tls_config_missing_both() {
        let result = validate_tls_config(None, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_tls_config_missing_cert() {
        // Paths live in a fresh temp dir so the outcome cannot depend on the CWD.
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("nonexistent.pem");
        let key_path = temp_dir.path().join("key.pem");
        fs::write(&key_path, "dummy key").unwrap();

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Certificate file not found"),
            "Expected 'Certificate file not found' error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_validate_tls_config_missing_key() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        fs::write(&cert_path, "dummy cert").unwrap();

        let result = validate_tls_config(Some(cert_path.to_str().unwrap()), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_tls_config_both_provided() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "dummy cert").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_load_tls_config_missing_cert_file() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("nonexistent.pem");
        let key_path = temp_dir.path().join("key.pem");
        fs::write(&key_path, "dummy key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_missing_key_file() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("nonexistent.pem");
        fs::write(&cert_path, "dummy cert").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_empty_cert_file() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("No certificates found"),
            "Expected 'No certificates found' error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_load_tls_config_invalid_cert_format() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "Not a valid certificate").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_tls_config_invalid_key_format() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHHCgVZU65BMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnNl\n-----END CERTIFICATE-----\n").unwrap();
        fs::write(&key_path, "Not a valid private key").unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_tls_config_only_cert_provided() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        fs::write(&cert_path, "dummy cert").unwrap();

        let result = validate_tls_config(Some(cert_path.to_str().unwrap()), None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("both omitted"));
    }

    #[test]
    fn test_validate_tls_config_only_key_provided() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("key.pem");
        fs::write(&key_path, "dummy key").unwrap();

        let result = validate_tls_config(None, Some(key_path.to_str().unwrap()));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("both omitted"));
    }

    #[test]
    fn test_validate_tls_config_unreadable_cert() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "dummy cert").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // Best-effort: the assertions below adapt to whether this took effect.
            fs::set_permissions(&cert_path, fs::Permissions::from_mode(0o000)).ok();
        }

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );

        #[cfg(unix)]
        {
            // Privileged users (e.g. root in CI containers) bypass mode bits, so only
            // require an error when the file really is unreadable to this process.
            if fs::File::open(&cert_path).is_err() {
                let err_msg = result.unwrap_err().to_string();
                assert!(
                    err_msg.contains("Cannot read certificate file"),
                    "Expected 'Cannot read certificate file' error, got: {}",
                    err_msg
                );
            } else {
                assert!(result.is_ok());
            }
        }

        #[cfg(not(unix))]
        {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_load_tls_config_with_valid_files() {
        init_crypto_provider();

        let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let cert_pem = cert.cert.pem();
        let key_pem = cert.key_pair.serialize_pem();

        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, cert_pem).unwrap();
        fs::write(&key_path, key_pem).unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_ok(), "Should successfully load valid TLS config: {:?}", result.err());
    }

    #[test]
    fn test_load_tls_config_with_certificate_chain() {
        init_crypto_provider();

        let ca_key_pair = rcgen::KeyPair::generate().unwrap();
        let ca_params = {
            let mut params = rcgen::CertificateParams::new(vec!["Test CA".to_string()]).unwrap();
            params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
            params.key_usages = vec![rcgen::KeyUsagePurpose::KeyCertSign, rcgen::KeyUsagePurpose::CrlSign];
            params
        };
        let ca_cert = ca_params.self_signed(&ca_key_pair).unwrap();

        let ee_key_pair = rcgen::KeyPair::generate().unwrap();
        let mut ee_params = rcgen::CertificateParams::new(vec!["localhost".to_string()]).unwrap();
        ee_params.key_usages = vec![rcgen::KeyUsagePurpose::DigitalSignature];
        let ee_cert = ee_params.signed_by(&ee_key_pair, &ca_cert, &ca_key_pair).unwrap();

        let ca_pem = ca_cert.pem();
        let ee_pem = ee_cert.pem();
        let key_pem = ee_key_pair.serialize_pem();

        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert_chain.pem");
        let key_path = temp_dir.path().join("key.pem");

        // End-entity certificate first, then the issuing CA.
        let cert_chain = format!("{}{}", ee_pem, ca_pem);
        fs::write(&cert_path, cert_chain).unwrap();
        fs::write(&key_path, key_pem).unwrap();

        let result = load_tls_config(&cert_path, &key_path);
        assert!(result.is_ok(), "Should successfully load TLS config with certificate chain: {:?}", result.err());
    }

    #[test]
    fn test_validate_tls_config_unreadable_key() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("key.pem");

        fs::write(&cert_path, "dummy cert").unwrap();
        fs::write(&key_path, "dummy key").unwrap();

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // Best-effort: the assertions below adapt to whether this took effect.
            fs::set_permissions(&key_path, fs::Permissions::from_mode(0o000)).ok();
        }

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );

        #[cfg(unix)]
        {
            // Privileged users (e.g. root in CI containers) bypass mode bits, so only
            // require an error when the file really is unreadable to this process.
            if fs::File::open(&key_path).is_err() {
                let err_msg = result.unwrap_err().to_string();
                assert!(
                    err_msg.contains("Cannot read private key file"),
                    "Expected 'Cannot read private key file' error, got: {}",
                    err_msg
                );
            } else {
                assert!(result.is_ok());
            }
        }

        #[cfg(not(unix))]
        {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_validate_tls_config_key_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let cert_path = temp_dir.path().join("cert.pem");
        let key_path = temp_dir.path().join("nonexistent_key.pem");

        fs::write(&cert_path, "dummy cert").unwrap();
        assert!(!key_path.exists());

        let result = validate_tls_config(
            Some(cert_path.to_str().unwrap()),
            Some(key_path.to_str().unwrap()),
        );

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Private key file not found"), "Expected 'Private key file not found' error, got: {}", err_msg);
    }
}