1use std::path::Path;
44
45use tldr_core::ast::ParserPool;
46use tldr_core::Language;
47use tree_sitter::{Node, Tree};
48
/// Result of classifying a single source file.
///
/// `Default` is the "not a test file" answer: `is_test_file == false` and a
/// zero count. `PartialEq`/`Eq` let callers and tests compare results directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TestFileInfo {
    /// Whether the file was classified as a test file.
    pub is_test_file: bool,
    /// Number of test functions/cases recognized in the file (0 when none
    /// were found or the file could not be parsed).
    pub test_function_count: u32,
}
58
59pub fn recognize(path: &Path, source: &str, language: Language) -> TestFileInfo {
66 if !is_candidate_test_file(path, language) {
67 return TestFileInfo::default();
68 }
69
70 if matches!(language, Language::Rust) && !source.contains("#[test]") {
78 return TestFileInfo::default();
79 }
80
81 if source.trim().is_empty() {
83 return TestFileInfo {
84 is_test_file: true,
85 test_function_count: 0,
86 };
87 }
88
89 let pool = ParserPool::new();
90 let tree = match pool.parse(source, language).ok() {
91 Some(t) => t,
92 None => {
93 return TestFileInfo {
94 is_test_file: true,
95 test_function_count: 0,
96 };
97 }
98 };
99
100 let count = count_test_functions(&tree, source.as_bytes(), language);
101
102 TestFileInfo {
103 is_test_file: true,
104 test_function_count: count,
105 }
106}
107
108fn is_candidate_test_file(path: &Path, language: Language) -> bool {
113 let file_name = match path.file_name().and_then(|n| n.to_str()) {
114 Some(n) => n,
115 None => return false,
116 };
117 let stem = path
118 .file_stem()
119 .and_then(|s| s.to_str())
120 .unwrap_or(file_name);
121 let lower = file_name.to_ascii_lowercase();
122
123 match language {
124 Language::Python => {
126 file_name.starts_with("test_") && file_name.ends_with(".py")
127 || file_name.ends_with("_test.py")
128 }
129 Language::JavaScript | Language::TypeScript => {
133 let in_tests_dir = path
134 .components()
135 .any(|c| c.as_os_str() == "__tests__" || c.as_os_str() == "test"
136 || c.as_os_str() == "tests" || c.as_os_str() == "spec");
137 let has_test_marker = stem.ends_with(".test")
138 || stem.ends_with(".spec")
139 || stem.ends_with("_test")
140 || stem.ends_with("_spec")
141 || stem.ends_with("Test")
142 || stem.ends_with("Spec");
143 (has_test_marker || in_tests_dir)
144 && (lower.ends_with(".js")
145 || lower.ends_with(".jsx")
146 || lower.ends_with(".mjs")
147 || lower.ends_with(".cjs")
148 || lower.ends_with(".ts")
149 || lower.ends_with(".tsx"))
150 }
151 Language::Java => {
154 if !lower.ends_with(".java") {
155 return false;
156 }
157 stem.ends_with("Test")
158 || stem.ends_with("Tests")
159 || stem.ends_with("IT")
160 || stem.ends_with("ITCase")
161 || path.components().any(|c| c.as_os_str() == "test")
162 }
163 Language::Kotlin => {
165 (lower.ends_with(".kt") || lower.ends_with(".kts"))
166 && (stem.ends_with("Test")
167 || stem.ends_with("Tests")
168 || path.components().any(|c| c.as_os_str() == "test"))
169 }
170 Language::Php => lower.ends_with(".php") && (stem.ends_with("Test") || stem.ends_with("Tests")),
172 Language::Swift => {
174 lower.ends_with(".swift")
175 && (stem.ends_with("Tests") || stem.ends_with("Test") || stem.ends_with("Spec"))
176 }
177 Language::Ruby => {
179 lower.ends_with(".rb")
180 && (file_name.starts_with("test_")
181 || stem.ends_with("_test")
182 || stem.ends_with("_spec"))
183 }
184 Language::Go => lower.ends_with("_test.go"),
186 Language::Scala => {
188 lower.ends_with(".scala")
189 && (stem.ends_with("Test")
190 || stem.ends_with("Tests")
191 || stem.ends_with("Spec")
192 || stem.ends_with("Suite"))
193 }
194 Language::Elixir => {
196 (lower.ends_with(".exs") || lower.ends_with(".ex"))
197 && (stem.ends_with("_test") || file_name.starts_with("test_"))
198 }
199 Language::Lua | Language::Luau => {
201 (lower.ends_with(".lua") || lower.ends_with(".luau"))
202 && (stem.ends_with("_spec")
203 || stem.ends_with("_test")
204 || file_name.starts_with("test_"))
205 }
206 Language::Rust => {
224 lower.ends_with(".rs")
225 }
226 Language::CSharp => {
230 lower.ends_with(".cs")
231 && (stem.ends_with("Test")
232 || stem.ends_with("Tests")
233 || path.components().any(|c| {
234 let s = c.as_os_str().to_string_lossy().to_ascii_lowercase();
235 s == "test" || s == "tests"
236 }))
237 }
238 Language::C | Language::Cpp | Language::Ocaml => {
244 lower.contains("test") || lower.contains("spec")
245 }
246 }
247}
248
249fn count_test_functions(tree: &Tree, source: &[u8], language: Language) -> u32 {
251 let root = tree.root_node();
252 let mut count = 0u32;
253 walk_count(&root, source, language, &mut count);
254 count
255}
256
257fn walk_count(node: &Node, source: &[u8], language: Language, count: &mut u32) {
258 if matches_test_function(node, source, language) {
259 *count += 1;
260 return;
265 }
266
267 let mut cursor = node.walk();
268 for child in node.children(&mut cursor) {
269 walk_count(&child, source, language, count);
270 }
271}
272
273pub fn is_test_function_node(node: &Node, source: &[u8], language: Language) -> bool {
279 matches_test_function(node, source, language)
280}
281
282fn matches_test_function(node: &Node, source: &[u8], language: Language) -> bool {
284 match language {
285 Language::Python => python_is_test_function(node, source),
286 Language::JavaScript | Language::TypeScript => js_is_test_call(node, source),
287 Language::Java | Language::Kotlin => jvm_has_test_annotation(node, source),
288 Language::Php => php_is_test_method(node, source),
289 Language::Swift => swift_is_test_method(node, source),
290 Language::Ruby => ruby_is_test_def_or_block(node, source),
291 Language::Go => go_is_top_level_test_function(node, source),
292 Language::Scala => scala_is_test_call(node, source),
293 Language::Elixir => elixir_is_test_macro(node, source),
294 Language::Lua | Language::Luau => lua_is_test_call(node, source),
295 Language::Rust => rust_is_test_function(node, source),
296 Language::CSharp => csharp_has_test_attribute(node, source),
297 Language::C | Language::Cpp | Language::Ocaml => false,
298 }
299}
300
301fn rust_is_test_function(node: &Node, source: &[u8]) -> bool {
311 if node.kind() != "function_item" {
312 return false;
313 }
314 let mut prev = node.prev_sibling();
315 while let Some(p) = prev {
316 match p.kind() {
317 "attribute_item" => {
318 if rust_attribute_is_test(&p, source) {
319 return true;
320 }
321 prev = p.prev_sibling();
322 }
323 "line_comment" | "block_comment" => {
324 prev = p.prev_sibling();
325 }
326 _ => break,
327 }
328 }
329 false
330}
331
332fn rust_attribute_is_test(attr_item: &Node, source: &[u8]) -> bool {
333 let mut cursor = attr_item.walk();
335 for child in attr_item.children(&mut cursor) {
336 if child.kind() == "attribute" {
337 let text = node_text(child, source);
340 let head = text
342 .split(|c: char| c == '(' || c.is_whitespace())
343 .next()
344 .unwrap_or("");
345 let tail = head.rsplit("::").next().unwrap_or("");
346 if matches!(
347 tail,
348 "test" | "tokio_test" | "async_test" | "rstest" | "test_case"
349 ) {
350 return true;
351 }
352 }
353 }
354 false
355}
356
357fn csharp_has_test_attribute(node: &Node, source: &[u8]) -> bool {
364 if node.kind() != "method_declaration" {
365 return false;
366 }
367 let mut cursor = node.walk();
368 for child in node.children(&mut cursor) {
369 if child.kind() == "attribute_list" {
370 let mut inner = child.walk();
371 for attr in child.children(&mut inner) {
372 if attr.kind() == "attribute" && csharp_attribute_is_test(&attr, source) {
373 return true;
374 }
375 }
376 }
377 }
378 false
379}
380
381fn csharp_attribute_is_test(attribute: &Node, source: &[u8]) -> bool {
382 let text = node_text(*attribute, source);
384 let head = text
385 .split(|c: char| c == '(' || c.is_whitespace())
386 .next()
387 .unwrap_or("");
388 let tail = head.rsplit('.').next().unwrap_or("");
389 matches!(
390 tail,
391 "Test"
392 | "TestAttribute"
393 | "Fact"
394 | "FactAttribute"
395 | "Theory"
396 | "TheoryAttribute"
397 | "TestMethod"
398 | "TestMethodAttribute"
399 | "TestCase"
400 | "TestCaseAttribute"
401 | "DataTestMethod"
402 | "DataTestMethodAttribute"
403 )
404}
405
406fn python_is_test_function(node: &Node, source: &[u8]) -> bool {
408 if node.kind() != "function_definition" {
409 return false;
410 }
411 let name = node
412 .child_by_field_name("name")
413 .map(|n| node_text(n, source))
414 .unwrap_or_default();
415 name.starts_with("test_")
416}
417
418fn js_is_test_call(node: &Node, source: &[u8]) -> bool {
420 if node.kind() != "call_expression" {
423 return false;
424 }
425 let func_node = match node.child_by_field_name("function") {
426 Some(n) => n,
427 None => return false,
428 };
429 if func_node.kind() != "identifier" {
432 return false;
433 }
434 let name = node_text(func_node, source);
435 matches!(name.as_str(), "it" | "test" | "fit" | "xit" | "xtest")
436}
437
438fn jvm_has_test_annotation(node: &Node, source: &[u8]) -> bool {
440 let kind = node.kind();
446 if kind != "method_declaration" && kind != "function_declaration" {
447 return false;
448 }
449
450 let mut cursor = node.walk();
454 for child in node.children(&mut cursor) {
455 if child.kind() == "modifiers" {
456 if subtree_contains_annotation_named(&child, source, "Test") {
457 return true;
458 }
459 } else if child.kind() == "annotation" || child.kind() == "marker_annotation" {
460 if annotation_has_name(&child, source, "Test") {
461 return true;
462 }
463 }
464 }
465 false
466}
467
468fn subtree_contains_annotation_named(node: &Node, source: &[u8], target: &str) -> bool {
469 let mut cursor = node.walk();
470 for child in node.children(&mut cursor) {
471 let kind = child.kind();
472 if (kind == "annotation" || kind == "marker_annotation")
473 && annotation_has_name(&child, source, target)
474 {
475 return true;
476 }
477 if subtree_contains_annotation_named(&child, source, target) {
478 return true;
479 }
480 }
481 false
482}
483
484fn annotation_has_name(annotation_node: &Node, source: &[u8], target: &str) -> bool {
487 let text = node_text(*annotation_node, source);
488 let trimmed = text.trim_start_matches('@');
491 let head = trimmed.split(|c: char| c == '(' || c.is_whitespace()).next().unwrap_or("");
492 let last = head.rsplit('.').next().unwrap_or("");
493 last == target
494}
495
496fn php_is_test_method(node: &Node, source: &[u8]) -> bool {
498 if node.kind() != "method_declaration" {
499 return false;
500 }
501 let name = node
502 .child_by_field_name("name")
503 .map(|n| node_text(n, source))
504 .unwrap_or_default();
505 name.starts_with("test")
506}
507
508fn swift_is_test_method(node: &Node, source: &[u8]) -> bool {
510 let kind = node.kind();
514 if !(kind == "function_declaration" || kind == "protocol_function_declaration") {
515 return false;
516 }
517 let mut cursor = node.walk();
520 for child in node.children(&mut cursor) {
521 if child.kind() == "simple_identifier" {
522 return node_text(child, source).starts_with("test");
523 }
524 }
525 false
526}
527
528fn ruby_is_test_def_or_block(node: &Node, source: &[u8]) -> bool {
530 match node.kind() {
531 "method" => {
532 let name = node
534 .child_by_field_name("name")
535 .map(|n| node_text(n, source))
536 .unwrap_or_default();
537 name.starts_with("test_")
538 }
539 "call" => {
540 let method = node
542 .child_by_field_name("method")
543 .map(|n| node_text(n, source))
544 .unwrap_or_default();
545 matches!(method.as_str(), "it" | "specify")
546 }
547 _ => false,
548 }
549}
550
551fn go_is_top_level_test_function(node: &Node, source: &[u8]) -> bool {
553 if node.kind() != "function_declaration" {
554 return false;
555 }
556 let name = node
557 .child_by_field_name("name")
558 .map(|n| node_text(n, source))
559 .unwrap_or_default();
560 if !name.starts_with("Test") {
561 return false;
562 }
563 let after = name.strip_prefix("Test").unwrap_or("");
566 let starts_upper = after.chars().next().map(|c| c.is_ascii_uppercase()).unwrap_or(false);
567 starts_upper
568}
569
570fn scala_is_test_call(node: &Node, source: &[u8]) -> bool {
572 if node.kind() != "call_expression" {
578 return false;
579 }
580 let func_node = match node.child_by_field_name("function") {
581 Some(n) => n,
582 None => return false,
583 };
584 let name = node_text(func_node, source);
585 matches!(name.as_str(), "test")
586}
587
588fn elixir_is_test_macro(node: &Node, source: &[u8]) -> bool {
590 if node.kind() != "call" {
593 return false;
594 }
595 let target = match node.child_by_field_name("target") {
596 Some(n) => n,
597 None => return false,
598 };
599 node_text(target, source) == "test"
600}
601
602fn lua_is_test_call(node: &Node, source: &[u8]) -> bool {
604 if node.kind() != "function_call" && node.kind() != "function_call_statement" {
607 return false;
608 }
609 let mut cursor = node.walk();
610 for child in node.children(&mut cursor) {
611 if child.kind() == "identifier" {
612 return matches!(node_text(child, source).as_str(), "it" | "test");
613 }
614 }
615 false
616}
617
618fn node_text(node: Node, source: &[u8]) -> String {
620 let start = node.start_byte();
621 let end = node.end_byte();
622 if end <= source.len() {
623 std::str::from_utf8(&source[start..end])
624 .unwrap_or("")
625 .to_string()
626 } else {
627 String::new()
628 }
629}
630
631pub fn detect_language(path: &Path) -> Option<Language> {
634 Language::from_path(path)
635}
636
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `body` to `dir/name`, creating parent directories as needed.
    fn write(dir: &Path, name: &str, body: &str) -> std::path::PathBuf {
        let p = dir.join(name);
        if let Some(parent) = p.parent() {
            fs::create_dir_all(parent).ok();
        }
        fs::write(&p, body).unwrap();
        p
    }

    /// Writes `body` to `name` in a fresh temp dir, reads it back, and runs
    /// `recognize` on it for `language`.
    fn recognize_fixture(name: &str, body: &str, language: Language) -> TestFileInfo {
        let tmp = tempdir().unwrap();
        let path = write(tmp.path(), name, body);
        let src = fs::read_to_string(&path).unwrap();
        recognize(&path, &src, language)
    }

    #[test]
    fn python_test_function_counted() {
        let info = recognize_fixture(
            "test_x.py",
            "def test_one():\n pass\n\ndef test_two():\n pass\n",
            Language::Python,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn javascript_describe_it_counted() {
        let info = recognize_fixture(
            "foo.test.js",
            "describe('s', () => { it('a', () => {}); it('b', () => {}); });",
            Language::JavaScript,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn java_test_annotation_counted() {
        let info = recognize_fixture(
            "FooTest.java",
            "import org.junit.Test;\nclass FooTest {\n @Test public void shouldFoo() {}\n @Test public void shouldBar() {}\n}\n",
            Language::Java,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn php_phpunit_counted() {
        let info = recognize_fixture(
            "FooTest.php",
            "<?php\nclass FooTest {\n public function testBar() {}\n public function testBaz() {}\n}\n",
            Language::Php,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn swift_xctest_counted() {
        let info = recognize_fixture(
            "FooTests.swift",
            "import XCTest\nclass FooTests: XCTestCase {\n func testBar() {}\n func testBaz() {}\n}\n",
            Language::Swift,
        );
        assert!(info.is_test_file);
        assert!(info.test_function_count >= 2);
    }

    #[test]
    fn go_testing_counted() {
        let info = recognize_fixture(
            "foo_test.go",
            "package foo\nimport \"testing\"\nfunc TestFoo(t *testing.T) {}\nfunc TestBar(t *testing.T) {}\nfunc helper() {}\n",
            Language::Go,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn ruby_minitest_counted() {
        let info = recognize_fixture(
            "foo_test.rb",
            "class FooTest\n def test_one; end\n def test_two; end\nend\n",
            Language::Ruby,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }
}