test_context_macros/
lib.rs

1mod args;
2
3use args::TestContextArgs;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::Ident;
7
8/// Macro to use on tests to add the setup/teardown functionality of your context.
9///
10/// Ordering of this attribute is important, and typically `test_context` should come
11/// before other test attributes. For example, the following is valid:
12///
13/// ```ignore
14/// #[test_context(MyContext)]
15/// #[test]
16/// fn my_test() {
17/// }
18/// ```
19///
20/// The following is NOT valid...
21///
22/// ```ignore
23/// #[test]
24/// #[test_context(MyContext)]
25/// fn my_test() {
26/// }
27/// ```
28#[proc_macro_attribute]
29pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
30    let args = syn::parse_macro_input!(attr as TestContextArgs);
31
32    let input = syn::parse_macro_input!(item as syn::ItemFn);
33    let ret = &input.sig.output;
34    let name = &input.sig.ident;
35    let arguments = &input.sig.inputs;
36    let body = &input.block;
37    let attrs = &input.attrs;
38    let is_async = input.sig.asyncness.is_some();
39
40    let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
41
42    let wrapper_body = if is_async {
43        async_wrapper_body(args, &wrapped_name)
44    } else {
45        sync_wrapper_body(args, &wrapped_name)
46    };
47
48    let async_tag = if is_async {
49        quote! { async }
50    } else {
51        quote! {}
52    };
53
54    quote! {
55        #(#attrs)*
56        #async_tag fn #name() #ret #wrapper_body
57
58        #async_tag fn #wrapped_name(#arguments) #ret #body
59    }
60    .into()
61}
62
63fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
64    let context_type = args.context_type;
65    let result_name = format_ident!("wrapped_result");
66
67    let body = if args.skip_teardown {
68        quote! {
69            let ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
70            let #result_name = std::panic::AssertUnwindSafe(
71                #wrapped_name(ctx)
72            ).catch_unwind().await;
73        }
74    } else {
75        quote! {
76            let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
77            let ctx_reference = &mut ctx;
78            let #result_name = std::panic::AssertUnwindSafe(
79                #wrapped_name(ctx_reference)
80            ).catch_unwind().await;
81            <#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
82        }
83    };
84
85    let handle_wrapped_result = handle_result(result_name);
86
87    quote! {
88        {
89            use test_context::futures::FutureExt;
90            #body
91            #handle_wrapped_result
92        }
93    }
94}
95
96fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
97    let context_type = args.context_type;
98    let result_name = format_ident!("wrapped_result");
99
100    let body = if args.skip_teardown {
101        quote! {
102            let ctx = <#context_type as test_context::TestContext>::setup();
103            let #result_name = std::panic::catch_unwind(move || {
104                #wrapped_name(ctx)
105            });
106        }
107    } else {
108        quote! {
109            let mut ctx = <#context_type as test_context::TestContext>::setup();
110            let mut pointer = std::panic::AssertUnwindSafe(&mut ctx);
111            let #result_name = std::panic::catch_unwind(move || {
112                #wrapped_name(*pointer)
113            });
114            <#context_type as test_context::TestContext>::teardown(ctx);
115        }
116    };
117
118    let handle_wrapped_result = handle_result(result_name);
119
120    quote! {
121        {
122            #body
123            #handle_wrapped_result
124        }
125    }
126}
127
128fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
129    quote! {
130        match #result_name {
131            Ok(value) => value,
132            Err(err) => {
133                std::panic::resume_unwind(err);
134            }
135        }
136    }
137}