Skip to main content

zlayer_proxy/
sni_resolver.rs

1//! SNI-based TLS Certificate Resolver
2//!
3//! This module provides dynamic TLS certificate selection based on Server Name
4//! Indication (SNI). It allows serving different certificates for different domains
5//! at runtime, with support for wildcard certificates.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use zlayer_proxy::SniCertResolver;
11//!
12//! let resolver = SniCertResolver::new();
13//!
14//! // Load certificates for specific domains
15//! resolver.load_cert("example.com", cert_pem, key_pem)?;
16//! resolver.load_cert("*.example.com", wildcard_cert_pem, wildcard_key_pem)?;
17//!
18//! // Set a fallback certificate
19//! resolver.set_default_cert(default_cert_pem, default_key_pem)?;
20//! ```
21
22use crate::error::{ProxyError, Result};
23use dashmap::DashMap;
24use rustls::pki_types::{CertificateDer, PrivateKeyDer};
25use rustls::server::{ClientHello, ResolvesServerCert};
26use rustls::sign::CertifiedKey;
27use std::io::BufReader;
28use std::sync::{Arc, RwLock};
29use tracing::{debug, trace};
30
/// SNI-based certificate resolver for dynamic TLS certificate selection
///
/// This resolver maintains a mapping of domain names to TLS certificates,
/// allowing the proxy to serve different certificates for different domains.
/// It supports:
///
/// - Exact domain matching (e.g., `api.example.com`)
/// - Wildcard certificates (e.g., `*.example.com`), which cover exactly one
///   extra leftmost label: they match `api.example.com`, but neither
///   `example.com` nor `a.b.example.com`
/// - A default/fallback certificate for unmatched domains
///
/// Domain names are normalized (whitespace-trimmed and lowercased) before they
/// are stored or looked up, so matching is case-insensitive.
///
/// The resolver is thread-safe and supports concurrent certificate updates:
/// all methods take `&self`, domain certificates live in a concurrent
/// `DashMap`, and the default certificate sits behind a `RwLock`.
#[derive(Debug)]
pub struct SniCertResolver {
    /// Normalized domain (or `*.`-prefixed wildcard pattern) -> `CertifiedKey` mapping
    certs: DashMap<String, Arc<CertifiedKey>>,
    /// Default/fallback certificate (optional), served when no domain entry matches.
    /// A blocking `std::sync::RwLock` is used because
    /// `ResolvesServerCert::resolve` is a synchronous callback.
    default_cert: RwLock<Option<Arc<CertifiedKey>>>,
}
49
50impl SniCertResolver {
51    /// Create a new empty SNI certificate resolver
52    ///
53    /// # Example
54    ///
55    /// ```rust
56    /// use zlayer_proxy::SniCertResolver;
57    ///
58    /// let resolver = SniCertResolver::new();
59    /// ```
60    #[must_use]
61    pub fn new() -> Self {
62        Self {
63            certs: DashMap::new(),
64            default_cert: RwLock::new(None),
65        }
66    }
67
68    /// Load a certificate for a specific domain
69    ///
70    /// Parses the PEM-encoded certificate chain and private key, then stores
71    /// the resulting `CertifiedKey` for the given domain.
72    ///
73    /// # Arguments
74    ///
75    /// * `domain` - The domain name (e.g., `example.com` or `*.example.com`)
76    /// * `cert_pem` - PEM-encoded certificate chain
77    /// * `key_pem` - PEM-encoded private key
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if:
82    /// - The certificate PEM cannot be parsed
83    /// - The private key PEM cannot be parsed
84    /// - The key is not compatible with the certificate
85    ///
86    /// # Example
87    ///
88    /// ```rust,ignore
89    /// resolver.load_cert("example.com", cert_pem, key_pem)?;
90    /// ```
91    pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
92        let certified_key = create_certified_key(cert_pem, key_pem)?;
93        let domain_normalized = normalize_domain(domain);
94
95        debug!(domain = %domain_normalized, "Loaded TLS certificate");
96        self.certs
97            .insert(domain_normalized, Arc::new(certified_key));
98
99        Ok(())
100    }
101
102    /// Set the default/fallback certificate
103    ///
104    /// This certificate is used when no domain-specific certificate matches
105    /// the client's SNI request.
106    ///
107    /// # Arguments
108    ///
109    /// * `cert_pem` - PEM-encoded certificate chain
110    /// * `key_pem` - PEM-encoded private key
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the certificate or key cannot be parsed.
115    ///
116    /// # Example
117    ///
118    /// ```rust,ignore
119    /// resolver.set_default_cert(default_cert_pem, default_key_pem)?;
120    /// ```
121    /// # Panics
122    ///
123    /// Panics if the internal `RwLock` is poisoned.
124    pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
125        let certified_key = create_certified_key(cert_pem, key_pem)?;
126
127        debug!("Set default TLS certificate");
128        let mut default = self.default_cert.write().expect("RwLock poisoned");
129        *default = Some(Arc::new(certified_key));
130
131        Ok(())
132    }
133
134    /// Remove a certificate for a specific domain
135    ///
136    /// # Arguments
137    ///
138    /// * `domain` - The domain name to remove the certificate for
139    ///
140    /// # Example
141    ///
142    /// ```rust,ignore
143    /// resolver.remove_cert("example.com");
144    /// ```
145    pub fn remove_cert(&self, domain: &str) {
146        let domain_normalized = normalize_domain(domain);
147        if self.certs.remove(&domain_normalized).is_some() {
148            debug!(domain = %domain_normalized, "Removed TLS certificate");
149        }
150    }
151
152    /// Refresh/update a certificate for an existing domain
153    ///
154    /// This is equivalent to calling `load_cert` but semantically indicates
155    /// an update to an existing certificate (e.g., for certificate renewal).
156    ///
157    /// # Arguments
158    ///
159    /// * `domain` - The domain name
160    /// * `cert_pem` - New PEM-encoded certificate chain
161    /// * `key_pem` - New PEM-encoded private key
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if the certificate or key cannot be parsed.
166    ///
167    /// # Example
168    ///
169    /// ```rust,ignore
170    /// resolver.refresh_cert("example.com", new_cert_pem, new_key_pem)?;
171    /// ```
172    pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
173        let certified_key = create_certified_key(cert_pem, key_pem)?;
174        let domain_normalized = normalize_domain(domain);
175
176        debug!(domain = %domain_normalized, "Refreshed TLS certificate");
177        self.certs
178            .insert(domain_normalized, Arc::new(certified_key));
179
180        Ok(())
181    }
182
183    /// Check if a certificate exists for a domain
184    ///
185    /// # Arguments
186    ///
187    /// * `domain` - The domain name to check
188    ///
189    /// # Returns
190    ///
191    /// `true` if a certificate is loaded for the exact domain name
192    #[must_use]
193    pub fn has_cert(&self, domain: &str) -> bool {
194        let domain_normalized = normalize_domain(domain);
195        self.certs.contains_key(&domain_normalized)
196    }
197
198    /// Get the number of loaded certificates
199    #[must_use]
200    pub fn cert_count(&self) -> usize {
201        self.certs.len()
202    }
203
204    /// List all domains with loaded certificates
205    #[must_use]
206    pub fn domains(&self) -> Vec<String> {
207        self.certs.iter().map(|r| r.key().clone()).collect()
208    }
209
210    /// Check if a default/fallback certificate is configured
211    #[must_use]
212    pub fn has_default_cert(&self) -> bool {
213        self.default_cert.read().is_ok_and(|guard| guard.is_some())
214    }
215
216    /// Internal method to resolve a certificate for a given server name
217    fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
218        let server_name = server_name?;
219        let normalized = normalize_domain(server_name);
220
221        // Try exact match first
222        if let Some(cert) = self.certs.get(&normalized) {
223            trace!(domain = %normalized, "Exact certificate match");
224            return Some(Arc::clone(cert.value()));
225        }
226
227        // Try wildcard match (e.g., for "foo.example.com", try "*.example.com")
228        if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
229            if let Some(cert) = self.certs.get(&wildcard_domain) {
230                trace!(
231                    domain = %normalized,
232                    wildcard = %wildcard_domain,
233                    "Wildcard certificate match"
234                );
235                return Some(Arc::clone(cert.value()));
236            }
237        }
238
239        // Fall back to default certificate
240        // Note: We use std::sync::RwLock since ResolvesServerCert::resolve is sync
241        if let Ok(guard) = self.default_cert.read() {
242            if let Some(default) = guard.as_ref() {
243                trace!(domain = %normalized, "Using default certificate");
244                return Some(Arc::clone(default));
245            }
246        }
247
248        // Demoted from warn! to debug!: a cert-less SNI (e.g. an ACME HTTP-01
249        // probe, a health check, or a client hitting a not-yet-provisioned host)
250        // is routine and was flooding daemon.log. Real misconfig surfaces when
251        // the handshake fails, not here.
252        debug!(domain = %normalized, "No certificate found");
253        None
254    }
255}
256
257impl Default for SniCertResolver {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
263impl ResolvesServerCert for SniCertResolver {
264    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
265        let server_name = client_hello.server_name();
266        self.resolve_cert(server_name)
267    }
268}
269
270/// Create a `CertifiedKey` from PEM-encoded certificate and private key
271///
272/// # Arguments
273///
274/// * `cert_pem` - PEM-encoded certificate chain (may contain multiple certificates)
275/// * `key_pem` - PEM-encoded private key (PKCS#1, PKCS#8, or SEC1 format)
276///
277/// # Errors
278///
279/// Returns an error if:
280/// - The certificate PEM cannot be parsed
281/// - No certificates are found in the PEM
282/// - The private key PEM cannot be parsed
283/// - No private key is found in the PEM
284/// - The key cannot be converted to a signing key
285fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
286    // Parse certificates
287    let certs = parse_certificates(cert_pem)?;
288    if certs.is_empty() {
289        return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
290    }
291
292    // Parse private key
293    let key = parse_private_key(key_pem)?;
294
295    // Create signing key using rustls crypto provider
296    let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
297        .map_err(|e| ProxyError::Tls(format!("Failed to create signing key: {e}")))?;
298
299    Ok(CertifiedKey::new(certs, signing_key))
300}
301
302/// Parse PEM-encoded certificates
303fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
304    let mut reader = BufReader::new(pem.as_bytes());
305    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
306        .collect::<std::result::Result<Vec<_>, _>>()
307        .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
308
309    Ok(certs)
310}
311
312/// Parse a PEM-encoded private key
313fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
314    let mut reader = BufReader::new(pem.as_bytes());
315
316    // Try to read different key formats
317    loop {
318        match rustls_pemfile::read_one(&mut reader) {
319            Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
320                return Ok(PrivateKeyDer::Pkcs1(key));
321            }
322            Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
323                return Ok(PrivateKeyDer::Pkcs8(key));
324            }
325            Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
326                return Ok(PrivateKeyDer::Sec1(key));
327            }
328            Ok(Some(_)) => {
329                // Skip non-key items (like certificates)
330            }
331            Ok(None) => {
332                return Err(ProxyError::Tls("No private key found in PEM".to_string()));
333            }
334            Err(e) => {
335                return Err(ProxyError::Tls(format!(
336                    "Failed to parse private key PEM: {e}"
337                )));
338            }
339        }
340    }
341}
342
/// Normalize a domain name for use as a lookup key
///
/// Trims surrounding whitespace, drops a single trailing root dot
/// (`example.com.` and `example.com` name the same host), and lowercases
/// the result, so matching is case-insensitive.
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim();
    // Without this, a fully-qualified name would be stored or looked up under
    // a different key than its relative form and silently never match.
    let relative = trimmed.strip_suffix('.').unwrap_or(trimmed);
    relative.to_lowercase()
}
347
/// Build the single-level wildcard pattern that would cover `domain`
///
/// The leftmost label is replaced by `*`, but only when what remains still
/// contains a dot, so bare registrable names never get a wildcard:
///
/// - `foo.example.com` -> `Some("*.example.com")`
/// - `example.com` / `localhost` -> `None`
fn get_wildcard_domain(domain: &str) -> Option<String> {
    let (_leftmost_label, parent) = domain.split_once('.')?;
    parent.contains('.').then(|| format!("*.{parent}"))
}
362
#[cfg(test)]
mod tests {
    use super::*;

