1use std::borrow::Cow;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as Tokens;
10
11use quote::quote;
12
13use syn::parse::Parse;
14use syn::parse_macro_input;
15use syn::Attribute;
16use syn::Expr;
17use syn::ItemFn;
18use syn::Lit;
19use syn::Meta;
20
21
22#[allow(missing_docs)]
24#[proc_macro_attribute]
25pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
26 let item = parse_macro_input!(item as ItemFn);
27 try_test(attr, item)
28 .unwrap_or_else(syn::Error::into_compile_error)
29 .into()
30}
31
32fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
33 let mut attribute_args = AttributeArgs::default();
34 if cfg!(feature = "unstable") {
35 let mut ignored_attrs = vec![];
36 for attr in attrs {
37 let matched = attribute_args.try_parse_attr_single(&attr)?;
38 if !matched {
40 ignored_attrs.push(attr);
41 }
42 }
43
44 Ok((attribute_args, ignored_attrs))
45 } else {
46 Ok((attribute_args, attrs))
47 }
48}
49
50fn is_test_attribute(attr: &Attribute) -> bool {
55 let path = match &attr.meta {
56 syn::Meta::Path(path) => path,
57 _ => return false,
58 };
59 let candidates = [
60 ["core", "prelude", "*", "test"],
61 ["std", "prelude", "*", "test"],
62 ];
63 if path.leading_colon.is_none()
64 && path.segments.len() == 1
65 && path.segments[0].arguments.is_none()
66 && path.segments[0].ident == "test"
67 {
68 return true;
69 } else if path.segments.len() != candidates[0].len() {
70 return false;
71 }
72 candidates.into_iter().any(|segments| {
73 path
74 .segments
75 .iter()
76 .zip(segments)
77 .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path))
78 })
79}
80
81fn try_test(attr: TokenStream, input: ItemFn) -> syn::Result<Tokens> {
82 let ItemFn {
83 attrs,
84 vis,
85 sig,
86 block,
87 } = input;
88
89 let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
90 let logging_init = expand_logging_init(&attribute_args);
91 let tracing_init = expand_tracing_init(&attribute_args);
92
93 let (inner_test, generated_test) = if attr.is_empty() {
94 let has_test = ignored_attrs.iter().any(is_test_attribute);
95 let generated_test = if has_test {
96 quote! {}
97 } else {
98 quote! { #[::core::prelude::v1::test]}
99 };
100 (quote! {}, generated_test)
101 } else {
102 let attr = Tokens::from(attr);
103 (quote! { #[#attr] }, quote! {})
104 };
105
106 let result = quote! {
107 #inner_test
108 #(#ignored_attrs)*
109 #generated_test
110 #vis #sig {
111 mod init {
122 pub fn init() {
123 #logging_init
124 #tracing_init
125 }
126 }
127
128 init::init();
129
130 #block
131 }
132 };
133 Ok(result)
134}
135
136
137#[derive(Debug, Default)]
138struct AttributeArgs {
139 default_log_filter: Option<Cow<'static, str>>,
140}
141
142impl AttributeArgs {
143 fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
144 if !attr.path().is_ident("test_log") {
145 return Ok(false)
146 }
147
148 let nested_meta = attr.parse_args_with(Meta::parse)?;
149 let name_value = if let Meta::NameValue(name_value) = nested_meta {
150 name_value
151 } else {
152 return Err(syn::Error::new_spanned(
153 &nested_meta,
154 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
155 ))
156 };
157
158 let ident = if let Some(ident) = name_value.path.get_ident() {
159 ident
160 } else {
161 return Err(syn::Error::new_spanned(
162 &name_value.path,
163 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
164 ))
165 };
166
167 let arg_ref = if ident == "default_log_filter" {
168 &mut self.default_log_filter
169 } else {
170 return Err(syn::Error::new_spanned(
171 &name_value.path,
172 "Unrecognized attribute, see documentation for details.",
173 ))
174 };
175
176 if let Expr::Lit(lit) = &name_value.value {
177 if let Lit::Str(lit_str) = &lit.lit {
178 *arg_ref = Some(Cow::from(lit_str.value()));
179 }
180 }
181
182 if arg_ref.is_none() {
185 return Err(syn::Error::new_spanned(
186 &name_value.value,
187 "Failed to parse value, expected a string",
188 ))
189 }
190
191 Ok(true)
192 }
193}
194
195
/// Generates the `env_logger` setup block.
///
/// This version is compiled when the `log` feature is enabled and `trace` is
/// not. The generated code:
/// - builds a logger whose environment configuration (`Env::default()`) falls
///   back to the `default_log_filter` argument, or to `"info"` when none was
///   given;
/// - targets stderr and enables test mode (`is_test(true)`);
/// - discards the result of `try_init()`. Presumably this is because another
///   test in the same process may already have installed a logger.
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
    let fallback_filter: &str = attribute_args
        .default_log_filter
        .as_deref()
        .unwrap_or("info");

    let env = quote! {
        ::test_log::env_logger::Env::default()
            .default_filter_or(#fallback_filter)
    };

    quote! {
        {
            let _result = ::test_log::env_logger::builder()
                .parse_env(#env)
                .target(::test_log::env_logger::Target::Stderr)
                .is_test(true)
                .try_init();
        }
    }
}
217
218#[cfg(not(all(feature = "log", not(feature = "trace"))))]
219fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
220 quote! {}
221}
222
/// Generates the `tracing_subscriber` setup block (`trace` feature).
///
/// The generated code, which runs at test time, does three things:
/// - Builds an `EnvFilter` with `from_env_lossy()`. Its default directive is
///   parsed from the `default_log_filter` argument (panicking if it does not
///   parse) or is `LevelFilter::INFO` when no argument was given.
/// - Reads `RUST_LOG_SPAN_EVENTS`, if set, as a comma-separated list. The list
///   is lowercased and each entry trimmed. Entries must be one of `new`,
///   `enter`, `exit`, `close`, `active`, `full`, and are OR-ed together. Any
///   other entry, including an empty one such as from a trailing comma,
///   panics, as does a non-UTF-8 value. When unset, no span events are
///   emitted.
/// - Installs an `FmtSubscriber` that writes through
///   `TestWriter::with_stderr`, discarding the result of `try_init()`.
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
    let default_directive = match &attribute_args.default_log_filter {
        Some(filter) => quote! {
            #filter
                .parse()
                .expect("test-log: default_log_filter must be valid")
        },
        None => quote! {
            ::test_log::tracing_subscriber::filter::LevelFilter::INFO.into()
        },
    };

    let env_filter = quote! {
        ::test_log::tracing_subscriber::EnvFilter::builder()
            .with_default_directive(#default_directive)
            .from_env_lossy()
    };

    let span_events = quote! {
        {
            use ::test_log::tracing_subscriber::fmt::format::FmtSpan;

            match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
                Some(mut value) => {
                    value.make_ascii_lowercase();
                    let value = value.to_str().expect("test-log: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
                    value
                        .split(",")
                        .map(|filter| match filter.trim() {
                            "new" => FmtSpan::NEW,
                            "enter" => FmtSpan::ENTER,
                            "exit" => FmtSpan::EXIT,
                            "close" => FmtSpan::CLOSE,
                            "active" => FmtSpan::ACTIVE,
                            "full" => FmtSpan::FULL,
                            _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                                For example: `active` or `new,close`\n\t\
                                Supported filters: new, enter, exit, close, active, full\n\t\
                                Got: {}", value),
                        })
                        .fold(FmtSpan::NONE, |acc, filter| filter | acc)
                },
                None => FmtSpan::NONE,
            }
        }
    };

    quote! {
        {
            let __internal_event_filter = #span_events;

            let _ = ::test_log::tracing_subscriber::FmtSubscriber::builder()
                .with_env_filter(#env_filter)
                .with_span_events(__internal_event_filter)
                .with_writer(::test_log::tracing_subscriber::fmt::TestWriter::with_stderr)
                .try_init();
        }
    }
}
283
284#[cfg(not(feature = "trace"))]
285fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
286 quote! {}
287}