1use crate::error::{ProxyError, Result};
23use dashmap::DashMap;
24use rustls::pki_types::{CertificateDer, PrivateKeyDer};
25use rustls::server::{ClientHello, ResolvesServerCert};
26use rustls::sign::CertifiedKey;
27use std::io::BufReader;
28use std::sync::{Arc, RwLock};
29use tracing::{debug, trace, warn};
30
31#[derive(Debug)]
43pub struct SniCertResolver {
44 certs: DashMap<String, Arc<CertifiedKey>>,
46 default_cert: RwLock<Option<Arc<CertifiedKey>>>,
48}
49
50impl SniCertResolver {
51 #[must_use]
61 pub fn new() -> Self {
62 Self {
63 certs: DashMap::new(),
64 default_cert: RwLock::new(None),
65 }
66 }
67
68 pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
92 let certified_key = create_certified_key(cert_pem, key_pem)?;
93 let domain_normalized = normalize_domain(domain);
94
95 debug!(domain = %domain_normalized, "Loaded TLS certificate");
96 self.certs
97 .insert(domain_normalized, Arc::new(certified_key));
98
99 Ok(())
100 }
101
102 pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
125 let certified_key = create_certified_key(cert_pem, key_pem)?;
126
127 debug!("Set default TLS certificate");
128 let mut default = self.default_cert.write().expect("RwLock poisoned");
129 *default = Some(Arc::new(certified_key));
130
131 Ok(())
132 }
133
134 pub fn remove_cert(&self, domain: &str) {
146 let domain_normalized = normalize_domain(domain);
147 if self.certs.remove(&domain_normalized).is_some() {
148 debug!(domain = %domain_normalized, "Removed TLS certificate");
149 }
150 }
151
152 pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
173 let certified_key = create_certified_key(cert_pem, key_pem)?;
174 let domain_normalized = normalize_domain(domain);
175
176 debug!(domain = %domain_normalized, "Refreshed TLS certificate");
177 self.certs
178 .insert(domain_normalized, Arc::new(certified_key));
179
180 Ok(())
181 }
182
183 #[must_use]
193 pub fn has_cert(&self, domain: &str) -> bool {
194 let domain_normalized = normalize_domain(domain);
195 self.certs.contains_key(&domain_normalized)
196 }
197
198 #[must_use]
200 pub fn cert_count(&self) -> usize {
201 self.certs.len()
202 }
203
204 #[must_use]
206 pub fn domains(&self) -> Vec<String> {
207 self.certs.iter().map(|r| r.key().clone()).collect()
208 }
209
210 #[must_use]
212 pub fn has_default_cert(&self) -> bool {
213 self.default_cert
214 .read()
215 .map(|guard| guard.is_some())
216 .unwrap_or(false)
217 }
218
219 fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
221 let server_name = server_name?;
222 let normalized = normalize_domain(server_name);
223
224 if let Some(cert) = self.certs.get(&normalized) {
226 trace!(domain = %normalized, "Exact certificate match");
227 return Some(Arc::clone(cert.value()));
228 }
229
230 if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
232 if let Some(cert) = self.certs.get(&wildcard_domain) {
233 trace!(
234 domain = %normalized,
235 wildcard = %wildcard_domain,
236 "Wildcard certificate match"
237 );
238 return Some(Arc::clone(cert.value()));
239 }
240 }
241
242 if let Ok(guard) = self.default_cert.read() {
245 if let Some(default) = guard.as_ref() {
246 trace!(domain = %normalized, "Using default certificate");
247 return Some(Arc::clone(default));
248 }
249 }
250
251 warn!(domain = %normalized, "No certificate found");
252 None
253 }
254}
255
256impl Default for SniCertResolver {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262impl ResolvesServerCert for SniCertResolver {
263 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
264 let server_name = client_hello.server_name();
265 self.resolve_cert(server_name)
266 }
267}
268
269fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
285 let certs = parse_certificates(cert_pem)?;
287 if certs.is_empty() {
288 return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
289 }
290
291 let key = parse_private_key(key_pem)?;
293
294 let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
296 .map_err(|e| ProxyError::Tls(format!("Failed to create signing key: {e}")))?;
297
298 Ok(CertifiedKey::new(certs, signing_key))
299}
300
301fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
303 let mut reader = BufReader::new(pem.as_bytes());
304 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
305 .collect::<std::result::Result<Vec<_>, _>>()
306 .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
307
308 Ok(certs)
309}
310
311fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
313 let mut reader = BufReader::new(pem.as_bytes());
314
315 loop {
317 match rustls_pemfile::read_one(&mut reader) {
318 Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
319 return Ok(PrivateKeyDer::Pkcs1(key));
320 }
321 Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
322 return Ok(PrivateKeyDer::Pkcs8(key));
323 }
324 Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
325 return Ok(PrivateKeyDer::Sec1(key));
326 }
327 Ok(Some(_)) => {
328 }
330 Ok(None) => {
331 return Err(ProxyError::Tls("No private key found in PEM".to_string()));
332 }
333 Err(e) => {
334 return Err(ProxyError::Tls(format!(
335 "Failed to parse private key PEM: {e}"
336 )));
337 }
338 }
339 }
340}
341
/// Canonicalizes a domain for use as a map key: surrounding whitespace is
/// stripped and the remainder is lowercased.
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim();
    trimmed.to_lowercase()
}
346
/// Derives the wildcard key that would cover `domain`.
///
/// The left-most label is replaced by `*`, e.g. `foo.example.com` becomes
/// `*.example.com`. Names with fewer than three dot-separated parts
/// (`example.com`, `localhost`) have no wildcard form and yield `None`.
fn get_wildcard_domain(domain: &str) -> Option<String> {
    // Everything after the first dot is the parent the wildcard applies to.
    let (_, parent) = domain.split_once('.')?;
    // The parent must itself contain a dot, i.e. the name had at least three parts.
    parent.contains('.').then(|| format!("*.{parent}"))
}
361
#[cfg(test)]
mod tests {
    //! Unit tests for the SNI resolver and its PEM helpers.
    //!
    //! Everything exercised here is synchronous, so the tests are plain
    //! `#[test]` functions and do not start an async runtime.

    use super::*;

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain("  foo.bar.com  "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
    }

    /// Generates a fresh self-signed certificate and returns `(cert_pem, key_pem)`.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};

        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();

        (cert.pem(), key_pair.serialize_pem())
    }

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));

        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        // Refreshing an existing domain replaces the entry instead of adding one.
        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_default_cert());

        // The default certificate is tracked separately from per-domain entries.
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());

        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();

        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        // Any single-label subdomain is covered by the wildcard entry.
        let result = resolver.resolve_cert(Some("api.example.com"));
        assert!(result.is_some());

        let result = resolver.resolve_cert(Some("web.example.com"));
        assert!(result.is_some());

        // The bare parent domain is not covered by its own wildcard.
        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        // No per-domain entry matches, so the default certificate is served.
        let result = resolver.resolve_cert(Some("unknown.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        // No exact, wildcard, or default certificate applies.
        let result = resolver.resolve_cert(Some("other.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        let resolver = SniCertResolver::new();

        // Without a host name there is nothing to look up.
        let result = resolver.resolve_cert(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        // Text without any PEM section is not an error at this layer: it simply
        // yields no certificates. `create_certified_key` rejects the empty list.
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }
}