    // Certificate parsing and resolution are fully synchronous, so plain
    // `#[test]` functions are used; no async runtime is needed.

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain("  foo.bar.com  "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
    }

    /// Generate a self-signed certificate (SANs `localhost`, `example.com`)
    /// for testing; returns `(cert_pem, key_pem)`.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};

        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();

        (cert.pem(), key_pair.serialize_pem())
    }

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
    }

    #[test]
    fn test_failed_load_or_refresh_keeps_existing_state() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // A bad key must not create an entry...
        assert!(resolver
            .load_cert("example.com", &cert_pem, "not a valid PEM")
            .is_err());
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);

        // ...and a bad refresh must not evict a working certificate.
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver
            .refresh_cert("example.com", &cert_pem, "not a valid PEM")
            .is_err());
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.resolve_cert(Some("example.com")).is_some());
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));

        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // Load initial cert
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver
            .resolve_cert(Some("example.com"))
            .expect("initial certificate should resolve");

        // Refresh with new cert
        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        assert_eq!(resolver.cert_count(), 1);

        // The stored certificate must actually have been replaced.
        let after = resolver
            .resolve_cert(Some("example.com"))
            .expect("refreshed certificate should resolve");
        assert!(!std::sync::Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());

        // Default cert should not show up in cert_count or domains
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());

        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();

        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        // Load wildcard cert
        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        // Should match subdomains
        let result = resolver.resolve_cert(Some("api.example.com"));
        assert!(result.is_some());

        let result = resolver.resolve_cert(Some("web.example.com"));
        assert!(result.is_some());

        // Should not match base domain
        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_wildcard_covers_single_label_only() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        // "*.example.com" must not cover a name two labels deeper.
        assert!(resolver.resolve_cert(Some("a.b.example.com")).is_none());
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        // Unknown domain should fall back to default
        let result = resolver.resolve_cert(Some("unknown.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        // No default, different domain
        let result = resolver.resolve_cert(Some("other.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        let resolver = SniCertResolver::new();

        // No server name provided and nothing configured
        let result = resolver.resolve_cert(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok()); // Will succeed but return empty vec
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }
}