Skip to main content

silent_openapi/
doc.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use http::Method;
5use once_cell::sync::Lazy;
6
7use silent::prelude::{HandlerGetter, Route};
8use silent::{
9    Handler, HandlerWrapper, Request as SilentRequest, Response as SilentResponse,
10    Result as SilentResult,
11};
12use utoipa::openapi::{Components, ComponentsBuilder, OpenApi};
13
/// Documentation metadata attached to an endpoint handler.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DocMeta {
    /// Summary text, if one was registered.
    pub summary: Option<String>,
    /// Description text, if one was registered.
    pub description: Option<String>,
}
20
21static DOC_REGISTRY: Lazy<Mutex<HashMap<usize, DocMeta>>> =
22    Lazy::new(|| Mutex::new(HashMap::new()));
23
24pub fn register_doc_by_ptr(ptr: usize, summary: Option<&str>, description: Option<&str>) {
25    let mut map = DOC_REGISTRY.lock().expect("doc registry poisoned");
26    map.insert(
27        ptr,
28        DocMeta {
29            summary: summary.map(|s| s.to_string()),
30            description: description.map(|s| s.to_string()),
31        },
32    );
33}
34
35pub(crate) fn lookup_doc_by_handler_ptr(ptr: usize) -> Option<DocMeta> {
36    DOC_REGISTRY.lock().ok().and_then(|m| m.get(&ptr).cloned())
37}
38
/// Response type metadata recorded for a handler.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ResponseMeta {
    /// Plain-text response.
    TextPlain,
    /// JSON response whose body type is named `type_name`.
    Json { type_name: &'static str },
}
45
46static RESPONSE_REGISTRY: Lazy<Mutex<HashMap<usize, ResponseMeta>>> =
47    Lazy::new(|| Mutex::new(HashMap::new()));
48
49pub fn register_response_by_ptr(ptr: usize, meta: ResponseMeta) {
50    let mut map = RESPONSE_REGISTRY
51        .lock()
52        .expect("response registry poisoned");
53    map.insert(ptr, meta);
54}
55
56pub(crate) fn lookup_response_by_handler_ptr(ptr: usize) -> Option<ResponseMeta> {
57    RESPONSE_REGISTRY
58        .lock()
59        .ok()
60        .and_then(|m| m.get(&ptr).cloned())
61}
62
63pub fn list_registered_json_types() -> Vec<&'static str> {
64    let map = RESPONSE_REGISTRY.lock().ok();
65    let mut out = Vec::new();
66    if let Some(map) = map {
67        for meta in map.values() {
68            if let ResponseMeta::Json { type_name } = meta
69                && !out.contains(type_name)
70            {
71                out.push(*type_name);
72            }
73        }
74    }
75    out
76}
77
// ====== Request metadata registry ======

