silent_openapi/
middleware.rs

1//! Swagger UI 中间件
2//!
3//! 提供中间件形式的Swagger UI支持,可以更灵活地集成到现有路由中。
4
5use crate::{OpenApiError, Result, SwaggerUiOptions};
6use async_trait::async_trait;
7use silent::{Handler, MiddleWareHandler, Next, Request, Response, StatusCode};
8use utoipa::openapi::OpenApi;
9
10/// Swagger UI 中间件
11///
12/// 实现了Silent的MiddleWareHandler trait,可以作为中间件添加到路由中。
13/// 当请求匹配Swagger UI相关路径时,直接返回响应;否则继续执行后续处理器。
14#[derive(Clone)]
15pub struct SwaggerUiMiddleware {
16    /// Swagger UI的基础路径
17    ui_path: String,
18    /// OpenAPI JSON的路径
19    api_doc_path: String,
20    /// OpenAPI 规范的JSON字符串
21    openapi_json: String,
22    /// UI 配置
23    options: SwaggerUiOptions,
24}
25
26impl SwaggerUiMiddleware {
27    /// 创建新的Swagger UI中间件
28    ///
29    /// # 参数
30    ///
31    /// - `ui_path`: Swagger UI的访问路径,如 "/swagger-ui"
32    /// - `openapi`: OpenAPI规范对象
33    ///
34    /// # 示例
35    ///
36    /// ```ignore
37    /// use silent::prelude::*;
38    /// use silent_openapi::SwaggerUiMiddleware;
39    /// use utoipa::OpenApi;
40    ///
41    /// #[derive(OpenApi)]
42    /// #[openapi(paths(), components(schemas()))]
43    /// struct ApiDoc;
44    ///
45    /// let middleware = SwaggerUiMiddleware::new("/swagger-ui", ApiDoc::openapi());
46    ///
47    /// let route = Route::new("")
48    ///     .hook(middleware)
49    ///     .get(your_handler);
50    /// ```
51    pub fn new(ui_path: &str, openapi: OpenApi) -> Result<Self> {
52        let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
53        let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
54
55        Ok(Self {
56            ui_path: ui_path.to_string(),
57            api_doc_path,
58            openapi_json,
59            options: SwaggerUiOptions::default(),
60        })
61    }
62
63    /// 使用自定义的API文档路径
64    pub fn with_custom_api_doc_path(
65        ui_path: &str,
66        api_doc_path: &str,
67        openapi: OpenApi,
68    ) -> Result<Self> {
69        let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
70
71        Ok(Self {
72            ui_path: ui_path.to_string(),
73            api_doc_path: api_doc_path.to_string(),
74            openapi_json,
75            options: SwaggerUiOptions::default(),
76        })
77    }
78
79    /// 使用自定义选项创建中间件
80    pub fn with_options(
81        ui_path: &str,
82        openapi: OpenApi,
83        options: SwaggerUiOptions,
84    ) -> Result<Self> {
85        let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
86        let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
87
88        Ok(Self {
89            ui_path: ui_path.to_string(),
90            api_doc_path,
91            openapi_json,
92            options,
93        })
94    }
95
96    /// 检查请求路径是否匹配Swagger UI相关路径
97    fn matches_swagger_path(&self, path: &str) -> bool {
98        path == self.ui_path
99            || path.starts_with(&format!("{}/", self.ui_path))
100            || path == self.api_doc_path
101    }
102
103    /// 处理Swagger UI相关请求
104    async fn handle_swagger_request(&self, path: &str) -> Result<Response> {
105        if path == self.api_doc_path {
106            self.handle_openapi_json().await
107        } else if path == self.ui_path {
108            self.handle_ui_redirect().await
109        } else {
110            self.handle_ui_resource(path).await
111        }
112    }
113
114    /// 处理OpenAPI JSON请求
115    async fn handle_openapi_json(&self) -> Result<Response> {
116        let mut response = Response::empty();
117        response.set_status(StatusCode::OK);
118        response.set_header(
119            http::header::CONTENT_TYPE,
120            http::HeaderValue::from_static("application/json; charset=utf-8"),
121        );
122        response.set_header(
123            http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
124            http::HeaderValue::from_static("*"),
125        );
126        response.set_body(self.openapi_json.clone().into());
127        Ok(response)
128    }
129
130    /// 处理UI主页重定向
131    async fn handle_ui_redirect(&self) -> Result<Response> {
132        let redirect_url = format!("{}/", self.ui_path);
133        let mut response = Response::empty();
134        response.set_status(StatusCode::MOVED_PERMANENTLY);
135        response.set_header(
136            http::header::LOCATION,
137            http::HeaderValue::from_str(&redirect_url)
138                .unwrap_or_else(|_| http::HeaderValue::from_static("/")),
139        );
140        Ok(response)
141    }
142
143    /// 处理UI资源请求
144    async fn handle_ui_resource(&self, path: &str) -> Result<Response> {
145        let relative_path = path
146            .strip_prefix(&format!("{}/", self.ui_path))
147            .unwrap_or("");
148
149        if relative_path.is_empty() || relative_path == "index.html" {
150            self.serve_swagger_ui_index().await
151        } else {
152            // 对于其他资源,返回404(基础版本使用CDN)
153            let mut response = Response::empty();
154            response.set_status(StatusCode::NOT_FOUND);
155            response.set_body("Resource not found".into());
156            Ok(response)
157        }
158    }
159
160    /// 生成Swagger UI主页HTML
161    async fn serve_swagger_ui_index(&self) -> Result<Response> {
162        let html = format!(
163            r#"<!DOCTYPE html>
164<html lang="zh-CN">
165<head>
166    <meta charset="UTF-8">
167    <meta name="viewport" content="width=device-width, initial-scale=1.0">
168    <title>API Documentation - Swagger UI</title>
169    <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@5.17.14/swagger-ui.css" />
170    <link rel="icon" type="image/png" href="https://unpkg.com/swagger-ui-dist@5.17.14/favicon-32x32.png" sizes="32x32" />
171    <style>
172        html {{
173            box-sizing: border-box;
174            overflow: -moz-scrollbars-vertical;
175            overflow-y: scroll;
176        }}
177        *, *:before, *:after {{
178            box-sizing: inherit;
179        }}
180        body {{
181            margin: 0;
182            background: #fafafa;
183        }}
184        .swagger-ui .topbar {{
185            display: none;
186        }}
187        .swagger-ui .info {{
188            margin: 50px 0;
189        }}
190        .custom-header {{
191            background: #89CFF0;
192            padding: 20px;
193            text-align: center;
194            color: #1976d2;
195            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
196        }}
197        .custom-header h1 {{
198            margin: 0;
199            font-size: 24px;
200            font-weight: 600;
201        }}
202        .custom-header p {{
203            margin: 8px 0 0 0;
204            opacity: 0.8;
205        }}
206    </style>
207</head>
208<body>
209    <div class="custom-header">
210        <h1>🚀 Silent Framework API Documentation</h1>
211        <p>基于 OpenAPI 3.0 规范的交互式 API 文档</p>
212    </div>
213    <div id="swagger-ui"></div>
214
215    <script src="https://unpkg.com/swagger-ui-dist@5.17.14/swagger-ui-bundle.js"></script>
216    <script src="https://unpkg.com/swagger-ui-dist@5.17.14/swagger-ui-standalone-preset.js"></script>
217    <script>
218        window.onload = function() {{
219            // 配置Swagger UI
220            const ui = SwaggerUIBundle({{
221                url: '{}',
222                dom_id: '#swagger-ui',
223                deepLinking: true,
224                presets: [
225                    SwaggerUIBundle.presets.apis,
226                    SwaggerUIStandalonePreset
227                ],
228                plugins: [
229                    SwaggerUIBundle.plugins.DownloadUrl
230                ],
231                layout: "StandaloneLayout",
232                validatorUrl: null,
233                docExpansion: "list",
234                defaultModelsExpandDepth: 1,
235                defaultModelExpandDepth: 1,
236                displayRequestDuration: true,
237                filter: true,
238                showExtensions: true,
239                showCommonExtensions: true,
240                tryItOutEnabled: {}
241            }});
242
243            // 添加自定义样式
244            window.ui = ui;
245        }}
246    </script>
247</body>
248</html>"#,
249            self.api_doc_path,
250            if self.options.try_it_out_enabled {
251                "true"
252            } else {
253                "false"
254            }
255        );
256
257        let mut response = Response::empty();
258        response.set_status(StatusCode::OK);
259        response.set_header(
260            http::header::CONTENT_TYPE,
261            http::HeaderValue::from_static("text/html; charset=utf-8"),
262        );
263        response.set_header(
264            http::header::CACHE_CONTROL,
265            http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
266        );
267        response.set_body(html.into());
268        Ok(response)
269    }
270}
271
272#[async_trait]
273impl MiddleWareHandler for SwaggerUiMiddleware {
274    /// 处理请求:命中 Swagger 相关路径则拦截返回,否则交由下一个处理器
275    async fn handle(&self, req: Request, next: &Next) -> silent::Result<Response> {
276        let path = req.uri().path();
277        if self.matches_swagger_path(path) {
278            match self.handle_swagger_request(path).await {
279                Ok(response) => Ok(response),
280                Err(e) => {
281                    eprintln!("Swagger UI中间件处理错误: {}", e);
282                    // 返回适当的错误响应
283                    let mut response = Response::empty();
284                    response.set_status(StatusCode::INTERNAL_SERVER_ERROR);
285                    response.set_body(format!("Swagger UI Error: {}", e).into());
286                    Ok(response)
287                }
288            }
289        } else {
290            next.call(req).await
291        }
292    }
293}
294
295/// 便捷函数:创建Swagger UI中间件并添加到路由
296///
297/// # 参数
298///
299/// - `route`: 要添加中间件的路由
300/// - `ui_path`: Swagger UI的访问路径
301/// - `openapi`: OpenAPI规范对象
302///
303/// # 示例
304///
305/// ```ignore
306/// use silent::prelude::*;
307/// use silent_openapi::add_swagger_ui;
308/// use utoipa::OpenApi;
309///
310/// #[derive(OpenApi)]
311/// #[openapi(paths(), components(schemas()))]
312/// struct ApiDoc;
313///
314/// let route = Route::new("api")
315///     .get(some_handler);
316///
317/// let route_with_swagger = add_swagger_ui(route, "/docs", ApiDoc::openapi());
318/// ```
319pub fn add_swagger_ui(
320    route: silent::prelude::Route,
321    ui_path: &str,
322    openapi: OpenApi,
323) -> silent::prelude::Route {
324    match SwaggerUiMiddleware::new(ui_path, openapi) {
325        Ok(middleware) => route.hook(middleware),
326        Err(e) => {
327            eprintln!("创建Swagger UI中间件失败: {}", e);
328            route
329        }
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use utoipa::OpenApi;

    #[derive(OpenApi)]
    #[openapi(
        info(title = "Test API", version = "1.0.0"),
        paths(),
        components(schemas())
    )]
    struct TestApiDoc;

    /// Middleware mounted at `/docs` with the default JSON path, shared by most tests.
    fn docs_middleware() -> SwaggerUiMiddleware {
        SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap()
    }

    #[test]
    fn test_middleware_creation() {
        let created = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi());
        assert!(created.is_ok());

        let mw = created.unwrap();
        assert_eq!(mw.ui_path, "/docs");
        assert_eq!(mw.api_doc_path, "/docs/openapi.json");
    }

    #[test]
    fn test_path_matching() {
        let mw = docs_middleware();

        for hit in ["/docs", "/docs/", "/docs/index.html", "/docs/openapi.json"] {
            assert!(mw.matches_swagger_path(hit), "expected {} to match", hit);
        }
        for miss in ["/api/users", "/doc"] {
            assert!(!mw.matches_swagger_path(miss), "expected {} not to match", miss);
        }
    }

    #[tokio::test]
    async fn test_openapi_json_handling() {
        let mw = docs_middleware();
        let resp = mw.handle_openapi_json().await.unwrap();
        let headers = resp.headers();

        // Silent's Response has no public status getter, so only headers are checked.
        assert!(headers.get(http::header::CONTENT_TYPE).is_some());
        // CORS header must be present so other origins can fetch the spec.
        assert!(headers
            .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .is_some());
    }

    #[tokio::test]
    async fn test_redirect_on_base_path() {
        let mw = docs_middleware();
        let resp = mw.handle_swagger_request("/docs").await.unwrap();
        // The status code is not readable; a LOCATION header confirms the redirect.
        assert!(resp.headers().get(http::header::LOCATION).is_some());
    }

    #[tokio::test]
    async fn test_custom_api_doc_path() {
        let mw = SwaggerUiMiddleware::with_custom_api_doc_path(
            "/docs",
            "/openapi-docs.json",
            TestApiDoc::openapi(),
        )
        .unwrap();

        // The custom JSON path is recognized...
        assert!(mw.matches_swagger_path("/openapi-docs.json"));

        // ...and answered with a JSON content type.
        let resp = mw
            .handle_swagger_request("/openapi-docs.json")
            .await
            .unwrap();
        let is_json = match resp.headers().get(http::header::CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap_or("").contains("application/json"),
            None => false,
        };
        assert!(is_json);
    }

    #[tokio::test]
    async fn test_non_match_request_path() {
        let mw = docs_middleware();
        assert!(!mw.matches_swagger_path("/other"));
    }

    #[tokio::test]
    async fn test_asset_404_branch() {
        let mw = docs_middleware();
        let resp = mw.handle_swagger_request("/docs/app.css").await.unwrap();
        // An unknown asset must not be answered with a redirect.
        assert!(resp.headers().get(http::header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_index_html_headers() {
        let mw = docs_middleware();
        let resp = mw.handle_swagger_request("/docs/index.html").await.unwrap();
        let headers = resp.headers();

        let content_type = headers.get(http::header::CONTENT_TYPE).unwrap();
        assert!(content_type.to_str().unwrap_or("").contains("text/html"));
        assert!(headers.get(http::header::CACHE_CONTROL).is_some());
    }
}
438
439// 选项类型在 crate 根导出