test_each_codegen/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{
4    meta::ParseNestedMeta, parse_macro_input, spanned::Spanned, Error, ItemFn, LitInt, LitStr,
5    Result,
6};
7
/// Options parsed from a `#[test_each_*(...)]` attribute.
struct Attrs {
    /// Value of `name(extension)`: keep the file extension in generated test names.
    extension: bool,
    /// Value of `name(index)`: append the match's position to generated test names.
    index: bool,
    /// Value of `name(segments = <num>)`: how many trailing path components form the name.
    segments: usize,
    /// Value of `glob = "<pattern>"`; `None` until the attribute supplies it.
    glob: Option<String>,
}
14
15impl Attrs {
16    pub fn new() -> Self {
17        Self {
18            glob: None,
19            segments: 1,
20            index: false,
21            extension: false,
22        }
23    }
24
25    fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> {
26        if meta.path.is_ident("glob") {
27            let glob: LitStr = meta.value()?.parse()?;
28            self.glob = Some(glob.value());
29        } else if meta.path.is_ident("name") {
30            meta.parse_nested_meta(|nested| {
31                if nested.path.is_ident("segments") {
32                    let path_segments: LitInt = nested.value()?.parse()?;
33                    self.segments = path_segments.base10_parse()?;
34                } else if nested.path.is_ident("index") {
35                    self.index = true
36                } else if nested.path.is_ident("extension") {
37                    self.extension = true
38                } else {
39                    return Err(nested.error(
40                        "unsupported property, specify `segments = <num>`, `index` or `extension`",
41                    ));
42                }
43
44                Ok(())
45            })?;
46        } else {
47            return Err(meta.error("unsupported property, specify `glob` or `name`"));
48        }
49
50        Ok(())
51    }
52}
53
54#[proc_macro_attribute]
55pub fn test_each_file(args: TokenStream, input: TokenStream) -> TokenStream {
56    let mut attrs = Attrs::new();
57    let attr_parser = syn::meta::parser(|meta| attrs.parse(meta));
58    let input = parse_macro_input!(input as ItemFn);
59    parse_macro_input!(args with attr_parser);
60
61    match test_each(attrs, input, Kind::File) {
62        Ok(output) => output,
63        Err(err) => err.into_compile_error().into(),
64    }
65}
66
67#[proc_macro_attribute]
68pub fn test_each_blob(args: TokenStream, input: TokenStream) -> TokenStream {
69    let mut attrs = Attrs::new();
70    let attr_parser = syn::meta::parser(|meta| attrs.parse(meta));
71    let input = parse_macro_input!(input as ItemFn);
72    parse_macro_input!(args with attr_parser);
73
74    match test_each(attrs, input, Kind::Blob) {
75        Ok(output) => output,
76        Err(err) => err.into_compile_error().into(),
77    }
78}
79
80#[proc_macro_attribute]
81pub fn test_each_path(args: TokenStream, input: TokenStream) -> TokenStream {
82    let mut attrs = Attrs::new();
83    let attr_parser = syn::meta::parser(|meta| attrs.parse(meta));
84    let input = parse_macro_input!(input as ItemFn);
85    parse_macro_input!(args with attr_parser);
86
87    match test_each(attrs, input, Kind::Path) {
88        Ok(output) => output,
89        Err(err) => err.into_compile_error().into(),
90    }
91}
92
/// Which `test_each_*` macro is being expanded. This decides how each matched
/// path is handed to the annotated function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    /// `test_each_file`: contents via `include_str!`; directories are skipped.
    File,
    /// `test_each_blob`: contents via `include_bytes!`; directories are skipped.
    Blob,
    /// `test_each_path`: only the path is passed; directories are kept.
    Path,
}
98
99fn test_each(attrs: Attrs, input: ItemFn, kind: Kind) -> Result<TokenStream> {
100    let mut functions = vec![input.to_token_stream()];
101
102    let name = &input.sig.ident;
103    let vis = &input.vis;
104    let ret = &input.sig.output;
105    let n_args = input.sig.inputs.len();
106
107    let pattern = attrs
108        .glob
109        .as_ref()
110        .ok_or_else(|| Error::new(input.span(), "missing `glob` attribute"))?;
111
112    let files = glob::glob(pattern)
113        .map_err(|err| Error::new(input.span(), format!("invalid path glob pattern: {}", err)))?;
114
115    for (i, file) in files.enumerate() {
116        let mut file = file
117            .map_err(|err| Error::new(input.span(), format!("could not read directory: {}", err)))?
118            .canonicalize()
119            .map_err(|err| Error::new(input.span(), format!("could not read file: {}", err)))?;
120
121        match kind {
122            Kind::File | Kind::Blob if file.is_dir() => continue,
123            _ => {}
124        };
125
126        let path = file.to_string_lossy().to_string();
127
128        if !attrs.extension {
129            file.set_extension("");
130        }
131
132        let mut path_segments = file
133            .iter()
134            .rev()
135            .take(attrs.segments)
136            .map(|s| make_safe_ident(&s.to_string_lossy()))
137            .collect::<Vec<_>>();
138
139        path_segments.reverse();
140
141        let path_name = path_segments.join("_");
142
143        let test_name = if attrs.index {
144            format_ident!("{}_{}_{}", name, path_name, i)
145        } else {
146            format_ident!("{}_{}", name, path_name)
147        };
148
149        let into_path = quote!(&::std::path::PathBuf::from(#path));
150
151        let call = match kind {
152            Kind::File if n_args < 2 => quote!(#name(include_str!(#path))),
153            Kind::File => quote!(#name(include_str!(#path), #into_path)),
154            Kind::Blob if n_args < 2 => quote!(#name(include_bytes!(#path))),
155            Kind::Blob => quote!(#name(include_bytes!(#path), #into_path)),
156            Kind::Path => quote!(#name(#into_path)),
157        };
158
159        functions.push(quote! {
160
161            #[test]
162            #[allow(non_snake_case)]
163            #vis fn #test_name() #ret {
164                #call
165            }
166        });
167    }
168
169    Ok(quote!( #(#functions)* ).into())
170}
171
/// Converts an arbitrary path component into a fragment usable inside a test
/// name.
///
/// Every character that is not [`char::is_alphanumeric`] is replaced with `_`.
/// Leading and trailing `_` are then stripped; inner runs of `_` are kept.
fn make_safe_ident(value: &str) -> String {
    let replaced: String = value
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect();

    replaced.trim_matches('_').to_owned()
}