use std::path::Path;

use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::parse::Result;
use syn::{meta::ParseNestedMeta, parse_macro_input, FnArg, ItemFn, LitStr, Pat};

const SUPPORTED_KINDS: [&str; 6] = ["csv", "json", "yaml", "ron", "toml", "list"];

#[allow(clippy::test_attr_in_doctest)]
#[proc_macro_attribute]
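/// Attribute macro that runs the annotated test once per record found in an
/// external data file (`csv`, `json`, `yaml`, `ron`, `toml`, or a space
/// separated `list`), binding each record's fields to the function arguments.
///
/// An illustrative sketch of the intended usage; the data file, path, and
/// argument names below are hypothetical examples, not shipped with the crate:
///
/// ```ignore
/// #[test_data_file(path = "tests/data/users.csv")]
/// #[test]
/// fn parses_user(name: String, age: u32) {
///     assert!(!name.is_empty());
///     assert!(age > 0);
/// }
/// ```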
pub fn test_data_file(args: TokenStream, item: TokenStream) -> TokenStream {
    let mut func = parse_macro_input!(item as ItemFn);
    let mut attrs = TestFileDataAttributes::default();

    let test_file_data_parser = syn::meta::parser(|meta| attrs.parse(meta));
    parse_macro_input!(args with test_file_data_parser);

    let path = attrs.path.expect("'path' attribute is required");
    let kind = attrs
        .kind
        .expect("'kind' attribute is required when it cannot be inferred from 'path'");

    let generated = impl_test_data_file(&func, path, kind);

    let mut input = proc_macro2::TokenStream::from(generated);
    // Strip the test-related attributes from the original function: the generated
    // wrapper re-emits them, and the original becomes a plain helper it calls.
    func.attrs.retain(|attr| {
        !(attr.path().is_ident("test")
            || attr.path().is_ident("should_panic")
            || attr
                .path()
                .segments
                .first()
                .map(|s| s.ident == "tokio")
                .unwrap_or(false))
    });
    // Rename the original function so the wrapper can take its name and call it.
    func.sig.ident = Ident::new(&format!("_{}", &func.sig.ident), func.sig.ident.span());
    func.to_tokens(&mut input);
    input.into()
}

#[derive(Default)]
struct TestFileDataAttributes {
    kind: Option<LitStr>,
    path: Option<LitStr>,
}

impl TestFileDataAttributes {
    fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> {
        if meta.path.is_ident("kind") {
            let kind: LitStr = meta.value()?.parse()?;
            if !SUPPORTED_KINDS.contains(&kind.value().as_str()) {
                return Err(meta.error("unsupported kind"));
            }
            self.kind = kind.into();
        } else if meta.path.is_ident("path") {
            let path: LitStr = meta.value()?.parse()?;
            let path_str = path.value();
            let file_path = Path::new(&path_str);
            if !file_path.exists() {
                return Err(meta.error("file does not exist"));
            }
            if !file_path.is_file() {
                return Err(meta.error("path must be a file"));
            }
            // If 'kind' was not given explicitly, infer it from the file extension.
            if let (true, Some(ext)) = (
                self.kind.is_none(),
                file_path.extension().and_then(|s| s.to_str()),
            ) {
                if SUPPORTED_KINDS.contains(&ext) {
                    self.kind = LitStr::new(ext, path.span()).into();
                }
            }
            self.path = path.into();
        } else {
            return Err(meta.error("unsupported property"));
        }
        Ok(())
    }
}
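
// Illustrative attribute forms accepted by the parser above (the paths are
// hypothetical; they must point at an existing file at macro-expansion time):
//
//   #[test_data_file(path = "tests/data/cases.yaml")]               // kind inferred from ".yaml"
//   #[test_data_file(kind = "list", path = "tests/data/cases.txt")] // kind given explicitly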

fn impl_test_data_file(item: &ItemFn, path: LitStr, kind: LitStr) -> TokenStream {
    let name = &item.sig.ident;
    let call_ident = Ident::new(&format!("_{}", &item.sig.ident), Span::call_site());

    // The test function's arguments become the fields of the generated `_Data`
    // struct that each record of the data file is deserialized into.
    let (field_names, field_types): (Vec<_>, Vec<_>) = item
        .sig
        .inputs
        .iter()
        .filter_map(|field| match field {
            FnArg::Receiver(_) => None,
            FnArg::Typed(pat_type) => {
                if let Pat::Ident(pat_ident) = &*pat_type.pat {
                    Some((&pat_ident.ident, &pat_type.ty))
                } else {
                    None
                }
            }
        })
        .unzip();

    let kind_str = kind.value();
    let func_attrs: Vec<_> = item.attrs.iter().collect();
    let func_async = item.sig.asyncness;
    let func_await = if func_async.is_some() {
        Some(quote! { .await })
    } else {
        None
    };

    let body = if kind_str == "csv" {
        // CSV: each row is deserialized by header name, so the header names must
        // match the test function's argument names.
        quote! {
            #[derive(Debug, serde::Deserialize)]
            struct _Data {
                #(#field_names: #field_types,)*
            }
            let file_path = #path;

            let mut rdr = csv::ReaderBuilder::new()
                .from_path(file_path)
                .unwrap();
            let mut executed = false;
            for result in rdr.deserialize() {
                let record: _Data = result.unwrap();
                executed = true;
                let _Data { #(#field_names,)* } = record;
                #call_ident(#(#field_names,)*)#func_await;
            }
            if !executed {
                panic!("Empty test data provided in {file_path}");
            }
        }
    } else if kind_str == "list" {
        // "list": a plain text table whose first line is a header (skipped); every
        // following line holds one space separated value per test function
        // argument, parsed via `FromStr`.
        quote! {
            use std::io::BufRead;
            let file_path = #path;
            let f = std::fs::File::open(file_path).unwrap();
            let lines = std::io::BufReader::new(f).lines();
            let mut executed = false;

            for (n, line) in lines.enumerate() {
                if n == 0 {
                    continue;
                }
                executed = true;
                let line = line.unwrap();
                let mut iter = line.split(' ').filter(|f| !f.is_empty());
                let mut column = 0;
                #(
                    let field = iter.next().unwrap();
                    let #field_names = field
                        .parse()
                        .map_err(|e| format!("Invalid value in row={n} column={column} {file_path} {e}"))
                        .unwrap();
                    column += 1;
                )*
                #call_ident(#(#field_names,)*)#func_await;
            }
            if !executed {
                panic!("Empty test data provided in {file_path}");
            }
        }
    } else {
        // json / yaml / toml / ron: load the whole file with the matching
        // serde-compatible crate, accepting either a sequence of records or a map
        // of named records.
        let kind = Ident::new(&kind_str, kind.span());
        let serde_read = match kind_str.as_str() {
            "yaml" | "json" => {
                let kind = Ident::new(&format!("serde_{kind_str}"), kind.span());
                quote! {
                    #kind::from_reader(std::fs::File::open(file_path).unwrap())
                        .map_err(|e| format!("Failed to load data in {file_path} {e}"))
                        .unwrap()
                }
            }
            "toml" => quote! {
                #kind::from_str(&std::fs::read_to_string(file_path).unwrap())
                    .map_err(|e| format!("Failed to load data in {file_path} {e}"))
                    .unwrap()
            },
            _ => quote! {
                #kind::de::from_reader(std::fs::File::open(file_path).unwrap())
                    .map_err(|e| format!("Failed to load data in {file_path} {e}"))
                    .unwrap()
            },
        };

        quote! {
            #[derive(Debug, serde::Deserialize)]
            struct _Data {
                #(#field_names: #field_types,)*
            }

            #[derive(Debug, serde::Deserialize)]
            #[serde(untagged)]
            enum Collection {
                Index(Vec<_Data>),
                Map(std::collections::HashMap<String, _Data>),
            }

            let file_path = #path;

            let values: Collection = #serde_read;
            let values = match values {
                Collection::Index(v) => v,
                Collection::Map(m) => m.into_iter().map(|(_, v)| v).collect(),
            };

            if values.is_empty() {
                panic!("Empty test data provided in {file_path}");
            }

            for test_data in values {
                let _Data { #(#field_names,)* } = test_data;
                #call_ident(#(#field_names,)*)#func_await;
            }
        }
    };

    // The wrapper keeps the original attributes (`#[test]`, `#[tokio::test]`, ...)
    // and drives the renamed original function once per record.
    // Note: `gen` is reserved in the 2024 edition, so use a longer name here.
    let generated = quote! {
        #(#func_attrs)*
        #func_async fn #name() {
            #body
        }
    };
    generated.into()
}
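
// For reference, a rough sketch of the expansion for a hypothetical test
// annotated with `#[test_data_file(path = "tests/data/users.csv")]` and
// `#[test]`, taking `(name: String, age: u32)` (spans and exact formatting
// differ in the real output):
//
//   #[test]
//   fn parses_user() {
//       #[derive(Debug, serde::Deserialize)]
//       struct _Data { name: String, age: u32 }
//       let file_path = "tests/data/users.csv";
//       let mut rdr = csv::ReaderBuilder::new().from_path(file_path).unwrap();
//       let mut executed = false;
//       for result in rdr.deserialize() {
//           let record: _Data = result.unwrap();
//           executed = true;
//           let _Data { name, age } = record;
//           _parses_user(name, age);
//       }
//       if !executed {
//           panic!("Empty test data provided in {file_path}");
//       }
//   }
//
//   // The original function, with `#[test]` removed and an underscore prefix.
//   fn _parses_user(name: String, age: u32) { /* original body */ }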