silent_openapi_macros/
lib.rs

1use convert_case::Casing;
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::Token;
5use syn::punctuated::Punctuated;
6use syn::{
7    Expr, ExprLit, FnArg, ItemFn, Lit, Meta, Result as SynResult, parse::Parse, parse::ParseStream,
8};
9
10fn endpoint_impl(
11    attr: proc_macro2::TokenStream,
12    item: proc_macro2::TokenStream,
13) -> proc_macro2::TokenStream {
14    struct MetaArgs(Punctuated<Meta, Token![,]>);
15    impl Parse for MetaArgs {
16        fn parse(input: ParseStream) -> SynResult<Self> {
17            Ok(MetaArgs(Punctuated::parse_terminated(input)?))
18        }
19    }
20    let MetaArgs(args) = syn::parse2::<MetaArgs>(attr).expect("parse attr");
21    let mut summary_arg: Option<String> = None;
22    let mut description_arg: Option<String> = None;
23    for meta in args {
24        if let Meta::NameValue(nv) = meta {
25            if nv.path.is_ident("summary")
26                && let Expr::Lit(ExprLit {
27                    lit: Lit::Str(s), ..
28                }) = &nv.value
29            {
30                summary_arg = Some(s.value());
31            } else if nv.path.is_ident("description")
32                && let Expr::Lit(ExprLit {
33                    lit: Lit::Str(s), ..
34                }) = &nv.value
35            {
36                description_arg = Some(s.value());
37            }
38        }
39    }
40
41    let input: ItemFn = syn::parse2(item).expect("parse item fn");
42    let vis = &input.vis;
43    let sig = input.sig.clone();
44    let attrs = &input.attrs;
45    let block = &input.block;
46    let name = &sig.ident;
47
48    // 收集文档注释作为默认 summary/description
49    let mut doc_lines: Vec<String> = Vec::new();
50    for a in attrs.iter() {
51        if a.path().is_ident("doc") {
52            let _ = a.parse_nested_meta(|meta| {
53                let lit: syn::LitStr = meta.value()?.parse()?;
54                let v = lit.value();
55                doc_lines.push(v.trim().to_string());
56                Ok(())
57            });
58        }
59    }
60    let (def_summary, def_description) = if !doc_lines.is_empty() {
61        let mut it = doc_lines.into_iter().filter(|s| !s.is_empty());
62        if let Some(first) = it.next() {
63            let rest = it.collect::<Vec<_>>().join("\n");
64            (Some(first), if rest.is_empty() { None } else { Some(rest) })
65        } else {
66            (None, None)
67        }
68    } else {
69        (None, None)
70    };
71
72    let summary = summary_arg.or(def_summary);
73    let description = description_arg.or(def_description);
74
75    // 真实处理函数改名
76    let impl_name = format_ident!("{}_impl", name);
77    // 生成实现函数签名(重命名)
78    let mut impl_sig = sig.clone();
79    impl_sig.ident = impl_name.clone();
80
81    // 端点类型 + 常量(实现与原 `.get(get_xxx)` 风格兼容)
82    let ep_ty = format_ident!(
83        "{}Endpoint",
84        name.to_string().to_case(convert_case::Case::UpperCamel)
85    );
86    let sum_tokens = if let Some(s) = &summary {
87        let lit = syn::LitStr::new(s, proc_macro2::Span::call_site());
88        quote!(Some(#lit))
89    } else {
90        quote!(None)
91    };
92    let desc_tokens = if let Some(s) = &description {
93        let lit = syn::LitStr::new(s, proc_macro2::Span::call_site());
94        quote!(Some(#lit))
95    } else {
96        quote!(None)
97    };
98
99    // 解析返回类型 Ok(T) -> ResponseMeta
100    let ret_meta = {
101        match &sig.output {
102            syn::ReturnType::Type(_, ty) => {
103                if let syn::Type::Path(tp) = ty.as_ref() {
104                    if let Some(seg) = tp.path.segments.last() {
105                        if seg.ident == "Result" || seg.ident == "SilentResult" {
106                            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
107                                if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
108                                    match ok_ty {
109                                        syn::Type::Path(tpath) => {
110                                            if let Some(id) = tpath.path.segments.last() {
111                                                if id.ident == "Response" {
112                                                    quote!(None)
113                                                } else if id.ident == "String" {
114                                                    quote!(Some(::silent_openapi::doc::ResponseMeta::TextPlain))
115                                                } else {
116                                                    let tn = id.ident.to_string();
117                                                    quote!(Some(::silent_openapi::doc::ResponseMeta::Json { type_name: #tn }))
118                                                }
119                                            } else {
120                                                quote!(None)
121                                            }
122                                        }
123                                        syn::Type::Reference(r) => {
124                                            if let syn::Type::Path(tp2) = r.elem.as_ref() {
125                                                if let Some(id) = tp2.path.segments.last() {
126                                                    if id.ident == "str" {
127                                                        quote!(Some(::silent_openapi::doc::ResponseMeta::TextPlain))
128                                                    } else {
129                                                        let tn = id.ident.to_string();
130                                                        quote!(Some(::silent_openapi::doc::ResponseMeta::Json { type_name: #tn }))
131                                                    }
132                                                } else {
133                                                    quote!(None)
134                                                }
135                                            } else {
136                                                quote!(None)
137                                            }
138                                        }
139                                        _ => quote!(None),
140                                    }
141                                } else {
142                                    quote!(None)
143                                }
144                            } else {
145                                quote!(None)
146                            }
147                        } else {
148                            quote!(None)
149                        }
150                    } else {
151                        quote!(None)
152                    }
153                } else {
154                    quote!(None)
155                }
156            }
157            _ => quote!(None),
158        }
159    };
160
161    // 为自定义 Ok(T) 注册 ToSchema 完整 schema
162    let ret_schema_register = {
163        match &sig.output {
164            syn::ReturnType::Type(_, ty) => {
165                if let syn::Type::Path(tp) = ty.as_ref() {
166                    if let Some(seg) = tp.path.segments.last() {
167                        if seg.ident == "Result" || seg.ident == "SilentResult" {
168                            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
169                                if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
170                                    match ok_ty {
171                                        syn::Type::Path(tpath) => {
172                                            if let Some(id) = tpath.path.segments.last() {
173                                                if id.ident == "Response" || id.ident == "String" {
174                                                    quote!()
175                                                } else {
176                                                    let ty = ok_ty.clone();
177                                                    quote!(::silent_openapi::doc::register_schema_for::<#ty>();)
178                                                }
179                                            } else {
180                                                quote!()
181                                            }
182                                        }
183                                        syn::Type::Reference(r) => {
184                                            if let syn::Type::Path(tp2) = r.elem.as_ref() {
185                                                if let Some(id) = tp2.path.segments.last() {
186                                                    if id.ident == "str" {
187                                                        quote!()
188                                                    } else {
189                                                        let inner = tp2.clone();
190                                                        quote!(::silent_openapi::doc::register_schema_for::<#inner>();)
191                                                    }
192                                                } else {
193                                                    quote!()
194                                                }
195                                            } else {
196                                                quote!()
197                                            }
198                                        }
199                                        _ => quote!(),
200                                    }
201                                } else {
202                                    quote!()
203                                }
204                            } else {
205                                quote!()
206                            }
207                        } else {
208                            quote!()
209                        }
210                    } else {
211                        quote!()
212                    }
213                } else {
214                    quote!()
215                }
216            }
217            _ => quote!(),
218        }
219    };
220
221    // 根据函数参数形态生成 IntoRouteHandler 实现
222    let inputs = sig.inputs.clone().into_iter().collect::<Vec<_>>();
223    let impls = if inputs.len() == 1 {
224        match &inputs[0] {
225            FnArg::Typed(pat_ty) => {
226                let ty = &pat_ty.ty;
227                // 简单规则:类型标识名为 Request 则认为是 Request 形态
228                let is_request = matches!(
229                    &**ty,
230                    syn::Type::Path(tp) if tp.path.segments.last().map(|s| s.ident == "Request").unwrap_or(false)
231                );
232                if is_request {
233                    quote! {
234                        impl ::silent::prelude::IntoRouteHandler<::silent::Request> for #ep_ty {
235                            fn into_handler(self) -> std::sync::Arc<dyn ::silent::Handler> {
236                                let handler = std::sync::Arc::new(::silent::HandlerWrapper::new(#impl_name));
237                                let ptr = std::sync::Arc::as_ptr(&handler) as *const () as usize;
238                                ::silent_openapi::doc::register_doc_by_ptr(
239                                    ptr,
240                                    #sum_tokens,
241                                    #desc_tokens,
242                                );
243                                #ret_schema_register
244                                if let Some(meta) = #ret_meta { ::silent_openapi::doc::register_response_by_ptr(ptr, meta); }
245                                handler
246                            }
247                        }
248                    }
249                } else {
250                    // 单萃取器参数
251                    quote! {
252                        impl ::silent::prelude::IntoRouteHandler<#ty> for #ep_ty {
253                            fn into_handler(self) -> std::sync::Arc<dyn ::silent::Handler> {
254                                let adapted = ::silent::extractor::handler_from_extractor::<#ty, _, _, _>(#impl_name);
255                                let handler = std::sync::Arc::new(::silent::HandlerWrapper::new(adapted));
256                                let ptr = std::sync::Arc::as_ptr(&handler) as *const () as usize;
257                                ::silent_openapi::doc::register_doc_by_ptr(
258                                    ptr,
259                                    #sum_tokens,
260                                    #desc_tokens,
261                                );
262                                #ret_schema_register
263                                if let Some(meta) = #ret_meta { ::silent_openapi::doc::register_response_by_ptr(ptr, meta); }
264                                handler
265                            }
266                        }
267                    }
268                }
269            }
270            _ => quote! {},
271        }
272    } else if inputs.len() == 2 {
273        match (&inputs[0], &inputs[1]) {
274            (FnArg::Typed(first), FnArg::Typed(second)) => {
275                let ty1 = &first.ty;
276                let ty2 = &second.ty;
277                // 期望形态: (Request, Args)
278                let is_request_first = matches!(
279                    &**ty1,
280                    syn::Type::Path(tp) if tp.path.segments.last().map(|s| s.ident == "Request").unwrap_or(false)
281                );
282                if is_request_first {
283                    quote! {
284                        impl ::silent::prelude::IntoRouteHandler<(::silent::Request, #ty2)> for #ep_ty {
285                            fn into_handler(self) -> std::sync::Arc<dyn ::silent::Handler> {
286                                let adapted = ::silent::extractor::handler_from_extractor_with_request::<#ty2, _, _, _>(#impl_name);
287                                let handler = std::sync::Arc::new(::silent::HandlerWrapper::new(adapted));
288                                let ptr = std::sync::Arc::as_ptr(&handler) as *const () as usize;
289                                ::silent_openapi::doc::register_doc_by_ptr(
290                                    ptr,
291                                    #sum_tokens,
292                                    #desc_tokens,
293                                );
294                                #ret_schema_register
295                                if let Some(meta) = #ret_meta { ::silent_openapi::doc::register_response_by_ptr(ptr, meta); }
296                                handler
297                            }
298                        }
299                    }
300                } else {
301                    quote! {}
302                }
303            }
304            _ => quote! {},
305        }
306    } else {
307        quote! {}
308    };
309
310    let code = quote! {
311        // 原函数体改名为实现函数
312        #(#attrs)*
313        #impl_sig #block
314
315        // 端点类型(零尺寸) + 常量,同名以保留 `.get(get_xxx)` 调用方式
316        pub struct #ep_ty;
317        #[allow(non_upper_case_globals)]
318        #vis const #name: #ep_ty = #ep_ty;
319
320        #impls
321    };
322
323    code
324}
325
326#[proc_macro_attribute]
327pub fn endpoint(attr: TokenStream, item: TokenStream) -> TokenStream {
328    endpoint_impl(attr.into(), item.into()).into()
329}
330
#[cfg(test)]
mod tests {
    use quote::quote;

    /// Runs the macro expansion and stringifies it so tests can look for token substrings.
    fn expand(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStream) -> String {
        super::endpoint_impl(attr, item).to_string()
    }

    /// Fails with the full expansion in the message when any needle is missing.
    fn assert_contains_all(expanded: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                expanded.contains(needle),
                "expected `{needle}` in expansion: {expanded}"
            );
        }
    }

    #[test]
    fn generates_endpoint_type_and_const_for_request_sig() {
        let expanded = expand(
            quote!(summary = "hello", description = "world"),
            quote!(
                async fn get_hello(_req: ::silent::Request) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert_contains_all(&expanded, &["struct GetHelloEndpoint", "const get_hello"]);
    }

    #[test]
    fn generates_into_route_handler_for_extractor_sig() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn get_user(_id: Path<u64>) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        // Endpoint type, constant and the IntoRouteHandler impl must all be generated.
        assert_contains_all(
            &expanded,
            &[
                "struct GetUserEndpoint",
                "const get_user",
                "IntoRouteHandler",
                "GetUserEndpoint",
            ],
        );
    }

    #[test]
    fn registers_response_meta_for_string() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn ping(_req: ::silent::Request) -> ::silent::Result<String> {
                    unimplemented!()
                }
            ),
        );
        // A `String` Ok type is documented as a plain-text response.
        assert_contains_all(&expanded, &["ResponseMeta :: TextPlain"]);
    }
}