Skip to main content

wae_https/
tls.rs

//! TLS configuration module.
//!
//! Helper functions for TLS/HTTPS support, including HTTP/2 negotiation via ALPN.

5use std::{fs::File, io::BufReader, sync::Arc};
6use tokio_rustls::{
7    TlsAcceptor,
8    rustls::{
9        RootCertStore, ServerConfig,
10        pki_types::{CertificateDer, PrivateKeyDer},
11        server::WebPkiClientVerifier,
12    },
13};
14
15use crate::{WaeError, WaeResult};
16
/// ALPN protocol identifiers offered by the server during the TLS handshake.
pub mod alpn {
    /// ALPN identifier for HTTP/1.1.
    pub const HTTP_1_1: &[u8] = "http/1.1".as_bytes();
    /// ALPN identifier for HTTP/2.
    pub const HTTP_2: &[u8] = "h2".as_bytes();
}

25/// 创建 TLS 接受器
26///
27/// 从 PEM 格式的证书和私钥文件创建 TLS 接受器。
28/// 默认支持 HTTP/1.1 协议。
29pub fn create_tls_acceptor(cert_path: &str, key_path: &str) -> WaeResult<TlsAcceptor> {
30    create_tls_acceptor_with_http2(cert_path, key_path, false)
31}
32
33/// 创建支持 HTTP/2 的 TLS 接受器
34///
35/// 从 PEM 格式的证书和私钥文件创建 TLS 接受器,
36/// 支持 HTTP/1.1 和 HTTP/2 的 ALPN 协商。
37///
38/// 参数:
39/// - `cert_path`: 证书文件路径
40/// - `key_path`: 私钥文件路径
41/// - `enable_http2`: 是否启用 HTTP/2 支持
42pub fn create_tls_acceptor_with_http2(cert_path: &str, key_path: &str, enable_http2: bool) -> WaeResult<TlsAcceptor> {
43    let certs = load_certs(cert_path)?;
44    let key = load_private_key(key_path)?;
45
46    let alpn_protocols =
47        if enable_http2 { vec![alpn::HTTP_2.to_vec(), alpn::HTTP_1_1.to_vec()] } else { vec![alpn::HTTP_1_1.to_vec()] };
48
49    let config = ServerConfig::builder()
50        .with_no_client_auth()
51        .with_single_cert(certs, key)
52        .map_err(|e| WaeError::internal(format!("Failed to create TLS config: {}", e)))?;
53
54    let mut config = Arc::new(config);
55    Arc::get_mut(&mut config).expect("Config should be unique").alpn_protocols = alpn_protocols;
56
57    Ok(TlsAcceptor::from(config))
58}
59
60/// 创建支持客户端证书验证的 TLS 接受器
61///
62/// 从 PEM 格式的证书和私钥文件创建 TLS 接受器,
63/// 同时验证客户端证书。
64///
65/// 参数:
66/// - `cert_path`: 服务端证书文件路径
67/// - `key_path`: 服务端私钥文件路径
68/// - `ca_path`: CA 证书文件路径(用于验证客户端证书)
69/// - `enable_http2`: 是否启用 HTTP/2 支持
70pub fn create_tls_acceptor_with_client_auth(
71    cert_path: &str,
72    key_path: &str,
73    ca_path: &str,
74    enable_http2: bool,
75) -> WaeResult<TlsAcceptor> {
76    let certs = load_certs(cert_path)?;
77    let key = load_private_key(key_path)?;
78    let ca_certs = load_certs(ca_path)?;
79
80    let mut root_cert_store = RootCertStore::empty();
81    for cert in ca_certs {
82        root_cert_store.add(cert).map_err(|e| WaeError::internal(format!("Failed to add CA cert: {}", e)))?;
83    }
84
85    let client_verifier = WebPkiClientVerifier::builder(Arc::new(root_cert_store))
86        .build()
87        .map_err(|e| WaeError::internal(format!("Failed to create client verifier: {}", e)))?;
88
89    let alpn_protocols =
90        if enable_http2 { vec![alpn::HTTP_2.to_vec(), alpn::HTTP_1_1.to_vec()] } else { vec![alpn::HTTP_1_1.to_vec()] };
91
92    let config = ServerConfig::builder()
93        .with_client_cert_verifier(client_verifier)
94        .with_single_cert(certs, key)
95        .map_err(|e| WaeError::internal(format!("Failed to create TLS config: {}", e)))?;
96
97    let mut config = Arc::new(config);
98    Arc::get_mut(&mut config).expect("Config should be unique").alpn_protocols = alpn_protocols;
99
100    Ok(TlsAcceptor::from(config))
101}
102
103/// 从 PEM 文件加载证书
104fn load_certs(path: &str) -> WaeResult<Vec<CertificateDer<'static>>> {
105    let file = File::open(path).map_err(|e| WaeError::internal(format!("Failed to open cert file {}: {}", path, e)))?;
106    let mut reader = BufReader::new(file);
107
108    rustls_pemfile::certs(&mut reader)
109        .collect::<Result<Vec<_>, _>>()
110        .map_err(|e| WaeError::internal(format!("Failed to parse cert file {}: {}", path, e)))
111}
112
113/// 从 PEM 文件加载私钥
114fn load_private_key(path: &str) -> WaeResult<PrivateKeyDer<'static>> {
115    let file = File::open(path).map_err(|e| WaeError::internal(format!("Failed to open key file {}: {}", path, e)))?;
116    let mut reader = BufReader::new(file);
117
118    let keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::private_key(&mut reader)
119        .map_err(|e| WaeError::internal(format!("Failed to parse key file {}: {}", path, e)))?
120        .into_iter()
121        .collect();
122
123    keys.into_iter().next().ok_or_else(|| WaeError::internal(format!("No private key found in {}", path)))
124}
125
/// Builder for a TLS acceptor configured from PEM files.
///
/// `cert_path` and `key_path` are required; setting `ca_path` switches the
/// acceptor to client-certificate verification.
#[derive(Debug, Clone)]
pub struct TlsConfigBuilder {
    /// Path to the PEM server certificate file (required).
    cert_path: Option<String>,
    /// Path to the PEM server private key file (required).
    key_path: Option<String>,
    /// Optional path to the PEM CA file used to verify client certificates.
    ca_path: Option<String>,
    /// Whether `h2` is advertised via ALPN in addition to `http/1.1`.
    enable_http2: bool,
}

