test_span_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5
6use syn::parse_macro_input;
7use syn::ExprAssign;
8use syn::ItemFn;
9use syn::Path;
10use syn::ReturnType;
11
12#[proc_macro_attribute]
13pub fn test_span(attr: TokenStream, item: TokenStream) -> TokenStream {
14    let test_fn = parse_macro_input!(item as ItemFn);
15
16    let macro_attrs = if attr.is_empty() {
17        quote! { test }
18    } else {
19        attr.into()
20    };
21
22    let fn_attrs = &test_fn.attrs;
23
24    let mut level = quote!(::test_span::reexports::tracing::Level::INFO);
25
26    let mut target_directives: Vec<_> = Vec::new();
27
28    // Get tracing level from #[level(tracing::Level::INFO)]
29    let fn_attrs = fn_attrs
30        .iter()
31        .filter(|attr| {
32            let path = &attr.path();
33            match quote!(#path).to_string().as_str() {
34                "level" => {
35                    let value: Path = attr.parse_args().expect(
36                        "wrong level attribute syntax. Example: #[level(tracing::Level::INFO)]",
37                    );
38                    level = quote!(#value);
39                    false
40                }
41                "target" => {
42                    let value: ExprAssign = attr.parse_args().expect("each targetFilter directive expects a single assignment expression. example: #[targetFilter(apollo_router=debug)]");
43                    // foo = Level::INFO => .with_target("foo".to_string(), Level::INFO)
44                    let name = value.left;
45                    let mut target_name = quote!(#name).to_string();
46                    target_name.retain(|c| !c.is_whitespace());
47
48                    let target_value = value.right;
49
50                    target_directives.push(quote!(.with_target(#target_name .to_string(), #target_value)));
51
52                    false
53                }
54                _ => true,
55            }
56        })
57        .collect::<Vec<_>>();
58
59    let maybe_async = &test_fn.sig.asyncness;
60
61    let body = &test_fn.block;
62    let test_name = &test_fn.sig.ident;
63    let output_type = &test_fn.sig.output;
64
65    let maybe_semicolon = if let ReturnType::Default = output_type {
66        quote! {;}
67    } else {
68        quote! {}
69    };
70
71    let run_test = if maybe_async.is_some() {
72        async_test(test_name)
73    } else {
74        sync_test(test_name)
75    };
76
77    let ret = quote! {#output_type};
78
79    let subscriber_boilerplate = subscriber_boilerplate(level, target_directives);
80
81    quote! {
82      #[#macro_attrs]
83      #(#fn_attrs)*
84      #maybe_async fn #test_name() #ret {
85        use ::test_span::reexports::tracing::Instrument;
86        #maybe_async fn #test_name(get_telemetry: impl Fn() -> (::test_span::Span, ::test_span::Records), get_logs: impl Fn() -> ::test_span::Records, get_spans: impl Fn() -> ::test_span::Span) #ret
87          #body
88
89
90        #subscriber_boilerplate
91
92        #run_test #maybe_semicolon
93      }
94    }
95    .into()
96}
97
98fn async_test(test_name: &Ident) -> TokenStream2 {
99    quote! {
100        #test_name(get_telemetry, get_logs, get_spans)
101            .instrument(root_span).await
102    }
103}
104
105fn sync_test(test_name: &Ident) -> TokenStream2 {
106    quote! {
107        root_span.in_scope(|| {
108            #test_name(get_telemetry, get_logs, get_spans)
109        });
110    }
111}
112fn subscriber_boilerplate(
113    level: TokenStream2,
114    target_directives: Vec<TokenStream2>,
115) -> TokenStream2 {
116    quote! {
117        let filter = ::test_span::Filter::new(#level) #(#target_directives)*;
118
119        ::test_span::init();
120
121        let root_span = ::test_span::reexports::tracing::span!(#level, "root");
122
123        let root_id = root_span.id().clone().expect("couldn't get root span id; this cannot happen.");
124
125        #[allow(unused)]
126        let get_telemetry = || ::test_span::get_telemetry_for_root(&root_id, &filter);
127
128        #[allow(unused)]
129        let get_logs = || ::test_span::get_logs_for_root(&root_id, &filter);
130
131
132        #[allow(unused)]
133        let get_spans = || ::test_span::get_spans_for_root(&root_id, &filter);
134    }
135}