Skip to main content

zlayer_proxy/
sni_resolver.rs

1//! SNI-based TLS Certificate Resolver
2//!
3//! This module provides dynamic TLS certificate selection based on Server Name
4//! Indication (SNI). It allows serving different certificates for different domains
5//! at runtime, with support for wildcard certificates.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use zlayer_proxy::SniCertResolver;
11//!
12//! let resolver = SniCertResolver::new();
13//!
14//! // Load certificates for specific domains
15//! resolver.load_cert("example.com", cert_pem, key_pem)?;
16//! resolver.load_cert("*.example.com", wildcard_cert_pem, wildcard_key_pem)?;
17//!
18//! // Set a fallback certificate
19//! resolver.set_default_cert(default_cert_pem, default_key_pem)?;
20//! ```
21
22use crate::error::{ProxyError, Result};
23use dashmap::DashMap;
24use rustls::pki_types::{CertificateDer, PrivateKeyDer};
25use rustls::server::{ClientHello, ResolvesServerCert};
26use rustls::sign::CertifiedKey;
27use std::io::BufReader;
28use std::sync::{Arc, RwLock};
29use tracing::{debug, trace, warn};
30
31/// SNI-based certificate resolver for dynamic TLS certificate selection
32///
33/// This resolver maintains a mapping of domain names to TLS certificates,
34/// allowing the proxy to serve different certificates for different domains.
35/// It supports:
36///
37/// - Exact domain matching (e.g., `api.example.com`)
38/// - Wildcard certificates (e.g., `*.example.com`)
39/// - A default/fallback certificate for unmatched domains
40///
41/// The resolver is thread-safe and supports concurrent certificate updates.
42#[derive(Debug)]
43pub struct SniCertResolver {
44    /// Domain -> `CertifiedKey` mapping
45    certs: DashMap<String, Arc<CertifiedKey>>,
46    /// Default/fallback certificate (optional)
47    default_cert: RwLock<Option<Arc<CertifiedKey>>>,
48}
49
50impl SniCertResolver {
51    /// Create a new empty SNI certificate resolver
52    ///
53    /// # Example
54    ///
55    /// ```rust
56    /// use zlayer_proxy::SniCertResolver;
57    ///
58    /// let resolver = SniCertResolver::new();
59    /// ```
60    #[must_use]
61    pub fn new() -> Self {
62        Self {
63            certs: DashMap::new(),
64            default_cert: RwLock::new(None),
65        }
66    }
67
68    /// Load a certificate for a specific domain
69    ///
70    /// Parses the PEM-encoded certificate chain and private key, then stores
71    /// the resulting `CertifiedKey` for the given domain.
72    ///
73    /// # Arguments
74    ///
75    /// * `domain` - The domain name (e.g., `example.com` or `*.example.com`)
76    /// * `cert_pem` - PEM-encoded certificate chain
77    /// * `key_pem` - PEM-encoded private key
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if:
82    /// - The certificate PEM cannot be parsed
83    /// - The private key PEM cannot be parsed
84    /// - The key is not compatible with the certificate
85    ///
86    /// # Example
87    ///
88    /// ```rust,ignore
89    /// resolver.load_cert("example.com", cert_pem, key_pem)?;
90    /// ```
91    pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
92        let certified_key = create_certified_key(cert_pem, key_pem)?;
93        let domain_normalized = normalize_domain(domain);
94
95        debug!(domain = %domain_normalized, "Loaded TLS certificate");
96        self.certs
97            .insert(domain_normalized, Arc::new(certified_key));
98
99        Ok(())
100    }
101
102    /// Set the default/fallback certificate
103    ///
104    /// This certificate is used when no domain-specific certificate matches
105    /// the client's SNI request.
106    ///
107    /// # Arguments
108    ///
109    /// * `cert_pem` - PEM-encoded certificate chain
110    /// * `key_pem` - PEM-encoded private key
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the certificate or key cannot be parsed.
115    ///
116    /// # Example
117    ///
118    /// ```rust,ignore
119    /// resolver.set_default_cert(default_cert_pem, default_key_pem)?;
120    /// ```
121    /// # Panics
122    ///
123    /// Panics if the internal `RwLock` is poisoned.
124    pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
125        let certified_key = create_certified_key(cert_pem, key_pem)?;
126
127        debug!("Set default TLS certificate");
128        let mut default = self.default_cert.write().expect("RwLock poisoned");
129        *default = Some(Arc::new(certified_key));
130
131        Ok(())
132    }
133
134    /// Remove a certificate for a specific domain
135    ///
136    /// # Arguments
137    ///
138    /// * `domain` - The domain name to remove the certificate for
139    ///
140    /// # Example
141    ///
142    /// ```rust,ignore
143    /// resolver.remove_cert("example.com");
144    /// ```
145    pub fn remove_cert(&self, domain: &str) {
146        let domain_normalized = normalize_domain(domain);
147        if self.certs.remove(&domain_normalized).is_some() {
148            debug!(domain = %domain_normalized, "Removed TLS certificate");
149        }
150    }
151
152    /// Refresh/update a certificate for an existing domain
153    ///
154    /// This is equivalent to calling `load_cert` but semantically indicates
155    /// an update to an existing certificate (e.g., for certificate renewal).
156    ///
157    /// # Arguments
158    ///
159    /// * `domain` - The domain name
160    /// * `cert_pem` - New PEM-encoded certificate chain
161    /// * `key_pem` - New PEM-encoded private key
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if the certificate or key cannot be parsed.
166    ///
167    /// # Example
168    ///
169    /// ```rust,ignore
170    /// resolver.refresh_cert("example.com", new_cert_pem, new_key_pem)?;
171    /// ```
172    pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
173        let certified_key = create_certified_key(cert_pem, key_pem)?;
174        let domain_normalized = normalize_domain(domain);
175
176        debug!(domain = %domain_normalized, "Refreshed TLS certificate");
177        self.certs
178            .insert(domain_normalized, Arc::new(certified_key));
179
180        Ok(())
181    }
182
183    /// Check if a certificate exists for a domain
184    ///
185    /// # Arguments
186    ///
187    /// * `domain` - The domain name to check
188    ///
189    /// # Returns
190    ///
191    /// `true` if a certificate is loaded for the exact domain name
192    #[must_use]
193    pub fn has_cert(&self, domain: &str) -> bool {
194        let domain_normalized = normalize_domain(domain);
195        self.certs.contains_key(&domain_normalized)
196    }
197
198    /// Get the number of loaded certificates
199    #[must_use]
200    pub fn cert_count(&self) -> usize {
201        self.certs.len()
202    }
203
204    /// List all domains with loaded certificates
205    #[must_use]
206    pub fn domains(&self) -> Vec<String> {
207        self.certs.iter().map(|r| r.key().clone()).collect()
208    }
209
210    /// Check if a default/fallback certificate is configured
211    #[must_use]
212    pub fn has_default_cert(&self) -> bool {
213        self.default_cert
214            .read()
215            .map(|guard| guard.is_some())
216            .unwrap_or(false)
217    }
218
219    /// Internal method to resolve a certificate for a given server name
220    fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
221        let server_name = server_name?;
222        let normalized = normalize_domain(server_name);
223
224        // Try exact match first
225        if let Some(cert) = self.certs.get(&normalized) {
226            trace!(domain = %normalized, "Exact certificate match");
227            return Some(Arc::clone(cert.value()));
228        }
229
230        // Try wildcard match (e.g., for "foo.example.com", try "*.example.com")
231        if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
232            if let Some(cert) = self.certs.get(&wildcard_domain) {
233                trace!(
234                    domain = %normalized,
235                    wildcard = %wildcard_domain,
236                    "Wildcard certificate match"
237                );
238                return Some(Arc::clone(cert.value()));
239            }
240        }
241
242        // Fall back to default certificate
243        // Note: We use std::sync::RwLock since ResolvesServerCert::resolve is sync
244        if let Ok(guard) = self.default_cert.read() {
245            if let Some(default) = guard.as_ref() {
246                trace!(domain = %normalized, "Using default certificate");
247                return Some(Arc::clone(default));
248            }
249        }
250
251        warn!(domain = %normalized, "No certificate found");
252        None
253    }
254}
255
256impl Default for SniCertResolver {
257    fn default() -> Self {
258        Self::new()
259    }
260}
261
262impl ResolvesServerCert for SniCertResolver {
263    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
264        let server_name = client_hello.server_name();
265        self.resolve_cert(server_name)
266    }
267}
268
269/// Create a `CertifiedKey` from PEM-encoded certificate and private key
270///
271/// # Arguments
272///
273/// * `cert_pem` - PEM-encoded certificate chain (may contain multiple certificates)
274/// * `key_pem` - PEM-encoded private key (PKCS#1, PKCS#8, or SEC1 format)
275///
276/// # Errors
277///
278/// Returns an error if:
279/// - The certificate PEM cannot be parsed
280/// - No certificates are found in the PEM
281/// - The private key PEM cannot be parsed
282/// - No private key is found in the PEM
283/// - The key cannot be converted to a signing key
284fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
285    // Parse certificates
286    let certs = parse_certificates(cert_pem)?;
287    if certs.is_empty() {
288        return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
289    }
290
291    // Parse private key
292    let key = parse_private_key(key_pem)?;
293
294    // Create signing key using rustls crypto provider
295    let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
296        .map_err(|e| ProxyError::Tls(format!("Failed to create signing key: {e}")))?;
297
298    Ok(CertifiedKey::new(certs, signing_key))
299}
300
301/// Parse PEM-encoded certificates
302fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
303    let mut reader = BufReader::new(pem.as_bytes());
304    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
305        .collect::<std::result::Result<Vec<_>, _>>()
306        .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
307
308    Ok(certs)
309}
310
311/// Parse a PEM-encoded private key
312fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
313    let mut reader = BufReader::new(pem.as_bytes());
314
315    // Try to read different key formats
316    loop {
317        match rustls_pemfile::read_one(&mut reader) {
318            Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
319                return Ok(PrivateKeyDer::Pkcs1(key));
320            }
321            Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
322                return Ok(PrivateKeyDer::Pkcs8(key));
323            }
324            Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
325                return Ok(PrivateKeyDer::Sec1(key));
326            }
327            Ok(Some(_)) => {
328                // Skip non-key items (like certificates)
329            }
330            Ok(None) => {
331                return Err(ProxyError::Tls("No private key found in PEM".to_string()));
332            }
333            Err(e) => {
334                return Err(ProxyError::Tls(format!(
335                    "Failed to parse private key PEM: {e}"
336                )));
337            }
338        }
339    }
340}
341
/// Normalize a domain name for use as a certificate lookup key.
///
/// The normalization:
/// - trims surrounding whitespace;
/// - drops the trailing root dot of a fully qualified name, because
///   `example.com.` and `example.com` name the same host;
/// - lowercases, because DNS names compare case-insensitively.
fn normalize_domain(domain: &str) -> String {
    domain.trim().trim_end_matches('.').to_lowercase()
}
346
/// Compute the wildcard entry that would cover `domain`.
///
/// The first label is replaced with `*`: `foo.example.com` yields
/// `*.example.com`. A name with fewer than three labels (such as
/// `example.com` or `localhost`) yields `None`, so a wildcard never covers a
/// bare two-label domain.
fn get_wildcard_domain(domain: &str) -> Option<String> {
    let (_first_label, parent) = domain.split_once('.')?;
    // The parent must itself contain a dot, i.e. the original had >= 3 labels.
    parent.contains('.').then(|| format!("*.{parent}"))
}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate a fresh self-signed certificate and private key, both PEM-encoded.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};

        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();

        (cert.pem(), key_pair.serialize_pem())
    }

    // All resolver operations are synchronous, so plain `#[test]` is used
    // instead of spinning up an async runtime per test.

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain("  foo.bar.com  "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
        assert!(!resolver.has_default_cert());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_invalid_cert_is_rejected() {
        let resolver = SniCertResolver::new();
        let (_, key_pem) = generate_test_cert();

        assert!(resolver
            .load_cert("example.com", "not a valid PEM", &key_pem)
            .is_err());
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));

        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_remove_missing_cert_is_noop() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        resolver.remove_cert("other.com");
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // Load initial cert
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver.resolve_cert(Some("example.com")).unwrap();

        // Refresh with new cert
        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        assert_eq!(resolver.cert_count(), 1);

        // The stored entry must actually have been replaced.
        let after = resolver.resolve_cert(Some("example.com")).unwrap();
        assert!(!Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_refresh_cert_failure_keeps_existing() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver.resolve_cert(Some("example.com")).unwrap();

        assert!(resolver
            .refresh_cert("example.com", "not a valid PEM", &key_pem)
            .is_err());

        let after = resolver.resolve_cert(Some("example.com")).unwrap();
        assert!(Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());

        // Default cert should not show up in cert_count or domains
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());

        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();

        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_exact_preferred_over_wildcard() {
        let resolver = SniCertResolver::new();
        let (exact_cert, exact_key) = generate_test_cert();
        let (wild_cert, wild_key) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &exact_cert, &exact_key)
            .unwrap();
        resolver
            .load_cert("*.example.com", &wild_cert, &wild_key)
            .unwrap();

        let exact = resolver.resolve_cert(Some("api.example.com")).unwrap();
        let wildcard = resolver.resolve_cert(Some("web.example.com")).unwrap();
        assert!(!Arc::ptr_eq(&exact, &wildcard));
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // Load wildcard cert
        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        // Should match subdomains
        let result = resolver.resolve_cert(Some("api.example.com"));
        assert!(result.is_some());

        let result = resolver.resolve_cert(Some("web.example.com"));
        assert!(result.is_some());

        // Should not match base domain
        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_wildcard_covers_single_label_only() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        // "*.example.com" must not cover names two labels deep.
        assert!(resolver.resolve_cert(Some("a.b.example.com")).is_none());
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        // Unknown domain should fall back to default
        let result = resolver.resolve_cert(Some("unknown.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        // No default, different domain
        let result = resolver.resolve_cert(Some("other.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        let resolver = SniCertResolver::new();

        // No server name provided and no default configured
        let result = resolver.resolve_cert(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok()); // Will succeed but return empty vec
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }
}