test_trace_macros/
lib.rs

1// Copyright (C) 2019-2023 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: (Apache-2.0 OR MIT)
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as Tokens;
8
9use quote::quote;
10
11use syn::parse::Parse;
12use syn::parse_macro_input;
13use syn::Attribute;
14use syn::Expr;
15use syn::ItemFn;
16use syn::Lit;
17use syn::Meta;
18
19
20#[proc_macro_attribute]
21pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
22  let item = parse_macro_input!(item as ItemFn);
23  try_test(attr, item)
24    .unwrap_or_else(syn::Error::into_compile_error)
25    .into()
26}
27
28fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
29  let mut attribute_args = AttributeArgs::default();
30  if cfg!(feature = "unstable") {
31    let mut ignored_attrs = vec![];
32    for attr in attrs {
33      let matched = attribute_args.try_parse_attr_single(&attr)?;
34      // Keep only attrs that didn't match the #[test_trace(_)] syntax.
35      if !matched {
36        ignored_attrs.push(attr);
37      }
38    }
39
40    Ok((attribute_args, ignored_attrs))
41  } else {
42    Ok((attribute_args, attrs))
43  }
44}
45
46fn try_test(attr: TokenStream, input: ItemFn) -> syn::Result<Tokens> {
47  let inner_test = if attr.is_empty() {
48    quote! { ::core::prelude::v1::test }
49  } else {
50    attr.into()
51  };
52
53  let ItemFn {
54    attrs,
55    vis,
56    sig,
57    block,
58  } = input;
59
60  let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
61  let logging_init = expand_logging_init(&attribute_args);
62  let tracing_init = expand_tracing_init(&attribute_args);
63
64  let result = quote! {
65    #[#inner_test]
66    #(#ignored_attrs)*
67    #vis #sig {
68      // We put all initialization code into a separate module here in
69      // order to prevent potential ambiguities that could result in
70      // compilation errors. E.g., client code could use traits that
71      // could have methods that interfere with ones we use as part of
72      // initialization; with a `Foo` trait that is implemented for T
73      // and that contains a `map` (or similarly common named) method
74      // that could cause an ambiguity with `Iterator::map`, for
75      // example.
76      // The alternative would be to use fully qualified call syntax in
77      // all initialization code, but that's much harder to control.
78      mod init {
79        pub fn init() {
80          #logging_init
81          #tracing_init
82        }
83      }
84
85      init::init();
86
87      #block
88    }
89  };
90  Ok(result)
91}
92
93
/// Settings collected from `#[test_trace(...)]` attributes on a test function.
#[derive(Default, Debug)]
struct AttributeArgs {
  /// Filter to fall back on when none is supplied through the environment;
  /// set via `#[test_trace(default_log_filter = "...")]`.
  default_log_filter: Option<String>,
}
98
99impl AttributeArgs {
100  fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
101    if !attr.path().is_ident("test_trace") {
102      return Ok(false)
103    }
104
105    let nested_meta = attr.parse_args_with(Meta::parse)?;
106    let name_value = if let Meta::NameValue(name_value) = nested_meta {
107      name_value
108    } else {
109      return Err(syn::Error::new_spanned(
110        &nested_meta,
111        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
112      ))
113    };
114
115    let ident = if let Some(ident) = name_value.path.get_ident() {
116      ident
117    } else {
118      return Err(syn::Error::new_spanned(
119        &name_value.path,
120        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
121      ))
122    };
123
124    let arg_ref = if ident == "default_log_filter" {
125      &mut self.default_log_filter
126    } else {
127      return Err(syn::Error::new_spanned(
128        &name_value.path,
129        "Unrecognized attribute, see documentation for details.",
130      ))
131    };
132
133    if let Expr::Lit(lit) = &name_value.value {
134      if let Lit::Str(lit_str) = &lit.lit {
135        *arg_ref = Some(lit_str.value());
136      }
137    }
138
139    // If we couldn't parse the value on the right-hand side because it was some
140    // unexpected type, e.g. #[test_trace::log(default_log_filter=10)], return an error.
141    if arg_ref.is_none() {
142      return Err(syn::Error::new_spanned(
143        &name_value.value,
144        "Failed to parse value, expected a string",
145      ))
146    }
147
148    Ok(true)
149  }
150}
151
152
/// Expand the initialization code for the `log` crate.
///
/// The generated block obtains an `env_logger` builder through the
/// `test_trace` re-export. When `default_log_filter` was configured, it is
/// applied as the fallback filter. The builder is then marked for test use and
/// installation is attempted. A failed installation is deliberately ignored
/// (e.g. because a logger was already installed by another test).
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
  // Only emitted when a default filter was configured; otherwise the builder
  // is used exactly as `env_logger::builder()` returned it.
  let add_default_log_filter = attribute_args
    .default_log_filter
    .as_ref()
    .map(|default_log_filter| {
      quote! {
        let env_logger_builder = env_logger_builder
          .parse_env(::test_trace::env_logger::Env::default().default_filter_or(#default_log_filter));
      }
    })
    .unwrap_or_else(Tokens::new);

  quote! {
    {
      let mut env_logger_builder = ::test_trace::env_logger::builder();
      #add_default_log_filter
      let _ = env_logger_builder.is_test(true).try_init();
    }
  }
}
174
175#[cfg(not(all(feature = "log", not(feature = "trace"))))]
176fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
177  quote! {}
178}
179
/// Expand the initialization code for the `tracing` crate.
///
/// The generated block installs a `tracing_subscriber` formatter that writes
/// through the test writer.
///
/// Filtering:
/// - The filter is read from the environment (lossily parsed).
/// - If the environment provides none, it falls back to `default_log_filter`
///   when configured, and to the `TRACE` level otherwise.
///
/// Span events:
/// - Which span lifecycle events are printed is chosen by the comma-separated,
///   case-insensitive `RUST_LOG_SPAN_EVENTS` variable.
/// - An unknown entry or a non-UTF-8 value panics when the test starts.
///
/// Errors returned by `try_init` are ignored (e.g. when a global subscriber
/// was already installed by another test).
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
  // Directive applied when the environment does not supply a filter.
  let default_directive = match &attribute_args.default_log_filter {
    Some(default_log_filter) => quote! {
      #default_log_filter
        .parse()
        .expect("test-trace: default_log_filter must be valid")
    },
    None => quote! {
      ::test_trace::tracing_subscriber::filter::LevelFilter::TRACE.into()
    },
  };

  let env_filter = quote! {
    ::test_trace::tracing_subscriber::EnvFilter::builder()
      .with_default_directive(#default_directive)
      .from_env_lossy()
  };

  quote! {
    {
      let __internal_event_filter = {
        use ::test_trace::tracing_subscriber::fmt::format::FmtSpan;

        match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
          Some(mut value) => {
            value.make_ascii_lowercase();
            let value = value.to_str().expect("test-trace: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
            value
              .split(",")
              .map(|filter| match filter.trim() {
                "new" => FmtSpan::NEW,
                "enter" => FmtSpan::ENTER,
                "exit" => FmtSpan::EXIT,
                "close" => FmtSpan::CLOSE,
                "active" => FmtSpan::ACTIVE,
                "full" => FmtSpan::FULL,
                _ => panic!("test-trace: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                  For example: `active` or `new,close`\n\t\
                  Supported filters: new, enter, exit, close, active, full\n\t\
                  Got: {}", value),
              })
              .fold(FmtSpan::NONE, |acc, filter| filter | acc)
          },
          None => FmtSpan::NONE,
        }
      };

      let _ = ::test_trace::tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(#env_filter)
        .with_span_events(__internal_event_filter)
        .with_test_writer()
        .try_init();
    }
  }
}
239
240#[cfg(not(feature = "trace"))]
241fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
242  quote! {}
243}