1use crate::config::{TlsConfig, TlsVersion};
6use crate::error::{ProxyError, Result};
7use rustls::pki_types::{CertificateDer, PrivateKeyDer};
8use rustls::server::ServerConfig;
9use std::fs::File;
10use std::io::BufReader;
11use std::sync::Arc;
12use tokio_rustls::TlsAcceptor;
13use tracing::{debug, info};
14
15#[derive(Debug, Clone)]
17pub struct TlsServerConfig {
18 pub cert_path: String,
20 pub key_path: String,
22 pub alpn_protocols: Vec<Vec<u8>>,
24 pub min_version: TlsVersion,
26}
27
28impl TlsServerConfig {
29 pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
31 Self {
32 cert_path: cert_path.into(),
33 key_path: key_path.into(),
34 alpn_protocols: vec![b"h2".to_vec(), b"http/1.1".to_vec()],
35 min_version: TlsVersion::Tls12,
36 }
37 }
38
39 #[must_use]
41 pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
42 self.alpn_protocols = protocols;
43 self
44 }
45
46 #[must_use]
48 pub fn with_min_version(mut self, version: TlsVersion) -> Self {
49 self.min_version = version;
50 self
51 }
52
53 #[must_use]
55 pub fn from_config(config: &TlsConfig) -> Self {
56 let mut alpn = vec![];
57 if config.alpn_h2 {
58 alpn.push(b"h2".to_vec());
59 }
60 alpn.push(b"http/1.1".to_vec());
61
62 Self {
63 cert_path: config.cert_path.to_string_lossy().to_string(),
64 key_path: config.key_path.to_string_lossy().to_string(),
65 alpn_protocols: alpn,
66 min_version: config.min_version,
67 }
68 }
69}
70
71pub fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
78 let file = File::open(path)
79 .map_err(|e| ProxyError::Tls(format!("Failed to open certificate file '{path}': {e}")))?;
80
81 let mut reader = BufReader::new(file);
82 let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
83 .collect::<std::result::Result<Vec<_>, _>>()
84 .map_err(|e| ProxyError::Tls(format!("Failed to parse certificates from '{path}': {e}")))?;
85
86 if certs.is_empty() {
87 return Err(ProxyError::Tls(format!(
88 "No certificates found in '{path}'"
89 )));
90 }
91
92 debug!(count = certs.len(), path = %path, "Loaded certificates");
93 Ok(certs)
94}
95
96pub fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
103 let file = File::open(path)
104 .map_err(|e| ProxyError::Tls(format!("Failed to open private key file '{path}': {e}")))?;
105
106 let mut reader = BufReader::new(file);
107
108 loop {
110 match rustls_pemfile::read_one(&mut reader) {
111 Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
112 debug!(path = %path, "Loaded PKCS#1 RSA private key");
113 return Ok(PrivateKeyDer::Pkcs1(key));
114 }
115 Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
116 debug!(path = %path, "Loaded PKCS#8 private key");
117 return Ok(PrivateKeyDer::Pkcs8(key));
118 }
119 Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
120 debug!(path = %path, "Loaded SEC1 EC private key");
121 return Ok(PrivateKeyDer::Sec1(key));
122 }
123 Ok(Some(_)) => {
124 }
126 Ok(None) => {
127 return Err(ProxyError::Tls(format!("No private key found in '{path}'")));
128 }
129 Err(e) => {
130 return Err(ProxyError::Tls(format!(
131 "Failed to parse private key from '{path}': {e}"
132 )));
133 }
134 }
135 }
136}
137
138pub fn create_tls_acceptor(config: &TlsServerConfig) -> Result<TlsAcceptor> {
145 let certs = load_certs(&config.cert_path)?;
146 let key = load_private_key(&config.key_path)?;
147
148 let server_config = create_server_config(certs, key, config)?;
149
150 info!(
151 cert_path = %config.cert_path,
152 min_version = ?config.min_version,
153 alpn = ?config.alpn_protocols.iter()
154 .map(|p| String::from_utf8_lossy(p).to_string())
155 .collect::<Vec<_>>(),
156 "Created TLS acceptor"
157 );
158
159 Ok(TlsAcceptor::from(Arc::new(server_config)))
160}
161
162fn create_server_config(
164 certs: Vec<CertificateDer<'static>>,
165 key: PrivateKeyDer<'static>,
166 config: &TlsServerConfig,
167) -> Result<ServerConfig> {
168 let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
170 TlsVersion::Tls12 => vec![&rustls::version::TLS13, &rustls::version::TLS12],
171 TlsVersion::Tls13 => vec![&rustls::version::TLS13],
172 };
173
174 let mut server_config = ServerConfig::builder_with_protocol_versions(&versions)
175 .with_no_client_auth()
176 .with_single_cert(certs, key)
177 .map_err(|e| ProxyError::Tls(format!("Failed to create TLS config: {e}")))?;
178
179 if !config.alpn_protocols.is_empty() {
181 server_config
182 .alpn_protocols
183 .clone_from(&config.alpn_protocols);
184 }
185
186 Ok(server_config)
187}
188
189pub fn create_acceptor_from_files(cert_path: &str, key_path: &str) -> Result<TlsAcceptor> {
196 let config = TlsServerConfig::new(cert_path, key_path);
197 create_tls_acceptor(&config)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    const CERT: &str = "/path/to/cert.pem";
    const KEY: &str = "/path/to/key.pem";

    /// Builds a `TlsConfig` pointing at the fixed test paths.
    fn sample_tls_config(min_version: TlsVersion, alpn_h2: bool) -> TlsConfig {
        TlsConfig {
            cert_path: CERT.into(),
            key_path: KEY.into(),
            min_version,
            alpn_h2,
        }
    }

    #[test]
    fn test_tls_server_config_creation() {
        let config = TlsServerConfig::new(CERT, KEY);

        assert_eq!(config.cert_path, CERT);
        assert_eq!(config.key_path, KEY);
        assert_eq!(config.alpn_protocols, [b"h2".to_vec(), b"http/1.1".to_vec()]);
        assert_eq!(config.min_version, TlsVersion::Tls12);
    }

    #[test]
    fn test_tls_server_config_builder() {
        let config = TlsServerConfig::new(CERT, KEY)
            .with_alpn_protocols(vec![b"http/1.1".to_vec()])
            .with_min_version(TlsVersion::Tls13);

        assert_eq!(config.alpn_protocols, [b"http/1.1".to_vec()]);
        assert_eq!(config.min_version, TlsVersion::Tls13);
    }

    #[test]
    fn test_load_certs_file_not_found() {
        let outcome = load_certs("/nonexistent/path/cert.pem");
        assert!(matches!(outcome, Err(ProxyError::Tls(_))));
    }

    #[test]
    fn test_load_private_key_file_not_found() {
        let outcome = load_private_key("/nonexistent/path/key.pem");
        assert!(matches!(outcome, Err(ProxyError::Tls(_))));
    }

    #[test]
    fn test_from_tls_config() {
        let server_config =
            TlsServerConfig::from_config(&sample_tls_config(TlsVersion::Tls13, true));

        assert_eq!(server_config.cert_path, CERT);
        assert_eq!(server_config.key_path, KEY);
        assert_eq!(server_config.min_version, TlsVersion::Tls13);
        assert!(server_config
            .alpn_protocols
            .iter()
            .any(|proto| proto.as_slice() == b"h2"));
    }

    #[test]
    fn test_from_tls_config_no_h2() {
        let server_config =
            TlsServerConfig::from_config(&sample_tls_config(TlsVersion::Tls12, false));
        let has = |name: &[u8]| {
            server_config
                .alpn_protocols
                .iter()
                .any(|proto| proto.as_slice() == name)
        };

        assert!(!has(b"h2"));
        assert!(has(b"http/1.1"));
    }
}