1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5
6use syn::parse_macro_input;
7use syn::ExprAssign;
8use syn::ItemFn;
9use syn::Path;
10use syn::ReturnType;
11
12#[proc_macro_attribute]
13pub fn test_span(attr: TokenStream, item: TokenStream) -> TokenStream {
14 let test_fn = parse_macro_input!(item as ItemFn);
15
16 let macro_attrs = if attr.is_empty() {
17 quote! { test }
18 } else {
19 attr.into()
20 };
21
22 let fn_attrs = &test_fn.attrs;
23
24 let mut level = quote!(::test_span::reexports::tracing::Level::INFO);
25
26 let mut target_directives: Vec<_> = Vec::new();
27
28 let fn_attrs = fn_attrs
30 .iter()
31 .filter(|attr| {
32 let path = &attr.path();
33 match quote!(#path).to_string().as_str() {
34 "level" => {
35 let value: Path = attr.parse_args().expect(
36 "wrong level attribute syntax. Example: #[level(tracing::Level::INFO)]",
37 );
38 level = quote!(#value);
39 false
40 }
41 "target" => {
42 let value: ExprAssign = attr.parse_args().expect("each targetFilter directive expects a single assignment expression. example: #[targetFilter(apollo_router=debug)]");
43 let name = value.left;
45 let mut target_name = quote!(#name).to_string();
46 target_name.retain(|c| !c.is_whitespace());
47
48 let target_value = value.right;
49
50 target_directives.push(quote!(.with_target(#target_name .to_string(), #target_value)));
51
52 false
53 }
54 _ => true,
55 }
56 })
57 .collect::<Vec<_>>();
58
59 let maybe_async = &test_fn.sig.asyncness;
60
61 let body = &test_fn.block;
62 let test_name = &test_fn.sig.ident;
63 let output_type = &test_fn.sig.output;
64
65 let maybe_semicolon = if let ReturnType::Default = output_type {
66 quote! {;}
67 } else {
68 quote! {}
69 };
70
71 let run_test = if maybe_async.is_some() {
72 async_test(test_name)
73 } else {
74 sync_test(test_name)
75 };
76
77 let ret = quote! {#output_type};
78
79 let subscriber_boilerplate = subscriber_boilerplate(level, target_directives);
80
81 quote! {
82 #[#macro_attrs]
83 #(#fn_attrs)*
84 #maybe_async fn #test_name() #ret {
85 use ::test_span::reexports::tracing::Instrument;
86 #maybe_async fn #test_name(get_telemetry: impl Fn() -> (::test_span::Span, ::test_span::Records), get_logs: impl Fn() -> ::test_span::Records, get_spans: impl Fn() -> ::test_span::Span) #ret
87 #body
88
89
90 #subscriber_boilerplate
91
92 #run_test #maybe_semicolon
93 }
94 }
95 .into()
96}
97
98fn async_test(test_name: &Ident) -> TokenStream2 {
99 quote! {
100 #test_name(get_telemetry, get_logs, get_spans)
101 .instrument(root_span).await
102 }
103}
104
105fn sync_test(test_name: &Ident) -> TokenStream2 {
106 quote! {
107 root_span.in_scope(|| {
108 #test_name(get_telemetry, get_logs, get_spans)
109 });
110 }
111}
112fn subscriber_boilerplate(
113 level: TokenStream2,
114 target_directives: Vec<TokenStream2>,
115) -> TokenStream2 {
116 quote! {
117 let filter = ::test_span::Filter::new(#level) #(#target_directives)*;
118
119 ::test_span::init();
120
121 let root_span = ::test_span::reexports::tracing::span!(#level, "root");
122
123 let root_id = root_span.id().clone().expect("couldn't get root span id; this cannot happen.");
124
125 #[allow(unused)]
126 let get_telemetry = || ::test_span::get_telemetry_for_root(&root_id, &filter);
127
128 #[allow(unused)]
129 let get_logs = || ::test_span::get_logs_for_root(&root_id, &filter);
130
131
132 #[allow(unused)]
133 let get_spans = || ::test_span::get_spans_for_root(&root_id, &filter);
134 }
135}