test_context_macros/
lib.rs

1mod args;
2
3use args::TestContextArgs;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::Ident;
7
8/// Macro to use on tests to add the setup/teardown functionality of your context.
9///
10/// Ordering of this attribute is important, and typically `test_context` should come
11/// before other test attributes. For example, the following is valid:
12///
13/// ```ignore
14/// #[test_context(MyContext)]
15/// #[test]
16/// fn my_test() {
17/// }
18/// ```
19///
20/// The following is NOT valid...
21///
22/// ```ignore
23/// #[test]
24/// #[test_context(MyContext)]
25/// fn my_test() {
26/// }
27/// ```
28#[proc_macro_attribute]
29pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
30    let args = syn::parse_macro_input!(attr as TestContextArgs);
31    let input = syn::parse_macro_input!(item as syn::ItemFn);
32
33    let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone());
34    let input = refactor_input_body(input, &args, context_arg_name);
35
36    quote! { #input }.into()
37}
38
39fn refactor_input_body(
40    mut input: syn::ItemFn,
41    args: &TestContextArgs,
42    context_arg_name: Option<Ident>,
43) -> syn::ItemFn {
44    let context_type = &args.context_type;
45    let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx"));
46    let result_name = format_ident!("wrapped_result");
47    let body = &input.block;
48    let is_async = input.sig.asyncness.is_some();
49
50    let body = match (is_async, args.skip_teardown) {
51        (true, true) => {
52            quote! {
53                use test_context::futures::FutureExt;
54                let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
55                let #context_arg_name = &mut __context;
56                let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
57            }
58        }
59        (true, false) => {
60            quote! {
61                use test_context::futures::FutureExt;
62                let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
63                let #context_arg_name = &mut __context;
64                let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65                <#context_type as test_context::AsyncTestContext>::teardown(__context).await;
66            }
67        }
68        (false, true) => {
69            quote! {
70                let mut __context = <#context_type as test_context::TestContext>::setup();
71                let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
72                    let #context_arg_name = &mut __context;
73                    #body
74                }));
75            }
76        }
77        (false, false) => {
78            quote! {
79                let mut __context = <#context_type as test_context::TestContext>::setup();
80                let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
81                    let #context_arg_name = &mut __context;
82                    #body
83                }));
84                <#context_type as test_context::TestContext>::teardown(__context);
85            }
86        }
87    };
88
89    let body = quote! {
90        {
91            #body
92            match #result_name {
93                Ok(value) => value,
94                Err(err) => {
95                    std::panic::resume_unwind(err);
96                }
97            }
98        }
99    };
100
101    input.block = Box::new(syn::parse2(body).unwrap());
102
103    input
104}
105
106fn remove_context_arg(
107    mut input: syn::ItemFn,
108    expected_context_type: syn::Type,
109) -> (syn::ItemFn, Option<syn::Ident>) {
110    let mut context_arg_name = None;
111    let mut new_args = syn::punctuated::Punctuated::new();
112
113    for arg in &input.sig.inputs {
114        // Extract function arg:
115        if let syn::FnArg::Typed(pat_type) = arg {
116            // Extract arg identifier:
117            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
118                // Check that context arg is only ref or mutable ref:
119                if let syn::Type::Reference(type_ref) = &*pat_type.ty {
120                    // Check that context has expected type:
121                    if types_equal(&type_ref.elem, &expected_context_type) {
122                        context_arg_name = Some(pat_ident.ident.clone());
123                        continue;
124                    }
125                }
126            }
127        }
128
129        new_args.push(arg.clone());
130    }
131
132    input.sig.inputs = new_args;
133
134    (input, context_arg_name)
135}
136
137fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
138    if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
139        return a_path.path.segments.last().unwrap().ident
140            == b_path.path.segments.last().unwrap().ident;
141    }
142    quote!(#a).to_string() == quote!(#b).to_string()
143}