simple_invoke_client_macro/
lib.rs

1use std::{collections::HashMap, str::FromStr};
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, FnArg, GenericArgument, Ident, ItemImpl, LitStr, Pat, ReturnType, Token, Type, TypeTuple};
6
7struct MatchGenericT {
8    pub ident: syn::Ident,
9    pub tp: syn::Type,
10    pub is_string: bool,
11    pub is_optional: bool,
12}
13fn match_generic_t(gtype: &str, arg: &FnArg) -> Option<MatchGenericT> {
14    let FnArg::Typed(arg) = arg.clone() else {
15        return None;
16    };
17    let Pat::Ident(ident) = *arg.pat else {
18        return None;
19    };
20    let ident = ident.ident.clone();
21    // get path last segment
22    let p_last_seg = if let syn::Type::Path(p) = *arg.ty {
23        p.path.segments.last().cloned()?
24    } else {
25        return None;
26    };
27    // is last segment Query ?
28    if p_last_seg.ident != gtype {
29        return None;
30    }
31    // get Query<T> or Query<Option<T>> T
32    let syn::PathArguments::AngleBracketed(a) = p_last_seg.arguments else {
33        return None;
34    };
35    let Some(GenericArgument::Type(t)) = a.args.into_iter().next() else {
36        return None;
37    };
38    let syn::Type::Path(p) = &t else {
39        return None;
40    };
41
42    let p_last_seg = p.path.segments.last()?;
43    let is_string;
44    let is_optional;
45    let mut tp = t.clone();
46    // is last segment Option ?
47    if p_last_seg.ident == "Option" {
48        is_optional = true;
49        let syn::PathArguments::AngleBracketed(a) = &p_last_seg.arguments else {
50            return None;
51        };
52        let Some(GenericArgument::Type(t)) = a.args.iter().next() else {
53            return None;
54        };
55        let syn::Type::Path(p) = t else {
56            return None;
57        };
58        let p_last_seg = p.path.segments.last()?;
59        is_string = p_last_seg.ident == "String";
60        tp = t.clone();
61    } else if p_last_seg.ident == "String" {
62        is_string = true;
63        is_optional = false;
64    } else {
65        is_string = false;
66        is_optional = false;
67    }
68    Some(MatchGenericT {
69        ident,
70        tp,
71        is_string,
72        is_optional,
73    })
74}
75
76fn match_result_t(gtype: &str, ty: &syn::Type) -> Option<syn::Type> {
77    // get path last segment
78    let p_last_seg = if let syn::Type::Path(p) = ty {
79        p.path.segments.last().cloned()?
80    } else {
81        return None;
82    };
83    // is last segment Query ?
84    if p_last_seg.ident != gtype {
85        return None;
86    }
87    // get Query<T> or Query<Option<T>> T
88    let syn::PathArguments::AngleBracketed(ref a) = p_last_seg.arguments else {
89        return None;
90    };
91    let Some(GenericArgument::Type(t)) = a.args.iter().next() else {
92        return None;
93    };
94    let tp = t.clone();
95    Some(tp)
96}
97enum PathItem {
98    Literal(String),
99    Variant { ident: syn::Ident, tp: syn::Type, is_string: bool },
100}
/// HTTP methods for which a client call can be generated.
enum Method {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
}
107
108impl FromStr for Method {
109    type Err = syn::Error;
110
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        match s.to_lowercase().as_str() {
113            "get" => Ok(Self::Get),
114            "post" => Ok(Self::Post),
115            "put" => Ok(Self::Put),
116            "delete" => Ok(Self::Delete),
117            _ => Err(syn::Error::new_spanned(s, "expect `get`, `post`, `put` or `delete`")),
118        }
119    }
120}
121
122struct ApiInfo {
123    pub name: Ident,
124    pub path: Vec<PathItem>,
125    pub query: Vec<(String, MatchGenericT)>,
126    pub body: syn::Type,
127    pub resp: syn::Type,
128    pub method: Method,
129}
130
131struct ApiInfoBuilder {
132    pub name: Ident,
133    pub path: Vec<PathItem>,
134    pub body: Option<syn::Type>,
135    pub resp: syn::Type,
136    pub query: Vec<(String, MatchGenericT)>,
137    pub method: Option<Method>,
138}
139
140impl ApiInfoBuilder {
141    pub fn new(name: Ident) -> Self {
142        Self {
143            name,
144            path: Vec::new(),
145            body: None,
146            resp: syn::Type::Tuple(TypeTuple {
147                paren_token: Default::default(),
148                elems: Default::default(),
149            }),
150            method: None,
151            query: Vec::new(),
152        }
153    }
154    pub fn build(self) -> Result<ApiInfo, syn::Error> {
155        let body = self.body.unwrap_or(syn::Type::Tuple(TypeTuple {
156            paren_token: Default::default(),
157            elems: Default::default(),
158        }));
159        let method = self.method.ok_or_else(|| syn::Error::new_spanned(&self.name, "missing method"))?;
160
161        Ok(ApiInfo {
162            name: self.name,
163            path: self.path,
164            body,
165            resp: self.resp,
166            method,
167            query: self.query,
168        })
169    }
170}
171
172/// # Usage
173/// This Attribute Macro is used to generate corresponding client methods for you api.
174/// Simplely add it **upon** `OpenApi` attribute.
175///
176/// The `Client` is your custom client struct witch implemented `SimpleInvokeClient` trait.
177/// ```no_run, ignore
178/// #[simple_invoke_client(Client)]
179/// #[poem_openapi::OpenApi(prefix_path = "/ct/msg")]
180/// impl Api {
181///     #[oai(method = "get", path = "/page")]
182///     pub async fn get_page(
183///         &self,
184///         page_number: Path<u32>,
185///         page_size: Query<Option<u32>>,
186///         TardisContextExtractor(ctx): TardisContextExtractor,
187///     ) -> TardisApiResult<TardisPage<String>> {
188///         // do something
189///         TardisResp::ok(TardisPage {
190///             page_number: 1,
191///             page_size: 10,
192///             total_size: 1,
193///             records: vec!["hello".to_string()],
194///         })
195///     }
196/// }
197/// ```
198#[proc_macro_attribute]
199pub fn simple_invoke_client(attr: TokenStream, item: TokenStream) -> TokenStream {
200    let input = parse_macro_input!(item as ItemImpl);
201    let mut metadata = parse_macro_input!(attr as Metadata);
202    // extract openapi metadata
203    input.attrs.iter().for_each(|attr| {
204        if attr.path().segments.iter().last().is_some_and(|last| last.ident == "OpenApi") {
205            let _ = attr.parse_nested_meta(|meta| {
206                if metadata.prefix_path.is_none() && meta.path.is_ident("prefix_path") {
207                    let path = meta.value()?.parse::<LitStr>()?;
208                    metadata.prefix_path.replace(path);
209                }
210                Ok(())
211            });
212        }
213    });
214    let method_info_list = input
215        .items
216        .iter()
217        .filter_map(|item| {
218            if let syn::ImplItem::Fn(func) = item {
219                let name = &func.sig.ident;
220
221                let mut builder = ApiInfoBuilder::new(name.clone());
222                let mut path_map = HashMap::new();
223                // 1. find out body: arg with type: Json<T>,
224                // 2. find out resp: ReturnType wrapped in TardisApiResult<T>,
225                // 3. find out path args: arg with type: Path<T>,
226                // 4. find out query args: arg with type: Query<T>,
227                for arg in &func.sig.inputs {
228                    if let Some(q) = match_generic_t("Query", arg) {
229                        builder.query.push((q.ident.to_string(), q));
230                    }
231                    if let Some(p) = match_generic_t("Path", arg) {
232                        path_map.insert(p.ident.to_string(), p);
233                    }
234                    if let Some(j) = match_generic_t("Json", arg) {
235                        builder.body = Some(j.tp);
236                    }
237                }
238                let _oai_metadata = func.attrs.iter().find(|attr| attr.path().is_ident("oai")).map(|attr| {
239                    attr.parse_nested_meta(|nested| {
240                        if nested.path.is_ident("method") {
241                            let method = nested.value()?;
242                            let method: LitStr = method.parse()?;
243                            let method = Method::from_str(&method.value())?;
244                            builder.method = Some(method);
245                        }
246                        if nested.path.is_ident("path") {
247                            let path = nested.value()?.parse::<LitStr>()?.value();
248                            builder.path = path
249                                .split('/')
250                                .filter(|x| !x.is_empty())
251                                .map(|x| {
252                                    if let Some(ident) = x.strip_prefix(':') {
253                                        path_map
254                                            .remove(ident)
255                                            .map(|arg| PathItem::Variant {
256                                                ident: arg.ident,
257                                                tp: arg.tp,
258                                                is_string: arg.is_string,
259                                            })
260                                            .unwrap_or(PathItem::Literal(x.to_string()))
261                                    } else {
262                                        PathItem::Literal(x.to_string())
263                                    }
264                                })
265                                .collect::<Vec<_>>();
266                        }
267                        Ok(())
268                    })
269                });
270                builder.resp = match &func.sig.output {
271                    ReturnType::Type(_, tp) => {
272                        match_result_t("TardisApiResult", tp).ok_or_else(|| syn::Error::new_spanned(&func.sig.output, "expect `TardisApiResult<T>`")).unwrap()
273                    }
274                    _ => syn::Type::Tuple(TypeTuple {
275                        paren_token: Default::default(),
276                        elems: Default::default(),
277                    }),
278                };
279                Some(builder.build().unwrap())
280            } else {
281                None
282            }
283        })
284        .collect::<Vec<_>>();
285
286    let client = metadata.client;
287    let impl_apis = generate_impl_tardis_api_client(&method_info_list, client, metadata.prefix_path);
288
289    let output = quote! {
290        #input
291        #impl_apis
292    };
293
294    output.into()
295}
296
297fn generate_impl_tardis_api_client(apis: &[ApiInfo], client: Type, prefix: Option<LitStr>) -> proc_macro2::TokenStream {
298    let mut impl_items = Vec::new();
299
300    for api_info in apis {
301        let name = &api_info.name;
302        let path = generate_path_tokens(&api_info.path);
303        let query = generate_query_tokens(&api_info.query);
304        let body = generate_type_tokens(&api_info.body);
305        let resp = generate_type_tokens(&api_info.resp);
306        let method = generate_method_token(&api_info.method);
307        let body_resp = match &api_info.method {
308            Method::Get | Method::Delete => quote!( #resp ),
309            Method::Post | Method::Put => quote!( #body => #resp ),
310        };
311        let path = match &prefix {
312            Some(prefix) => quote! { #prefix, #path },
313            None => quote! { #path },
314        };
315        let item = quote! {
316            { #name, #method [#path] {#query} #body_resp }
317        };
318
319        impl_items.push(item);
320    }
321
322    quote! {
323        bios_sdk_invoke::impl_tardis_api_client! {
324            #client:
325            #(#impl_items)*
326        }
327    }
328}
329
330fn generate_path_tokens(path: &[PathItem]) -> proc_macro2::TokenStream {
331    let tokens = path.iter().map(|item| match item {
332        PathItem::Literal(s) => quote! { #s },
333        PathItem::Variant { ident, tp, is_string } => {
334            if *is_string {
335                quote! { #ident }
336            } else {
337                quote! { #ident: #tp }
338            }
339        }
340    });
341
342    quote! { #(#tokens),* }
343}
344
345fn generate_query_tokens(query: &[(String, MatchGenericT)]) -> proc_macro2::TokenStream {
346    let tokens = query.iter().map(|(_name, ty)| {
347        let ty_ts = generate_type_tokens(&ty.tp);
348        let ident = &ty.ident;
349        match (ty.is_optional, ty.is_string) {
350            (true, true) => quote! { #ident? },
351            (true, false) => quote! { #ident?: #ty_ts },
352            (false, true) => quote! { #ident },
353            (false, false) => quote! { #ident: #ty_ts },
354        }
355    });
356
357    quote! { #(#tokens),* }
358}
359
360fn generate_type_tokens(ty: &Type) -> proc_macro2::TokenStream {
361    quote! { #ty }
362}
363
364fn generate_method_token(method: &Method) -> proc_macro2::TokenStream {
365    match method {
366        Method::Get => quote! { get },
367        Method::Post => quote! { post },
368        Method::Put => quote! { put },
369        Method::Delete => quote! { delete },
370    }
371}
372
373struct Metadata {
374    client: syn::Type,
375    prefix_path: Option<syn::LitStr>,
376}
377
378impl syn::parse::Parse for Metadata {
379    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
380        let client = input.parse::<Type>()?;
381        let mut meta_data = Self { client, prefix_path: None };
382        if let Ok(_comma) = input.parse::<Token![,]>() {
383            let prefix_path = Some(input.parse::<LitStr>()?);
384            meta_data.prefix_path = prefix_path;
385        }
386        Ok(meta_data)
387    }
388}