1extern crate proc_macro;
2
3use eyre::{OptionExt, WrapErr};
4use itertools::Itertools;
5use once_cell::sync::Lazy;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::{quote, ToTokens};
9use std::{
10 collections::{HashMap, HashSet},
11 io::Read,
12 path::{Path, PathBuf},
13 sync::{Arc, Mutex},
14};
15use syn::{
16 parse::Parse, parse_macro_input, punctuated::Punctuated, Expr, ExprCall, ExprLit, ExprPath,
17 File, Ident, Item, ItemFn, Lit, LitStr, Meta, ReturnType, Signature, Token, Type,
18};
19use walkdir::WalkDir;
20
21static TEST_CASES: Lazy<Arc<Mutex<HashSet<TestCase>>>> =
22 Lazy::new(|| Arc::new(Mutex::new(HashSet::new())));
23
/// A single test case, identified by the two names derived for it.
///
/// Equality and hashing cover both fields, which is what the duplicate check in the
/// `test` macro relies on.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TestCase {
    /// Identifier of the wrapper function emitted by the `test` macro.
    func_name: String,
    /// Name under which the case is handed to the runner.
    test_name: String,
}
32
33impl TestCase {
34 fn from_func_name(input: &Input, org_func_name: &str) -> TestCase {
36 TestCase {
37 func_name: generate_test_name(org_func_name, input, Some("tanu")),
38 test_name: generate_test_name(org_func_name, input, None),
39 }
40 }
41}
42
43#[derive(Debug)]
44#[allow(dead_code)]
45struct TestModule {
46 module: String,
48 func_name: String,
50 test_cases: Vec<TestCase>,
52}
53
54struct Input {
56 args: Punctuated<Expr, Token![,]>,
58 name: Option<LitStr>,
60}
61
62impl Parse for Input {
63 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
64 if input.is_empty() {
65 Ok(Input {
66 args: Default::default(),
67 name: None,
68 })
69 } else {
70 let args: Punctuated<Expr, Token![,]> =
71 Punctuated::parse_separated_nonempty_with(input, Expr::parse)?;
72
73 let name = if input.parse::<Token![;]>().is_ok() {
74 input.parse::<LitStr>().ok()
75 } else {
76 None
77 };
78
79 Ok(Input { args, name })
80 }
81 }
82}
83
84fn generate_test_name(org_func_name: &str, input: &Input, prefix: Option<&str>) -> String {
87 let func_name = org_func_name.to_string();
88
89 if input.args.is_empty() {
90 if let Some(prefix) = prefix {
91 return format!("{prefix}_{func_name}");
92 } else {
93 return func_name.to_string();
94 }
95 }
96
97 let generated = match &input.name {
98 Some(name_argument) => name_argument.value(),
99 _ => {
100 let args = input
101 .args
102 .iter()
103 .filter_map(|expr| match expr {
104 Expr::Lit(ExprLit { lit, .. }) => match lit {
105 Lit::Str(lit_str) => Some(lit_str.value()),
106 other_literal => Some(quote!(#other_literal).to_string()),
107 },
108 expr @ Expr::Path(_) | expr @ Expr::Call(_) => {
109 extract_and_stringify_option(expr)
110 }
111 other_expr => Some(quote!(#other_expr).to_string()),
112 })
113 .map(|s| {
114 s.replace("+=", "_add_")
115 .replace("+", "_add_")
116 .replace("-=", "_sub_")
117 .replace("-", "_sub_")
118 .replace("/=", "_div_")
119 .replace("/", "_div_")
120 .replace("*=", "_mul_")
121 .replace("*", "_mul_")
122 .replace("%=", "_mod_")
123 .replace("%", "_mod_")
124 .replace("==", "_eq_")
125 .replace("!=", "_nq_")
126 .replace("&&", "_and_")
127 .replace("||", "_or_")
128 .replace("!", "not_")
129 .replace("&=", "_and_")
130 .replace("&", "_and_")
131 .replace("|=", "_or_")
132 .replace("|", "_or_")
133 .replace("^=", "_xor_")
134 .replace("^", "_xor_")
135 .replace("<<=", "_lshift_")
136 .replace("<<", "_lshift_")
137 .replace("<=", "_le_")
138 .replace("<", "_lt_")
139 .replace(">>=", "_rshift_")
140 .replace(">>", "_rshift_")
141 .replace(">=", "_ge_")
142 .replace(">", "_gt_")
143 .replace("&mut ", "")
144 .replace("*mut ", "")
145 .replace("&", "")
146 .replace("*", "")
147 .replace(" :: ", "_")
148 .replace("\"", "")
149 .replace("(", "")
150 .replace(")", "")
151 .replace(" ", "")
152 .replace(".", "_")
153 .to_lowercase()
154 })
155 .collect::<Vec<_>>()
156 .join("_");
157
158 if let Some(prefix) = prefix {
159 return format!("{prefix}_{func_name}_{args}");
160 } else {
161 return format!("{func_name}_{args}");
162 }
163 }
164 };
165
166 if syn::parse_str::<Ident>(&generated).is_err() {
168 panic!(
169 r#"Test case generation error! The provided test parameters contain
170 invalid characters that cannot be used in a function name (function name: {generated}).
171 Please specify a valid test name using only letters, numbers, and underscores."#
172 );
173 }
174
175 generated
176}
177
/// Classification of a test function's return type, as decided by `inspect_error_crate`.
#[derive(Debug, PartialEq, Eq)]
enum ErrorCrate {
    /// The return type path starts with `eyre` and ends with `Result`.
    Eyre,
    /// Any other return type path.
    AnythingElse,
}
183
184fn inspect_error_crate(sig: &Signature) -> ErrorCrate {
197 match &sig.output {
198 ReturnType::Default => panic!("return type needs to be other than ()"),
199 ReturnType::Type(_, ty) => {
200 let Type::Path(type_path) = ty.as_ref() else {
201 panic!("failed to get return type path");
202 };
203
204 let path = &type_path.path;
205 match (path.segments.first(), path.segments.last()) {
206 (Some(first), Some(last)) => {
207 if first.ident == "eyre" && last.ident == "Result" {
208 ErrorCrate::Eyre
209 } else {
210 ErrorCrate::AnythingElse
211 }
212 }
213 _ => {
214 panic!("unexpected return type");
215 }
216 }
217 }
218 }
219}
220
221#[allow(dead_code)]
222fn get_expr_variant_name(expr: &Expr) -> &'static str {
224 match expr {
225 Expr::Array(_) => "Array",
226 Expr::Assign(_) => "Assign",
227 Expr::Async(_) => "Async",
228 Expr::Await(_) => "Await",
229 Expr::Binary(_) => "Binary",
230 Expr::Block(_) => "Block",
231 Expr::Break(_) => "Break",
232 Expr::Call(_) => "Call",
233 Expr::Cast(_) => "Cast",
234 Expr::Closure(_) => "Closure",
235 Expr::Continue(_) => "Continue",
236 Expr::Field(_) => "Field",
237 Expr::ForLoop(_) => "ForLoop",
238 Expr::Group(_) => "Group",
239 Expr::If(_) => "If",
240 Expr::Index(_) => "Index",
241 Expr::Let(_) => "Let",
242 Expr::Lit(_) => "Lit",
243 Expr::Loop(_) => "Loop",
244 Expr::Macro(_) => "Macro",
245 Expr::Match(_) => "Match",
246 Expr::MethodCall(_) => "MethodCall",
247 Expr::Paren(_) => "Paren",
248 Expr::Path(_) => "Path",
249 Expr::Range(_) => "Range",
250 Expr::Reference(_) => "Reference",
251 Expr::Repeat(_) => "Repeat",
252 Expr::Return(_) => "Return",
253 Expr::Struct(_) => "Struct",
254 Expr::Try(_) => "Try",
255 Expr::TryBlock(_) => "TryBlock",
256 Expr::Tuple(_) => "Tuple",
257 Expr::Unary(_) => "Unary",
258 Expr::Unsafe(_) => "Unsafe",
259 Expr::Verbatim(_) => "Verbatim",
260 Expr::While(_) => "While",
261 Expr::Yield(_) => "Yield",
262 _ => "Unknown",
263 }
264}
265
266fn extract_and_stringify_option(expr: &Expr) -> Option<String> {
267 match expr {
268 Expr::Call(ExprCall { func, args, .. }) => {
269 if let Expr::Path(ExprPath { path, .. }) = &**func {
270 let segment = path.segments.last()?;
271 if segment.ident == "Some" {
272 match args.first()? {
273 Expr::Lit(ExprLit { lit, .. }) => match lit {
274 Lit::Str(lit_str) => {
275 return Some(lit_str.value());
276 }
277 other_type_of_literal => {
278 return Some(other_type_of_literal.to_token_stream().to_string());
279 }
280 },
281 first_arg => {
282 return Some(quote!(#first_arg).to_string());
283 }
284 }
285 }
286 }
287 }
288 Expr::Path(ExprPath { path, .. }) => {
289 if path.get_ident()? == "None" {
290 return Some("None".into());
291 }
292 }
293 _ => {}
294 }
295
296 None
297}
298
299#[proc_macro_attribute]
302pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
303 let input_args = parse_macro_input!(args as Input);
304 let input_fn = parse_macro_input!(input as ItemFn);
305
306 let func_name_inner = &input_fn.sig.ident;
307 let test_case = TestCase::from_func_name(&input_args, &func_name_inner.to_string());
308
309 let not_duplicated = TEST_CASES
310 .lock()
311 .expect("failed to accuire test case lock")
312 .insert(test_case.clone());
313 if !not_duplicated {
314 panic!(
315 r#"tanu does not yet support registering test with the exactly same signature.
316 please check the name of this function "{func_name_inner}" and try again."#
317 );
318 }
319
320 let func_name = Ident::new(&test_case.func_name, Span::call_site());
321 let args = input_args.args.to_token_stream();
322
323 let error_crate = inspect_error_crate(&input_fn.sig);
333 let output = if error_crate == ErrorCrate::Eyre {
334 quote! {
335 #input_fn
336 pub(crate) async fn #func_name() -> tanu::eyre::Result<()> {
337 #func_name_inner(#args).await
338 }
339 }
340 } else {
341 quote! {
342 #input_fn
343 pub(crate) async fn #func_name() -> tanu::eyre::Result<()> {
344 #func_name_inner(#args).await.map_err(|e| tanu::eyre::eyre!(Box::new(e)))
345 }
346 }
347 };
348
349 output.into()
350}
351
352fn find_crate_root() -> eyre::Result<PathBuf> {
353 let dir = std::env::var("CARGO_MANIFEST_DIR")?;
354 Ok(dir.into())
355}
356
357fn discover_tests() -> eyre::Result<Vec<TestModule>> {
358 let root = find_crate_root()?;
359
360 let source_paths: Vec<_> = WalkDir::new(root)
362 .into_iter()
363 .filter_map(|entry| {
364 let path = entry.ok()?.into_path();
365 let ext = path.extension()?;
366 if ext.eq_ignore_ascii_case("rs") {
367 Some(path)
368 } else {
369 None
370 }
371 })
372 .collect();
373
374 let mut test_modules = Vec::<TestModule>::new();
375 for source_path in source_paths {
376 let mut source_file = std::fs::File::open(&source_path)
377 .wrap_err_with(|| format!("could not open file: {}", source_path.display()))?;
378 let mut code = String::new();
379 source_file.read_to_string(&mut code)?;
380
381 let file = syn::parse_file(&code)?;
382 test_modules.extend(extract_module_and_test(
383 &extract_module_path(&source_path).ok_or_eyre("malformed module path")?,
384 file,
385 ));
386 }
387
388 Ok(test_modules)
389}
390
/// Converts the path of a source file into a `::`-separated module path.
///
/// Everything after the crate's `src` directory is used: the `.rs` suffix is stripped and
/// `mod` components (from `mod.rs`) are dropped, so `.../src/foo/bar/mod.rs` becomes
/// `foo::bar`. Components that are not valid UTF-8 are skipped.
///
/// The *last* `src` component is taken as the crate's source directory, so that a checkout
/// living below another `src` directory (`/usr/src/app/src/foo.rs`, `~/src/proj/src/foo.rs`)
/// still yields `foo` rather than `app::src::foo`.
///
/// Returns `None` when the path has no `src` component.
fn extract_module_path(path: &Path) -> Option<String> {
    let components: Vec<_> = path.iter().collect();
    let src_index = components
        .iter()
        .rposition(|component| *component == "src")?;

    // `src_index + 1` is at most `components.len()`, so the slice cannot go out of bounds.
    let module_path = components[src_index + 1..]
        .iter()
        .filter_map(|component| component.to_str())
        .map(|name| name.strip_suffix(".rs").unwrap_or(name))
        .filter(|name| *name != "mod")
        .collect::<Vec<_>>()
        .join("::");

    Some(module_path)
}
404
405fn has_test_attribute(path: &syn::Path) -> bool {
407 let has_test = path.is_ident("test");
409 let has_tanu_test = match (path.segments.first(), path.segments.last()) {
411 (Some(first), Some(last)) => {
412 path.segments.len() == 2 && first.ident == "tanu" && last.ident == "test"
413 }
414 _ => false,
415 };
416
417 has_test || has_tanu_test
418}
419
420fn extract_module_and_test(module: &str, input: File) -> Vec<TestModule> {
421 let mut test_modules = Vec::new();
422 for item in input.items {
423 if let Item::Fn(item_fn) = item {
424 let mut is_test = false;
425 let mut test_cases = Vec::new();
426 for attr in item_fn.attrs {
427 if has_test_attribute(attr.path()) {
428 is_test = true;
429
430 match &attr.meta {
431 Meta::Path(_path) => {
433 let test_case = TestCase {
434 func_name: format!("tanu_{}", item_fn.sig.ident),
435 test_name: format!("tanu_{}", item_fn.sig.ident),
436 };
437 test_cases.push(test_case);
438 }
439 Meta::List(_list) => match attr.parse_args_with(Input::parse) {
441 Ok(test_case_token) => {
442 let test_case = TestCase::from_func_name(
443 &test_case_token,
444 &item_fn.sig.ident.to_string(),
445 );
446 test_cases.push(test_case);
447 }
448 Err(e) => {
449 eprintln!("failed to parse attributes in #[test]: {e:#}");
450 }
451 },
452 _ => {}
453 }
454 }
455 }
456 if is_test {
457 test_modules.push(TestModule {
458 module: module.to_owned(),
459 func_name: item_fn.sig.ident.to_string(),
460 test_cases,
461 });
462 }
463 }
464 }
465
466 test_modules
467}
468
469#[proc_macro_attribute]
470pub fn main(_args: TokenStream, input: TokenStream) -> TokenStream {
471 let main_fn = parse_macro_input!(input as ItemFn);
472
473 let test_modules = discover_tests().expect("failed to discover test cases");
474 let test_modules: HashMap<String, String> = test_modules
475 .iter()
476 .flat_map(|module| {
477 let module_name = module.module.clone();
478 module.test_cases.iter().map(move |test_case| {
479 (
480 test_case.func_name.clone(),
481 if module_name == "main" {
482 "crate".into()
483 } else {
484 module_name.clone()
485 },
486 )
487 })
488 })
489 .collect();
490
491 let (test_mods, test_names, func_names): (Vec<_>, Vec<_>, Vec<_>) = TEST_CASES
492 .lock()
493 .expect("failed to accuire test case lock")
494 .iter()
495 .map(|f| {
496 let test_module = test_modules.get(&f.func_name).expect("module not found");
497 let test_module_path: syn::Path =
498 syn::parse_str(test_module).expect("failed to parse module path");
499 (
500 test_module_path,
501 f.test_name.clone(),
502 Ident::new(&f.func_name, Span::call_site()),
503 )
504 })
505 .multiunzip();
506
507 let output = quote! {
508 fn run() -> tanu::Runner {
509 let mut runner = tanu::Runner::new();
510 #(
511 runner.add_test(
512 #test_names,
513 &stringify!(#test_mods).replace(" ", ""),
514 std::sync::Arc::new(|| Box::pin(#test_mods::#func_names()))
515 );
516 )*
517 runner
518 }
519
520 #main_fn
521 };
522
523 output.into()
524}
525
// Unit tests for the helper functions of this crate. Each `#[test_case(...)]` line is one
// table entry: inputs, then either `=> expected return value` or extra arguments, then
// `; "case name"`.
#[cfg(test)]
mod test {
    use crate::Input;

    use super::{ErrorCrate, Expr, Path, TestCase};
    use test_case::test_case;

    // Only `test` and exactly `tanu::test` are recognised as test attributes.
    #[test_case("test" => true; "test")]
    #[test_case("tanu::test" => true; "tanu_test")]
    #[test_case("tanu::foo::test" => false; "not_tanu_test")]
    #[test_case("foo::test" => false; "also_not_tanu_test")]
    fn has_test_attribute(s: &str) -> bool {
        let path: syn::Path = syn::parse_str(s).expect("Failed to parse path");
        super::has_test_attribute(&path)
    }

    // Module paths are taken from the components after `src`; `mod.rs` maps to its directory.
    #[test_case("/home/yukinari/tanu/src/main.rs", "main"; "main")]
    #[test_case("/home/yukinari/tanu/src/foo.rs", "foo"; "foo")]
    #[test_case("/home/yukinari/tanu/src/foo/bar.rs", "foo::bar"; "foo::bar")]
    #[test_case("/home/yukinari/tanu/src/foo/bar/baz.rs", "foo::bar::baz"; "foo::bar::baz")]
    #[test_case("/home/yukinari/tanu/src/foo/bar/mod.rs", "foo::bar"; "foo::bar::mod")]
    fn test_extract_module_path(path: &str, module_path: &str) {
        let path = Path::new(path);
        let extracted_module = super::extract_module_path(path);
        assert_eq!(extracted_module, Some(module_path.to_string()));
    }

    // Only a return type path starting with `eyre` and ending with `Result` counts as eyre.
    #[test_case("fn foo() -> eyre::Result" => ErrorCrate::Eyre; "eyre")]
    #[test_case("fn foo() -> anyhow::Result" => ErrorCrate::AnythingElse; "anyhow")]
    #[test_case("fn foo() -> miette::Result" => ErrorCrate::AnythingElse; "miette")]
    #[test_case("fn foo() -> Result" => ErrorCrate::AnythingElse; "std_result")]
    fn inspect_error_crate(s: &str) -> ErrorCrate {
        let sig: syn::Signature = syn::parse_str(s).expect("failed to parse function signature");
        super::inspect_error_crate(&sig)
    }

    // `Some(..)` yields the rendering of its first argument (token text keeps the spaces
    // inserted between tokens); a bare `None` yields "None".
    #[test_case("Some(1)" => Some("1".into()); "Some with int")]
    #[test_case("Some(\"test\")" => Some("test".into()); "Some with string")]
    #[test_case("Some(true)" => Some("true".into()); "Some with boolean")]
    #[test_case("Some(1.0)" => Some("1.0".into()); "Some with float")]
    #[test_case("Some(StatusCode::OK)" => Some("StatusCode :: OK".into()); "Some third party type")]
    #[test_case("Some(\"foo\".to_string())" => Some("\"foo\" . to_string ()".into()); "Some expression")]
    #[test_case("None" => Some("None".into()); "None")]
    fn extract_and_stringify_option(s: &str) -> Option<String> {
        let expr: Expr = syn::parse_str(s).expect("failed to parse expression");
        super::extract_and_stringify_option(&expr)
    }

    // Generated test names for the function `foo`: an explicit name wins, otherwise the
    // arguments are spelled out (operators replaced by words) and appended to `foo_`.
    #[allow(clippy::erasing_op)]
    #[test_case("a, b; \"test_name\"" => "test_name"; "with test name")]
    #[test_case("1+1" => "foo_1_add_1"; "with add expression")]
    #[test_case("1+=1" => "foo_1_add_1"; "with add assignment expression")]
    #[test_case("1-1" => "foo_1_sub_1"; "with sub expression")]
    #[test_case("1-=1" => "foo_1_sub_1"; "with sub assignment expression")]
    #[test_case("1/1" => "foo_1_div_1"; "with div expression")]
    #[test_case("1/=1" => "foo_1_div_1"; "with div assignment expression")]
    #[test_case("1*1" => "foo_1_mul_1"; "with mul expression")]
    #[test_case("1*=1" => "foo_1_mul_1"; "with mul assignment expression")]
    #[test_case("1%1" => "foo_1_mod_1"; "with mod expression")]
    #[test_case("1%=1" => "foo_1_mod_1"; "with mod assignment expression")]
    #[test_case("1==1" => "foo_1_eq_1"; "with eq expression")]
    #[test_case("1!=1" => "foo_1_nq_1"; "with neq expression")]
    #[test_case("1<1" => "foo_1_lt_1"; "with lt expression")]
    #[test_case("1>1" => "foo_1_gt_1"; "with gt expression")]
    #[test_case("1<=1" => "foo_1_le_1"; "with le expression")]
    #[test_case("1>=1" => "foo_1_ge_1"; "with ge expression")]
    #[test_case("true&&false" => "foo_true_and_false"; "with and expression")]
    #[test_case("true||false" => "foo_true_or_false"; "with or expression")]
    #[test_case("!true" => "foo_not_true"; "with not expression")]
    #[test_case("1&1" => "foo_1_and_1"; "with bitwise and expression")]
    #[test_case("1&=1" => "foo_1_and_1"; "with bitwise and assignment expression")]
    #[test_case("1|1" => "foo_1_or_1"; "with bitwise or expression")]
    #[test_case("1|=1" => "foo_1_or_1"; "with bitwise or assignment expression")]
    #[test_case("1^1" => "foo_1_xor_1"; "with xor expression")]
    #[test_case("1^=1" => "foo_1_xor_1"; "with xor assignment expression")]
    #[test_case("1<<1" => "foo_1_lshift_1"; "with left shift expression")]
    #[test_case("1<<=1" => "foo_1_lshift_1"; "with left shift assignment expression")]
    #[test_case("1>>1" => "foo_1_rshift_1"; "with right shift expression")]
    #[test_case("1>>=1" => "foo_1_rshift_1"; "with right shift assignment expression")]
    #[test_case("\"bar\".to_string()" => "foo_bar_to_string"; "to_string")]
    #[test_case("1+1*2" => "foo_1_add_1_mul_2"; "with add and mul expression")]
    #[test_case("1*(2+3)" => "foo_1_mul_2_add_3"; "with mul and add expression")]
    #[test_case("1+2-3" => "foo_1_add_2_sub_3"; "with add and sub expression")]
    #[test_case("1/2*3" => "foo_1_div_2_mul_3"; "with div and mul expression")]
    #[test_case("1%2+3" => "foo_1_mod_2_add_3"; "with mod and add expression")]
    #[test_case("1==2&&3!=4" => "foo_1_eq_2_and_3_nq_4"; "with eq and and expression")]
    #[test_case("true||false&&true" => "foo_true_or_false_and_true"; "with or and and expression")]
    #[test_case("!(1+2)" => "foo_not_1_add_2"; "with not and add expression")]
    #[test_case("1&2|3^4" => "foo_1_and_2_or_3_xor_4"; "with bitwise and, or, xor expression")]
    #[test_case("1<<2>>3" => "foo_1_lshift_2_rshift_3"; "with left shift and right shift expression")]
    #[test_case("Some(1+2)" => "foo_1_add_2"; "with Some and add expression")]
    #[test_case("None" => "foo_none"; "with None")]
    #[test_case("\"foo\".to_string().len()" => "foo_foo_to_string_len"; "with function call chain")]
    fn generate_test_name(args: &str) -> String {
        let input_args: Input = syn::parse_str(args).expect("failed to parse input args");
        let test_case = TestCase::from_func_name(&input_args, "foo");
        test_case.test_name
    }
}