test_context_macros/
lib.rs1mod args;
2
3use args::TestContextArgs;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::Ident;
7
8#[proc_macro_attribute]
29pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
30 let args = syn::parse_macro_input!(attr as TestContextArgs);
31
32 let input = syn::parse_macro_input!(item as syn::ItemFn);
33 let ret = &input.sig.output;
34 let name = &input.sig.ident;
35 let arguments = &input.sig.inputs;
36 let body = &input.block;
37 let attrs = &input.attrs;
38 let is_async = input.sig.asyncness.is_some();
39
40 let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
41
42 let wrapper_body = if is_async {
43 async_wrapper_body(args, &wrapped_name)
44 } else {
45 sync_wrapper_body(args, &wrapped_name)
46 };
47
48 let async_tag = if is_async {
49 quote! { async }
50 } else {
51 quote! {}
52 };
53
54 quote! {
55 #(#attrs)*
56 #async_tag fn #name() #ret #wrapper_body
57
58 #async_tag fn #wrapped_name(#arguments) #ret #body
59 }
60 .into()
61}
62
63fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
64 let context_type = args.context_type;
65 let result_name = format_ident!("wrapped_result");
66
67 let body = if args.skip_teardown {
68 quote! {
69 let ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
70 let #result_name = std::panic::AssertUnwindSafe(
71 #wrapped_name(ctx)
72 ).catch_unwind().await;
73 }
74 } else {
75 quote! {
76 let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
77 let ctx_reference = &mut ctx;
78 let #result_name = std::panic::AssertUnwindSafe(
79 #wrapped_name(ctx_reference)
80 ).catch_unwind().await;
81 <#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
82 }
83 };
84
85 let handle_wrapped_result = handle_result(result_name);
86
87 quote! {
88 {
89 use test_context::futures::FutureExt;
90 #body
91 #handle_wrapped_result
92 }
93 }
94}
95
96fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
97 let context_type = args.context_type;
98 let result_name = format_ident!("wrapped_result");
99
100 let body = if args.skip_teardown {
101 quote! {
102 let ctx = <#context_type as test_context::TestContext>::setup();
103 let #result_name = std::panic::catch_unwind(move || {
104 #wrapped_name(ctx)
105 });
106 }
107 } else {
108 quote! {
109 let mut ctx = <#context_type as test_context::TestContext>::setup();
110 let mut pointer = std::panic::AssertUnwindSafe(&mut ctx);
111 let #result_name = std::panic::catch_unwind(move || {
112 #wrapped_name(*pointer)
113 });
114 <#context_type as test_context::TestContext>::teardown(ctx);
115 }
116 };
117
118 let handle_wrapped_result = handle_result(result_name);
119
120 quote! {
121 {
122 #body
123 #handle_wrapped_result
124 }
125 }
126}
127
128fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
129 quote! {
130 match #result_name {
131 Ok(value) => value,
132 Err(err) => {
133 std::panic::resume_unwind(err);
134 }
135 }
136 }
137}