Skip to main content

wae_https/
lib.rs

1#![doc = include_str!("../readme.md")]
2#![warn(missing_docs)]
3
4pub mod error;
5pub mod extract;
6pub mod middleware;
7pub mod response;
8pub mod router;
9pub mod template;
10pub mod tls;
11
12pub use wae_session as session;
13
14use axum::{
15    Router as AxumRouter,
16    body::Body,
17    http::{StatusCode, header},
18    response::{IntoResponse, Response},
19};
20use hyper_util::service::TowerToHyperService;
21use std::{net::SocketAddr, time::Duration};
22use tokio::net::TcpListener;
23use tracing::info;
24
25pub use wae_types::{CloudError, CloudResult, WaeError, WaeResult};
26
/// Result type used by the HTTPS service; an alias for [`WaeResult`].
pub type HttpsResult<T> = WaeResult<T>;

/// Error type used by the HTTPS service; an alias for [`WaeError`].
pub type HttpsError = WaeError;
32
/// Selects which HTTP protocol versions the server is configured to speak.
///
/// The default is [`HttpVersion::Both`]. The type is `Copy` and comparable, so
/// it can be passed by value and used directly in equality checks.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum HttpVersion {
    /// HTTP/1.1 only.
    Http1Only,
    /// HTTP/2 only.
    Http2Only,
    /// Automatic selection: both HTTP/1.1 and HTTP/2 are supported.
    #[default]
    Both,
}
44
/// Tunable options for HTTP/2 connections.
///
/// Build one with [`Http2Config::new`] (or [`Http2Config::disabled`]) and
/// refine it with the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct Http2Config {
    /// Whether HTTP/2 is enabled.
    pub enabled: bool,
    /// Whether HTTP/2 server push is enabled.
    pub enable_push: bool,
    /// Maximum number of concurrent streams.
    pub max_concurrent_streams: u32,
    /// Initial stream window size.
    pub initial_stream_window_size: u32,
    /// Maximum frame size.
    pub max_frame_size: u32,
    /// Whether the CONNECT protocol is enabled.
    pub enable_connect_protocol: bool,
    /// Idle timeout applied to a stream.
    pub stream_idle_timeout: Duration,
}

impl Default for Http2Config {
    /// Returns the baseline configuration: HTTP/2 on, push and CONNECT off,
    /// 256 concurrent streams, a 65535 initial window, 16384 maximum frame
    /// size and a 60 second stream idle timeout.
    fn default() -> Self {
        Http2Config {
            enabled: true,
            enable_push: false,
            enable_connect_protocol: false,
            max_concurrent_streams: 256,
            initial_stream_window_size: 65_535,
            max_frame_size: 16_384,
            stream_idle_timeout: Duration::from_secs(60),
        }
    }
}

impl Http2Config {
    /// Creates a configuration holding the default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a default configuration with HTTP/2 switched off.
    pub fn disabled() -> Self {
        let defaults = Self::default();
        Self { enabled: false, ..defaults }
    }

    /// Sets whether HTTP/2 server push is enabled.
    pub fn with_enable_push(self, enable: bool) -> Self {
        Self { enable_push: enable, ..self }
    }

    /// Sets the maximum number of concurrent streams.
    pub fn with_max_concurrent_streams(self, max: u32) -> Self {
        Self { max_concurrent_streams: max, ..self }
    }

    /// Sets the initial stream window size.
    pub fn with_initial_stream_window_size(self, size: u32) -> Self {
        Self { initial_stream_window_size: size, ..self }
    }

    /// Sets the maximum frame size.
    pub fn with_max_frame_size(self, size: u32) -> Self {
        Self { max_frame_size: size, ..self }
    }

    /// Sets whether the CONNECT protocol is enabled.
    pub fn with_enable_connect_protocol(self, enable: bool) -> Self {
        Self { enable_connect_protocol: enable, ..self }
    }

    /// Sets the stream idle timeout.
    pub fn with_stream_idle_timeout(self, timeout: Duration) -> Self {
        Self { stream_idle_timeout: timeout, ..self }
    }
}
125
/// Locations of the TLS certificate and private key files.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private key file.
    pub key_path: String,
}

