// test_context_macros/lib.rs

1mod args;
2
3use args::TestContextArgs;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::{Block, Ident};
7
8/// Macro to use on tests to add the setup/teardown functionality of your context.
9///
10/// Ordering of this attribute is important, and typically `test_context` should come
11/// before other test attributes. For example, the following is valid:
12///
13/// ```ignore
14/// #[test_context(MyContext)]
15/// #[test]
16/// fn my_test() {
17/// }
18/// ```
19///
20/// The following is NOT valid...
21///
22/// ```ignore
23/// #[test]
24/// #[test_context(MyContext)]
25/// fn my_test() {
26/// }
27/// ```
28#[proc_macro_attribute]
29pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
30    let args = syn::parse_macro_input!(attr as TestContextArgs);
31
32    let input = syn::parse_macro_input!(item as syn::ItemFn);
33    let is_async = input.sig.asyncness.is_some();
34
35    let (new_input, context_arg_name) =
36        extract_and_remove_context_arg(input.clone(), args.context_type.clone());
37
38    let wrapper_body = if is_async {
39        async_wrapper_body(args, &context_arg_name, &input.block)
40    } else {
41        sync_wrapper_body(args, &context_arg_name, &input.block)
42    };
43
44    let mut result_input = new_input;
45    result_input.block = Box::new(syn::parse2(wrapper_body).unwrap());
46
47    quote! { #result_input }.into()
48}
49
50fn async_wrapper_body(
51    args: TestContextArgs,
52    context_arg_name: &Option<syn::Ident>,
53    body: &Block,
54) -> proc_macro2::TokenStream {
55    let context_type = args.context_type;
56    let result_name = format_ident!("wrapped_result");
57
58    let binding = format_ident!("test_ctx");
59    let context_name = context_arg_name.as_ref().unwrap_or(&binding);
60
61    let body = if args.skip_teardown {
62        quote! {
63            let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
64            let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65        }
66    } else {
67        quote! {
68            let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
69            let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
70            <#context_type as test_context::AsyncTestContext>::teardown(#context_name).await;
71        }
72    };
73
74    let handle_wrapped_result = handle_result(result_name);
75
76    quote! {
77        {
78            use test_context::futures::FutureExt;
79            #body
80            #handle_wrapped_result
81        }
82    }
83}
84
85fn sync_wrapper_body(
86    args: TestContextArgs,
87    context_arg_name: &Option<syn::Ident>,
88    body: &Block,
89) -> proc_macro2::TokenStream {
90    let context_type = args.context_type;
91    let result_name = format_ident!("wrapped_result");
92
93    let binding = format_ident!("test_ctx");
94    let context_name = context_arg_name.as_ref().unwrap_or(&binding);
95
96    let body = if args.skip_teardown {
97        quote! {
98            let mut #context_name = <#context_type as test_context::TestContext>::setup();
99            let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
100                let #context_name = &mut #context_name;
101                #body
102            }));
103        }
104    } else {
105        quote! {
106            let mut #context_name = <#context_type as test_context::TestContext>::setup();
107            let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
108                #body
109            }));
110            <#context_type as test_context::TestContext>::teardown(#context_name);
111        }
112    };
113
114    let handle_wrapped_result = handle_result(result_name);
115
116    quote! {
117        {
118            #body
119            #handle_wrapped_result
120        }
121    }
122}
123
124fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
125    quote! {
126        match #result_name {
127            Ok(value) => value,
128            Err(err) => {
129                std::panic::resume_unwind(err);
130            }
131        }
132    }
133}
134
135fn extract_and_remove_context_arg(
136    mut input: syn::ItemFn,
137    expected_context_type: syn::Type,
138) -> (syn::ItemFn, Option<syn::Ident>) {
139    let mut context_arg_name = None;
140    let mut new_args = syn::punctuated::Punctuated::new();
141
142    for arg in &input.sig.inputs {
143        // Extract function arg:
144        if let syn::FnArg::Typed(pat_type) = arg {
145            // Extract arg identifier:
146            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
147                // Check that context arg is only ref or mutable ref:
148                if let syn::Type::Reference(type_ref) = &*pat_type.ty {
149                    // Check that context has expected type:
150                    if types_equal(&type_ref.elem, &expected_context_type) {
151                        context_arg_name = Some(pat_ident.ident.clone());
152                        continue;
153                    }
154                }
155            }
156        }
157        new_args.push(arg.clone());
158    }
159
160    input.sig.inputs = new_args;
161    (input, context_arg_name)
162}
163
164fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
165    if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
166        return a_path.path.segments.last().unwrap().ident
167            == b_path.path.segments.last().unwrap().ident;
168    }
169    quote!(#a).to_string() == quote!(#b).to_string()
170}