test_data_file/
lib.rs

1use std::path::Path;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{quote, ToTokens};
6use syn::parse::Result;
7use syn::{meta::ParseNestedMeta, parse_macro_input, FnArg, ItemFn, LitStr, Pat};
8
/// Data formats accepted by the `kind` argument; a file extension equal to one of
/// these names is also used to infer the format when `kind` is omitted.
const SUPPORTED_KINDS: [&str; 6] = [
    "csv",  // read with the `csv` crate
    "json", // read with `serde_json`
    "yaml", // read with `serde_yaml`
    "ron",  // read with `ron::de`
    "toml", // read with `toml`
    "list", // plain text, space-separated columns
];
10
11#[allow(clippy::test_attr_in_doctest)]
12/// Provide sample data from a file to your test function
13///
14/// # Arguments
15///
16/// * path - path to the sample
17/// * kind - optional file format (if extension is not specified)
18///
19/// # Example
20///
21/// ```
22/// use test_data_file::test_data_file;
23/// #[test_data_file(path = "tests/samples/test_me.yaml")]
24/// #[test]
25/// fn test_is_name_above_max_size(name: Option<String>, max_size: usize, is_above: bool) {
26///     assert_eq!(
27///         name.map(|n| n.len()) > Some(max_size),
28///         is_above,
29///         "failed for {max_size}"
30///     );
31/// }
32/// ```
33///
34#[proc_macro_attribute]
35pub fn test_data_file(args: TokenStream, item: TokenStream) -> TokenStream {
36    let mut func = parse_macro_input!(item as ItemFn);
37    let mut attrs = TestFileDataAttributes::default();
38
39    let test_file_dat_parser = syn::meta::parser(|meta| attrs.parse(meta));
40    parse_macro_input!(args with test_file_dat_parser);
41
42    let path = attrs
43        .path
44        .unwrap_or_else(|| panic!("'path' attribute is required"));
45    let kind = attrs
46        .kind
47        .unwrap_or_else(|| panic!("'kind' attribute is required"));
48
49    let generated = impl_test_data_file(&func, path, kind);
50
51    let mut input = proc_macro2::TokenStream::from(generated);
52    func.attrs.retain(|attr| {
53        !(attr.path().is_ident("test")
54            || attr.path().is_ident("should_panic")
55            || attr
56                .path()
57                .segments
58                .first()
59                .map(|s| s.ident == "tokio")
60                .unwrap_or(false))
61    });
62    func.sig.ident = Ident::new(&format!("_{}", &func.sig.ident), func.sig.ident.span());
63    func.to_tokens(&mut input);
64    input.into()
65}
66
67#[derive(Default)]
68struct TestFileDataAttributes {
69    kind: Option<LitStr>,
70    path: Option<LitStr>,
71}
72
73impl TestFileDataAttributes {
74    fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> {
75        if meta.path.is_ident("kind") {
76            let kind: LitStr = meta.value()?.parse()?;
77            if !SUPPORTED_KINDS.contains(&kind.value().as_str()) {
78                return Err(meta.error("unsupported kind"));
79            }
80            self.kind = kind.into();
81        } else if meta.path.is_ident("path") {
82            let path: LitStr = meta.value()?.parse()?;
83            let path_str = path.value();
84            let file_path = Path::new(&path_str);
85            if !file_path.exists() {
86                return Err(meta.error("file does not exist"));
87            }
88            if !file_path.is_file() {
89                return Err(meta.error("path must be a file"));
90            }
91            if let (true, Some(ext)) = (
92                self.kind.is_none(),
93                file_path.extension().and_then(|s| s.to_str()),
94            ) {
95                if SUPPORTED_KINDS.contains(&ext) {
96                    self.kind = LitStr::new(ext, path.span()).into();
97                }
98            }
99            self.path = path.into();
100        } else {
101            return Err(meta.error("unsupported property"));
102        }
103        Ok(())
104    }
105}
106
107fn impl_test_data_file(item: &ItemFn, path: LitStr, kind: LitStr) -> TokenStream {
108    let name = &item.sig.ident;
109    let call_ident = Ident::new(&format!("_{}", &item.sig.ident), Span::call_site());
110
111    let (field_names, field_types): (Vec<_>, Vec<_>) = item
112        .sig
113        .inputs
114        .iter()
115        .filter_map(|field| match field {
116            FnArg::Receiver(_) => None,
117            FnArg::Typed(pat_type) => {
118                if let Pat::Ident(pat_ident) = &*pat_type.pat {
119                    Some((&pat_ident.ident, &pat_type.ty))
120                } else {
121                    None
122                }
123            }
124        })
125        .unzip();
126
127    let kind_str = kind.value();
128    let func_attrs: Vec<_> = item.attrs.iter().collect();
129    let func_async = item.sig.asyncness;
130    let func_await = if func_async.is_some() {
131        Some(quote! {.await})
132    } else {
133        None
134    };
135
136    let body = if kind_str == "csv" {
137        quote! {
138            #[derive(Debug, serde::Deserialize)]
139            struct _Data {
140                #(#field_names: #field_types,)*
141            }
142            let file_path = #path;
143
144            let mut rdr = csv::ReaderBuilder::new()
145                .from_path(file_path)
146                .unwrap();
147            let mut executed = false;
148            for result in rdr.deserialize() {
149                let record: _Data = result.unwrap();
150                executed = true;
151                let _Data { #(#field_names,)* } = record;
152                #call_ident(#(#field_names,)*)#func_await;
153            }
154            if !executed {
155                panic!("Empty test data provided in {file_path}");
156            }
157        }
158    } else if kind_str == "list" {
159        quote! {
160            use std::io::BufRead;
161            let file_path = #path;
162            let f = std::fs::File::open(file_path).unwrap();
163            let lines = std::io::BufReader::new(f).lines();
164            let mut executed = false;
165
166            for (n, line) in lines.enumerate() {
167                if n == 0 {
168                    continue;
169                }
170                executed = true;
171                let line = line.unwrap();
172                let mut iter = line.split(' ').filter(|f| !f.is_empty());
173                let mut column = 0;
174                #(
175                    let field = iter.next().unwrap();
176                    let #field_names = field.parse().map_err(|e| format!("Invalid value in row={n} column={column} {file_path} {e}")).unwrap();
177                    column += 1;
178                )*
179                #call_ident(#(#field_names,)*)#func_await;
180            }
181            if !executed {
182                panic!("Empty test data provided in {file_path}");
183            }
184        }
185    } else {
186        let kind = Ident::new(&kind_str, kind.span());
187        let serde_read = match kind_str.as_str() {
188            "yaml" | "json" => {
189                let kind = Ident::new(&format!("serde_{kind_str}"), kind.span());
190                quote! {
191                    #kind::from_reader(std::fs::File::open(file_path).unwrap()).map_err(|e| format!("Failed to load data in {file_path} {e}")).unwrap()
192                }
193            }
194            "toml" => quote! {
195                #kind::from_str(&std::fs::read_to_string(file_path).unwrap()).map_err(|e| format!("Failed to load data in {file_path} {e}")).unwrap()
196            },
197            _ => quote! {
198                #kind::de::from_reader(std::fs::File::open(file_path).unwrap()).map_err(|e| format!("Failed to load data in {file_path} {e}")).unwrap()
199            },
200        };
201
202        quote! {
203            #[derive(Debug, serde::Deserialize)]
204            struct _Data {
205                #(#field_names: #field_types,)*
206            }
207
208            #[derive(Debug, serde::Deserialize)]
209            #[serde(untagged)]
210            enum Collection {
211                Index(Vec<_Data>),
212                Map(std::collections::HashMap<String, _Data>)
213            }
214
215            let file_path = #path;
216
217            let values: Collection = #serde_read;
218            let values = match values {
219                Collection::Index(v) => v,
220                Collection::Map(m) => m.into_iter().map(|(_, v)| v).collect(),
221            };
222
223            if values.is_empty() {
224                panic!("Empty test data provided in {file_path}");
225            }
226
227            for test_data in values {
228                let _Data { #(#field_names,)* } = test_data;
229                #call_ident(#(#field_names,)*)#func_await;
230            }
231        }
232    };
233
234    let gen = quote! {
235        #(#func_attrs)*
236        #func_async fn #name() {
237            #body
238        }
239    };
240    gen.into()
241}