134impl TlsConfigBuilder {
135    /// 创建新的 TLS 配置构建器
136    pub fn new() -> Self {
137        Self { cert_path: None, key_path: None, ca_path: None, enable_http2: true }
138    }
139
140    /// 设置证书文件路径
141    pub fn cert_path(mut self, path: impl Into<String>) -> Self {
142        self.cert_path = Some(path.into());
143        self
144    }
145
146    /// 设置私钥文件路径
147    pub fn key_path(mut self, path: impl Into<String>) -> Self {
148        self.key_path = Some(path.into());
149        self
150    }
151
152    /// 设置 CA 证书文件路径(用于客户端证书验证)
153    pub fn ca_path(mut self, path: impl Into<String>) -> Self {
154        self.ca_path = Some(path.into());
155        self
156    }
157
158    /// 设置是否启用 HTTP/2
159    pub fn enable_http2(mut self, enable: bool) -> Self {
160        self.enable_http2 = enable;
161        self
162    }
163
164    /// 构建 TLS 接受器
165    pub fn build(self) -> WaeResult<TlsAcceptor> {
166        let cert_path = self.cert_path.ok_or_else(|| WaeError::internal("Certificate path is required"))?;
167        let key_path = self.key_path.ok_or_else(|| WaeError::internal("Key path is required"))?;
168
169        match self.ca_path {
170            Some(ca_path) => create_tls_acceptor_with_client_auth(&cert_path, &key_path, &ca_path, self.enable_http2),
171            None => create_tls_acceptor_with_http2(&cert_path, &key_path, self.enable_http2),
172        }
173    }
174}
175
176impl Default for TlsConfigBuilder {
177    fn default() -> Self {
178        Self::new()
179    }
180}