test_pretty_log_macros/
lib.rs1extern crate proc_macro;
5
6use std::iter::Peekable;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as Tokens;
10
11use quote::{quote, ToTokens};
12
13use syn::Attribute;
14use syn::ExprLit;
15use syn::LitBool;
16use syn::LitStr;
17use syn::parse_macro_input;
18use syn::Expr;
19use syn::ItemFn;
20use syn::Lit;
21use syn::Meta;
22use syn::punctuated::Punctuated;
23use syn::token::Comma;
24
25#[derive(Debug, Default)]
26struct MacroArgs {
27 inner_test: Option<Tokens>,
28 default_log_filter: Option<String>,
29 color: Option<bool>
30}
31
32impl MacroArgs {
33 fn from_punctuated(punctuated: Punctuated<Meta, Comma>) -> syn::Result<Self> {
34 let mut new_self = Self::default();
35 let mut punctuated_iter = punctuated.into_iter().peekable();
36
37 new_self.parse_inner_test(&mut punctuated_iter);
38 new_self.parse_name_value_args(&mut punctuated_iter)?;
39
40 Ok(new_self)
41 }
42
43 fn parse_inner_test<I: Iterator<Item = Meta>>(&mut self, punctuated: &mut Peekable<I>) {
44 if let Some(Meta::Path(_)) = punctuated.peek() {
45 self.inner_test = punctuated.next().map(|path| path.into_token_stream());
46 } else {
47 self.inner_test = None
48 }
49 }
50
51 fn parse_name_value_args<I: Iterator<Item = Meta>>(&mut self, punctuated: &mut I) -> syn::Result<()> {
52 for m in punctuated {
53 let name_value = m.require_name_value().map_err(map_name_value_error)?;
54 let ident = name_value.path.require_ident().map_err(map_name_value_error)?;
55 match ident.to_string().as_str() {
56 "default_log_filter" => self.default_log_filter = Some(require_lit_str(&name_value.value)?.value()),
57 "color" => self.color = Some(require_lit_bool(&name_value.value)?.value()),
58 _ => return Err(syn::Error::new_spanned(
59 &name_value.path,
60 "Unrecognized attribute, see documentation for details.",
61 ))
62 };
63 }
64
65 Ok(())
66 }
67}
68
69
70fn map_name_value_error(err: syn::Error) -> syn::Error {
71 syn::Error::new(err.span(), "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.")
72}
73
74fn require_lit_str(expr: &Expr) -> syn::Result<&LitStr> {
75 match expr {
76 Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. }) => Ok(lit_str),
77 _ => Err(syn::Error::new_spanned(
78 &expr,
79 "Failed to parse value, expected a string",
80 ))
81 }
82}
83
84fn require_lit_bool(expr: &Expr) -> syn::Result<&LitBool> {
85 match expr {
86 Expr::Lit(ExprLit { lit: Lit::Bool(lit_bool), .. }) => Ok(lit_bool),
87 _ => Err(syn::Error::new_spanned(
88 &expr,
89 "Failed to parse value, expected a bool",
90 ))
91 }
92}
93
94
95#[proc_macro_attribute]
96pub fn test(args: TokenStream, item: TokenStream) -> TokenStream {
97 let punctuated_args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
98 let item = parse_macro_input!(item as ItemFn);
99
100 try_test(punctuated_args, item)
101 .unwrap_or_else(syn::Error::into_compile_error)
102 .into()
103}
104
105fn try_test(punctuated_args: Punctuated<Meta, Comma>, input: ItemFn) -> syn::Result<Tokens> {
106 let macro_args = MacroArgs::from_punctuated(punctuated_args)?;
107
108 let ItemFn {
109 attrs,
110 vis,
111 sig,
112 block,
113 } = input;
114
115 let test_attr = extract_test_attribute(¯o_args, &attrs);
116
117 let env_filter = ¯o_args.default_log_filter.map_or(quote! { None }, |s| quote! { Some(#s) });
118 let enable_ansi = ¯o_args.color.map_or(quote! { None }, |b| quote! { Some(#b) });
119
120 let result = quote! {
121 #test_attr
122 #(#attrs)*
123 #vis #sig {
124 let __default_tracing_subscriber_guard = ::test_pretty_log::runtime::init(#env_filter, #enable_ansi);
125
126 #block
127 }
128 };
129 Ok(result)
130}
131
132
133fn extract_test_attribute(macro_args: &MacroArgs, attrs: &Vec<Attribute>) -> Option<Tokens> {
136 if let Some(inner_test_arg) = ¯o_args.inner_test {
137 Some(quote! { #[#inner_test_arg] })
138 } else if attrs.iter().find(|&attr| is_test_attribute(attr)).is_none() {
139 Some(quote! { #[::core::prelude::v1::test] })
140 } else {
141 None
142 }
143}
144
145fn is_test_attribute(attribute: &Attribute) -> bool {
146 attribute.meta.path().segments.last().is_some_and(|seg| seg.ident == "test")
147}