test_with_tokio_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4use syn::visit::Visit;
5use syn::Stmt;
6
7fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
8    tokens.extend(TokenStream::from(error.into_compile_error()));
9    tokens
10}
11
/// Visitor state recording whether an `async` block or an `.await` expression has been
/// encountered while walking a syntax tree.
#[derive(Default, Debug)]
struct AsyncSearcher {
    /// Starts out `false` (via `Default`) and is set to `true` once a match is visited.
    found_async: bool,
}
16
17impl<'ast> Visit<'ast> for AsyncSearcher {
18    fn visit_expr_async(&mut self, _i: &'ast syn::ExprAsync) {
19        self.found_async = true;
20    }
21    fn visit_expr_await(&mut self, _i: &'ast syn::ExprAwait) {
22        self.found_async = true;
23    }
24}
25
26fn has_async(stmt: &&Stmt) -> bool {
27    let mut s = AsyncSearcher::default();
28    s.visit_stmt(stmt);
29    s.found_async
30}
31
32#[proc_macro_attribute]
33pub fn please(_args: TokenStream, item: TokenStream) -> TokenStream {
34    // If any of the steps for this macro fail, we still want to expand to an item that is as close
35    // to the expected output as possible. This helps out IDEs such that completions and other
36    // related features keep working.
37    let mut input: syn::ItemFn = match syn::parse(item.clone()) {
38        Ok(it) => it,
39        Err(e) => return token_stream_with_error(item, e),
40    };
41    input.sig.asyncness = None;
42    let mut cases: Vec<(syn::Expr, syn::Expr, String)> = Vec::new();
43    for stmt in input.block.stmts.iter() {
44        if let Stmt::Local(local) = stmt {
45            if let Some((_, e)) = &local.init {
46                if let syn::Expr::Match(m) = e.as_ref() {
47                    if let syn::Expr::Path(p) = m.expr.as_ref() {
48                        if let Some(i) = p.path.get_ident() {
49                            if format!("{i}") == "CASE" {
50                                for arm in m.arms.iter() {
51                                    if let syn::Pat::Lit(p) = &arm.pat {
52                                        if let syn::Expr::Lit(e) = p.expr.as_ref() {
53                                            if let syn::Lit::Str(s) = &e.lit {
54                                                if s.value()
55                                                    .chars()
56                                                    .any(|c| !c.is_alphanumeric() && c != '_')
57                                                {
58                                                    return quote_spanned! {
59                                                        s.span() =>
60                                                        compile_error!("not a valid identifier");
61                                                    }
62                                                    .into();
63                                                }
64                                                cases.push((
65                                                    (*p.expr).clone(),
66                                                    (*arm.body).clone(),
67                                                    s.value(),
68                                                ));
69                                            } else {
70                                                return quote_spanned! {
71                                                    e.span() =>
72                                                    compile_error!("expected string literal");
73                                                }
74                                                .into();
75                                            }
76                                        } else {
77                                            return quote_spanned! {
78                                                p.expr.span() =>
79                                                compile_error!("expected string literal");
80                                            }
81                                            .into();
82                                        }
83                                    } else {
84                                        return quote_spanned! {
85                                            arm.pat.span() =>
86                                            compile_error!("expected string literal");
87                                        }
88                                        .into();
89                                    }
90                                }
91                                break;
92                            }
93                        }
94                    }
95                }
96            }
97        }
98    }
99    let first_async = input
100        .block
101        .stmts
102        .iter()
103        .enumerate()
104        .find(|(_, s)| has_async(s))
105        .map(|(i, _)| i)
106        .unwrap_or(input.block.stmts.len());
107    let async_statements = input.block.stmts.split_off(first_async);
108    let last_statement: Stmt = syn::parse2(quote! {
109        ::tokio::runtime::Builder::new_current_thread()
110            .enable_all()
111            .build()
112            .unwrap()
113            .block_on(async {
114                #(#async_statements)*
115            });
116    })
117    .expect("Constructing tokio call");
118    let last_statement = if let Stmt::Semi(e, _) = last_statement {
119        Stmt::Expr(e)
120    } else {
121        last_statement
122    };
123    input.block.stmts.push(last_statement);
124    if cases.is_empty() {
125        let result = quote! {
126            #[::core::prelude::v1::test]
127            #input
128        };
129        result.into()
130    } else {
131        let mut functions = Vec::new();
132        for (e, b, n) in cases.into_iter() {
133            let mut f = input.clone();
134            f.sig.ident = syn::Ident::new(&format!("{}_{n}", f.sig.ident), f.sig.ident.span());
135            for stmt in f.block.stmts.iter_mut() {
136                if let Stmt::Local(local) = stmt {
137                    if let Some((_, e)) = &mut local.init {
138                        let is_case_match = if let syn::Expr::Match(m) = e.as_mut() {
139                            if let syn::Expr::Path(p) = m.expr.as_ref() {
140                                if let Some(i) = p.path.get_ident() {
141                                    format!("{i}") == "CASE"
142                                } else {
143                                    false
144                                }
145                            } else {
146                                false
147                            }
148                        } else {
149                            false
150                        };
151                        if is_case_match {
152                            *e = Box::new(b);
153                            break;
154                        }
155                    }
156                }
157            }
158            f.block.stmts.insert(
159                0,
160                syn::parse2(quote! {
161                    const CASE: &str = #e;
162                })
163                .unwrap(),
164            );
165            functions.push(quote! {
166               #[::core::prelude::v1::test]
167               #f
168            });
169        }
170        let result = quote! {
171            #( #functions )*
172        };
173        result.into()
174    }
175}