test_context_macros/
lib.rs1mod args;
2
3use args::TestContextArgs;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::Ident;
7
8#[proc_macro_attribute]
29pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
30 let args = syn::parse_macro_input!(attr as TestContextArgs);
31 let input = syn::parse_macro_input!(item as syn::ItemFn);
32
33 let (input, context_arg_name) = remove_context_arg(input, args.context_type.clone());
34 let input = refactor_input_body(input, &args, context_arg_name);
35
36 quote! { #input }.into()
37}
38
39fn refactor_input_body(
40 mut input: syn::ItemFn,
41 args: &TestContextArgs,
42 context_arg_name: Option<Ident>,
43) -> syn::ItemFn {
44 let context_type = &args.context_type;
45 let context_arg_name = context_arg_name.unwrap_or_else(|| format_ident!("test_ctx"));
46 let result_name = format_ident!("wrapped_result");
47 let body = &input.block;
48 let is_async = input.sig.asyncness.is_some();
49
50 let body = match (is_async, args.skip_teardown) {
51 (true, true) => {
52 quote! {
53 use test_context::futures::FutureExt;
54 let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
55 let #context_arg_name = &mut __context;
56 let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
57 }
58 }
59 (true, false) => {
60 quote! {
61 use test_context::futures::FutureExt;
62 let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
63 let #context_arg_name = &mut __context;
64 let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65 <#context_type as test_context::AsyncTestContext>::teardown(__context).await;
66 }
67 }
68 (false, true) => {
69 quote! {
70 let mut __context = <#context_type as test_context::TestContext>::setup();
71 let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
72 let #context_arg_name = &mut __context;
73 #body
74 }));
75 }
76 }
77 (false, false) => {
78 quote! {
79 let mut __context = <#context_type as test_context::TestContext>::setup();
80 let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
81 let #context_arg_name = &mut __context;
82 #body
83 }));
84 <#context_type as test_context::TestContext>::teardown(__context);
85 }
86 }
87 };
88
89 let body = quote! {
90 {
91 #body
92 match #result_name {
93 Ok(value) => value,
94 Err(err) => {
95 std::panic::resume_unwind(err);
96 }
97 }
98 }
99 };
100
101 input.block = Box::new(syn::parse2(body).unwrap());
102
103 input
104}
105
106fn remove_context_arg(
107 mut input: syn::ItemFn,
108 expected_context_type: syn::Type,
109) -> (syn::ItemFn, Option<syn::Ident>) {
110 let mut context_arg_name = None;
111 let mut new_args = syn::punctuated::Punctuated::new();
112
113 for arg in &input.sig.inputs {
114 if let syn::FnArg::Typed(pat_type) = arg {
116 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
118 if let syn::Type::Reference(type_ref) = &*pat_type.ty {
120 if types_equal(&type_ref.elem, &expected_context_type) {
122 context_arg_name = Some(pat_ident.ident.clone());
123 continue;
124 }
125 }
126 }
127 }
128
129 new_args.push(arg.clone());
130 }
131
132 input.sig.inputs = new_args;
133
134 (input, context_arg_name)
135}
136
137fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
138 if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
139 return a_path.path.segments.last().unwrap().ident
140 == b_path.path.segments.last().unwrap().ident;
141 }
142 quote!(#a).to_string() == quote!(#b).to_string()
143}