1use crate::{OpenApiError, Result, SwaggerUiOptions};
6use async_trait::async_trait;
7use silent::{Handler, MiddleWareHandler, Next, Request, Response, StatusCode};
8use utoipa::openapi::OpenApi;
9
10#[derive(Clone)]
15pub struct SwaggerUiMiddleware {
16 ui_path: String,
18 api_doc_path: String,
20 openapi_json: String,
22 options: SwaggerUiOptions,
24}
25
26impl SwaggerUiMiddleware {
27 pub fn new(ui_path: &str, openapi: OpenApi) -> Result<Self> {
52 let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
53 let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
54
55 Ok(Self {
56 ui_path: ui_path.to_string(),
57 api_doc_path,
58 openapi_json,
59 options: SwaggerUiOptions::default(),
60 })
61 }
62
63 pub fn with_custom_api_doc_path(
65 ui_path: &str,
66 api_doc_path: &str,
67 openapi: OpenApi,
68 ) -> Result<Self> {
69 let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
70
71 Ok(Self {
72 ui_path: ui_path.to_string(),
73 api_doc_path: api_doc_path.to_string(),
74 openapi_json,
75 options: SwaggerUiOptions::default(),
76 })
77 }
78
79 pub fn with_options(
81 ui_path: &str,
82 openapi: OpenApi,
83 options: SwaggerUiOptions,
84 ) -> Result<Self> {
85 let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
86 let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
87
88 Ok(Self {
89 ui_path: ui_path.to_string(),
90 api_doc_path,
91 openapi_json,
92 options,
93 })
94 }
95
96 fn matches_swagger_path(&self, path: &str) -> bool {
98 path == self.ui_path
99 || path.starts_with(&format!("{}/", self.ui_path))
100 || path == self.api_doc_path
101 }
102
103 async fn handle_swagger_request(&self, path: &str) -> Result<Response> {
105 if path == self.api_doc_path {
106 self.handle_openapi_json().await
107 } else if path == self.ui_path {
108 self.handle_ui_redirect().await
109 } else {
110 self.handle_ui_resource(path).await
111 }
112 }
113
114 async fn handle_openapi_json(&self) -> Result<Response> {
116 let mut response = Response::empty();
117 response.set_status(StatusCode::OK);
118 response.set_header(
119 http::header::CONTENT_TYPE,
120 http::HeaderValue::from_static("application/json; charset=utf-8"),
121 );
122 response.set_header(
123 http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
124 http::HeaderValue::from_static("*"),
125 );
126 response.set_body(self.openapi_json.clone().into());
127 Ok(response)
128 }
129
130 async fn handle_ui_redirect(&self) -> Result<Response> {
132 let redirect_url = format!("{}/", self.ui_path);
133 let mut response = Response::empty();
134 response.set_status(StatusCode::MOVED_PERMANENTLY);
135 response.set_header(
136 http::header::LOCATION,
137 http::HeaderValue::from_str(&redirect_url)
138 .unwrap_or_else(|_| http::HeaderValue::from_static("/")),
139 );
140 Ok(response)
141 }
142
143 async fn handle_ui_resource(&self, path: &str) -> Result<Response> {
145 let relative_path = path
146 .strip_prefix(&format!("{}/", self.ui_path))
147 .unwrap_or("");
148
149 if relative_path.is_empty() || relative_path == "index.html" {
150 self.serve_swagger_ui_index().await
151 } else {
152 let mut response = Response::empty();
154 response.set_status(StatusCode::NOT_FOUND);
155 response.set_body("Resource not found".into());
156 Ok(response)
157 }
158 }
159
160 async fn serve_swagger_ui_index(&self) -> Result<Response> {
162 let html = format!(
163 r#"<!DOCTYPE html>
164<html lang="zh-CN">
165<head>
166 <meta charset="UTF-8">
167 <meta name="viewport" content="width=device-width, initial-scale=1.0">
168 <title>API Documentation - Swagger UI</title>
169 <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@5.17.14/swagger-ui.css" />
170 <link rel="icon" type="image/png" href="https://unpkg.com/swagger-ui-dist@5.17.14/favicon-32x32.png" sizes="32x32" />
171 <style>
172 html {{
173 box-sizing: border-box;
174 overflow: -moz-scrollbars-vertical;
175 overflow-y: scroll;
176 }}
177 *, *:before, *:after {{
178 box-sizing: inherit;
179 }}
180 body {{
181 margin: 0;
182 background: #fafafa;
183 }}
184 .swagger-ui .topbar {{
185 display: none;
186 }}
187 .swagger-ui .info {{
188 margin: 50px 0;
189 }}
190 .custom-header {{
191 background: #89CFF0;
192 padding: 20px;
193 text-align: center;
194 color: #1976d2;
195 font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
196 }}
197 .custom-header h1 {{
198 margin: 0;
199 font-size: 24px;
200 font-weight: 600;
201 }}
202 .custom-header p {{
203 margin: 8px 0 0 0;
204 opacity: 0.8;
205 }}
206 </style>
207</head>
208<body>
209 <div class="custom-header">
210 <h1>🚀 Silent Framework API Documentation</h1>
211 <p>基于 OpenAPI 3.0 规范的交互式 API 文档</p>
212 </div>
213 <div id="swagger-ui"></div>
214
215 <script src="https://unpkg.com/swagger-ui-dist@5.17.14/swagger-ui-bundle.js"></script>
216 <script src="https://unpkg.com/swagger-ui-dist@5.17.14/swagger-ui-standalone-preset.js"></script>
217 <script>
218 window.onload = function() {{
219 // 配置Swagger UI
220 const ui = SwaggerUIBundle({{
221 url: '{}',
222 dom_id: '#swagger-ui',
223 deepLinking: true,
224 presets: [
225 SwaggerUIBundle.presets.apis,
226 SwaggerUIStandalonePreset
227 ],
228 plugins: [
229 SwaggerUIBundle.plugins.DownloadUrl
230 ],
231 layout: "StandaloneLayout",
232 validatorUrl: null,
233 docExpansion: "list",
234 defaultModelsExpandDepth: 1,
235 defaultModelExpandDepth: 1,
236 displayRequestDuration: true,
237 filter: true,
238 showExtensions: true,
239 showCommonExtensions: true,
240 tryItOutEnabled: {}
241 }});
242
243 // 添加自定义样式
244 window.ui = ui;
245 }}
246 </script>
247</body>
248</html>"#,
249 self.api_doc_path,
250 if self.options.try_it_out_enabled {
251 "true"
252 } else {
253 "false"
254 }
255 );
256
257 let mut response = Response::empty();
258 response.set_status(StatusCode::OK);
259 response.set_header(
260 http::header::CONTENT_TYPE,
261 http::HeaderValue::from_static("text/html; charset=utf-8"),
262 );
263 response.set_header(
264 http::header::CACHE_CONTROL,
265 http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
266 );
267 response.set_body(html.into());
268 Ok(response)
269 }
270}
271
272#[async_trait]
273impl MiddleWareHandler for SwaggerUiMiddleware {
274 async fn handle(&self, req: Request, next: &Next) -> silent::Result<Response> {
276 let path = req.uri().path();
277 if self.matches_swagger_path(path) {
278 match self.handle_swagger_request(path).await {
279 Ok(response) => Ok(response),
280 Err(e) => {
281 eprintln!("Swagger UI中间件处理错误: {}", e);
282 let mut response = Response::empty();
284 response.set_status(StatusCode::INTERNAL_SERVER_ERROR);
285 response.set_body(format!("Swagger UI Error: {}", e).into());
286 Ok(response)
287 }
288 }
289 } else {
290 next.call(req).await
291 }
292 }
293}
294
295pub fn add_swagger_ui(
320 route: silent::prelude::Route,
321 ui_path: &str,
322 openapi: OpenApi,
323) -> silent::prelude::Route {
324 match SwaggerUiMiddleware::new(ui_path, openapi) {
325 Ok(middleware) => route.hook(middleware),
326 Err(e) => {
327 eprintln!("创建Swagger UI中间件失败: {}", e);
328 route
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use utoipa::OpenApi;

    #[derive(OpenApi)]
    #[openapi(
        info(title = "Test API", version = "1.0.0"),
        paths(),
        components(schemas())
    )]
    struct TestApiDoc;

    /// Reads a response header as `&str`; `None` if absent or not valid text.
    fn header_str(resp: &Response, name: http::header::HeaderName) -> Option<&str> {
        resp.headers().get(name).and_then(|v| v.to_str().ok())
    }

    #[test]
    fn test_middleware_creation() {
        let middleware = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi());
        assert!(middleware.is_ok());

        let middleware = middleware.unwrap();
        assert_eq!(middleware.ui_path, "/docs");
        assert_eq!(middleware.api_doc_path, "/docs/openapi.json");
    }

    #[test]
    fn test_with_options_uses_default_doc_path() {
        let middleware = SwaggerUiMiddleware::with_options(
            "/docs",
            TestApiDoc::openapi(),
            SwaggerUiOptions::default(),
        )
        .unwrap();
        assert_eq!(middleware.ui_path, "/docs");
        assert_eq!(middleware.api_doc_path, "/docs/openapi.json");
    }

    #[test]
    fn test_path_matching() {
        let middleware = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap();

        assert!(middleware.matches_swagger_path("/docs"));
        assert!(middleware.matches_swagger_path("/docs/"));
        assert!(middleware.matches_swagger_path("/docs/index.html"));
        assert!(middleware.matches_swagger_path("/docs/openapi.json"));
        assert!(!middleware.matches_swagger_path("/api/users"));
        assert!(!middleware.matches_swagger_path("/doc"));
        // A shared string prefix is not enough; it must be a path segment.
        assert!(!middleware.matches_swagger_path("/docsx"));
        assert!(!middleware.matches_swagger_path("/docsx/index.html"));
    }

    #[tokio::test]
    async fn test_openapi_json_handling() {
        let middleware = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap();
        let response = middleware.handle_openapi_json().await.unwrap();

        let content_type = header_str(&response, http::header::CONTENT_TYPE);
        assert!(content_type.is_some_and(|ct| ct.contains("application/json")));
        assert_eq!(
            header_str(&response, http::header::ACCESS_CONTROL_ALLOW_ORIGIN),
            Some("*")
        );
    }

    #[tokio::test]
    async fn test_redirect_on_base_path() {
        let middleware = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap();
        let resp = middleware.handle_swagger_request("/docs").await.unwrap();
        assert_eq!(header_str(&resp, http::header::LOCATION), Some("/docs/"));
    }

    #[tokio::test]
    async fn test_custom_api_doc_path() {
        let mw = SwaggerUiMiddleware::with_custom_api_doc_path(
            "/docs",
            "/openapi-docs.json",
            TestApiDoc::openapi(),
        )
        .unwrap();
        assert!(mw.matches_swagger_path("/openapi-docs.json"));

        let resp = mw
            .handle_swagger_request("/openapi-docs.json")
            .await
            .unwrap();
        assert!(header_str(&resp, http::header::CONTENT_TYPE)
            .is_some_and(|ct| ct.contains("application/json")));
    }

    #[tokio::test]
    async fn test_non_match_request_path() {
        let mw = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap();
        assert!(!mw.matches_swagger_path("/other"));
    }

    #[tokio::test]
    async fn test_asset_404_branch() {
        let mw = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap();
        let resp = mw.handle_swagger_request("/docs/app.css").await.unwrap();
        // Neither a redirect nor the index page (which always sets Cache-Control).
        assert!(resp.headers().get(http::header::LOCATION).is_none());
        assert!(resp.headers().get(http::header::CACHE_CONTROL).is_none());
    }

    #[tokio::test]
    async fn test_index_html_headers() {
        let mw = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap();
        for path in ["/docs/", "/docs/index.html"] {
            let resp = mw.handle_swagger_request(path).await.unwrap();
            let ct = header_str(&resp, http::header::CONTENT_TYPE).unwrap_or("");
            assert!(ct.contains("text/html"), "unexpected content type for {path}");
            assert!(resp.headers().get(http::header::CACHE_CONTROL).is_some());
        }
    }
}
438
439