test_with_tokio_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4use syn::visit::Visit;
5use syn::Stmt;
6
7fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
8 tokens.extend(TokenStream::from(error.into_compile_error()));
9 tokens
10}
11
/// Visitor state used to decide whether a statement must run inside the
/// async runtime.
#[derive(Default, Debug)]
struct AsyncSearcher {
    /// Set to `true` by the visitor when it meets an async construct.
    found_async: bool,
}

17impl<'ast> Visit<'ast> for AsyncSearcher {
18 fn visit_expr_async(&mut self, _i: &'ast syn::ExprAsync) {
19 self.found_async = true;
20 }
21 fn visit_expr_await(&mut self, _i: &'ast syn::ExprAwait) {
22 self.found_async = true;
23 }
24}
25
26fn has_async(stmt: &&Stmt) -> bool {
27 let mut s = AsyncSearcher::default();
28 s.visit_stmt(stmt);
29 s.found_async
30}
31
32#[proc_macro_attribute]
33pub fn please(_args: TokenStream, item: TokenStream) -> TokenStream {
34 let mut input: syn::ItemFn = match syn::parse(item.clone()) {
38 Ok(it) => it,
39 Err(e) => return token_stream_with_error(item, e),
40 };
41 input.sig.asyncness = None;
42 let mut cases: Vec<(syn::Expr, syn::Expr, String)> = Vec::new();
43 for stmt in input.block.stmts.iter() {
44 if let Stmt::Local(local) = stmt {
45 if let Some((_, e)) = &local.init {
46 if let syn::Expr::Match(m) = e.as_ref() {
47 if let syn::Expr::Path(p) = m.expr.as_ref() {
48 if let Some(i) = p.path.get_ident() {
49 if format!("{i}") == "CASE" {
50 for arm in m.arms.iter() {
51 if let syn::Pat::Lit(p) = &arm.pat {
52 if let syn::Expr::Lit(e) = p.expr.as_ref() {
53 if let syn::Lit::Str(s) = &e.lit {
54 if s.value()
55 .chars()
56 .any(|c| !c.is_alphanumeric() && c != '_')
57 {
58 return quote_spanned! {
59 s.span() =>
60 compile_error!("not a valid identifier");
61 }
62 .into();
63 }
64 cases.push((
65 (*p.expr).clone(),
66 (*arm.body).clone(),
67 s.value(),
68 ));
69 } else {
70 return quote_spanned! {
71 e.span() =>
72 compile_error!("expected string literal");
73 }
74 .into();
75 }
76 } else {
77 return quote_spanned! {
78 p.expr.span() =>
79 compile_error!("expected string literal");
80 }
81 .into();
82 }
83 } else {
84 return quote_spanned! {
85 arm.pat.span() =>
86 compile_error!("expected string literal");
87 }
88 .into();
89 }
90 }
91 break;
92 }
93 }
94 }
95 }
96 }
97 }
98 }
99 let first_async = input
100 .block
101 .stmts
102 .iter()
103 .enumerate()
104 .find(|(_, s)| has_async(s))
105 .map(|(i, _)| i)
106 .unwrap_or(input.block.stmts.len());
107 let async_statements = input.block.stmts.split_off(first_async);
108 let last_statement: Stmt = syn::parse2(quote! {
109 ::tokio::runtime::Builder::new_current_thread()
110 .enable_all()
111 .build()
112 .unwrap()
113 .block_on(async {
114 #(#async_statements)*
115 });
116 })
117 .expect("Constructing tokio call");
118 let last_statement = if let Stmt::Semi(e, _) = last_statement {
119 Stmt::Expr(e)
120 } else {
121 last_statement
122 };
123 input.block.stmts.push(last_statement);
124 if cases.is_empty() {
125 let result = quote! {
126 #[::core::prelude::v1::test]
127 #input
128 };
129 result.into()
130 } else {
131 let mut functions = Vec::new();
132 for (e, b, n) in cases.into_iter() {
133 let mut f = input.clone();
134 f.sig.ident = syn::Ident::new(&format!("{}_{n}", f.sig.ident), f.sig.ident.span());
135 for stmt in f.block.stmts.iter_mut() {
136 if let Stmt::Local(local) = stmt {
137 if let Some((_, e)) = &mut local.init {
138 let is_case_match = if let syn::Expr::Match(m) = e.as_mut() {
139 if let syn::Expr::Path(p) = m.expr.as_ref() {
140 if let Some(i) = p.path.get_ident() {
141 format!("{i}") == "CASE"
142 } else {
143 false
144 }
145 } else {
146 false
147 }
148 } else {
149 false
150 };
151 if is_case_match {
152 *e = Box::new(b);
153 break;
154 }
155 }
156 }
157 }
158 f.block.stmts.insert(
159 0,
160 syn::parse2(quote! {
161 const CASE: &str = #e;
162 })
163 .unwrap(),
164 );
165 functions.push(quote! {
166 #[::core::prelude::v1::test]
167 #f
168 });
169 }
170 let result = quote! {
171 #( #functions )*
172 };
173 result.into()
174 }
175}