impl TlsConfig {
    /// Creates a TLS configuration from a certificate path and a private key path.
    ///
    /// Both arguments accept anything convertible into a `String`; the paths
    /// are stored as given and are not checked here.
    pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
        let cert_path: String = cert_path.into();
        let key_path: String = key_path.into();
        TlsConfig { cert_path, key_path }
    }
}
141
142/// HTTPS 服务配置
143#[derive(Debug, Clone)]
144pub struct HttpsServerConfig {
145    /// 服务监听地址
146    pub addr: SocketAddr,
147    /// 服务名称
148    pub service_name: String,
149    /// HTTP 协议版本
150    pub http_version: HttpVersion,
151    /// HTTP/2 配置
152    pub http2_config: Http2Config,
153    /// TLS 配置(可选)
154    pub tls_config: Option<TlsConfig>,
155}
156
157impl Default for HttpsServerConfig {
158    fn default() -> Self {
159        Self {
160            addr: "0.0.0.0:3000".parse().unwrap(),
161            service_name: "wae-https-service".to_string(),
162            http_version: HttpVersion::Both,
163            http2_config: Http2Config::default(),
164            tls_config: None,
165        }
166    }
167}
168
169/// HTTPS 服务构建器
170pub struct HttpsServerBuilder {
171    config: HttpsServerConfig,
172    router: AxumRouter,
173}
174
175impl HttpsServerBuilder {
176    /// 创建新的服务构建器
177    pub fn new() -> Self {
178        Self { config: HttpsServerConfig::default(), router: AxumRouter::new() }
179    }
180
181    /// 设置监听地址
182    pub fn addr(mut self, addr: SocketAddr) -> Self {
183        self.config.addr = addr;
184        self
185    }
186
187    /// 设置服务名称
188    pub fn service_name(mut self, name: impl Into<String>) -> Self {
189        self.config.service_name = name.into();
190        self
191    }
192
193    /// 设置路由
194    pub fn router(mut self, router: AxumRouter) -> Self {
195        self.router = router;
196        self
197    }
198
199    /// 合并路由
200    pub fn merge_router(mut self, router: AxumRouter) -> Self {
201        self.router = self.router.merge(router);
202        self
203    }
204
205    /// 设置 HTTP 协议版本
206    pub fn http_version(mut self, version: HttpVersion) -> Self {
207        self.config.http_version = version;
208        self
209    }
210
211    /// 设置 HTTP/2 配置
212    pub fn http2_config(mut self, config: Http2Config) -> Self {
213        self.config.http2_config = config;
214        self
215    }
216
217    /// 设置 TLS 配置
218    pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
219        self.config.tls_config = Some(TlsConfig::new(cert_path, key_path));
220        self
221    }
222
223    /// 设置 TLS 配置对象
224    pub fn tls_config(mut self, config: TlsConfig) -> Self {
225        self.config.tls_config = Some(config);
226        self
227    }
228
229    /// 构建 HTTPS 服务
230    pub fn build(self) -> HttpsServer {
231        HttpsServer { config: self.config, router: self.router }
232    }
233}
234
235impl Default for HttpsServerBuilder {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241/// HTTPS 服务
242pub struct HttpsServer {
243    config: HttpsServerConfig,
244    router: AxumRouter,
245}
246
247impl HttpsServer {
248    /// 启动 HTTPS 服务
249    pub async fn serve(self) -> HttpsResult<()> {
250        let addr = self.config.addr;
251        let service_name = self.config.service_name.clone();
252        let protocol_info = self.get_protocol_info();
253        let tls_config = self.config.tls_config.clone();
254
255        let listener =
256            TcpListener::bind(addr).await.map_err(|e| WaeError::internal(format!("Failed to bind address: {}", e)))?;
257
258        info!("{} {} server starting on {}", service_name, protocol_info, addr);
259
260        match tls_config {
261            Some(tls_config) => self.serve_tls(listener, &tls_config).await,
262            None => self.serve_plain(listener).await,
263        }
264    }
265
266    /// 启动纯文本 HTTP 服务(支持 HTTP/1.1 和 h2c)
267    async fn serve_plain(self, listener: TcpListener) -> HttpsResult<()> {
268        let app = self.router;
269
270        axum::serve(listener, app).await.map_err(|e| WaeError::internal(format!("Server error: {}", e)))?;
271
272        Ok(())
273    }
274
275    /// 启动 TLS HTTPS 服务(支持 HTTP/1.1 和 HTTP/2 over TLS)
276    async fn serve_tls(self, listener: TcpListener, tls_config: &TlsConfig) -> HttpsResult<()> {
277        let tls_acceptor =
278            tls::create_tls_acceptor_with_http2(&tls_config.cert_path, &tls_config.key_path, self.config.http2_config.enabled)
279                .map_err(|e| WaeError::internal(format!("TLS config error: {}", e)))?;
280
281        let app = self.router;
282
283        loop {
284            let (stream, _remote_addr) =
285                listener.accept().await.map_err(|e| WaeError::internal(format!("Failed to accept connection: {}", e)))?;
286
287            let acceptor = tls_acceptor.clone();
288            let app = app.clone();
289
290            tokio::spawn(async move {
291                let tls_stream = match acceptor.accept(stream).await {
292                    Ok(s) => s,
293                    Err(e) => {
294                        tracing::debug!("TLS handshake error: {}", e);
295                        return;
296                    }
297                };
298
299                let service = TowerToHyperService::new(app);
300                let io = hyper_util::rt::TokioIo::new(tls_stream);
301
302                let builder = hyper::server::conn::http2::Builder::new(hyper_util::rt::TokioExecutor::new());
303                let conn = builder.serve_connection(io, service);
304
305                if let Err(e) = conn.await {
306                    tracing::debug!("HTTP/2 connection error: {}", e);
307                }
308            });
309        }
310    }
311
312    /// 获取协议信息字符串
313    fn get_protocol_info(&self) -> String {
314        let tls_info = if self.config.tls_config.is_some() { "S" } else { "" };
315        let version_info = match self.config.http_version {
316            HttpVersion::Http1Only => "HTTP/1.1",
317            HttpVersion::Http2Only => "HTTP/2",
318            HttpVersion::Both => "HTTP/1.1+HTTP/2",
319        };
320        format!("{}{}", version_info, tls_info)
321    }
322}
323
324/// 统一 JSON 响应结构
325#[derive(Debug, serde::Serialize)]
326pub struct ApiResponse<T> {
327    /// 是否成功
328    pub success: bool,
329    /// 响应数据
330    pub data: Option<T>,
331    /// 错误信息
332    pub error: Option<ApiErrorBody>,
333    /// 请求追踪 ID
334    pub trace_id: Option<String>,
335}
336
337/// API 错误体
338#[derive(Debug, serde::Serialize)]
339pub struct ApiErrorBody {
340    /// 错误码
341    pub code: String,
342    /// 错误消息
343    pub message: String,
344}
345
346impl<T: serde::Serialize> IntoResponse for ApiResponse<T> {
347    fn into_response(self) -> Response {
348        let status = if self.success { StatusCode::OK } else { StatusCode::BAD_REQUEST };
349
350        let body = serde_json::to_string(&self).unwrap_or_default();
351        Response::builder().status(status).header(header::CONTENT_TYPE, "application/json").body(Body::from(body)).unwrap()
352    }
353}
354
355impl<T> ApiResponse<T>
356where
357    T: serde::Serialize,
358{
359    /// 创建成功响应
360    pub fn success(data: T) -> Self {
361        Self { success: true, data: Some(data), error: None, trace_id: None }
362    }
363
364    /// 创建成功响应(带追踪 ID)
365    pub fn success_with_trace(data: T, trace_id: impl Into<String>) -> Self {
366        Self { success: true, data: Some(data), error: None, trace_id: Some(trace_id.into()) }
367    }
368
369    /// 创建错误响应
370    pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
371        Self {
372            success: false,
373            data: None,
374            error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
375            trace_id: None,
376        }
377    }
378
379    /// 创建错误响应(带追踪 ID)
380    pub fn error_with_trace(code: impl Into<String>, message: impl Into<String>, trace_id: impl Into<String>) -> Self {
381        Self {
382            success: false,
383            data: None,
384            error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
385            trace_id: Some(trace_id.into()),
386        }
387    }
388}