1use std::borrow::Cow;
7
8use proc_macro2::TokenStream as Tokens;
9
10use quote::quote;
11
12use syn::parse::Parse;
13use syn::Attribute;
14use syn::Expr;
15use syn::ItemFn;
16use syn::Lit;
17use syn::Meta;
18
19
20fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
23 let mut attribute_args = AttributeArgs::default();
24 if cfg!(feature = "unstable") {
25 let mut ignored_attrs = vec![];
26 for attr in attrs {
27 let matched = attribute_args.try_parse_attr_single(&attr)?;
28 if !matched {
30 ignored_attrs.push(attr);
31 }
32 }
33
34 Ok((attribute_args, ignored_attrs))
35 } else {
36 Ok((attribute_args, attrs))
37 }
38}
39
40fn is_test_attribute(attr: &Attribute) -> bool {
45 let path = match &attr.meta {
46 syn::Meta::Path(path) => path,
47 _ => return false,
48 };
49 let candidates = [
50 ["core", "prelude", "*", "test"],
51 ["std", "prelude", "*", "test"],
52 ];
53 if path.leading_colon.is_none()
54 && path.segments.len() == 1
55 && path.segments[0].arguments.is_none()
56 && path.segments[0].ident == "test"
57 {
58 return true;
59 } else if path.segments.len() != candidates[0].len() {
60 return false;
61 }
62 candidates.into_iter().any(|segments| {
63 path
64 .segments
65 .iter()
66 .zip(segments)
67 .all(|(segment, path)| segment.arguments.is_none() && (path == "*" || segment.ident == path))
68 })
69}
70
71
72pub fn try_test(attr: Tokens, input: ItemFn) -> syn::Result<Tokens> {
74 let ItemFn {
75 attrs,
76 vis,
77 sig,
78 mut block,
79 } = input;
80
81 let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
82 let logging_init = expand_logging_init(&attribute_args);
83 let tracing_init = expand_tracing_init(&attribute_args);
84
85 let (inner_test, generated_test) = if attr.is_empty() {
86 let has_test = ignored_attrs.iter().any(is_test_attribute);
87 let generated_test = if has_test {
88 quote! {}
89 } else {
90 quote! { #[::core::prelude::v1::test]}
91 };
92 (quote! {}, generated_test)
93 } else {
94 (quote! { #[#attr] }, quote! {})
95 };
96
97 let prologue = quote! {
104 {
105 mod init {
116 pub fn init() {
117 #logging_init
118 #tracing_init
119 }
120 }
121
122 init::init();
123 }
124 };
125 let prologue_block: syn::Stmt = syn::parse2(prologue)?;
126 block.stmts.insert(0, prologue_block);
127
128 let result = quote! {
129 #inner_test
130 #(#ignored_attrs)*
131 #generated_test
132 #vis #sig #block
133 };
134 Ok(result)
135}
136
137
/// Configuration collected from `#[test_log(...)]` attributes on a test function.
#[derive(Default, Debug)]
struct AttributeArgs {
    /// Filter string handed to the logging/tracing initialization code.
    /// `None` selects the built-in `info` default.
    default_log_filter: Option<Cow<'static, str>>,
}
144
145impl AttributeArgs {
146 fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
148 if !attr.path().is_ident("test_log") {
149 return Ok(false)
150 }
151
152 let nested_meta = attr.parse_args_with(Meta::parse)?;
153 let name_value = if let Meta::NameValue(name_value) = nested_meta {
154 name_value
155 } else {
156 return Err(syn::Error::new_spanned(
157 &nested_meta,
158 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
159 ))
160 };
161
162 let ident = if let Some(ident) = name_value.path.get_ident() {
163 ident
164 } else {
165 return Err(syn::Error::new_spanned(
166 &name_value.path,
167 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
168 ))
169 };
170
171 let arg_ref = if ident == "default_log_filter" {
172 &mut self.default_log_filter
173 } else {
174 return Err(syn::Error::new_spanned(
175 &name_value.path,
176 "Unrecognized attribute, see documentation for details.",
177 ))
178 };
179
180 if let Expr::Lit(lit) = &name_value.value {
181 if let Lit::Str(lit_str) = &lit.lit {
182 *arg_ref = Some(Cow::from(lit_str.value()));
183 }
184 }
185
186 if arg_ref.is_none() {
189 return Err(syn::Error::new_spanned(
190 &name_value.value,
191 "Failed to parse value, expected a string",
192 ))
193 }
194
195 Ok(true)
196 }
197}
198
199
/// Generate the `env_logger` initialization code run at the start of each test.
///
/// Compiled only with the `log` feature and without `trace`. The logger reads
/// its filter from the environment (`Env::default()`), falling back to the
/// attribute's `default_log_filter` or to `"info"`. It writes to stderr in test
/// mode (`is_test(true)`). The result of `try_init` is bound to `_result` and
/// not inspected.
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
    let filter: &str = attribute_args
        .default_log_filter
        .as_deref()
        .unwrap_or("info");

    let env = quote! {
        ::test_log::env_logger::Env::default().default_filter_or(#filter)
    };

    quote! {
        {
            let _result = ::test_log::env_logger::builder()
                .parse_env(#env)
                .target(::test_log::env_logger::Target::Stderr)
                .is_test(true)
                .try_init();
        }
    }
}
221
222#[cfg(not(all(feature = "log", not(feature = "trace"))))]
223fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
224 quote! {}
225}
226
/// Generate the `tracing-subscriber` initialization code run at the start of
/// each test (compiled only with the `trace` feature).
///
/// The generated code:
/// - builds an `EnvFilter` via `from_env_lossy`, using the attribute's
///   `default_log_filter` as default directive when given and
///   `LevelFilter::INFO` otherwise (an invalid `default_log_filter` panics);
/// - reads `RUST_LOG_SPAN_EVENTS` as a comma-separated, ASCII-case-insensitive
///   list of span events (`new`, `enter`, `exit`, `close`, `active`, `full`).
///   Empty entries (an empty variable, a trailing or doubled `,`) are ignored,
///   and any other entry panics with a usage message;
/// - installs a `FmtSubscriber` writing via `TestWriter::with_stderr` and
///   discards the result of `try_init`.
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
    // Only the default directive differs between the two configurations.
    let default_directive = match &attribute_args.default_log_filter {
        Some(default_log_filter) => quote! {
            #default_log_filter
                .parse()
                .expect("test-log: default_log_filter must be valid")
        },
        None => quote! {
            ::test_log::tracing_subscriber::filter::LevelFilter::INFO.into()
        },
    };

    let env_filter = quote! {
        ::test_log::tracing_subscriber::EnvFilter::builder()
            .with_default_directive(#default_directive)
            .from_env_lossy()
    };

    quote! {
        {
            let __internal_event_filter = {
                use ::test_log::tracing_subscriber::fmt::format::FmtSpan;

                match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
                    Some(mut value) => {
                        value.make_ascii_lowercase();
                        let value = value.to_str().expect("test-log: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
                        value
                            .split(',')
                            .map(|filter| filter.trim())
                            // An empty variable or a stray `,` must not abort every test.
                            .filter(|filter| !filter.is_empty())
                            .map(|filter| match filter {
                                "new" => FmtSpan::NEW,
                                "enter" => FmtSpan::ENTER,
                                "exit" => FmtSpan::EXIT,
                                "close" => FmtSpan::CLOSE,
                                "active" => FmtSpan::ACTIVE,
                                "full" => FmtSpan::FULL,
                                _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                                    For example: `active` or `new,close`\n\t\
                                    Supported filters: new, enter, exit, close, active, full\n\t\
                                    Got: {}", value),
                            })
                            .fold(FmtSpan::NONE, |acc, filter| filter | acc)
                    },
                    None => FmtSpan::NONE,
                }
            };

            let _ = ::test_log::tracing_subscriber::FmtSubscriber::builder()
                .with_env_filter(#env_filter)
                .with_span_events(__internal_event_filter)
                .with_writer(::test_log::tracing_subscriber::fmt::TestWriter::with_stderr)
                .try_init();
        }
    }
}
287
288#[cfg(not(feature = "trace"))]
289fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
290 quote! {}
291}