1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use proc_macro::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{
    parse::Parser, parse_macro_input, punctuated::Punctuated, Error, ItemFn, LitStr, Result, Token,
};

/// Attribute macro that generates one `#[test]` function per file matched by a glob pattern.
///
/// The attribute takes a single string literal containing the glob pattern. Each generated
/// test calls the annotated function with the file's contents embedded via `include_str!`;
/// when the function declares two or more parameters, the file's canonical path is passed
/// as a second `PathBuf` argument. Directories matched by the pattern are skipped.
///
/// Any error raised while expanding is turned into a compile error.
#[proc_macro_attribute]
pub fn test_each_file(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemFn);
    test_each(attrs, item, Kind::File).unwrap_or_else(|err| err.into_compile_error().into())
}

/// Attribute macro that generates one `#[test]` function per file matched by a glob pattern,
/// handing the file over as raw bytes.
///
/// The attribute takes a single string literal containing the glob pattern. Each generated
/// test calls the annotated function with the file's contents embedded via `include_bytes!`;
/// when the function declares two or more parameters, the file's canonical path is passed
/// as a second `PathBuf` argument. Directories matched by the pattern are skipped.
///
/// Any error raised while expanding is turned into a compile error.
#[proc_macro_attribute]
pub fn test_each_blob(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemFn);
    test_each(attrs, item, Kind::Blob).unwrap_or_else(|err| err.into_compile_error().into())
}

/// Attribute macro that generates one `#[test]` function per path matched by a glob pattern.
///
/// The attribute takes a single string literal containing the glob pattern. Each generated
/// test calls the annotated function with the canonical path as a `PathBuf`; the contents
/// are not embedded. Unlike the file and blob variants, matched directories are not skipped.
///
/// Any error raised while expanding is turned into a compile error.
#[proc_macro_attribute]
pub fn test_each_path(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let item = parse_macro_input!(input as ItemFn);
    test_each(attrs, item, Kind::Path).unwrap_or_else(|err| err.into_compile_error().into())
}

/// Selects what a generated test passes to the annotated function.
enum Kind {
    /// Pass the contents via `include_str!`, plus the path when the function takes two or
    /// more parameters. Matched directories are skipped.
    File,
    /// Pass the contents via `include_bytes!`, plus the path when the function takes two or
    /// more parameters. Matched directories are skipped.
    Blob,
    /// Pass only the canonical path as a `PathBuf`.
    Path,
}

/// Expands one `test_each_*` attribute: re-emits the annotated function unchanged and appends
/// one `#[test]` wrapper per path matched by the glob pattern in `attrs`.
///
/// * `attrs` - the attribute arguments; must be exactly one string literal holding a glob
///   pattern.
/// * `input` - the annotated function. Generated tests call it by name and reuse its
///   visibility and return type.
/// * `kind` - what each test passes to the function: `include_str!` contents (`File`),
///   `include_bytes!` contents (`Blob`), or only the path (`Path`). For `File` and `Blob` the
///   path is added as a second `PathBuf` argument when the function declares two or more
///   parameters.
///
/// Generated tests are named `<fn>_<sanitized file name>_<index>`, where the index is the
/// entry's position in the glob results.
///
/// # Errors
///
/// Returns a `syn::Error` when the arguments do not parse as string literals, when there is
/// not exactly one literal, when the glob pattern is invalid, when a glob entry cannot be
/// read, or when a matched path cannot be canonicalized.
fn test_each(attrs: TokenStream, input: ItemFn, kind: Kind) -> Result<TokenStream> {
    let lits = Punctuated::<LitStr, Token![,]>::parse_terminated.parse(attrs)?;

    // The original function is emitted first so the generated tests can call it.
    // `to_token_stream` only borrows, so no clone of the parsed function is needed.
    let mut functions = vec![input.to_token_stream()];

    let name = input.sig.ident;
    let vis = input.vis;
    let ret = input.sig.output;
    let n_args = input.sig.inputs.len();

    if lits.len() != 1 {
        return Err(Error::new(
            name.span(),
            "expected a single path glob literal",
        ));
    }

    let pattern = lits[0].value();

    let files = glob::glob(&pattern).map_err(|err| {
        Error::new(
            lits[0].span(),
            format!("invalid path glob pattern: {}", err),
        )
    })?;

    // The index counts every glob entry, including directories skipped below, so the
    // numeric suffixes of the generated tests may have gaps.
    for (i, file) in files.enumerate() {
        let file = file.map_err(|err| {
            Error::new(lits[0].span(), format!("could not read directory: {}", err))
        })?;

        // Directory contents cannot be embedded; only the `Path` kind keeps directories.
        if matches!(kind, Kind::File | Kind::Blob) && file.is_dir() {
            continue;
        }

        // Paths without a final component contribute nothing to the test name.
        let file_name = file
            .file_name()
            .map(|name| format!("{}_", make_safe_ident(&name.to_string_lossy())))
            .unwrap_or_default();

        let test_name = format_ident!("{}_{}{}", name, file_name, i);

        let path = file
            .canonicalize()
            .map_err(|err| Error::new(lits[0].span(), format!("could not read file: {}", err)))?
            .to_string_lossy()
            .to_string();

        let into_path = quote!(::std::path::PathBuf::from(#path));

        let call = match kind {
            Kind::File if n_args < 2 => quote!(#name(include_str!(#path))),
            Kind::File => quote!(#name(include_str!(#path), #into_path)),
            Kind::Blob if n_args < 2 => quote!(#name(include_bytes!(#path))),
            Kind::Blob => quote!(#name(include_bytes!(#path), #into_path)),
            Kind::Path => quote!(#name(#into_path)),
        };

        functions.push(quote! {
            #[test]
            #vis fn #test_name() #ret {
                #call
            }
        });
    }

    Ok(quote!( #(#functions)* ).into())
}

/// Returns a copy of `value` in which every character that is not alphanumeric (as decided
/// by `char::is_alphanumeric`, so Unicode letters and digits are kept) is replaced by `_`.
///
/// The output has the same number of characters as the input; an empty input gives an
/// empty string.
fn make_safe_ident(value: &str) -> String {
    value
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}