1extern crate proc_macro;
18
19use eyre::WrapErr;
20use itertools::Itertools;
21use once_cell::sync::Lazy;
22use proc_macro::TokenStream;
23use proc_macro2::Span;
24use quote::{quote, ToTokens};
25use std::{
26 collections::{HashMap, HashSet},
27 io::Read,
28 path::{Path, PathBuf},
29 sync::{Arc, Mutex},
30};
31use syn::{
32 parse::Parse, parse_macro_input, punctuated::Punctuated, Expr, ExprCall, ExprLit, ExprPath,
33 File, Ident, Item, ItemFn, Lit, LitStr, Meta, ReturnType, Signature, Token, Type,
34};
35use walkdir::WalkDir;
36
37static TEST_CASES: Lazy<Arc<Mutex<HashSet<TestCase>>>> =
38 Lazy::new(|| Arc::new(Mutex::new(HashSet::new())));
39
/// One concrete invocation of a test function, i.e. one `#[tanu::test(...)]` attribute.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct TestCase {
    /// Identifier of the generated wrapper function (e.g. `tanu_foo_1_add_1`).
    func_name: String,
    /// Display name of the test (e.g. `foo::1_add_1`).
    test_name: String,
}
48
49impl TestCase {
50 fn from_func_name(input: &Input, org_func_name: &str) -> TestCase {
52 let test_name = generate_test_name(org_func_name, input);
53 let func_name = format!("tanu_{test_name}").replace("::", "_");
54 if syn::parse_str::<Ident>(&func_name).is_err() {
55 panic!(
56 r#"Test case generation error! The provided test parameters contain
57 invalid characters that cannot be used in a function name (function name: {func_name}).
58 Please specify a valid test name using only letters, numbers, and underscores."#
59 );
60 }
61 TestCase {
62 func_name,
63 test_name,
64 }
65 }
66}
67
68#[derive(Debug)]
69#[allow(dead_code)]
70struct TestModule {
71 module: String,
73 func_name: String,
75 test_cases: Vec<TestCase>,
77}
78
79struct Input {
81 args: Punctuated<Expr, Token![,]>,
83 name: Option<LitStr>,
85}
86
87impl Parse for Input {
88 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
89 if input.is_empty() {
90 Ok(Input {
91 args: Default::default(),
92 name: None,
93 })
94 } else {
95 let args: Punctuated<Expr, Token![,]> =
96 Punctuated::parse_separated_nonempty_with(input, Expr::parse)?;
97
98 let name = if input.parse::<Token![;]>().is_ok() {
99 input.parse::<LitStr>().ok()
100 } else {
101 None
102 };
103
104 Ok(Input { args, name })
105 }
106 }
107}
108
109fn generate_test_name(org_func_name: &str, input: &Input) -> String {
112 let func_name = org_func_name.to_string();
113
114 if input.args.is_empty() {
115 return func_name.to_string();
116 }
117
118 let stringified_args = match &input.name {
119 Some(name_argument) => name_argument.value(),
120 _ => input
121 .args
122 .iter()
123 .filter_map(|expr| match expr {
124 Expr::Lit(ExprLit { lit, .. }) => match lit {
125 Lit::Str(lit_str) => Some(lit_str.value()),
126 other_literal => Some(quote!(#other_literal).to_string()),
127 },
128 expr @ Expr::Path(_) | expr @ Expr::Call(_) => extract_and_stringify_option(expr),
129 other_expr => Some(quote!(#other_expr).to_string()),
130 })
131 .map(|s| {
132 s.replace("+=", "_add_")
133 .replace("+", "_add_")
134 .replace("-=", "_sub_")
135 .replace("-", "_sub_")
136 .replace("/=", "_div_")
137 .replace("/", "_div_")
138 .replace("*=", "_mul_")
139 .replace("*", "_mul_")
140 .replace("%=", "_mod_")
141 .replace("%", "_mod_")
142 .replace("==", "_eq_")
143 .replace("!=", "_nq_")
144 .replace("&&", "_and_")
145 .replace("||", "_or_")
146 .replace("!", "not_")
147 .replace("&=", "_and_")
148 .replace("&", "_and_")
149 .replace("|=", "_or_")
150 .replace("|", "_or_")
151 .replace("^=", "_xor_")
152 .replace("^", "_xor_")
153 .replace("<<=", "_lshift_")
154 .replace("<<", "_lshift_")
155 .replace("<=", "_le_")
156 .replace("<", "_lt_")
157 .replace(">>=", "_rshift_")
158 .replace(">>", "_rshift_")
159 .replace(">=", "_ge_")
160 .replace(">", "_gt_")
161 .replace("&mut ", "")
162 .replace("*mut ", "")
163 .replace("&", "")
164 .replace("*", "")
165 .replace(" :: ", "_")
166 .replace("\\", "")
167 .replace("/", "")
168 .replace("\"", "")
169 .replace("(", "")
170 .replace(")", "")
171 .replace("{", "")
172 .replace("}", "")
173 .replace("[", "")
174 .replace("]", "")
175 .replace(" ", "")
176 .replace(",", "_")
177 .replace(".", "_")
178 .to_lowercase()
179 })
180 .collect::<Vec<_>>()
181 .join("_"),
182 };
183
184 format!("{func_name}::{stringified_args}")
185}
186
/// How a test function's error type relates to eyre, as decided by `inspect_error_crate`.
#[derive(PartialEq, Eq, Debug)]
enum ErrorCrate {
    /// The function returns an eyre result, which the generated wrapper returns unchanged.
    Eyre,
    /// Any other result type; the wrapper converts its error with `tanu::eyre::eyre!`.
    AnythingElse,
}
192
193fn inspect_error_crate(sig: &Signature) -> ErrorCrate {
206 match &sig.output {
207 ReturnType::Default => panic!("return type needs to be other than ()"),
208 ReturnType::Type(_, ty) => {
209 let Type::Path(type_path) = ty.as_ref() else {
210 panic!("failed to get return type path");
211 };
212
213 let path = &type_path.path;
214 match (path.segments.first(), path.segments.last()) {
215 (Some(first), Some(last)) => {
216 if first.ident == "eyre" && last.ident == "Result" {
217 ErrorCrate::Eyre
218 } else {
219 ErrorCrate::AnythingElse
220 }
221 }
222 _ => {
223 panic!("unexpected return type");
224 }
225 }
226 }
227 }
228}
229
230#[allow(dead_code)]
231fn get_expr_variant_name(expr: &Expr) -> &'static str {
233 match expr {
234 Expr::Array(_) => "Array",
235 Expr::Assign(_) => "Assign",
236 Expr::Async(_) => "Async",
237 Expr::Await(_) => "Await",
238 Expr::Binary(_) => "Binary",
239 Expr::Block(_) => "Block",
240 Expr::Break(_) => "Break",
241 Expr::Call(_) => "Call",
242 Expr::Cast(_) => "Cast",
243 Expr::Closure(_) => "Closure",
244 Expr::Continue(_) => "Continue",
245 Expr::Field(_) => "Field",
246 Expr::ForLoop(_) => "ForLoop",
247 Expr::Group(_) => "Group",
248 Expr::If(_) => "If",
249 Expr::Index(_) => "Index",
250 Expr::Let(_) => "Let",
251 Expr::Lit(_) => "Lit",
252 Expr::Loop(_) => "Loop",
253 Expr::Macro(_) => "Macro",
254 Expr::Match(_) => "Match",
255 Expr::MethodCall(_) => "MethodCall",
256 Expr::Paren(_) => "Paren",
257 Expr::Path(_) => "Path",
258 Expr::Range(_) => "Range",
259 Expr::Reference(_) => "Reference",
260 Expr::Repeat(_) => "Repeat",
261 Expr::Return(_) => "Return",
262 Expr::Struct(_) => "Struct",
263 Expr::Try(_) => "Try",
264 Expr::TryBlock(_) => "TryBlock",
265 Expr::Tuple(_) => "Tuple",
266 Expr::Unary(_) => "Unary",
267 Expr::Unsafe(_) => "Unsafe",
268 Expr::Verbatim(_) => "Verbatim",
269 Expr::While(_) => "While",
270 Expr::Yield(_) => "Yield",
271 _ => "Unknown",
272 }
273}
274
275fn extract_and_stringify_option(expr: &Expr) -> Option<String> {
276 match expr {
277 Expr::Call(ExprCall { func, args, .. }) => {
278 if let Expr::Path(ExprPath { path, .. }) = &**func {
279 let segment = path.segments.last()?;
280 if segment.ident == "Some" {
281 match args.first()? {
282 Expr::Lit(ExprLit { lit, .. }) => match lit {
283 Lit::Str(lit_str) => {
284 return Some(lit_str.value());
285 }
286 other_type_of_literal => {
287 return Some(other_type_of_literal.to_token_stream().to_string());
288 }
289 },
290 first_arg => {
291 return Some(quote!(#first_arg).to_string());
292 }
293 }
294 }
295 }
296 }
297 Expr::Path(ExprPath { path, .. }) => {
298 if path.get_ident()? == "None" {
299 return Some("None".into());
300 }
301 }
302 _ => {}
303 }
304
305 None
306}
307
308#[proc_macro_attribute]
348pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
349 let input_args = parse_macro_input!(args as Input);
350 let input_fn = parse_macro_input!(input as ItemFn);
351
352 let func_name_inner = &input_fn.sig.ident;
353 let test_case = TestCase::from_func_name(&input_args, &func_name_inner.to_string());
354
355 match TEST_CASES.lock() {
356 Ok(mut lock) => lock.insert(test_case.clone()),
357 Err(e) => {
358 eprintln!("Failed to acquire test case lock: {e}");
359 return quote! { #input_fn }.into();
360 }
361 };
362
363 let func_name = Ident::new(&test_case.func_name, Span::call_site());
364 let args = input_args.args.to_token_stream();
365
366 let error_crate = inspect_error_crate(&input_fn.sig);
376 let output = if error_crate == ErrorCrate::Eyre {
377 quote! {
378 #input_fn
379 pub(crate) async fn #func_name() -> tanu::eyre::Result<()> {
380 #func_name_inner(#args).await
381 }
382 }
383 } else {
384 quote! {
385 #input_fn
386 pub(crate) async fn #func_name() -> tanu::eyre::Result<()> {
387 #func_name_inner(#args).await.map_err(|e| tanu::eyre::eyre!(Box::new(e)))
388 }
389 }
390 };
391
392 output.into()
393}
394
395fn find_crate_root() -> eyre::Result<PathBuf> {
396 let dir = std::env::var("CARGO_MANIFEST_DIR")?;
397 Ok(dir.into())
398}
399
400fn discover_tests() -> eyre::Result<Vec<TestModule>> {
401 let root = find_crate_root()?;
402
403 let source_paths: Vec<_> = WalkDir::new(root)
405 .into_iter()
406 .filter_map(|entry| {
407 let path = entry.ok()?.into_path();
408 let ext = path.extension()?;
409 if ext.eq_ignore_ascii_case("rs") {
410 Some(path)
411 } else {
412 None
413 }
414 })
415 .collect();
416
417 let mut test_modules = Vec::<TestModule>::new();
418 for source_path in source_paths {
419 let mut source_file = std::fs::File::open(&source_path)
420 .wrap_err_with(|| format!("could not open file: {}", source_path.display()))?;
421 let mut code = String::new();
422 source_file.read_to_string(&mut code)?;
423
424 let file = syn::parse_file(&code)?;
425 let Some(module) = extract_module_path(&source_path) else {
426 continue;
427 };
428 test_modules.extend(extract_module_and_test(&module, file));
429 }
430
431 Ok(test_modules)
432}
433
/// Maps a source file path to its Rust module path.
///
/// Everything after the first `src` component is used: `.rs` suffixes are stripped, `mod`
/// components are dropped, and the rest are joined with `::`
/// (`.../src/foo/bar/mod.rs` -> `foo::bar`). Returns `None` when the path has no `src`
/// component.
fn extract_module_path(path: &Path) -> Option<String> {
    let mut components = path.iter().skip_while(|component| *component != "src");
    // Consume the `src` component itself; if it was never found the iterator is exhausted.
    components.next()?;

    let segments: Vec<&str> = components
        .filter_map(|component| component.to_str())
        .map(|name| name.strip_suffix(".rs").unwrap_or(name))
        .filter(|name| *name != "mod")
        .collect();

    Some(segments.join("::"))
}
447
448fn has_test_attribute(path: &syn::Path) -> bool {
450 let has_test = path.is_ident("test");
452 let has_tanu_test = match (path.segments.first(), path.segments.last()) {
454 (Some(first), Some(last)) => {
455 path.segments.len() == 2 && first.ident == "tanu" && last.ident == "test"
456 }
457 _ => false,
458 };
459
460 has_test || has_tanu_test
461}
462
463fn extract_module_and_test(module: &str, input: File) -> Vec<TestModule> {
464 let mut test_modules = Vec::new();
465 for item in input.items {
466 if let Item::Fn(item_fn) = item {
467 let mut is_test = false;
468 let mut test_cases = Vec::new();
469 for attr in item_fn.attrs {
470 if has_test_attribute(attr.path()) {
471 is_test = true;
472
473 match &attr.meta {
474 Meta::Path(_path) => {
476 let test_case = TestCase {
477 func_name: format!("tanu_{}", item_fn.sig.ident),
478 test_name: format!("tanu_{}", item_fn.sig.ident),
479 };
480 test_cases.push(test_case);
481 }
482 Meta::List(_list) => match attr.parse_args_with(Input::parse) {
484 Ok(test_case_token) => {
485 let test_case = TestCase::from_func_name(
486 &test_case_token,
487 &item_fn.sig.ident.to_string(),
488 );
489 test_cases.push(test_case);
490 }
491 Err(e) => {
492 eprintln!("failed to parse attributes in #[test]: {e:#}");
493 }
494 },
495 _ => {}
496 }
497 }
498 }
499 if is_test {
500 test_modules.push(TestModule {
501 module: module.to_owned(),
502 func_name: item_fn.sig.ident.to_string(),
503 test_cases,
504 });
505 }
506 }
507 }
508
509 test_modules
510}
511
512#[proc_macro_attribute]
550pub fn main(_args: TokenStream, input: TokenStream) -> TokenStream {
551 let main_fn = parse_macro_input!(input as ItemFn);
552
553 let test_modules = discover_tests().expect("failed to discover test cases");
554 let test_modules: HashMap<String, Vec<String>> = test_modules
558 .iter()
559 .flat_map(|module| {
560 let module_name = module.module.clone();
561 module.test_cases.iter().map(move |test_case| {
562 (
563 test_case.func_name.clone(),
564 if module_name == "main" {
565 "crate".into()
566 } else {
567 module_name.clone()
568 },
569 )
570 })
571 })
572 .fold(HashMap::new(), |mut acc, (func_name, module_name)| {
573 acc.entry(func_name).or_default().push(module_name);
574 acc
575 });
576
577 let (test_mods, test_names, func_names): (Vec<_>, Vec<_>, Vec<_>) = match TEST_CASES.lock() {
578 Ok(lock) => lock
579 .iter()
580 .flat_map(|f| {
581 test_modules
582 .get(&f.func_name)
583 .into_iter() .flatten() .filter_map(|module_name| {
586 let test_module_path = match syn::parse_str::<syn::Path>(module_name) {
587 Ok(path) => path,
588 Err(e) => {
589 eprintln!("failed to parse module path '{module_name}': {e}");
590 return None;
591 }
592 };
593
594 Some((
595 test_module_path,
596 f.test_name.clone(),
597 Ident::new(&f.func_name, Span::call_site()),
598 ))
599 })
600 })
601 .multiunzip(),
602 Err(e) => {
603 eprintln!("failed to acquire test case lock: {e}");
604 (Vec::new(), Vec::new(), Vec::new())
605 }
606 };
607
608 let output = quote! {
609 fn run() -> tanu::Runner {
610 let mut runner = tanu::Runner::new();
611 #(
612 runner.add_test(
613 #test_names,
614 &stringify!(#test_mods).replace(" ", ""),
615 std::sync::Arc::new(|| Box::pin(#test_mods::#func_names()))
616 );
617 )*
618 runner
619 }
620
621 #main_fn
622 };
623
624 output.into()
625}
626
#[cfg(test)]
mod test {
    use super::{ErrorCrate, Expr, Input, Path, TestCase};
    use test_case::test_case;

    // Only the bare `test` and exactly `tanu::test` count as test attributes.
    #[test_case("test" => true; "test")]
    #[test_case("tanu::test" => true; "tanu_test")]
    #[test_case("tanu::foo::test" => false; "not_tanu_test")]
    #[test_case("foo::test" => false; "also_not_tanu_test")]
    fn has_test_attribute(s: &str) -> bool {
        let attr_path = syn::parse_str::<syn::Path>(s).expect("Failed to parse path");
        super::has_test_attribute(&attr_path)
    }

    // Module paths are derived from everything after `src/`, with `mod.rs` folded into its dir.
    #[test_case("/home/yukinari/tanu/src/main.rs", "main"; "main")]
    #[test_case("/home/yukinari/tanu/src/foo.rs", "foo"; "foo")]
    #[test_case("/home/yukinari/tanu/src/foo/bar.rs", "foo::bar"; "foo::bar")]
    #[test_case("/home/yukinari/tanu/src/foo/bar/baz.rs", "foo::bar::baz"; "foo::bar::baz")]
    #[test_case("/home/yukinari/tanu/src/foo/bar/mod.rs", "foo::bar"; "foo::bar::mod")]
    fn test_extract_module_path(path: &str, module_path: &str) {
        let extracted = super::extract_module_path(Path::new(path));
        assert_eq!(extracted.as_deref(), Some(module_path));
    }

    #[test_case("fn foo() -> eyre::Result" => ErrorCrate::Eyre; "eyre")]
    #[test_case("fn foo() -> anyhow::Result" => ErrorCrate::AnythingElse; "anyhow")]
    #[test_case("fn foo() -> miette::Result" => ErrorCrate::AnythingElse; "miette")]
    #[test_case("fn foo() -> Result" => ErrorCrate::AnythingElse; "std_result")]
    fn inspect_error_crate(s: &str) -> ErrorCrate {
        let signature =
            syn::parse_str::<syn::Signature>(s).expect("failed to parse function signature");
        super::inspect_error_crate(&signature)
    }

    #[test_case("Some(1)" => Some("1".into()); "Some with int")]
    #[test_case("Some(\"test\")" => Some("test".into()); "Some with string")]
    #[test_case("Some(true)" => Some("true".into()); "Some with boolean")]
    #[test_case("Some(1.0)" => Some("1.0".into()); "Some with float")]
    #[test_case("Some(StatusCode::OK)" => Some("StatusCode :: OK".into()); "Some third party type")]
    #[test_case("Some(\"foo\".to_string())" => Some("\"foo\" . to_string ()".into()); "Some expression")]
    #[test_case("None" => Some("None".into()); "None")]
    fn extract_and_stringify_option(s: &str) -> Option<String> {
        let parsed = syn::parse_str::<Expr>(s).expect("failed to parse expression");
        super::extract_and_stringify_option(&parsed)
    }

    // Test names are derived from the attribute arguments of a function named `foo`.
    #[allow(clippy::erasing_op)]
    #[test_case("a, b; \"test_name\"" => "foo::test_name"; "with test name")]
    #[test_case("1+1" => "foo::1_add_1"; "with add expression")]
    #[test_case("1+=1" => "foo::1_add_1"; "with add assignment expression")]
    #[test_case("1-1" => "foo::1_sub_1"; "with sub expression")]
    #[test_case("1-=1" => "foo::1_sub_1"; "with sub assignment expression")]
    #[test_case("1/1" => "foo::1_div_1"; "with div expression")]
    #[test_case("1/=1" => "foo::1_div_1"; "with div assignment expression")]
    #[test_case("1*1" => "foo::1_mul_1"; "with mul expression")]
    #[test_case("1*=1" => "foo::1_mul_1"; "with mul assignment expression")]
    #[test_case("1%1" => "foo::1_mod_1"; "with mod expression")]
    #[test_case("1%=1" => "foo::1_mod_1"; "with mod assignment expression")]
    #[test_case("1==1" => "foo::1_eq_1"; "with eq expression")]
    #[test_case("1!=1" => "foo::1_nq_1"; "with neq expression")]
    #[test_case("1<1" => "foo::1_lt_1"; "with lt expression")]
    #[test_case("1>1" => "foo::1_gt_1"; "with gt expression")]
    #[test_case("1<=1" => "foo::1_le_1"; "with le expression")]
    #[test_case("1>=1" => "foo::1_ge_1"; "with ge expression")]
    #[test_case("true&&false" => "foo::true_and_false"; "with and expression")]
    #[test_case("true||false" => "foo::true_or_false"; "with or expression")]
    #[test_case("!true" => "foo::not_true"; "with not expression")]
    #[test_case("1&1" => "foo::1_and_1"; "with bitwise and expression")]
    #[test_case("1&=1" => "foo::1_and_1"; "with bitwise and assignment expression")]
    #[test_case("1|1" => "foo::1_or_1"; "with bitwise or expression")]
    #[test_case("1|=1" => "foo::1_or_1"; "with bitwise or assignment expression")]
    #[test_case("1^1" => "foo::1_xor_1"; "with xor expression")]
    #[test_case("1^=1" => "foo::1_xor_1"; "with xor assignment expression")]
    #[test_case("1<<1" => "foo::1_lshift_1"; "with left shift expression")]
    #[test_case("1<<=1" => "foo::1_lshift_1"; "with left shift assignment expression")]
    #[test_case("1>>1" => "foo::1_rshift_1"; "with right shift expression")]
    #[test_case("1>>=1" => "foo::1_rshift_1"; "with right shift assignment expression")]
    #[test_case("\"bar\".to_string()" => "foo::bar_to_string"; "to_string")]
    #[test_case("1+1*2" => "foo::1_add_1_mul_2"; "with add and mul expression")]
    #[test_case("1*(2+3)" => "foo::1_mul_2_add_3"; "with mul and add expression")]
    #[test_case("1+2-3" => "foo::1_add_2_sub_3"; "with add and sub expression")]
    #[test_case("1/2*3" => "foo::1_div_2_mul_3"; "with div and mul expression")]
    #[test_case("1%2+3" => "foo::1_mod_2_add_3"; "with mod and add expression")]
    #[test_case("1==2&&3!=4" => "foo::1_eq_2_and_3_nq_4"; "with eq and and expression")]
    #[test_case("true||false&&true" => "foo::true_or_false_and_true"; "with or and and expression")]
    #[test_case("!(1+2)" => "foo::not_1_add_2"; "with not and add expression")]
    #[test_case("1&2|3^4" => "foo::1_and_2_or_3_xor_4"; "with bitwise and, or, xor expression")]
    #[test_case("1<<2>>3" => "foo::1_lshift_2_rshift_3"; "with left shift and right shift expression")]
    #[test_case("Some(1+2)" => "foo::1_add_2"; "with Some and add expression")]
    #[test_case("None" => "foo::none"; "with None")]
    #[test_case("[1, 2]" => "foo::1_2"; "with array")]
    #[test_case("vec![1, 2]" => "foo::vecnot_1_2"; "with macro")]
    #[test_case("\"foo\".to_string().len()" => "foo::foo_to_string_len"; "with function call chain")]
    #[test_case("0.5+0.3" => "foo::0_5_add_0_3"; "with floating point add")]
    #[test_case("-10" => "foo::_sub_10"; "with negative number")]
    #[test_case("1.0e10" => "foo::1_0e10"; "with scientific notation")]
    #[test_case("0xff" => "foo::0xff"; "with hex literal")]
    #[test_case("0o777" => "foo::0o777"; "with octal literal")]
    #[test_case("0b1010" => "foo::0b1010"; "with binary literal")]
    #[test_case("\"hello\" + \"world\"" => "foo::hello_add_world"; "with string concatenation")]
    #[test_case("format!(\"{}{}\", 1, 2)" => "foo::formatnot__1_2"; "with format macro")]
    #[test_case("r#\"raw string\"#" => "foo::rawstring"; "with raw string")]
    #[test_case("(1, \"hello\", true)" => "foo::1_hello_true"; "with mixed tuple")]
    #[test_case("vec![1..5]" => "foo::vecnot_1__5"; "with range in macro")]
    #[test_case("x.map(|v| v+1)" => "foo::x_map_or_v_or_v_add_1"; "with closure")]
    #[test_case("a.into()" => "foo::a_into"; "with method call no args")]
    #[test_case("a.parse::<i32>().unwrap()" => "foo::a_parse__lt_i32_gt__unwrap"; "with turbofish syntax")]
    #[test_case("1..10" => "foo::1__10"; "with range")]
    fn generate_test_name(args: &str) -> String {
        let parsed_args = syn::parse_str::<Input>(args).expect("failed to parse input args");
        TestCase::from_func_name(&parsed_args, "foo").test_name
    }
}