Skip to main content

zlayer_proxy/
tls.rs

//! TLS server configuration
//!
//! This module builds rustls-based TLS acceptors that the proxy server uses to terminate TLS.
4
5use crate::config::{TlsConfig, TlsVersion};
6use crate::error::{ProxyError, Result};
7use rustls::pki_types::{CertificateDer, PrivateKeyDer};
8use rustls::server::ServerConfig;
9use std::fs::File;
10use std::io::BufReader;
11use std::sync::Arc;
12use tokio_rustls::TlsAcceptor;
13use tracing::{debug, info};
14
15/// TLS server configuration builder
16#[derive(Debug, Clone)]
17pub struct TlsServerConfig {
18    /// Path to certificate file (PEM format)
19    pub cert_path: String,
20    /// Path to private key file (PEM format)
21    pub key_path: String,
22    /// ALPN protocols to advertise
23    pub alpn_protocols: Vec<Vec<u8>>,
24    /// Minimum TLS version
25    pub min_version: TlsVersion,
26}
27
28impl TlsServerConfig {
29    /// Create a new TLS server configuration
30    pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
31        Self {
32            cert_path: cert_path.into(),
33            key_path: key_path.into(),
34            alpn_protocols: vec![b"h2".to_vec(), b"http/1.1".to_vec()],
35            min_version: TlsVersion::Tls12,
36        }
37    }
38
39    /// Set ALPN protocols
40    #[must_use]
41    pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
42        self.alpn_protocols = protocols;
43        self
44    }
45
46    /// Set minimum TLS version
47    #[must_use]
48    pub fn with_min_version(mut self, version: TlsVersion) -> Self {
49        self.min_version = version;
50        self
51    }
52
53    /// Create from a `TlsConfig`
54    #[must_use]
55    pub fn from_config(config: &TlsConfig) -> Self {
56        let mut alpn = vec![];
57        if config.alpn_h2 {
58            alpn.push(b"h2".to_vec());
59        }
60        alpn.push(b"http/1.1".to_vec());
61
62        Self {
63            cert_path: config.cert_path.to_string_lossy().to_string(),
64            key_path: config.key_path.to_string_lossy().to_string(),
65            alpn_protocols: alpn,
66            min_version: config.min_version,
67        }
68    }
69}
70
71/// Load certificates from a PEM file
72///
73/// # Errors
74///
75/// Returns an error if the file cannot be opened, the PEM cannot be parsed,
76/// or no certificates are found.
77pub fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
78    let file = File::open(path)
79        .map_err(|e| ProxyError::Tls(format!("Failed to open certificate file '{path}': {e}")))?;
80
81    let mut reader = BufReader::new(file);
82    let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
83        .collect::<std::result::Result<Vec<_>, _>>()
84        .map_err(|e| ProxyError::Tls(format!("Failed to parse certificates from '{path}': {e}")))?;
85
86    if certs.is_empty() {
87        return Err(ProxyError::Tls(format!(
88            "No certificates found in '{path}'"
89        )));
90    }
91
92    debug!(count = certs.len(), path = %path, "Loaded certificates");
93    Ok(certs)
94}
95
96/// Load a private key from a PEM file
97///
98/// # Errors
99///
100/// Returns an error if the file cannot be opened, the PEM cannot be parsed,
101/// or no private key is found.
102pub fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
103    let file = File::open(path)
104        .map_err(|e| ProxyError::Tls(format!("Failed to open private key file '{path}': {e}")))?;
105
106    let mut reader = BufReader::new(file);
107
108    // Try to read different key formats
109    loop {
110        match rustls_pemfile::read_one(&mut reader) {
111            Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
112                debug!(path = %path, "Loaded PKCS#1 RSA private key");
113                return Ok(PrivateKeyDer::Pkcs1(key));
114            }
115            Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
116                debug!(path = %path, "Loaded PKCS#8 private key");
117                return Ok(PrivateKeyDer::Pkcs8(key));
118            }
119            Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
120                debug!(path = %path, "Loaded SEC1 EC private key");
121                return Ok(PrivateKeyDer::Sec1(key));
122            }
123            Ok(Some(_)) => {
124                // Skip non-key items (like certificates)
125            }
126            Ok(None) => {
127                return Err(ProxyError::Tls(format!("No private key found in '{path}'")));
128            }
129            Err(e) => {
130                return Err(ProxyError::Tls(format!(
131                    "Failed to parse private key from '{path}': {e}"
132                )));
133            }
134        }
135    }
136}
137
138/// Create a TLS acceptor from configuration
139///
140/// # Errors
141///
142/// Returns an error if loading the certificate or private key fails,
143/// or if building the TLS server configuration fails.
144pub fn create_tls_acceptor(config: &TlsServerConfig) -> Result<TlsAcceptor> {
145    let certs = load_certs(&config.cert_path)?;
146    let key = load_private_key(&config.key_path)?;
147
148    let server_config = create_server_config(certs, key, config)?;
149
150    info!(
151        cert_path = %config.cert_path,
152        min_version = ?config.min_version,
153        alpn = ?config.alpn_protocols.iter()
154            .map(|p| String::from_utf8_lossy(p).to_string())
155            .collect::<Vec<_>>(),
156        "Created TLS acceptor"
157    );
158
159    Ok(TlsAcceptor::from(Arc::new(server_config)))
160}
161
162/// Create a rustls `ServerConfig`
163fn create_server_config(
164    certs: Vec<CertificateDer<'static>>,
165    key: PrivateKeyDer<'static>,
166    config: &TlsServerConfig,
167) -> Result<ServerConfig> {
168    // Select TLS protocol versions based on minimum version
169    let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
170        TlsVersion::Tls12 => vec![&rustls::version::TLS13, &rustls::version::TLS12],
171        TlsVersion::Tls13 => vec![&rustls::version::TLS13],
172    };
173
174    let mut server_config = ServerConfig::builder_with_protocol_versions(&versions)
175        .with_no_client_auth()
176        .with_single_cert(certs, key)
177        .map_err(|e| ProxyError::Tls(format!("Failed to create TLS config: {e}")))?;
178
179    // Configure ALPN
180    if !config.alpn_protocols.is_empty() {
181        server_config
182            .alpn_protocols
183            .clone_from(&config.alpn_protocols);
184    }
185
186    Ok(server_config)
187}
188
189/// Create a TLS acceptor directly from file paths
190///
191/// # Errors
192///
193/// Returns an error if loading the certificate or key files fails,
194/// or if creating the TLS acceptor fails.
195pub fn create_acceptor_from_files(cert_path: &str, key_path: &str) -> Result<TlsAcceptor> {
196    let config = TlsServerConfig::new(cert_path, key_path);
197    create_tls_acceptor(&config)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder paths; none of these tests touch the filesystem with them.
    const CERT: &str = "/path/to/cert.pem";
    const KEY: &str = "/path/to/key.pem";

    fn h2() -> Vec<u8> {
        b"h2".to_vec()
    }

    fn http11() -> Vec<u8> {
        b"http/1.1".to_vec()
    }

    #[test]
    fn test_tls_server_config_creation() {
        let cfg = TlsServerConfig::new(CERT, KEY);

        assert_eq!(cfg.cert_path, CERT);
        assert_eq!(cfg.key_path, KEY);
        assert_eq!(cfg.alpn_protocols, vec![h2(), http11()]);
        assert_eq!(cfg.min_version, TlsVersion::Tls12);
    }

    #[test]
    fn test_tls_server_config_builder() {
        let cfg = TlsServerConfig::new(CERT, KEY)
            .with_alpn_protocols(vec![http11()])
            .with_min_version(TlsVersion::Tls13);

        assert_eq!(cfg.alpn_protocols, vec![http11()]);
        assert_eq!(cfg.min_version, TlsVersion::Tls13);
    }

    #[test]
    fn test_load_certs_file_not_found() {
        let outcome = load_certs("/nonexistent/path/cert.pem");
        assert!(matches!(outcome, Err(ProxyError::Tls(_))));
    }

    #[test]
    fn test_load_private_key_file_not_found() {
        let outcome = load_private_key("/nonexistent/path/key.pem");
        assert!(matches!(outcome, Err(ProxyError::Tls(_))));
    }

    #[test]
    fn test_from_tls_config() {
        let source = TlsConfig {
            cert_path: CERT.into(),
            key_path: KEY.into(),
            min_version: TlsVersion::Tls13,
            alpn_h2: true,
        };

        let cfg = TlsServerConfig::from_config(&source);

        assert_eq!(cfg.cert_path, CERT);
        assert_eq!(cfg.key_path, KEY);
        assert_eq!(cfg.min_version, TlsVersion::Tls13);
        assert!(cfg.alpn_protocols.contains(&h2()));
    }

    #[test]
    fn test_from_tls_config_no_h2() {
        let source = TlsConfig {
            cert_path: CERT.into(),
            key_path: KEY.into(),
            min_version: TlsVersion::Tls12,
            alpn_h2: false,
        };

        let cfg = TlsServerConfig::from_config(&source);

        assert!(!cfg.alpn_protocols.contains(&h2()));
        assert!(cfg.alpn_protocols.contains(&http11()));
    }
}