test_context_macros/
lib.rs1mod args;
2
3use args::TestContextArgs;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::{Block, Ident};
7
8#[proc_macro_attribute]
29pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
30 let args = syn::parse_macro_input!(attr as TestContextArgs);
31
32 let input = syn::parse_macro_input!(item as syn::ItemFn);
33 let is_async = input.sig.asyncness.is_some();
34
35 let (new_input, context_arg_name) =
36 extract_and_remove_context_arg(input.clone(), args.context_type.clone());
37
38 let wrapper_body = if is_async {
39 async_wrapper_body(args, &context_arg_name, &input.block)
40 } else {
41 sync_wrapper_body(args, &context_arg_name, &input.block)
42 };
43
44 let mut result_input = new_input;
45 result_input.block = Box::new(syn::parse2(wrapper_body).unwrap());
46
47 quote! { #result_input }.into()
48}
49
50fn async_wrapper_body(
51 args: TestContextArgs,
52 context_arg_name: &Option<syn::Ident>,
53 body: &Block,
54) -> proc_macro2::TokenStream {
55 let context_type = args.context_type;
56 let result_name = format_ident!("wrapped_result");
57
58 let binding = format_ident!("test_ctx");
59 let context_name = context_arg_name.as_ref().unwrap_or(&binding);
60
61 let body = if args.skip_teardown {
62 quote! {
63 let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
64 let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
65 }
66 } else {
67 quote! {
68 let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
69 let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
70 <#context_type as test_context::AsyncTestContext>::teardown(#context_name).await;
71 }
72 };
73
74 let handle_wrapped_result = handle_result(result_name);
75
76 quote! {
77 {
78 use test_context::futures::FutureExt;
79 #body
80 #handle_wrapped_result
81 }
82 }
83}
84
85fn sync_wrapper_body(
86 args: TestContextArgs,
87 context_arg_name: &Option<syn::Ident>,
88 body: &Block,
89) -> proc_macro2::TokenStream {
90 let context_type = args.context_type;
91 let result_name = format_ident!("wrapped_result");
92
93 let binding = format_ident!("test_ctx");
94 let context_name = context_arg_name.as_ref().unwrap_or(&binding);
95
96 let body = if args.skip_teardown {
97 quote! {
98 let mut #context_name = <#context_type as test_context::TestContext>::setup();
99 let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
100 let #context_name = &mut #context_name;
101 #body
102 }));
103 }
104 } else {
105 quote! {
106 let mut #context_name = <#context_type as test_context::TestContext>::setup();
107 let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
108 #body
109 }));
110 <#context_type as test_context::TestContext>::teardown(#context_name);
111 }
112 };
113
114 let handle_wrapped_result = handle_result(result_name);
115
116 quote! {
117 {
118 #body
119 #handle_wrapped_result
120 }
121 }
122}
123
124fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
125 quote! {
126 match #result_name {
127 Ok(value) => value,
128 Err(err) => {
129 std::panic::resume_unwind(err);
130 }
131 }
132 }
133}
134
135fn extract_and_remove_context_arg(
136 mut input: syn::ItemFn,
137 expected_context_type: syn::Type,
138) -> (syn::ItemFn, Option<syn::Ident>) {
139 let mut context_arg_name = None;
140 let mut new_args = syn::punctuated::Punctuated::new();
141
142 for arg in &input.sig.inputs {
143 if let syn::FnArg::Typed(pat_type) = arg {
145 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
147 if let syn::Type::Reference(type_ref) = &*pat_type.ty {
149 if types_equal(&type_ref.elem, &expected_context_type) {
151 context_arg_name = Some(pat_ident.ident.clone());
152 continue;
153 }
154 }
155 }
156 }
157 new_args.push(arg.clone());
158 }
159
160 input.sig.inputs = new_args;
161 (input, context_arg_name)
162}
163
164fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
165 if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
166 return a_path.path.segments.last().unwrap().ident
167 == b_path.path.segments.last().unwrap().ident;
168 }
169 quote!(#a).to_string() == quote!(#b).to_string()
170}