1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{
4 meta::ParseNestedMeta, parse_macro_input, spanned::Spanned, Error, ItemFn, LitInt, LitStr,
5 Result,
6};
7
/// Options collected from the arguments of a `test_each_*` attribute.
struct Attrs {
    /// Glob pattern selecting the inputs, set via `glob = "..."`.
    /// Expansion fails with "missing `glob` attribute" while this is `None`.
    glob: Option<String>,
    /// Number of trailing path components used to build each test name,
    /// set via `name(segments = <num>)`.
    segments: usize,
    /// When set via `name(index)`, the match position is appended to each test name.
    index: bool,
    /// When set via `name(extension)`, the file extension is kept in the test name.
    extension: bool,
}
14
15impl Attrs {
16 pub fn new() -> Self {
17 Self {
18 glob: None,
19 segments: 1,
20 index: false,
21 extension: false,
22 }
23 }
24
25 fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> {
26 if meta.path.is_ident("glob") {
27 let glob: LitStr = meta.value()?.parse()?;
28 self.glob = Some(glob.value());
29 } else if meta.path.is_ident("name") {
30 meta.parse_nested_meta(|nested| {
31 if nested.path.is_ident("segments") {
32 let path_segments: LitInt = nested.value()?.parse()?;
33 self.segments = path_segments.base10_parse()?;
34 } else if nested.path.is_ident("index") {
35 self.index = true
36 } else if nested.path.is_ident("extension") {
37 self.extension = true
38 } else {
39 return Err(nested.error(
40 "unsupported property, specify `segments = <num>`, `index` or `extension`",
41 ));
42 }
43
44 Ok(())
45 })?;
46 } else {
47 return Err(meta.error("unsupported property, specify `glob` or `name`"));
48 }
49
50 Ok(())
51 }
52}
53
54#[proc_macro_attribute]
55pub fn test_each_file(args: TokenStream, input: TokenStream) -> TokenStream {
56 let mut attrs = Attrs::new();
57 let attr_parser = syn::meta::parser(|meta| attrs.parse(meta));
58 let input = parse_macro_input!(input as ItemFn);
59 parse_macro_input!(args with attr_parser);
60
61 match test_each(attrs, input, Kind::File) {
62 Ok(output) => output,
63 Err(err) => err.into_compile_error().into(),
64 }
65}
66
67#[proc_macro_attribute]
68pub fn test_each_blob(args: TokenStream, input: TokenStream) -> TokenStream {
69 let mut attrs = Attrs::new();
70 let attr_parser = syn::meta::parser(|meta| attrs.parse(meta));
71 let input = parse_macro_input!(input as ItemFn);
72 parse_macro_input!(args with attr_parser);
73
74 match test_each(attrs, input, Kind::Blob) {
75 Ok(output) => output,
76 Err(err) => err.into_compile_error().into(),
77 }
78}
79
80#[proc_macro_attribute]
81pub fn test_each_path(args: TokenStream, input: TokenStream) -> TokenStream {
82 let mut attrs = Attrs::new();
83 let attr_parser = syn::meta::parser(|meta| attrs.parse(meta));
84 let input = parse_macro_input!(input as ItemFn);
85 parse_macro_input!(args with attr_parser);
86
87 match test_each(attrs, input, Kind::Path) {
88 Ok(output) => output,
89 Err(err) => err.into_compile_error().into(),
90 }
91}
92
/// Selects what a generated test hands to the annotated function.
enum Kind {
    /// File contents as a string via `include_str!`; directories are skipped.
    File,
    /// File contents as bytes via `include_bytes!`; directories are skipped.
    Blob,
    /// Only the matched path; directories are included.
    Path,
}
98
99fn test_each(attrs: Attrs, input: ItemFn, kind: Kind) -> Result<TokenStream> {
100 let mut functions = vec![input.to_token_stream()];
101
102 let name = &input.sig.ident;
103 let vis = &input.vis;
104 let ret = &input.sig.output;
105 let n_args = input.sig.inputs.len();
106
107 let pattern = attrs
108 .glob
109 .as_ref()
110 .ok_or_else(|| Error::new(input.span(), "missing `glob` attribute"))?;
111
112 let files = glob::glob(pattern)
113 .map_err(|err| Error::new(input.span(), format!("invalid path glob pattern: {}", err)))?;
114
115 for (i, file) in files.enumerate() {
116 let mut file = file
117 .map_err(|err| Error::new(input.span(), format!("could not read directory: {}", err)))?
118 .canonicalize()
119 .map_err(|err| Error::new(input.span(), format!("could not read file: {}", err)))?;
120
121 match kind {
122 Kind::File | Kind::Blob if file.is_dir() => continue,
123 _ => {}
124 };
125
126 let path = file.to_string_lossy().to_string();
127
128 if !attrs.extension {
129 file.set_extension("");
130 }
131
132 let mut path_segments = file
133 .iter()
134 .rev()
135 .take(attrs.segments)
136 .map(|s| make_safe_ident(&s.to_string_lossy()))
137 .collect::<Vec<_>>();
138
139 path_segments.reverse();
140
141 let path_name = path_segments.join("_");
142
143 let test_name = if attrs.index {
144 format_ident!("{}_{}_{}", name, path_name, i)
145 } else {
146 format_ident!("{}_{}", name, path_name)
147 };
148
149 let into_path = quote!(&::std::path::PathBuf::from(#path));
150
151 let call = match kind {
152 Kind::File if n_args < 2 => quote!(#name(include_str!(#path))),
153 Kind::File => quote!(#name(include_str!(#path), #into_path)),
154 Kind::Blob if n_args < 2 => quote!(#name(include_bytes!(#path))),
155 Kind::Blob => quote!(#name(include_bytes!(#path), #into_path)),
156 Kind::Path => quote!(#name(#into_path)),
157 };
158
159 functions.push(quote! {
160
161 #[test]
162 #[allow(non_snake_case)]
163 #vis fn #test_name() #ret {
164 #call
165 }
166 });
167 }
168
169 Ok(quote!( #(#functions)* ).into())
170}
171
/// Turns one path component into a fragment for a generated test name.
///
/// Characters for which `char::is_alphanumeric` is false become `_`. Non-ASCII
/// letters and digits are kept. Leading and trailing underscores are then trimmed;
/// interior runs of `_` are kept as-is. The result may be empty (e.g. for `/`).
fn make_safe_ident(value: &str) -> String {
    let replaced: String = value
        .chars()
        .map(|c| if c.is_alphanumeric() { c } else { '_' })
        .collect();

    replaced.trim_matches('_').to_owned()
}