Skip to main content

rustbasic_core_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Item, TraitItem, ImplItem, Signature, ReturnType, FnArg};
4
5#[proc_macro_attribute]
6pub fn async_trait(_attr: TokenStream, item: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(item as Item);
8    match input {
9        Item::Trait(mut trait_item) => {
10            for item in &mut trait_item.items {
11                if let TraitItem::Fn(method) = item {
12                    if method.sig.asyncness.is_some() {
13                        let original_body = method.default.clone();
14                        transform_signature(&mut method.sig);
15                        if let Some(body) = original_body {
16                            method.default = Some(syn::parse2(quote! {
17                                {
18                                    ::std::boxed::Box::pin(async move {
19                                        #body
20                                    })
21                                }
22                            }).unwrap());
23                        }
24                    }
25                }
26            }
27            TokenStream::from(quote!(#trait_item))
28        }
29        Item::Impl(mut impl_item) => {
30            for item in &mut impl_item.items {
31                if let ImplItem::Fn(method) = item {
32                    if method.sig.asyncness.is_some() {
33                        let original_body = method.block.clone();
34                        transform_signature(&mut method.sig);
35                        method.block = syn::parse2(quote! {
36                            {
37                                ::std::boxed::Box::pin(async move {
38                                    #original_body
39                                })
40                            }
41                        }).unwrap();
42                    }
43                }
44            }
45            TokenStream::from(quote!(#impl_item))
46        }
47        _ => TokenStream::from(quote!(#input)),
48    }
49}
50
51fn transform_signature(sig: &mut Signature) {
52    sig.asyncness = None;
53    let ret_type = match &sig.output {
54        ReturnType::Default => quote!(()),
55        ReturnType::Type(_, ty) => quote!(#ty),
56    };
57
58    // Find any lifetime in the generics. If none, default to '_
59    let mut lifetime_str = quote!('_);
60    for param in &sig.generics.params {
61        if let syn::GenericParam::Lifetime(lt) = param {
62            let lt_ident = &lt.lifetime;
63            lifetime_str = quote!(#lt_ident);
64            break;
65        }
66    }
67
68    // If there is self, and we have a specific lifetime like 'a, bind self to 'a
69    if lifetime_str.to_string() != "'_" {
70        for arg in &mut sig.inputs {
71            if let FnArg::Receiver(receiver) = arg {
72                if let Some((_, ref mut opt_lifetime)) = receiver.reference {
73                    if opt_lifetime.is_none() {
74                        let lt: syn::Lifetime = syn::parse2(quote!(#lifetime_str)).unwrap();
75                        *opt_lifetime = Some(lt);
76                    }
77                }
78            }
79        }
80    }
81
82    sig.output = syn::parse2(quote! {
83        -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = #ret_type> + ::std::marker::Send + #lifetime_str>>
84    }).unwrap();
85}