Skip to main content

test_log_core/
lib.rs

1// Copyright (C) 2019-2026 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: (Apache-2.0 OR MIT)
3
4//! Core logic for the `test-log` procedural macro.
5
6use std::borrow::Cow;
7
8use proc_macro2::TokenStream as Tokens;
9
10use quote::quote;
11
12use syn::parse::Parse;
13use syn::Attribute;
14use syn::Expr;
15use syn::ItemFn;
16use syn::Lit;
17use syn::Meta;
18
19
20/// Parse `#[test_log(...)]` attributes from a function's attribute
21/// list, separating them from other attributes.
22fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
23  let mut attribute_args = AttributeArgs::default();
24  if cfg!(feature = "unstable") {
25    let mut ignored_attrs = vec![];
26    for attr in attrs {
27      let matched = attribute_args.try_parse_attr_single(&attr)?;
28      // Keep only attrs that didn't match the #[test_log(_)] syntax.
29      if !matched {
30        ignored_attrs.push(attr);
31      }
32    }
33
34    Ok((attribute_args, ignored_attrs))
35  } else {
36    Ok((attribute_args, attrs))
37  }
38}
39
40/// Check whether given attribute is a test attribute of forms:
41/// * `#[test]`
42/// * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]`
43/// * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]`
44fn is_test_attribute(attr: &Attribute) -> bool {
45  let path = match &attr.meta {
46    syn::Meta::Path(path) => path,
47    _ => return false,
48  };
49  let candidates = [
50    ["core", "prelude", "*", "test"],
51    ["std", "prelude", "*", "test"],
52  ];
53  if path.leading_colon.is_none()
54    && path.segments.len() == 1
55    && path.segments[0].arguments.is_none()
56    && path.segments[0].ident == "test"
57  {
58    return true;
59  } else if path.segments.len() != candidates[0].len() {
60    return false;
61  }
62  candidates.into_iter().any(|segments| {
63    path
64      .segments
65      .iter()
66      .zip(segments)
67      .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path))
68  })
69}
70
71
72/// Main expansion logic for `#[test_log::test]`.
73pub fn try_test(attr: Tokens, input: ItemFn) -> syn::Result<Tokens> {
74  let ItemFn {
75    attrs,
76    vis,
77    sig,
78    mut block,
79  } = input;
80
81  let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
82  let logging_init = expand_logging_init(&attribute_args);
83  let tracing_init = expand_tracing_init(&attribute_args);
84
85  let (inner_test, generated_test) = if attr.is_empty() {
86    let has_test = ignored_attrs.iter().any(is_test_attribute);
87    let generated_test = if has_test {
88      quote! {}
89    } else {
90      quote! { #[::core::prelude::v1::test]}
91    };
92    (quote! {}, generated_test)
93  } else {
94    (quote! { #[#attr] }, quote! {})
95  };
96
97  // Insert the initialization prologue into the existing block, so that
98  // the fn body keeps its original source spans. Building a fresh outer
99  // block with `quote!` here would give it a span covering the whole
100  // annotated fn. That makes rust-analyzer generate spurious
101  // jump-to-definition results for code inside the test; see
102  // <https://github.com/rust-lang/rust-analyzer/issues/20441>.
103  let prologue = quote! {
104    {
105      // We put all initialization code into a separate module here in
106      // order to prevent potential ambiguities that could result in
107      // compilation errors. E.g., client code could use traits that
108      // could have methods that interfere with ones we use as part of
109      // initialization; with a `Foo` trait that is implemented for T
110      // and that contains a `map` (or similarly common named) method
111      // that could cause an ambiguity with `Iterator::map`, for
112      // example.
113      // The alternative would be to use fully qualified call syntax in
114      // all initialization code, but that's much harder to control.
115      mod init {
116        pub fn init() {
117          #logging_init
118          #tracing_init
119        }
120      }
121
122      init::init();
123    }
124  };
125  let prologue_block: syn::Stmt = syn::parse2(prologue)?;
126  block.stmts.insert(0, prologue_block);
127
128  let result = quote! {
129    #inner_test
130    #(#ignored_attrs)*
131    #generated_test
132    #vis #sig #block
133  };
134  Ok(result)
135}
136
137
138/// Parsed `#[test_log(...)]` attributes.
139#[derive(Debug, Default)]
140struct AttributeArgs {
141  /// The default log filter directive (e.g., `"debug"`).
142  default_log_filter: Option<Cow<'static, str>>,
143}
144
145impl AttributeArgs {
146  /// Try to parse a single `#[test_log(...)]` attribute.
147  fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
148    if !attr.path().is_ident("test_log") {
149      return Ok(false)
150    }
151
152    let nested_meta = attr.parse_args_with(Meta::parse)?;
153    let name_value = if let Meta::NameValue(name_value) = nested_meta {
154      name_value
155    } else {
156      return Err(syn::Error::new_spanned(
157        &nested_meta,
158        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
159      ))
160    };
161
162    let ident = if let Some(ident) = name_value.path.get_ident() {
163      ident
164    } else {
165      return Err(syn::Error::new_spanned(
166        &name_value.path,
167        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
168      ))
169    };
170
171    let arg_ref = if ident == "default_log_filter" {
172      &mut self.default_log_filter
173    } else {
174      return Err(syn::Error::new_spanned(
175        &name_value.path,
176        "Unrecognized attribute, see documentation for details.",
177      ))
178    };
179
180    if let Expr::Lit(lit) = &name_value.value {
181      if let Lit::Str(lit_str) = &lit.lit {
182        *arg_ref = Some(Cow::from(lit_str.value()));
183      }
184    }
185
186    // If we couldn't parse the value on the right-hand side because it was some
187    // unexpected type, e.g. #[test_log::log(default_log_filter=10)], return an error.
188    if arg_ref.is_none() {
189      return Err(syn::Error::new_spanned(
190        &name_value.value,
191        "Failed to parse value, expected a string",
192      ))
193    }
194
195    Ok(true)
196  }
197}
198
199
200/// Expand the initialization code for the `log` crate.
201#[cfg(all(feature = "log", not(feature = "trace")))]
202fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
203  let default_filter = attribute_args
204    .default_log_filter
205    .as_ref()
206    .unwrap_or(&Cow::Borrowed("info"));
207
208  quote! {
209    {
210      let _result = ::test_log::env_logger::builder()
211        .parse_env(
212          ::test_log::env_logger::Env::default()
213            .default_filter_or(#default_filter)
214        )
215        .target(::test_log::env_logger::Target::Stderr)
216        .is_test(true)
217        .try_init();
218    }
219  }
220}
221
222#[cfg(not(all(feature = "log", not(feature = "trace"))))]
223fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
224  quote! {}
225}
226
227/// Expand the initialization code for the `tracing` crate.
228#[cfg(feature = "trace")]
229fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
230  let env_filter = if let Some(default_log_filter) = &attribute_args.default_log_filter {
231    quote! {
232      ::test_log::tracing_subscriber::EnvFilter::builder()
233        .with_default_directive(
234          #default_log_filter
235            .parse()
236            .expect("test-log: default_log_filter must be valid")
237        )
238        .from_env_lossy()
239    }
240  } else {
241    quote! {
242      ::test_log::tracing_subscriber::EnvFilter::builder()
243        .with_default_directive(
244          ::test_log::tracing_subscriber::filter::LevelFilter::INFO.into()
245        )
246        .from_env_lossy()
247    }
248  };
249
250  quote! {
251    {
252      let __internal_event_filter = {
253        use ::test_log::tracing_subscriber::fmt::format::FmtSpan;
254
255        match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
256          Some(mut value) => {
257            value.make_ascii_lowercase();
258            let value = value.to_str().expect("test-log: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
259            value
260              .split(",")
261              .map(|filter| match filter.trim() {
262                "new" => FmtSpan::NEW,
263                "enter" => FmtSpan::ENTER,
264                "exit" => FmtSpan::EXIT,
265                "close" => FmtSpan::CLOSE,
266                "active" => FmtSpan::ACTIVE,
267                "full" => FmtSpan::FULL,
268                _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
269                  For example: `active` or `new,close`\n\t\
270                  Supported filters: new, enter, exit, close, active, full\n\t\
271                  Got: {}", value),
272              })
273              .fold(FmtSpan::NONE, |acc, filter| filter | acc)
274          },
275          None => FmtSpan::NONE,
276        }
277      };
278
279      let _ = ::test_log::tracing_subscriber::FmtSubscriber::builder()
280        .with_env_filter(#env_filter)
281        .with_span_events(__internal_event_filter)
282        .with_writer(::test_log::tracing_subscriber::fmt::TestWriter::with_stderr)
283        .try_init();
284    }
285  }
286}
287
288#[cfg(not(feature = "trace"))]
289fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
290  quote! {}
291}