1#![doc = include_str!("../readme.md")]
2#![warn(missing_docs)]
3
4pub mod error;
5pub mod extract;
6pub mod middleware;
7pub mod response;
8pub mod router;
9pub mod template;
10pub mod tls;
11
12pub use wae_session as session;
13
14use axum::{
15 Router as AxumRouter,
16 body::Body,
17 http::{StatusCode, header},
18 response::{IntoResponse, Response},
19};
20use hyper_util::service::TowerToHyperService;
21use std::{net::SocketAddr, time::Duration};
22use tokio::net::TcpListener;
23use tracing::info;
24
25pub use wae_types::{CloudError, CloudResult, WaeError, WaeResult};
26
/// Result type returned by the HTTPS server APIs in this crate (e.g. [`HttpsServer::serve`]);
/// an alias of [`WaeResult`].
pub type HttpsResult<T> = WaeResult<T>;

/// Error type used by the HTTPS server APIs in this crate; an alias of [`WaeError`].
pub type HttpsError = WaeError;
32
/// HTTP protocol version(s) a server is configured to speak.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared and used as keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum HttpVersion {
    /// HTTP/1.1 only.
    Http1Only,
    /// HTTP/2 only.
    Http2Only,
    /// Both HTTP/1.1 and HTTP/2. This is the default.
    #[default]
    Both,
}
44
/// HTTP/2 connection settings used by the server.
#[derive(Debug, Clone)]
pub struct Http2Config {
    /// Whether HTTP/2 is enabled.
    pub enabled: bool,
    /// Whether HTTP/2 server push is requested.
    pub enable_push: bool,
    /// Maximum number of concurrent streams per HTTP/2 connection.
    pub max_concurrent_streams: u32,
    /// Initial per-stream flow-control window size, in bytes.
    pub initial_stream_window_size: u32,
    /// Maximum HTTP/2 frame payload size, in bytes.
    pub max_frame_size: u32,
    /// Whether the extended CONNECT protocol is enabled.
    pub enable_connect_protocol: bool,
    /// Idle timeout for HTTP/2 streams.
    pub stream_idle_timeout: Duration,
}
63
64impl Default for Http2Config {
65 fn default() -> Self {
66 Self {
67 enabled: true,
68 enable_push: false,
69 max_concurrent_streams: 256,
70 initial_stream_window_size: 65535,
71 max_frame_size: 16384,
72 enable_connect_protocol: false,
73 stream_idle_timeout: Duration::from_secs(60),
74 }
75 }
76}
77
78impl Http2Config {
79 pub fn new() -> Self {
81 Self::default()
82 }
83
84 pub fn disabled() -> Self {
86 Self { enabled: false, ..Self::default() }
87 }
88
89 pub fn with_enable_push(mut self, enable: bool) -> Self {
91 self.enable_push = enable;
92 self
93 }
94
95 pub fn with_max_concurrent_streams(mut self, max: u32) -> Self {
97 self.max_concurrent_streams = max;
98 self
99 }
100
101 pub fn with_initial_stream_window_size(mut self, size: u32) -> Self {
103 self.initial_stream_window_size = size;
104 self
105 }
106
107 pub fn with_max_frame_size(mut self, size: u32) -> Self {
109 self.max_frame_size = size;
110 self
111 }
112
113 pub fn with_enable_connect_protocol(mut self, enable: bool) -> Self {
115 self.enable_connect_protocol = enable;
116 self
117 }
118
119 pub fn with_stream_idle_timeout(mut self, timeout: Duration) -> Self {
121 self.stream_idle_timeout = timeout;
122 self
123 }
124}
125
/// Locations of the TLS certificate and private key used to build the server's TLS acceptor.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private-key file.
    pub key_path: String,
}
134
135impl TlsConfig {
136 pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
138 Self { cert_path: cert_path.into(), key_path: key_path.into() }
139 }
140}
141
/// Complete configuration for an [`HttpsServer`].
#[derive(Debug, Clone)]
pub struct HttpsServerConfig {
    /// Socket address the listener binds to.
    pub addr: SocketAddr,
    /// Service name included in the startup log message.
    pub service_name: String,
    /// HTTP version(s) the server is configured for.
    pub http_version: HttpVersion,
    /// HTTP/2 settings.
    pub http2_config: Http2Config,
    /// TLS settings; when `None` the server serves plain (unencrypted) TCP.
    pub tls_config: Option<TlsConfig>,
}
156
157impl Default for HttpsServerConfig {
158 fn default() -> Self {
159 Self {
160 addr: "0.0.0.0:3000".parse().unwrap(),
161 service_name: "wae-https-service".to_string(),
162 http_version: HttpVersion::Both,
163 http2_config: Http2Config::default(),
164 tls_config: None,
165 }
166 }
167}
168
/// Fluent builder for [`HttpsServer`]; finish with [`HttpsServerBuilder::build`].
pub struct HttpsServerBuilder {
    /// Configuration accumulated so far.
    config: HttpsServerConfig,
    /// Router that will handle requests.
    router: AxumRouter,
}
174
175impl HttpsServerBuilder {
176 pub fn new() -> Self {
178 Self { config: HttpsServerConfig::default(), router: AxumRouter::new() }
179 }
180
181 pub fn addr(mut self, addr: SocketAddr) -> Self {
183 self.config.addr = addr;
184 self
185 }
186
187 pub fn service_name(mut self, name: impl Into<String>) -> Self {
189 self.config.service_name = name.into();
190 self
191 }
192
193 pub fn router(mut self, router: AxumRouter) -> Self {
195 self.router = router;
196 self
197 }
198
199 pub fn merge_router(mut self, router: AxumRouter) -> Self {
201 self.router = self.router.merge(router);
202 self
203 }
204
205 pub fn http_version(mut self, version: HttpVersion) -> Self {
207 self.config.http_version = version;
208 self
209 }
210
211 pub fn http2_config(mut self, config: Http2Config) -> Self {
213 self.config.http2_config = config;
214 self
215 }
216
217 pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
219 self.config.tls_config = Some(TlsConfig::new(cert_path, key_path));
220 self
221 }
222
223 pub fn tls_config(mut self, config: TlsConfig) -> Self {
225 self.config.tls_config = Some(config);
226 self
227 }
228
229 pub fn build(self) -> HttpsServer {
231 HttpsServer { config: self.config, router: self.router }
232 }
233}
234
235impl Default for HttpsServerBuilder {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
/// A configured server produced by [`HttpsServerBuilder::build`]; start it with
/// [`HttpsServer::serve`].
pub struct HttpsServer {
    /// Server configuration (address, protocol, TLS).
    config: HttpsServerConfig,
    /// Router that handles every request.
    router: AxumRouter,
}
246
247impl HttpsServer {
248 pub async fn serve(self) -> HttpsResult<()> {
250 let addr = self.config.addr;
251 let service_name = self.config.service_name.clone();
252 let protocol_info = self.get_protocol_info();
253 let tls_config = self.config.tls_config.clone();
254
255 let listener =
256 TcpListener::bind(addr).await.map_err(|e| WaeError::internal(format!("Failed to bind address: {}", e)))?;
257
258 info!("{} {} server starting on {}", service_name, protocol_info, addr);
259
260 match tls_config {
261 Some(tls_config) => self.serve_tls(listener, &tls_config).await,
262 None => self.serve_plain(listener).await,
263 }
264 }
265
266 async fn serve_plain(self, listener: TcpListener) -> HttpsResult<()> {
268 let app = self.router;
269
270 axum::serve(listener, app).await.map_err(|e| WaeError::internal(format!("Server error: {}", e)))?;
271
272 Ok(())
273 }
274
275 async fn serve_tls(self, listener: TcpListener, tls_config: &TlsConfig) -> HttpsResult<()> {
277 let tls_acceptor =
278 tls::create_tls_acceptor_with_http2(&tls_config.cert_path, &tls_config.key_path, self.config.http2_config.enabled)
279 .map_err(|e| WaeError::internal(format!("TLS config error: {}", e)))?;
280
281 let app = self.router;
282
283 loop {
284 let (stream, _remote_addr) =
285 listener.accept().await.map_err(|e| WaeError::internal(format!("Failed to accept connection: {}", e)))?;
286
287 let acceptor = tls_acceptor.clone();
288 let app = app.clone();
289
290 tokio::spawn(async move {
291 let tls_stream = match acceptor.accept(stream).await {
292 Ok(s) => s,
293 Err(e) => {
294 tracing::debug!("TLS handshake error: {}", e);
295 return;
296 }
297 };
298
299 let service = TowerToHyperService::new(app);
300 let io = hyper_util::rt::TokioIo::new(tls_stream);
301
302 let builder = hyper::server::conn::http2::Builder::new(hyper_util::rt::TokioExecutor::new());
303 let conn = builder.serve_connection(io, service);
304
305 if let Err(e) = conn.await {
306 tracing::debug!("HTTP/2 connection error: {}", e);
307 }
308 });
309 }
310 }
311
312 fn get_protocol_info(&self) -> String {
314 let tls_info = if self.config.tls_config.is_some() { "S" } else { "" };
315 let version_info = match self.config.http_version {
316 HttpVersion::Http1Only => "HTTP/1.1",
317 HttpVersion::Http2Only => "HTTP/2",
318 HttpVersion::Both => "HTTP/1.1+HTTP/2",
319 };
320 format!("{}{}", version_info, tls_info)
321 }
322}
323
/// JSON envelope for API responses, serialized via `serde` and sent by its `IntoResponse` impl.
#[derive(Debug, serde::Serialize)]
pub struct ApiResponse<T> {
    /// `true` for a successful response; also selects the HTTP status when converted into a response.
    pub success: bool,
    /// Payload; set by the success constructors, `None` for errors.
    pub data: Option<T>,
    /// Error details; set by the error constructors, `None` for successes.
    pub error: Option<ApiErrorBody>,
    /// Optional trace identifier included in the body.
    pub trace_id: Option<String>,
}
336
/// Error details carried in the `error` field of an [`ApiResponse`].
#[derive(Debug, serde::Serialize)]
pub struct ApiErrorBody {
    /// Error code string.
    pub code: String,
    /// Error message.
    pub message: String,
}
345
346impl<T: serde::Serialize> IntoResponse for ApiResponse<T> {
347 fn into_response(self) -> Response {
348 let status = if self.success { StatusCode::OK } else { StatusCode::BAD_REQUEST };
349
350 let body = serde_json::to_string(&self).unwrap_or_default();
351 Response::builder().status(status).header(header::CONTENT_TYPE, "application/json").body(Body::from(body)).unwrap()
352 }
353}
354
355impl<T> ApiResponse<T>
356where
357 T: serde::Serialize,
358{
359 pub fn success(data: T) -> Self {
361 Self { success: true, data: Some(data), error: None, trace_id: None }
362 }
363
364 pub fn success_with_trace(data: T, trace_id: impl Into<String>) -> Self {
366 Self { success: true, data: Some(data), error: None, trace_id: Some(trace_id.into()) }
367 }
368
369 pub fn error(code: impl Into<String>, message: impl Into<String>) -> Self {
371 Self {
372 success: false,
373 data: None,
374 error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
375 trace_id: None,
376 }
377 }
378
379 pub fn error_with_trace(code: impl Into<String>, message: impl Into<String>, trace_id: impl Into<String>) -> Self {
381 Self {
382 success: false,
383 data: None,
384 error: Some(ApiErrorBody { code: code.into(), message: message.into() }),
385 trace_id: Some(trace_id.into()),
386 }
387 }
388}