tokio_wrap/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ItemFn, PatType, Result, ReturnType, Signature, Stmt, Token, Type};
4
5struct ClosureArg<P: Parse> {
6    pat: P,
7    colon_token: Option<Token![:]>,
8    ty: Option<Type>,
9}
10
11struct ClosureInput<P: Parse> {
12    args: Punctuated<ClosureArg<P>, Token![,]>,
13    body: Expr,
14}
15
/// Parsed input of the `block!` macro: a run of statements with an optional
/// trailing expression.
struct BlockInput {
    /// Leading items that were successfully parsed as `Stmt`, in source order.
    stmts: Vec<Stmt>,
    /// The final expression, set when the remaining input did not parse as a
    /// `Stmt` and was parsed as an `Expr` instead.
    expr: Option<Expr>,
}
20
21impl<P: Parse> Parse for ClosureArg<P> {
22    fn parse(input: ParseStream) -> Result<Self> {
23        let pat: P = input.parse()?;
24        let (colon_token, ty) = if input.peek(Token![:]) {
25            let colon_token = input.parse()?;
26            let ty = input.parse()?;
27            (Some(colon_token), Some(ty))
28        } else {
29            (None, None)
30        };
31        Ok(ClosureArg { pat, colon_token, ty })
32    }
33}
34
35impl<P: Parse> Parse for ClosureInput<P> {
36    fn parse(input: ParseStream) -> Result<Self> {
37        let args = if input.peek(Token![|]) {
38            let _: Token![|] = input.parse()?;
39            let mut args = Punctuated::new();
40            while !input.peek(Token![|]) {
41                let arg: ClosureArg<P> = input.parse()?;
42                args.push_value(arg);
43                if input.peek(Token![|]) {
44                    break;
45                }
46                let punct: Token![,] = input.parse()?;
47                args.push_punct(punct);
48            }
49            let _: Token![|] = input.parse()?;
50            args
51        } else if input.peek(syn::token::Paren) {
52            let content;
53            syn::parenthesized!(content in input);
54            content.parse_terminated(ClosureArg::parse, Token![,])?
55        } else {
56            return Err(input.error("expected closure arguments"));
57        };
58
59        input.parse::<Token![=>]>()?;
60        let body = input.parse()?;
61
62        Ok(ClosureInput { args, body })
63    }
64}
65
66impl Parse for BlockInput {
67    fn parse(input: ParseStream) -> Result<Self> {
68        let mut stmts = Vec::new();
69        let mut expr = None;
70
71        while !input.is_empty() {
72            if input.fork().parse::<Stmt>().is_ok() {
73                stmts.push(input.parse()?);
74            } else {
75                expr = Some(input.parse()?);
76                break;
77            }
78        }
79
80        Ok(BlockInput { stmts, expr })
81    }
82}
83
84#[proc_macro]
85pub fn closure(input: TokenStream) -> TokenStream {
86    let input: ClosureInput<PatType> = parse_macro_input!(input);
87    let ClosureInput { args, body } = input;
88
89    let args = args.iter().map(|arg| {
90        let ClosureArg { pat, colon_token, ty } = arg;
91        quote! { #pat #colon_token #ty }
92    });
93
94    let gen = quote! {{
95        |#(#args),*| {
96            let fut = async move { #body };
97            let rt = tokio::runtime::Runtime::new().unwrap();
98            rt.block_on(fut)
99        }
100    }};
101
102    gen.into()
103}
104
105#[proc_macro]
106pub fn block(input: TokenStream) -> TokenStream {
107    let input: BlockInput = parse_macro_input!(input);
108    let BlockInput { stmts, expr } = input;
109
110    let gen = if let Some(final_expr) = expr {
111        quote! {{
112            let rt = tokio::runtime::Runtime::new().unwrap();
113            rt.block_on(async move {
114                #(#stmts)*
115                #final_expr
116            })
117        }}
118    } else {
119        quote! {{
120            let rt = tokio::runtime::Runtime::new().unwrap();
121            rt.block_on(async move {
122                #(#stmts)*
123            })
124        }}
125    };
126
127    gen.into()
128}
129
130#[proc_macro_attribute]
131pub fn sync(_attr: TokenStream, item: TokenStream) -> TokenStream {
132    let ItemFn { attrs, vis, sig, block } = parse_macro_input!(item);
133    let Signature { ident, generics, inputs, output, .. } = sig;
134
135    let return_type = match output {
136        ReturnType::Default => quote! { () },
137        ReturnType::Type(_, ty) => quote! { #ty },
138    };
139
140    let gen = quote! {
141        #(#attrs)*
142        #vis fn #ident #generics(#inputs) -> #return_type {
143            let rt = tokio::runtime::Runtime::new().unwrap();
144            rt.block_on(async {
145                #block
146            })
147        }
148    };
149
150    gen.into()
151}
152
#[cfg(test)]
mod tests {
    /// Registers every file matching `tests/*.rs` with trybuild as a case
    /// that is expected to pass.
    #[test]
    fn test_tokio_sync_wrapper() {
        let cases = trybuild::TestCases::new();
        cases.pass("tests/*.rs");
    }
}