1use crate::error::{ProxyError, Result};
23use dashmap::DashMap;
24use rustls::pki_types::{CertificateDer, PrivateKeyDer};
25use rustls::server::{ClientHello, ResolvesServerCert};
26use rustls::sign::CertifiedKey;
27use std::io::BufReader;
28use std::sync::{Arc, RwLock};
29use tracing::{debug, trace};
30
31#[derive(Debug)]
43pub struct SniCertResolver {
44 certs: DashMap<String, Arc<CertifiedKey>>,
46 default_cert: RwLock<Option<Arc<CertifiedKey>>>,
48}
49
50impl SniCertResolver {
51 #[must_use]
61 pub fn new() -> Self {
62 Self {
63 certs: DashMap::new(),
64 default_cert: RwLock::new(None),
65 }
66 }
67
68 pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
92 let certified_key = create_certified_key(cert_pem, key_pem)?;
93 let domain_normalized = normalize_domain(domain);
94
95 debug!(domain = %domain_normalized, "Loaded TLS certificate");
96 self.certs
97 .insert(domain_normalized, Arc::new(certified_key));
98
99 Ok(())
100 }
101
102 pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
125 let certified_key = create_certified_key(cert_pem, key_pem)?;
126
127 debug!("Set default TLS certificate");
128 let mut default = self.default_cert.write().expect("RwLock poisoned");
129 *default = Some(Arc::new(certified_key));
130
131 Ok(())
132 }
133
134 pub fn remove_cert(&self, domain: &str) {
146 let domain_normalized = normalize_domain(domain);
147 if self.certs.remove(&domain_normalized).is_some() {
148 debug!(domain = %domain_normalized, "Removed TLS certificate");
149 }
150 }
151
152 pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
173 let certified_key = create_certified_key(cert_pem, key_pem)?;
174 let domain_normalized = normalize_domain(domain);
175
176 debug!(domain = %domain_normalized, "Refreshed TLS certificate");
177 self.certs
178 .insert(domain_normalized, Arc::new(certified_key));
179
180 Ok(())
181 }
182
183 #[must_use]
193 pub fn has_cert(&self, domain: &str) -> bool {
194 let domain_normalized = normalize_domain(domain);
195 self.certs.contains_key(&domain_normalized)
196 }
197
198 #[must_use]
200 pub fn cert_count(&self) -> usize {
201 self.certs.len()
202 }
203
204 #[must_use]
206 pub fn domains(&self) -> Vec<String> {
207 self.certs.iter().map(|r| r.key().clone()).collect()
208 }
209
210 #[must_use]
212 pub fn has_default_cert(&self) -> bool {
213 self.default_cert.read().is_ok_and(|guard| guard.is_some())
214 }
215
216 fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
218 let server_name = server_name?;
219 let normalized = normalize_domain(server_name);
220
221 if let Some(cert) = self.certs.get(&normalized) {
223 trace!(domain = %normalized, "Exact certificate match");
224 return Some(Arc::clone(cert.value()));
225 }
226
227 if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
229 if let Some(cert) = self.certs.get(&wildcard_domain) {
230 trace!(
231 domain = %normalized,
232 wildcard = %wildcard_domain,
233 "Wildcard certificate match"
234 );
235 return Some(Arc::clone(cert.value()));
236 }
237 }
238
239 if let Ok(guard) = self.default_cert.read() {
242 if let Some(default) = guard.as_ref() {
243 trace!(domain = %normalized, "Using default certificate");
244 return Some(Arc::clone(default));
245 }
246 }
247
248 debug!(domain = %normalized, "No certificate found");
253 None
254 }
255}
256
257impl Default for SniCertResolver {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
263impl ResolvesServerCert for SniCertResolver {
264 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
265 let server_name = client_hello.server_name();
266 self.resolve_cert(server_name)
267 }
268}
269
270fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
286 let certs = parse_certificates(cert_pem)?;
288 if certs.is_empty() {
289 return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
290 }
291
292 let key = parse_private_key(key_pem)?;
294
295 let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
297 .map_err(|e| ProxyError::Tls(format!("Failed to create signing key: {e}")))?;
298
299 Ok(CertifiedKey::new(certs, signing_key))
300}
301
302fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
304 let mut reader = BufReader::new(pem.as_bytes());
305 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
306 .collect::<std::result::Result<Vec<_>, _>>()
307 .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
308
309 Ok(certs)
310}
311
312fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
314 let mut reader = BufReader::new(pem.as_bytes());
315
316 loop {
318 match rustls_pemfile::read_one(&mut reader) {
319 Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
320 return Ok(PrivateKeyDer::Pkcs1(key));
321 }
322 Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
323 return Ok(PrivateKeyDer::Pkcs8(key));
324 }
325 Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
326 return Ok(PrivateKeyDer::Sec1(key));
327 }
328 Ok(Some(_)) => {
329 }
331 Ok(None) => {
332 return Err(ProxyError::Tls("No private key found in PEM".to_string()));
333 }
334 Err(e) => {
335 return Err(ProxyError::Tls(format!(
336 "Failed to parse private key PEM: {e}"
337 )));
338 }
339 }
340 }
341}
342
/// Normalizes a domain name into the key used for certificate lookup.
///
/// Surrounding whitespace and a single trailing root dot (`example.com.`)
/// are removed, then ASCII letters are lower-cased. DNS names compare
/// case-insensitively only over ASCII (RFC 4343), so non-ASCII characters
/// are left as-is instead of being Unicode-folded (which would e.g. map the
/// Kelvin sign U+212A onto `k`).
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim();
    let without_root = trimmed.strip_suffix('.').unwrap_or(trimmed);
    without_root.to_ascii_lowercase()
}
347
/// Returns the wildcard name that would cover `domain`, e.g.
/// `foo.example.com` -> `*.example.com`.
///
/// Only the left-most label is replaced, and a wildcard is produced only when
/// the remainder still has at least two labels, so `example.com` never maps
/// to `*.com`.
fn get_wildcard_domain(domain: &str) -> Option<String> {
    let (_first_label, parent) = domain.split_once('.')?;
    if parent.contains('.') {
        Some(format!("*.{parent}"))
    } else {
        None
    }
}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain(" foo.bar.com "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
    }

    /// Generates a fresh self-signed certificate and key (both PEM) valid for
    /// `localhost` and `example.com`. Each call yields a distinct key pair.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};

        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();

        (cert.pem(), key_pair.serialize_pem())
    }

    // None of the resolver APIs are async, so plain `#[test]` is sufficient.

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));

        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver.resolve_cert(Some("example.com")).unwrap();

        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        assert_eq!(resolver.cert_count(), 1);

        // The refreshed entry must actually replace the old certificate.
        let after = resolver.resolve_cert(Some("example.com")).unwrap();
        assert!(!Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_default_cert());

        // The default certificate is not counted as a per-domain certificate.
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());

        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();

        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        assert!(resolver.resolve_cert(Some("example.com")).is_some());
        // SNI lookups are normalized the same way as stored domains.
        assert!(resolver.resolve_cert(Some("Example.COM")).is_some());
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        assert!(resolver.resolve_cert(Some("api.example.com")).is_some());
        assert!(resolver.resolve_cert(Some("web.example.com")).is_some());

        // A wildcard does not cover the bare parent domain...
        assert!(resolver.resolve_cert(Some("example.com")).is_none());
        // ...nor more than one extra label.
        assert!(resolver.resolve_cert(Some("a.b.example.com")).is_none());
    }

    #[test]
    fn test_resolve_exact_preferred_over_wildcard() {
        let resolver = SniCertResolver::new();
        let (wild_cert, wild_key) = generate_test_cert();
        let (exact_cert, exact_key) = generate_test_cert();

        resolver
            .load_cert("*.example.com", &wild_cert, &wild_key)
            .unwrap();
        resolver
            .load_cert("api.example.com", &exact_cert, &exact_key)
            .unwrap();

        let exact = resolver.resolve_cert(Some("api.example.com")).unwrap();
        let wildcard = resolver.resolve_cert(Some("web.example.com")).unwrap();
        assert!(!Arc::ptr_eq(&exact, &wildcard));
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.resolve_cert(Some("unknown.com")).is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        assert!(resolver.resolve_cert(Some("other.com")).is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        let resolver = SniCertResolver::new();

        // No SNI and no default certificate configured: nothing to serve.
        assert!(resolver.resolve_cert(None).is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        // Non-PEM text is skipped by the parser, so it yields no certificates
        // rather than an error; `create_certified_key` rejects the empty chain.
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }
}