1extern crate proc_macro;
2
3use eyre::{OptionExt, WrapErr};
4use itertools::Itertools;
5use once_cell::sync::Lazy;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::{quote, ToTokens};
9use std::{
10 collections::{HashMap, HashSet},
11 io::Read,
12 path::{Path, PathBuf},
13 sync::{Arc, Mutex},
14};
15use syn::{
16 parse::Parse, parse_macro_input, punctuated::Punctuated, Expr, ExprCall, ExprLit, ExprPath,
17 File, Ident, Item, ItemFn, Lit, LitStr, Meta, ReturnType, Signature, Token, Type,
18};
19use walkdir::WalkDir;
20
21static TEST_CASES: Lazy<Arc<Mutex<HashSet<TestCase>>>> =
22 Lazy::new(|| Arc::new(Mutex::new(HashSet::new())));
23
/// One generated test: its human-readable name plus the wrapper function that runs it.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct TestCase {
    /// Identifier of the generated zero-argument wrapper function (as built by
    /// `TestCase::from_func_name`: `tanu_` + `test_name` with `::` replaced by `_`).
    func_name: String,
    /// Name reported to the runner, e.g. `foo` or `foo::<argument suffix>`.
    test_name: String,
}
32
33impl TestCase {
34 fn from_func_name(input: &Input, org_func_name: &str) -> TestCase {
36 let test_name = generate_test_name(org_func_name, input);
37 let func_name = format!("tanu_{test_name}").replace("::", "_");
38 if syn::parse_str::<Ident>(&func_name).is_err() {
39 panic!(
40 r#"Test case generation error! The provided test parameters contain
41 invalid characters that cannot be used in a function name (function name: {func_name}).
42 Please specify a valid test name using only letters, numbers, and underscores."#
43 );
44 }
45 TestCase {
46 func_name,
47 test_name,
48 }
49 }
50}
51
52#[derive(Debug)]
53#[allow(dead_code)]
54struct TestModule {
55 module: String,
57 func_name: String,
59 test_cases: Vec<TestCase>,
61}
62
63struct Input {
65 args: Punctuated<Expr, Token![,]>,
67 name: Option<LitStr>,
69}
70
71impl Parse for Input {
72 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
73 if input.is_empty() {
74 Ok(Input {
75 args: Default::default(),
76 name: None,
77 })
78 } else {
79 let args: Punctuated<Expr, Token![,]> =
80 Punctuated::parse_separated_nonempty_with(input, Expr::parse)?;
81
82 let name = if input.parse::<Token![;]>().is_ok() {
83 input.parse::<LitStr>().ok()
84 } else {
85 None
86 };
87
88 Ok(Input { args, name })
89 }
90 }
91}
92
93fn generate_test_name(org_func_name: &str, input: &Input) -> String {
96 let func_name = org_func_name.to_string();
97
98 if input.args.is_empty() {
99 return func_name.to_string();
100 }
101
102 let stringified_args = match &input.name {
103 Some(name_argument) => name_argument.value(),
104 _ => input
105 .args
106 .iter()
107 .filter_map(|expr| match expr {
108 Expr::Lit(ExprLit { lit, .. }) => match lit {
109 Lit::Str(lit_str) => Some(lit_str.value()),
110 other_literal => Some(quote!(#other_literal).to_string()),
111 },
112 expr @ Expr::Path(_) | expr @ Expr::Call(_) => extract_and_stringify_option(expr),
113 other_expr => Some(quote!(#other_expr).to_string()),
114 })
115 .map(|s| {
116 s.replace("+=", "_add_")
117 .replace("+", "_add_")
118 .replace("-=", "_sub_")
119 .replace("-", "_sub_")
120 .replace("/=", "_div_")
121 .replace("/", "_div_")
122 .replace("*=", "_mul_")
123 .replace("*", "_mul_")
124 .replace("%=", "_mod_")
125 .replace("%", "_mod_")
126 .replace("==", "_eq_")
127 .replace("!=", "_nq_")
128 .replace("&&", "_and_")
129 .replace("||", "_or_")
130 .replace("!", "not_")
131 .replace("&=", "_and_")
132 .replace("&", "_and_")
133 .replace("|=", "_or_")
134 .replace("|", "_or_")
135 .replace("^=", "_xor_")
136 .replace("^", "_xor_")
137 .replace("<<=", "_lshift_")
138 .replace("<<", "_lshift_")
139 .replace("<=", "_le_")
140 .replace("<", "_lt_")
141 .replace(">>=", "_rshift_")
142 .replace(">>", "_rshift_")
143 .replace(">=", "_ge_")
144 .replace(">", "_gt_")
145 .replace("&mut ", "")
146 .replace("*mut ", "")
147 .replace("&", "")
148 .replace("*", "")
149 .replace(" :: ", "_")
150 .replace("\\", "")
151 .replace("/", "")
152 .replace("\"", "")
153 .replace("(", "")
154 .replace(")", "")
155 .replace("{", "")
156 .replace("}", "")
157 .replace("[", "")
158 .replace("]", "")
159 .replace(" ", "")
160 .replace(",", "_")
161 .replace(".", "_")
162 .to_lowercase()
163 })
164 .collect::<Vec<_>>()
165 .join("_"),
166 };
167
168 format!("{func_name}::{stringified_args}")
169}
170
/// Error family of a test function's return type, as classified by `inspect_error_crate`.
#[derive(PartialEq, Eq, Debug)]
enum ErrorCrate {
    /// An `eyre` result: the generated wrapper returns the test's result unchanged.
    Eyre,
    /// Any other result type: the wrapper converts the error with `tanu::eyre::eyre!`.
    AnythingElse,
}
176
177fn inspect_error_crate(sig: &Signature) -> ErrorCrate {
190 match &sig.output {
191 ReturnType::Default => panic!("return type needs to be other than ()"),
192 ReturnType::Type(_, ty) => {
193 let Type::Path(type_path) = ty.as_ref() else {
194 panic!("failed to get return type path");
195 };
196
197 let path = &type_path.path;
198 match (path.segments.first(), path.segments.last()) {
199 (Some(first), Some(last)) => {
200 if first.ident == "eyre" && last.ident == "Result" {
201 ErrorCrate::Eyre
202 } else {
203 ErrorCrate::AnythingElse
204 }
205 }
206 _ => {
207 panic!("unexpected return type");
208 }
209 }
210 }
211 }
212}
213
214#[allow(dead_code)]
215fn get_expr_variant_name(expr: &Expr) -> &'static str {
217 match expr {
218 Expr::Array(_) => "Array",
219 Expr::Assign(_) => "Assign",
220 Expr::Async(_) => "Async",
221 Expr::Await(_) => "Await",
222 Expr::Binary(_) => "Binary",
223 Expr::Block(_) => "Block",
224 Expr::Break(_) => "Break",
225 Expr::Call(_) => "Call",
226 Expr::Cast(_) => "Cast",
227 Expr::Closure(_) => "Closure",
228 Expr::Continue(_) => "Continue",
229 Expr::Field(_) => "Field",
230 Expr::ForLoop(_) => "ForLoop",
231 Expr::Group(_) => "Group",
232 Expr::If(_) => "If",
233 Expr::Index(_) => "Index",
234 Expr::Let(_) => "Let",
235 Expr::Lit(_) => "Lit",
236 Expr::Loop(_) => "Loop",
237 Expr::Macro(_) => "Macro",
238 Expr::Match(_) => "Match",
239 Expr::MethodCall(_) => "MethodCall",
240 Expr::Paren(_) => "Paren",
241 Expr::Path(_) => "Path",
242 Expr::Range(_) => "Range",
243 Expr::Reference(_) => "Reference",
244 Expr::Repeat(_) => "Repeat",
245 Expr::Return(_) => "Return",
246 Expr::Struct(_) => "Struct",
247 Expr::Try(_) => "Try",
248 Expr::TryBlock(_) => "TryBlock",
249 Expr::Tuple(_) => "Tuple",
250 Expr::Unary(_) => "Unary",
251 Expr::Unsafe(_) => "Unsafe",
252 Expr::Verbatim(_) => "Verbatim",
253 Expr::While(_) => "While",
254 Expr::Yield(_) => "Yield",
255 _ => "Unknown",
256 }
257}
258
259fn extract_and_stringify_option(expr: &Expr) -> Option<String> {
260 match expr {
261 Expr::Call(ExprCall { func, args, .. }) => {
262 if let Expr::Path(ExprPath { path, .. }) = &**func {
263 let segment = path.segments.last()?;
264 if segment.ident == "Some" {
265 match args.first()? {
266 Expr::Lit(ExprLit { lit, .. }) => match lit {
267 Lit::Str(lit_str) => {
268 return Some(lit_str.value());
269 }
270 other_type_of_literal => {
271 return Some(other_type_of_literal.to_token_stream().to_string());
272 }
273 },
274 first_arg => {
275 return Some(quote!(#first_arg).to_string());
276 }
277 }
278 }
279 }
280 }
281 Expr::Path(ExprPath { path, .. }) => {
282 if path.get_ident()? == "None" {
283 return Some("None".into());
284 }
285 }
286 _ => {}
287 }
288
289 None
290}
291
292#[proc_macro_attribute]
295pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
296 let input_args = parse_macro_input!(args as Input);
297 let input_fn = parse_macro_input!(input as ItemFn);
298
299 let func_name_inner = &input_fn.sig.ident;
300 let test_case = TestCase::from_func_name(&input_args, &func_name_inner.to_string());
301
302 match TEST_CASES.lock() {
303 Ok(mut lock) => lock.insert(test_case.clone()),
304 Err(e) => {
305 eprintln!("Failed to acquire test case lock: {}", e);
306 return quote! { #input_fn }.into();
307 }
308 };
309
310 let func_name = Ident::new(&test_case.func_name, Span::call_site());
311 let args = input_args.args.to_token_stream();
312
313 let error_crate = inspect_error_crate(&input_fn.sig);
323 let output = if error_crate == ErrorCrate::Eyre {
324 quote! {
325 #input_fn
326 pub(crate) async fn #func_name() -> tanu::eyre::Result<()> {
327 #func_name_inner(#args).await
328 }
329 }
330 } else {
331 quote! {
332 #input_fn
333 pub(crate) async fn #func_name() -> tanu::eyre::Result<()> {
334 #func_name_inner(#args).await.map_err(|e| tanu::eyre::eyre!(Box::new(e)))
335 }
336 }
337 };
338
339 output.into()
340}
341
342fn find_crate_root() -> eyre::Result<PathBuf> {
343 let dir = std::env::var("CARGO_MANIFEST_DIR")?;
344 Ok(dir.into())
345}
346
347fn discover_tests() -> eyre::Result<Vec<TestModule>> {
348 let root = find_crate_root()?;
349
350 let source_paths: Vec<_> = WalkDir::new(root)
352 .into_iter()
353 .filter_map(|entry| {
354 let path = entry.ok()?.into_path();
355 let ext = path.extension()?;
356 if ext.eq_ignore_ascii_case("rs") {
357 Some(path)
358 } else {
359 None
360 }
361 })
362 .collect();
363
364 let mut test_modules = Vec::<TestModule>::new();
365 for source_path in source_paths {
366 let mut source_file = std::fs::File::open(&source_path)
367 .wrap_err_with(|| format!("could not open file: {}", source_path.display()))?;
368 let mut code = String::new();
369 source_file.read_to_string(&mut code)?;
370
371 let file = syn::parse_file(&code)?;
372 test_modules.extend(extract_module_and_test(
373 &extract_module_path(&source_path).ok_or_eyre("malformed module path")?,
374 file,
375 ));
376 }
377
378 Ok(test_modules)
379}
380
/// Derives a Rust module path from a source file path.
///
/// Takes the components after the first `src` component, strips a trailing `.rs`, drops
/// `mod` components (so `foo/mod.rs` maps to `foo`), and joins the rest with `::`. For
/// example `/crate/src/foo/bar.rs` becomes `foo::bar` and `/crate/src/main.rs` becomes
/// `main`. Returns `None` if the path has no `src` component.
fn extract_module_path(path: &Path) -> Option<String> {
    let mut components = path.iter();
    // Consume everything up to and including the first `src` component.
    components.find(|component| *component == "src")?;

    let segments: Vec<&str> = components
        .filter_map(|component| component.to_str())
        .map(|name| name.strip_suffix(".rs").unwrap_or(name))
        .filter(|name| *name != "mod")
        .collect();

    Some(segments.join("::"))
}
394
395fn has_test_attribute(path: &syn::Path) -> bool {
397 let has_test = path.is_ident("test");
399 let has_tanu_test = match (path.segments.first(), path.segments.last()) {
401 (Some(first), Some(last)) => {
402 path.segments.len() == 2 && first.ident == "tanu" && last.ident == "test"
403 }
404 _ => false,
405 };
406
407 has_test || has_tanu_test
408}
409
410fn extract_module_and_test(module: &str, input: File) -> Vec<TestModule> {
411 let mut test_modules = Vec::new();
412 for item in input.items {
413 if let Item::Fn(item_fn) = item {
414 let mut is_test = false;
415 let mut test_cases = Vec::new();
416 for attr in item_fn.attrs {
417 if has_test_attribute(attr.path()) {
418 is_test = true;
419
420 match &attr.meta {
421 Meta::Path(_path) => {
423 let test_case = TestCase {
424 func_name: format!("tanu_{}", item_fn.sig.ident),
425 test_name: format!("tanu_{}", item_fn.sig.ident),
426 };
427 test_cases.push(test_case);
428 }
429 Meta::List(_list) => match attr.parse_args_with(Input::parse) {
431 Ok(test_case_token) => {
432 let test_case = TestCase::from_func_name(
433 &test_case_token,
434 &item_fn.sig.ident.to_string(),
435 );
436 test_cases.push(test_case);
437 }
438 Err(e) => {
439 eprintln!("failed to parse attributes in #[test]: {e:#}");
440 }
441 },
442 _ => {}
443 }
444 }
445 }
446 if is_test {
447 test_modules.push(TestModule {
448 module: module.to_owned(),
449 func_name: item_fn.sig.ident.to_string(),
450 test_cases,
451 });
452 }
453 }
454 }
455
456 test_modules
457}
458
459#[proc_macro_attribute]
460pub fn main(_args: TokenStream, input: TokenStream) -> TokenStream {
461 let main_fn = parse_macro_input!(input as ItemFn);
462
463 let test_modules = discover_tests().expect("failed to discover test cases");
464 let test_modules: HashMap<String, Vec<String>> = test_modules
468 .iter()
469 .flat_map(|module| {
470 let module_name = module.module.clone();
471 module.test_cases.iter().map(move |test_case| {
472 (
473 test_case.func_name.clone(),
474 if module_name == "main" {
475 "crate".into()
476 } else {
477 module_name.clone()
478 },
479 )
480 })
481 })
482 .fold(HashMap::new(), |mut acc, (func_name, module_name)| {
483 acc.entry(func_name).or_default().push(module_name);
484 acc
485 });
486
487 let (test_mods, test_names, func_names): (Vec<_>, Vec<_>, Vec<_>) = match TEST_CASES.lock() {
488 Ok(lock) => lock
489 .iter()
490 .flat_map(|f| {
491 test_modules
492 .get(&f.func_name)
493 .into_iter() .flatten() .filter_map(|module_name| {
496 let test_module_path = match syn::parse_str::<syn::Path>(module_name) {
497 Ok(path) => path,
498 Err(e) => {
499 eprintln!("failed to parse module path '{module_name}': {e}");
500 return None;
501 }
502 };
503
504 Some((
505 test_module_path,
506 f.test_name.clone(),
507 Ident::new(&f.func_name, Span::call_site()),
508 ))
509 })
510 })
511 .multiunzip(),
512 Err(e) => {
513 eprintln!("failed to acquire test case lock: {}", e);
514 (Vec::new(), Vec::new(), Vec::new())
515 }
516 };
517
518 let output = quote! {
519 fn run() -> tanu::Runner {
520 let mut runner = tanu::Runner::new();
521 #(
522 runner.add_test(
523 #test_names,
524 &stringify!(#test_mods).replace(" ", ""),
525 std::sync::Arc::new(|| Box::pin(#test_mods::#func_names()))
526 );
527 )*
528 runner
529 }
530
531 #main_fn
532 };
533
534 output.into()
535}
536
// Unit tests for the attribute-parsing and naming helpers. They call the helpers directly;
// the `test` / `main` proc-macro entry points are not exercised here.
#[cfg(test)]
mod test {
    use crate::Input;

    use super::{ErrorCrate, Expr, Path, TestCase};
    use test_case::test_case;

    // A bare `test` and exactly `tanu::test` are recognized; longer or foreign paths are not.
    #[test_case("test" => true; "test")]
    #[test_case("tanu::test" => true; "tanu_test")]
    #[test_case("tanu::foo::test" => false; "not_tanu_test")]
    #[test_case("foo::test" => false; "also_not_tanu_test")]
    fn has_test_attribute(s: &str) -> bool {
        let path: syn::Path = syn::parse_str(s).expect("Failed to parse path");
        super::has_test_attribute(&path)
    }

    // The module path comes from the components after `src/`: `.rs` is stripped and a
    // `mod.rs` file maps to its directory's module.
    #[test_case("/home/yukinari/tanu/src/main.rs", "main"; "main")]
    #[test_case("/home/yukinari/tanu/src/foo.rs", "foo"; "foo")]
    #[test_case("/home/yukinari/tanu/src/foo/bar.rs", "foo::bar"; "foo::bar")]
    #[test_case("/home/yukinari/tanu/src/foo/bar/baz.rs", "foo::bar::baz"; "foo::bar::baz")]
    #[test_case("/home/yukinari/tanu/src/foo/bar/mod.rs", "foo::bar"; "foo::bar::mod")]
    fn test_extract_module_path(path: &str, module_path: &str) {
        let path = Path::new(path);
        let extracted_module = super::extract_module_path(path);
        assert_eq!(extracted_module, Some(module_path.to_string()));
    }

    // `eyre::Result` is classified as eyre; other crates' `Result` and a bare `Result` are not.
    #[test_case("fn foo() -> eyre::Result" => ErrorCrate::Eyre; "eyre")]
    #[test_case("fn foo() -> anyhow::Result" => ErrorCrate::AnythingElse; "anyhow")]
    #[test_case("fn foo() -> miette::Result" => ErrorCrate::AnythingElse; "miette")]
    #[test_case("fn foo() -> Result" => ErrorCrate::AnythingElse; "std_result")]
    fn inspect_error_crate(s: &str) -> ErrorCrate {
        let sig: syn::Signature = syn::parse_str(s).expect("failed to parse function signature");
        super::inspect_error_crate(&sig)
    }

    // `Some(x)` yields `x` (string literals without quotes, other expressions as token text
    // with token spacing), and `None` yields "None".
    #[test_case("Some(1)" => Some("1".into()); "Some with int")]
    #[test_case("Some(\"test\")" => Some("test".into()); "Some with string")]
    #[test_case("Some(true)" => Some("true".into()); "Some with boolean")]
    #[test_case("Some(1.0)" => Some("1.0".into()); "Some with float")]
    #[test_case("Some(StatusCode::OK)" => Some("StatusCode :: OK".into()); "Some third party type")]
    #[test_case("Some(\"foo\".to_string())" => Some("\"foo\" . to_string ()".into()); "Some expression")]
    #[test_case("None" => Some("None".into()); "None")]
    fn extract_and_stringify_option(s: &str) -> Option<String> {
        let expr: Expr = syn::parse_str(s).expect("failed to parse expression");
        super::extract_and_stringify_option(&expr)
    }

    // Test-name generation for a function `foo`: an explicit `; "name"` is used as-is;
    // otherwise each argument is stringified with operators spelled out (`+` -> `_add_`,
    // `<<` -> `_lshift_`, ...), brackets, quotes and whitespace removed, `,`/`.` turned into
    // `_`, and the result lowercased.
    // NOTE(review): which expansion trips `clippy::erasing_op` is not visible here.
    #[allow(clippy::erasing_op)]
    #[test_case("a, b; \"test_name\"" => "foo::test_name"; "with test name")]
    #[test_case("1+1" => "foo::1_add_1"; "with add expression")]
    #[test_case("1+=1" => "foo::1_add_1"; "with add assignment expression")]
    #[test_case("1-1" => "foo::1_sub_1"; "with sub expression")]
    #[test_case("1-=1" => "foo::1_sub_1"; "with sub assignment expression")]
    #[test_case("1/1" => "foo::1_div_1"; "with div expression")]
    #[test_case("1/=1" => "foo::1_div_1"; "with div assignment expression")]
    #[test_case("1*1" => "foo::1_mul_1"; "with mul expression")]
    #[test_case("1*=1" => "foo::1_mul_1"; "with mul assignment expression")]
    #[test_case("1%1" => "foo::1_mod_1"; "with mod expression")]
    #[test_case("1%=1" => "foo::1_mod_1"; "with mod assignment expression")]
    #[test_case("1==1" => "foo::1_eq_1"; "with eq expression")]
    #[test_case("1!=1" => "foo::1_nq_1"; "with neq expression")]
    #[test_case("1<1" => "foo::1_lt_1"; "with lt expression")]
    #[test_case("1>1" => "foo::1_gt_1"; "with gt expression")]
    #[test_case("1<=1" => "foo::1_le_1"; "with le expression")]
    #[test_case("1>=1" => "foo::1_ge_1"; "with ge expression")]
    #[test_case("true&&false" => "foo::true_and_false"; "with and expression")]
    #[test_case("true||false" => "foo::true_or_false"; "with or expression")]
    #[test_case("!true" => "foo::not_true"; "with not expression")]
    #[test_case("1&1" => "foo::1_and_1"; "with bitwise and expression")]
    #[test_case("1&=1" => "foo::1_and_1"; "with bitwise and assignment expression")]
    #[test_case("1|1" => "foo::1_or_1"; "with bitwise or expression")]
    #[test_case("1|=1" => "foo::1_or_1"; "with bitwise or assignment expression")]
    #[test_case("1^1" => "foo::1_xor_1"; "with xor expression")]
    #[test_case("1^=1" => "foo::1_xor_1"; "with xor assignment expression")]
    #[test_case("1<<1" => "foo::1_lshift_1"; "with left shift expression")]
    #[test_case("1<<=1" => "foo::1_lshift_1"; "with left shift assignment expression")]
    #[test_case("1>>1" => "foo::1_rshift_1"; "with right shift expression")]
    #[test_case("1>>=1" => "foo::1_rshift_1"; "with right shift assignment expression")]
    #[test_case("\"bar\".to_string()" => "foo::bar_to_string"; "to_string")]
    #[test_case("1+1*2" => "foo::1_add_1_mul_2"; "with add and mul expression")]
    #[test_case("1*(2+3)" => "foo::1_mul_2_add_3"; "with mul and add expression")]
    #[test_case("1+2-3" => "foo::1_add_2_sub_3"; "with add and sub expression")]
    #[test_case("1/2*3" => "foo::1_div_2_mul_3"; "with div and mul expression")]
    #[test_case("1%2+3" => "foo::1_mod_2_add_3"; "with mod and add expression")]
    #[test_case("1==2&&3!=4" => "foo::1_eq_2_and_3_nq_4"; "with eq and and expression")]
    #[test_case("true||false&&true" => "foo::true_or_false_and_true"; "with or and and expression")]
    #[test_case("!(1+2)" => "foo::not_1_add_2"; "with not and add expression")]
    #[test_case("1&2|3^4" => "foo::1_and_2_or_3_xor_4"; "with bitwise and, or, xor expression")]
    #[test_case("1<<2>>3" => "foo::1_lshift_2_rshift_3"; "with left shift and right shift expression")]
    #[test_case("Some(1+2)" => "foo::1_add_2"; "with Some and add expression")]
    #[test_case("None" => "foo::none"; "with None")]
    #[test_case("[1, 2]" => "foo::1_2"; "with array")]
    #[test_case("vec![1, 2]" => "foo::vecnot_1_2"; "with macro")]
    #[test_case("\"foo\".to_string().len()" => "foo::foo_to_string_len"; "with function call chain")]
    #[test_case("0.5+0.3" => "foo::0_5_add_0_3"; "with floating point add")]
    #[test_case("-10" => "foo::_sub_10"; "with negative number")]
    #[test_case("1.0e10" => "foo::1_0e10"; "with scientific notation")]
    #[test_case("0xff" => "foo::0xff"; "with hex literal")]
    #[test_case("0o777" => "foo::0o777"; "with octal literal")]
    #[test_case("0b1010" => "foo::0b1010"; "with binary literal")]
    #[test_case("\"hello\" + \"world\"" => "foo::hello_add_world"; "with string concatenation")]
    #[test_case("format!(\"{}{}\", 1, 2)" => "foo::formatnot__1_2"; "with format macro")]
    #[test_case("r#\"raw string\"#" => "foo::rawstring"; "with raw string")]
    #[test_case("(1, \"hello\", true)" => "foo::1_hello_true"; "with mixed tuple")]
    #[test_case("vec![1..5]" => "foo::vecnot_1__5"; "with range in macro")]
    #[test_case("x.map(|v| v+1)" => "foo::x_map_or_v_or_v_add_1"; "with closure")]
    #[test_case("a.into()" => "foo::a_into"; "with method call no args")]
    #[test_case("a.parse::<i32>().unwrap()" => "foo::a_parse__lt_i32_gt__unwrap"; "with turbofish syntax")]
    #[test_case("1..10" => "foo::1__10"; "with range")]
    fn generate_test_name(args: &str) -> String {
        let input_args: Input = syn::parse_str(args).expect("failed to parse input args");
        let test_case = TestCase::from_func_name(&input_args, "foo");
        test_case.test_name
    }
}