/// Request parameter / request body metadata recorded for a handler.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RequestMeta {
    /// JSON request body (corresponds to the `Json<T>` extractor).
    JsonBody { type_name: &'static str },
    /// Form request body (corresponds to the `Form<T>` extractor).
    FormBody { type_name: &'static str },
    /// Query parameters (corresponds to the `Query<T>` extractor).
    QueryParams { type_name: &'static str },
}
90
91static REQUEST_REGISTRY: Lazy<Mutex<HashMap<usize, Vec<RequestMeta>>>> =
92    Lazy::new(|| Mutex::new(HashMap::new()));
93
94pub fn register_request_by_ptr(ptr: usize, meta: RequestMeta) {
95    let mut map = REQUEST_REGISTRY.lock().expect("request registry poisoned");
96    map.entry(ptr).or_default().push(meta);
97}
98
99pub(crate) fn lookup_request_by_handler_ptr(ptr: usize) -> Option<Vec<RequestMeta>> {
100    REQUEST_REGISTRY
101        .lock()
102        .ok()
103        .and_then(|m| m.get(&ptr).cloned())
104}
105
106// ====== ToSchema 完整 schema 注册 ======
107type SchemaRegFn = fn(&mut Components);
108static SCHEMA_REGISTRY: Lazy<Mutex<Vec<SchemaRegFn>>> = Lazy::new(|| Mutex::new(Vec::new()));
109
110pub fn register_schema_for<T>()
111where
112    T: crate::ToSchema + ::utoipa::PartialSchema + 'static,
113{
114    fn add_impl<U: crate::ToSchema + ::utoipa::PartialSchema>(components: &mut Components) {
115        let mut refs: Vec<(
116            String,
117            ::utoipa::openapi::RefOr<::utoipa::openapi::schema::Schema>,
118        )> = Vec::new();
119        <U as crate::ToSchema>::schemas(&mut refs);
120        for (name, schema) in refs {
121            components.schemas.entry(name).or_insert(schema);
122        }
123        let name = <U as crate::ToSchema>::name().into_owned();
124        let schema = <U as ::utoipa::PartialSchema>::schema();
125        components.schemas.entry(name).or_insert(schema);
126    }
127    let mut reg = SCHEMA_REGISTRY.lock().expect("schema registry poisoned");
128    reg.push(add_impl::<T> as SchemaRegFn);
129}
130
131pub fn apply_registered_schemas(openapi: &mut OpenApi) {
132    let mut components = openapi
133        .components
134        .clone()
135        .unwrap_or_else(|| ComponentsBuilder::new().build());
136    if let Ok(reg) = SCHEMA_REGISTRY.lock() {
137        for f in reg.iter() {
138            f(&mut components);
139        }
140    }
141    openapi.components = Some(components);
142}
143
144/// 路由文档标注扩展:在完成 handler 挂载后,追加文档说明
145pub trait RouteDocMarkExt {
146    fn doc(self, method: Method, summary: &str, description: &str) -> Self;
147}
148
149/// 便捷构造:将基于 Request 的处理函数包装为 `Arc<dyn Handler>` 并注册文档
150pub fn handler_with_doc<F, Fut, T>(f: F, summary: &str, description: &str) -> Arc<dyn Handler>
151where
152    F: Fn(SilentRequest) -> Fut + Send + Sync + 'static,
153    Fut: core::future::Future<Output = SilentResult<T>> + Send + 'static,
154    T: Into<SilentResponse> + Send + 'static,
155{
156    let handler = Arc::new(HandlerWrapper::new(f));
157    let ptr = Arc::as_ptr(&handler) as *const () as usize;
158    register_doc_by_ptr(ptr, Some(summary), Some(description));
159    handler
160}
161
162impl RouteDocMarkExt for Route {
163    fn doc(self, method: Method, summary: &str, description: &str) -> Self {
164        if let Some(handler) = self.handler.get(&method).cloned() {
165            let ptr = Arc::as_ptr(&handler) as *const () as usize;
166            register_doc_by_ptr(ptr, Some(summary), Some(description));
167        }
168        self
169    }
170}
171
172/// 便捷追加:同时挂载处理器并标注文档
173pub trait RouteDocAppendExt {
174    fn get_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self;
175    fn post_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self;
176    fn put_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self;
177    fn delete_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self;
178    fn patch_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self;
179    fn options_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self;
180}
181
182impl RouteDocAppendExt for Route {
183    fn get_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self {
184        let ptr = Arc::as_ptr(&handler) as *const () as usize;
185        register_doc_by_ptr(ptr, Some(summary), Some(description));
186        <Route as HandlerGetter>::handler(self, Method::GET, handler)
187    }
188
189    fn post_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self {
190        let ptr = Arc::as_ptr(&handler) as *const () as usize;
191        register_doc_by_ptr(ptr, Some(summary), Some(description));
192        <Route as HandlerGetter>::handler(self, Method::POST, handler)
193    }
194
195    fn put_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self {
196        let ptr = Arc::as_ptr(&handler) as *const () as usize;
197        register_doc_by_ptr(ptr, Some(summary), Some(description));
198        <Route as HandlerGetter>::handler(self, Method::PUT, handler)
199    }
200
201    fn delete_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self {
202        let ptr = Arc::as_ptr(&handler) as *const () as usize;
203        register_doc_by_ptr(ptr, Some(summary), Some(description));
204        <Route as HandlerGetter>::handler(self, Method::DELETE, handler)
205    }
206
207    fn patch_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self {
208        let ptr = Arc::as_ptr(&handler) as *const () as usize;
209        register_doc_by_ptr(ptr, Some(summary), Some(description));
210        <Route as HandlerGetter>::handler(self, Method::PATCH, handler)
211    }
212
213    fn options_with_doc(self, handler: Arc<dyn Handler>, summary: &str, description: &str) -> Self {
214        let ptr = Arc::as_ptr(&handler) as *const () as usize;
215        register_doc_by_ptr(ptr, Some(summary), Some(description));
216        <Route as HandlerGetter>::handler(self, Method::OPTIONS, handler)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use utoipa::ToSchema;

    async fn ok_handler(_req: SilentRequest) -> SilentResult<SilentResponse> {
        Ok(SilentResponse::text("ok"))
    }

    #[test]
    fn test_register_and_lookup_doc() {
        let handler = Arc::new(HandlerWrapper::new(|_req: SilentRequest| async move {
            Ok::<_, silent::SilentError>(SilentResponse::text("doc"))
        }));
        let ptr = Arc::as_ptr(&handler) as *const () as usize;
        register_doc_by_ptr(ptr, Some("summary"), Some("desc"));
        let got = lookup_doc_by_handler_ptr(ptr).expect("doc meta");
        assert_eq!(got.summary.as_deref(), Some("summary"));
        assert_eq!(got.description.as_deref(), Some("desc"));
    }

    #[test]
    fn test_register_and_lookup_response() {
        let handler = Arc::new(HandlerWrapper::new(ok_handler));
        let ptr = Arc::as_ptr(&handler) as *const () as usize;
        register_response_by_ptr(ptr, ResponseMeta::TextPlain);
        let got = lookup_response_by_handler_ptr(ptr).expect("resp meta");
        // `matches!` only evaluates to a bool; without `assert!` nothing would be checked.
        assert!(matches!(got, ResponseMeta::TextPlain));
    }

    #[test]
    fn test_list_registered_json_types() {
        let h1 = Arc::new(HandlerWrapper::new(ok_handler));
        let h2 = Arc::new(HandlerWrapper::new(ok_handler));
        let p1 = Arc::as_ptr(&h1) as *const () as usize;
        let p2 = Arc::as_ptr(&h2) as *const () as usize;
        register_response_by_ptr(p1, ResponseMeta::Json { type_name: "User" });
        register_response_by_ptr(p2, ResponseMeta::Json { type_name: "User" });
        let list = list_registered_json_types();
        // The response registry is process-global and shared with every other test in this
        // binary, so asserting on the total length depends on which tests ran first. Only
        // check that the type registered twice here is reported exactly once.
        assert_eq!(list.iter().filter(|name| **name == "User").count(), 1);
    }

    #[derive(Serialize, ToSchema)]
    struct FooSchema {
        id: i32,
        name: String,
    }

    #[test]
    fn test_register_schema_and_apply() {
        register_schema_for::<FooSchema>();
        let mut openapi = crate::OpenApiDoc::new("T", "1").into_openapi();
        apply_registered_schemas(&mut openapi);
        let components = openapi.components.expect("components");
        assert!(components.schemas.contains_key("FooSchema"));
    }

    // ====== Enum variant documentation tests ======

    #[derive(Serialize, ToSchema)]
    #[allow(dead_code)] // only used as a schema source; the variants are never constructed
    enum ApiResponse {
        Success { data: String },
        Error { code: i32, message: String },
    }

    #[test]
    fn test_register_enum_schema() {
        register_schema_for::<ApiResponse>();
        let mut openapi = crate::OpenApiDoc::new("T", "1").into_openapi();
        apply_registered_schemas(&mut openapi);
        let components = openapi.components.expect("components");
        assert!(components.schemas.contains_key("ApiResponse"));
    }

    #[derive(Serialize, ToSchema)]
    #[allow(dead_code)] // only used as a schema source; the variants are never constructed
    enum Status {
        Active,
        Inactive,
        Pending,
    }

    #[test]
    fn test_register_unit_enum_schema() {
        register_schema_for::<Status>();
        let mut openapi = crate::OpenApiDoc::new("T", "1").into_openapi();
        apply_registered_schemas(&mut openapi);
        let components = openapi.components.expect("components");
        assert!(components.schemas.contains_key("Status"));
    }

    #[derive(Serialize, ToSchema)]
    struct NestedData {
        value: i32,
    }

    #[derive(Serialize, ToSchema)]
    #[allow(dead_code)] // only used as a schema source; the variants are never constructed
    enum ComplexEnum {
        WithStruct(NestedData),
        WithString(String),
        Empty,
    }

    #[test]
    fn test_register_enum_with_nested_schemas() {
        register_schema_for::<ComplexEnum>();
        let mut openapi = crate::OpenApiDoc::new("T", "1").into_openapi();
        apply_registered_schemas(&mut openapi);
        let components = openapi.components.expect("components");
        assert!(components.schemas.contains_key("ComplexEnum"));
        // The nested NestedData schema must be registered as well.
        assert!(components.schemas.contains_key("NestedData"));
    }

    #[test]
    fn test_register_request_and_lookup() {
        let handler = Arc::new(HandlerWrapper::new(ok_handler));
        let ptr = Arc::as_ptr(&handler) as *const () as usize;
        register_request_by_ptr(ptr, RequestMeta::JsonBody { type_name: "User" });
        register_request_by_ptr(
            ptr,
            RequestMeta::QueryParams {
                type_name: "Filter",
            },
        );
        let got = lookup_request_by_handler_ptr(ptr).expect("request meta");
        assert_eq!(got.len(), 2);
        assert!(matches!(
            &got[0],
            RequestMeta::JsonBody { type_name: "User" }
        ));
        assert!(matches!(
            &got[1],
            RequestMeta::QueryParams {
                type_name: "Filter"
            }
        ));
    }
}