Skip to main content

tldr_cli/commands/contracts/
test_recognizer.rs

1//! Per-language test framework recognizers.
2//!
3//! Closes phase-11 BUG-AGG-3 (HIGH): `tldr specs --from-tests` and
4//! `tldr invariants --from-tests` previously hard-coded Python `test_*`
5//! pytest-style discovery, so JavaScript/Java/PHP/Swift/Go/Kotlin/Scala/
6//! Ruby/Elixir/Lua test directories returned `test_files_scanned = 0`
7//! despite containing real test functions.
8//!
9//! Each recognizer answers two questions per file:
10//!
11//! 1. **Is this a test file?** (per the language's discovery convention)
12//! 2. **How many test functions does it contain?** (counted via tree-sitter
13//!    AST walks rather than text-level heuristics, so comments and string
14//!    literals can't false-match.)
15//!
16//! Languages handled:
17//!
18//! | Language    | Convention                                                 |
19//! |-------------|------------------------------------------------------------|
20//! | Python      | `def test_*` (pytest) or methods inside `class Test*`      |
21//! | JavaScript  | `it(...)` / `test(...)` calls (Mocha/Jest/Jasmine)         |
22//! | TypeScript  | `it(...)` / `test(...)` calls                              |
23//! | Java        | Methods annotated with `@Test`                             |
24//! | Kotlin      | Methods annotated with `@Test`                             |
25//! | PHP         | `public function test*` inside class whose name ends `Test`|
26//! | Swift       | `func test*()` inside class extending `XCTestCase`         |
//! | Ruby        | `def test_*` (Minitest) or `it/specify` blocks (RSpec)     |
28//! | Go          | Top-level `func TestXxx(t *testing.T)`                     |
29//! | Scala       | `test("...")` calls (Munit/ScalaTest FunSuite)             |
30//! | Elixir      | `test "..." do ... end` blocks (ExUnit)                    |
//! | Lua / Luau  | `it(...)` / `test(...)` calls (busted)                     |
32//! | Rust        | `fn` items immediately preceded by `#[test]`               |
33//! | C#          | Methods with `[Test]` / `[Fact]` / `[TestMethod]`          |
34//! | C / C++ /   | (No widely standard test framework — fall back to file     |
35//! | OCaml       |  count when name suggests test, function count = 0; the   |
36//! |             |  framework adapter can be wired later.)                    |
37//!
38//! For languages without a clear convention, the recognizer treats files
39//! whose name contains `test` (case-insensitive) as test files but reports
40//! `0` functions — strictly better than the previous behaviour where the
41//! file count itself was always `0` for non-Python.
42
43use std::path::Path;
44
45use tldr_core::ast::ParserPool;
46use tldr_core::Language;
47use tree_sitter::{Node, Tree};
48
/// Result of inspecting a single candidate file for test functions.
///
/// The `Default` value (`is_test_file == false`, `test_function_count == 0`)
/// is what `recognize` hands back for files that are not test candidates.
#[derive(Debug, Clone, Default)]
pub struct TestFileInfo {
    /// True if this file participates in a test suite (i.e. should bump
    /// `test_files_scanned`).
    pub is_test_file: bool,
    /// Number of test functions detected by walking the AST. Stays `0`
    /// when the file was empty or could not be parsed.
    pub test_function_count: u32,
}
58
59/// Public entry point: classify a candidate file and count its tests.
60///
61/// `language` is the language the caller has already detected for the
62/// file (typically via `Language::from_path` in `run_specs` /
63/// `collect_observations`). Returns a default zero-info value if the
64/// file is not parseable in this language.
65pub fn recognize(path: &Path, source: &str, language: Language) -> TestFileInfo {
66    if !is_candidate_test_file(path, language) {
67        return TestFileInfo::default();
68    }
69
70    // language-specific-bugs-v1 (P14.AGG14-9): for Rust, every `.rs` is
71    // a path-level candidate so `tldr specs --from-tests` can cover
72    // inline `#[cfg(test)] mod tests { ... }` blocks inside production
73    // source files (e.g. ripgrep `crates/globset/src/lib.rs`). To keep
74    // directory walks cheap, gate on a fast `#[test]` substring check
75    // before parsing — any `.rs` without `#[test]` cannot contribute
76    // and parsing the entire file would just be wasted work.
77    if matches!(language, Language::Rust) && !source.contains("#[test]") {
78        return TestFileInfo::default();
79    }
80
81    // Empty or whitespace-only files: nothing to count.
82    if source.trim().is_empty() {
83        return TestFileInfo {
84            is_test_file: true,
85            test_function_count: 0,
86        };
87    }
88
89    let pool = ParserPool::new();
90    let tree = match pool.parse(source, language).ok() {
91        Some(t) => t,
92        None => {
93            return TestFileInfo {
94                is_test_file: true,
95                test_function_count: 0,
96            };
97        }
98    };
99
100    let count = count_test_functions(&tree, source.as_bytes(), language);
101
102    TestFileInfo {
103        is_test_file: true,
104        test_function_count: count,
105    }
106}
107
108/// Decide whether `path` matches the language's test-file naming convention.
109///
110/// Each language has its own convention; we centralise them here so
111/// `run_specs` can early-skip non-tests cheaply (without parsing).
112fn is_candidate_test_file(path: &Path, language: Language) -> bool {
113    let file_name = match path.file_name().and_then(|n| n.to_str()) {
114        Some(n) => n,
115        None => return false,
116    };
117    let stem = path
118        .file_stem()
119        .and_then(|s| s.to_str())
120        .unwrap_or(file_name);
121    let lower = file_name.to_ascii_lowercase();
122
123    match language {
124        // Python: `test_*.py` (pytest) or `*_test.py` (unittest-style).
125        Language::Python => {
126            file_name.starts_with("test_") && file_name.ends_with(".py")
127                || file_name.ends_with("_test.py")
128        }
129        // JavaScript / TypeScript: Jest/Mocha conventions.
130        // `*.test.js` / `*.spec.js` (and tsx/jsx variants), or any file
131        // inside a directory literally named `__tests__`.
132        Language::JavaScript | Language::TypeScript => {
133            let in_tests_dir = path
134                .components()
135                .any(|c| c.as_os_str() == "__tests__" || c.as_os_str() == "test"
136                    || c.as_os_str() == "tests" || c.as_os_str() == "spec");
137            let has_test_marker = stem.ends_with(".test")
138                || stem.ends_with(".spec")
139                || stem.ends_with("_test")
140                || stem.ends_with("_spec")
141                || stem.ends_with("Test")
142                || stem.ends_with("Spec");
143            (has_test_marker || in_tests_dir)
144                && (lower.ends_with(".js")
145                    || lower.ends_with(".jsx")
146                    || lower.ends_with(".mjs")
147                    || lower.ends_with(".cjs")
148                    || lower.ends_with(".ts")
149                    || lower.ends_with(".tsx"))
150        }
151        // Java: Maven/Gradle convention — files under `src/test/java` are
152        // tests, or any class whose name ends with `Test`/`Tests`.
153        Language::Java => {
154            if !lower.ends_with(".java") {
155                return false;
156            }
157            stem.ends_with("Test")
158                || stem.ends_with("Tests")
159                || stem.ends_with("IT")
160                || stem.ends_with("ITCase")
161                || path.components().any(|c| c.as_os_str() == "test")
162        }
163        // Kotlin: same convention as Java.
164        Language::Kotlin => {
165            (lower.ends_with(".kt") || lower.ends_with(".kts"))
166                && (stem.ends_with("Test")
167                    || stem.ends_with("Tests")
168                    || path.components().any(|c| c.as_os_str() == "test"))
169        }
170        // PHP: PHPUnit convention — class FooTest in FooTest.php.
171        Language::Php => lower.ends_with(".php") && (stem.ends_with("Test") || stem.ends_with("Tests")),
172        // Swift: XCTest convention — files named `*Tests.swift`.
173        Language::Swift => {
174            lower.ends_with(".swift")
175                && (stem.ends_with("Tests") || stem.ends_with("Test") || stem.ends_with("Spec"))
176        }
177        // Ruby: Minitest `test_*.rb` / `*_test.rb`; RSpec `*_spec.rb`.
178        Language::Ruby => {
179            lower.ends_with(".rb")
180                && (file_name.starts_with("test_")
181                    || stem.ends_with("_test")
182                    || stem.ends_with("_spec"))
183        }
184        // Go: convention is `*_test.go`.
185        Language::Go => lower.ends_with("_test.go"),
186        // Scala: Munit / ScalaTest convention — `*Suite.scala`/`*Spec.scala`/`*Test.scala`.
187        Language::Scala => {
188            lower.ends_with(".scala")
189                && (stem.ends_with("Test")
190                    || stem.ends_with("Tests")
191                    || stem.ends_with("Spec")
192                    || stem.ends_with("Suite"))
193        }
194        // Elixir: ExUnit convention — `*_test.exs`.
195        Language::Elixir => {
196            (lower.ends_with(".exs") || lower.ends_with(".ex"))
197                && (stem.ends_with("_test") || file_name.starts_with("test_"))
198        }
199        // Lua / Luau: busted convention — `*_spec.lua`/`*_test.lua`.
200        Language::Lua | Language::Luau => {
201            (lower.ends_with(".lua") || lower.ends_with(".luau"))
202                && (stem.ends_with("_spec")
203                    || stem.ends_with("_test")
204                    || file_name.starts_with("test_"))
205        }
206        // Rust: built-in `#[test]` framework — files under `tests/` are
207        // integration tests, and any source file may contain `#[cfg(test)]`
208        // mod blocks. Treat any `.rs` whose path contains `test` (a tests/
209        // directory or a *_test.rs filename), `bench`, OR contains a
210        // `#[test]` substring as a candidate;
211        // matches_test_function then filters down to actual `#[test]` items.
212        //
213        // language-specific-bugs-v1 (P14.AGG14-9): the path-only filter
214        // missed the canonical Rust convention of inline
215        // `#[cfg(test)] mod tests { ... }` blocks inside a regular
216        // `lib.rs` / module file (every cargo crate has these). The
217        // additional substring check at recognise-time
218        // (`source.contains("#[test]")`) is cheap relative to parsing and
219        // turns single-file invocations like `tldr specs --from-tests
220        // crates/globset/src/lib.rs` into yielding the inline tests they
221        // contain. Directory walks accept any `.rs` here and still rely on
222        // `matches_test_function` for per-fn filtering.
223        Language::Rust => {
224            lower.ends_with(".rs")
225        }
226        // C#: NUnit / xUnit / MSTest — files named `*Tests.cs` or under
227        // a Tests directory. matches_test_function filters down to methods
228        // carrying `[Test]` / `[Fact]` / `[TestMethod]` etc.
229        Language::CSharp => {
230            lower.ends_with(".cs")
231                && (stem.ends_with("Test")
232                    || stem.ends_with("Tests")
233                    || path.components().any(|c| {
234                        let s = c.as_os_str().to_string_lossy().to_ascii_lowercase();
235                        s == "test" || s == "tests"
236                    }))
237        }
238        // Languages without a single dominant convention. Fall back to the
239        // weak heuristic of "filename contains 'test'" so directories laid
240        // out as `tests/` still count their files. The function-count
241        // walker still returns 0 for these — wiring grammar-specific
242        // recognisers is left as TODO.
243        Language::C | Language::Cpp | Language::Ocaml => {
244            lower.contains("test") || lower.contains("spec")
245        }
246    }
247}
248
249/// AST-walk a parsed test file and count test functions per language convention.
250fn count_test_functions(tree: &Tree, source: &[u8], language: Language) -> u32 {
251    let root = tree.root_node();
252    let mut count = 0u32;
253    walk_count(&root, source, language, &mut count);
254    count
255}
256
257fn walk_count(node: &Node, source: &[u8], language: Language, count: &mut u32) {
258    if matches_test_function(node, source, language) {
259        *count += 1;
260        // Don't recurse into a function body — nested calls inside a
261        // matched test function shouldn't be double-counted (e.g. an
262        // `it(...)` inside a `describe(...)` block both match the JS
263        // recogniser; only the leaf `it` counts).
264        return;
265    }
266
267    let mut cursor = node.walk();
268    for child in node.children(&mut cursor) {
269        walk_count(&child, source, language, count);
270    }
271}
272
/// Public wrapper around the per-language test-function predicate.
///
/// `tldr specs --from-tests` re-uses this from the generic spec extractor
/// so the same definition of "test function" used to count tests is used
/// to harvest assertions inside them.
///
/// Returns exactly what `matches_test_function` returns for the same
/// `node`, `source` bytes and `language`.
pub fn is_test_function_node(node: &Node, source: &[u8], language: Language) -> bool {
    matches_test_function(node, source, language)
}
281
282/// Per-language predicate: is this AST node a test function declaration?
283fn matches_test_function(node: &Node, source: &[u8], language: Language) -> bool {
284    match language {
285        Language::Python => python_is_test_function(node, source),
286        Language::JavaScript | Language::TypeScript => js_is_test_call(node, source),
287        Language::Java | Language::Kotlin => jvm_has_test_annotation(node, source),
288        Language::Php => php_is_test_method(node, source),
289        Language::Swift => swift_is_test_method(node, source),
290        Language::Ruby => ruby_is_test_def_or_block(node, source),
291        Language::Go => go_is_top_level_test_function(node, source),
292        Language::Scala => scala_is_test_call(node, source),
293        Language::Elixir => elixir_is_test_macro(node, source),
294        Language::Lua | Language::Luau => lua_is_test_call(node, source),
295        Language::Rust => rust_is_test_function(node, source),
296        Language::CSharp => csharp_has_test_attribute(node, source),
297        Language::C | Language::Cpp | Language::Ocaml => false,
298    }
299}
300
301// -- Rust: `#[test]` attribute precedes a `fn` item ---------------------------
302//
303// In tree-sitter-rust, `#[test]` is parsed as an `attribute_item` that is a
304// SIBLING (preceding) of the `function_item`, not a child. So at every
305// `function_item` we walk back to the previous siblings collecting any
306// `attribute_item` nodes; if any contains an `attribute` whose head
307// identifier is `test`, this is a unit test. We also accept aliases
308// commonly used in async/integration setups: `tokio::test`, `async_std::test`,
309// `rstest`, `proptest`, plus the common `test_case` macro.
310fn rust_is_test_function(node: &Node, source: &[u8]) -> bool {
311    if node.kind() != "function_item" {
312        return false;
313    }
314    let mut prev = node.prev_sibling();
315    while let Some(p) = prev {
316        match p.kind() {
317            "attribute_item" => {
318                if rust_attribute_is_test(&p, source) {
319                    return true;
320                }
321                prev = p.prev_sibling();
322            }
323            "line_comment" | "block_comment" => {
324                prev = p.prev_sibling();
325            }
326            _ => break,
327        }
328    }
329    false
330}
331
332fn rust_attribute_is_test(attr_item: &Node, source: &[u8]) -> bool {
333    // attribute_item -> [#, [, attribute(...), ]]
334    let mut cursor = attr_item.walk();
335    for child in attr_item.children(&mut cursor) {
336        if child.kind() == "attribute" {
337            // attribute can be `test`, `tokio::test`, `test_case::test_case`,
338            // etc. Walk and collect the tail identifier(s).
339            let text = node_text(child, source);
340            // Strip any argument list `(...)` and whitespace; take the path tail.
341            let head = text
342                .split(|c: char| c == '(' || c.is_whitespace())
343                .next()
344                .unwrap_or("");
345            let tail = head.rsplit("::").next().unwrap_or("");
346            if matches!(
347                tail,
348                "test" | "tokio_test" | "async_test" | "rstest" | "test_case"
349            ) {
350                return true;
351            }
352        }
353    }
354    false
355}
356
357// -- C#: methods with `[Test]` / `[Fact]` / `[TestMethod]` etc. --------------
358//
359// tree-sitter-c-sharp uses `method_declaration` whose direct children include
360// one or more `attribute_list` nodes. Each `attribute_list` contains
361// `attribute` children whose first identifier names the attribute (e.g.
362// "Test", "Fact", "TestMethod", "TestCase", "Theory").
363fn csharp_has_test_attribute(node: &Node, source: &[u8]) -> bool {
364    if node.kind() != "method_declaration" {
365        return false;
366    }
367    let mut cursor = node.walk();
368    for child in node.children(&mut cursor) {
369        if child.kind() == "attribute_list" {
370            let mut inner = child.walk();
371            for attr in child.children(&mut inner) {
372                if attr.kind() == "attribute" && csharp_attribute_is_test(&attr, source) {
373                    return true;
374                }
375            }
376        }
377    }
378    false
379}
380
381fn csharp_attribute_is_test(attribute: &Node, source: &[u8]) -> bool {
382    // Read the tail identifier of the attribute name.
383    let text = node_text(*attribute, source);
384    let head = text
385        .split(|c: char| c == '(' || c.is_whitespace())
386        .next()
387        .unwrap_or("");
388    let tail = head.rsplit('.').next().unwrap_or("");
389    matches!(
390        tail,
391        "Test"
392            | "TestAttribute"
393            | "Fact"
394            | "FactAttribute"
395            | "Theory"
396            | "TheoryAttribute"
397            | "TestMethod"
398            | "TestMethodAttribute"
399            | "TestCase"
400            | "TestCaseAttribute"
401            | "DataTestMethod"
402            | "DataTestMethodAttribute"
403    )
404}
405
406// -- Python: `def test_*` -----------------------------------------------------
407fn python_is_test_function(node: &Node, source: &[u8]) -> bool {
408    if node.kind() != "function_definition" {
409        return false;
410    }
411    let name = node
412        .child_by_field_name("name")
413        .map(|n| node_text(n, source))
414        .unwrap_or_default();
415    name.starts_with("test_")
416}
417
418// -- JS/TS: `it(...)` / `test(...)` -------------------------------------------
419fn js_is_test_call(node: &Node, source: &[u8]) -> bool {
420    // tree-sitter-typescript / -javascript both expose `call_expression`
421    // with a `function` child that's an identifier for top-level calls.
422    if node.kind() != "call_expression" {
423        return false;
424    }
425    let func_node = match node.child_by_field_name("function") {
426        Some(n) => n,
427        None => return false,
428    };
429    // We only want unqualified identifiers (`it("...")`, `test("...")`),
430    // not member calls like `obj.it(...)` which are unrelated.
431    if func_node.kind() != "identifier" {
432        return false;
433    }
434    let name = node_text(func_node, source);
435    matches!(name.as_str(), "it" | "test" | "fit" | "xit" | "xtest")
436}
437
438// -- Java/Kotlin: methods with @Test annotation -------------------------------
439fn jvm_has_test_annotation(node: &Node, source: &[u8]) -> bool {
440    // Java tree-sitter: `method_declaration` with a sibling `modifiers`
441    // child containing `marker_annotation` / `annotation` whose name is
442    // `Test`. Kotlin (kotlin-ng): `function_declaration` with a `modifiers`
443    // child containing `annotation` -> `user_type` -> `type_identifier`
444    // text "Test".
445    let kind = node.kind();
446    if kind != "method_declaration" && kind != "function_declaration" {
447        return false;
448    }
449
450    // Walk this node's modifier subtree looking for an annotation whose
451    // last identifier component is "Test". Spans both Java and Kotlin AST
452    // shapes, since both use roughly the same node names.
453    let mut cursor = node.walk();
454    for child in node.children(&mut cursor) {
455        if child.kind() == "modifiers" {
456            if subtree_contains_annotation_named(&child, source, "Test") {
457                return true;
458            }
459        } else if child.kind() == "annotation" || child.kind() == "marker_annotation" {
460            if annotation_has_name(&child, source, "Test") {
461                return true;
462            }
463        }
464    }
465    false
466}
467
468fn subtree_contains_annotation_named(node: &Node, source: &[u8], target: &str) -> bool {
469    let mut cursor = node.walk();
470    for child in node.children(&mut cursor) {
471        let kind = child.kind();
472        if (kind == "annotation" || kind == "marker_annotation")
473            && annotation_has_name(&child, source, target)
474        {
475            return true;
476        }
477        if subtree_contains_annotation_named(&child, source, target) {
478            return true;
479        }
480    }
481    false
482}
483
484/// True if `annotation_node` has its tail identifier equal to `target`
485/// (e.g. `@Test`, `@org.junit.Test`, `@org.junit.jupiter.api.Test`).
486fn annotation_has_name(annotation_node: &Node, source: &[u8], target: &str) -> bool {
487    let text = node_text(*annotation_node, source);
488    // Strip leading `@` and any argument list, then compare the tail
489    // identifier (after the last `.`).
490    let trimmed = text.trim_start_matches('@');
491    let head = trimmed.split(|c: char| c == '(' || c.is_whitespace()).next().unwrap_or("");
492    let last = head.rsplit('.').next().unwrap_or("");
493    last == target
494}
495
496// -- PHP: PHPUnit `public function test*` -------------------------------------
497fn php_is_test_method(node: &Node, source: &[u8]) -> bool {
498    if node.kind() != "method_declaration" {
499        return false;
500    }
501    let name = node
502        .child_by_field_name("name")
503        .map(|n| node_text(n, source))
504        .unwrap_or_default();
505    name.starts_with("test")
506}
507
508// -- Swift: `func test*()` ----------------------------------------------------
509fn swift_is_test_method(node: &Node, source: &[u8]) -> bool {
510    // tree-sitter-swift uses `function_declaration` (top-level) and
511    // `protocol_function_declaration` (inside class body); we accept any
512    // declaration whose `name` child starts with "test".
513    let kind = node.kind();
514    if !(kind == "function_declaration" || kind == "protocol_function_declaration") {
515        return false;
516    }
517    // Swift grammar exposes the method name as the first `simple_identifier`
518    // child after the `func` keyword. Walk children to find it.
519    let mut cursor = node.walk();
520    for child in node.children(&mut cursor) {
521        if child.kind() == "simple_identifier" {
522            return node_text(child, source).starts_with("test");
523        }
524    }
525    false
526}
527
528// -- Ruby: `def test_*` (Minitest) or `it/describe` blocks (RSpec) -----------
529fn ruby_is_test_def_or_block(node: &Node, source: &[u8]) -> bool {
530    match node.kind() {
531        "method" => {
532            // Minitest: `def test_<name>`.
533            let name = node
534                .child_by_field_name("name")
535                .map(|n| node_text(n, source))
536                .unwrap_or_default();
537            name.starts_with("test_")
538        }
539        "call" => {
540            // RSpec: `it "..." do ... end` / `specify "..." do ... end`.
541            let method = node
542                .child_by_field_name("method")
543                .map(|n| node_text(n, source))
544                .unwrap_or_default();
545            matches!(method.as_str(), "it" | "specify")
546        }
547        _ => false,
548    }
549}
550
551// -- Go: top-level `func TestXxx(t *testing.T)` -------------------------------
552fn go_is_top_level_test_function(node: &Node, source: &[u8]) -> bool {
553    if node.kind() != "function_declaration" {
554        return false;
555    }
556    let name = node
557        .child_by_field_name("name")
558        .map(|n| node_text(n, source))
559        .unwrap_or_default();
560    if !name.starts_with("Test") {
561        return false;
562    }
563    // Filter out `Test` exactly (no following uppercase). The Go testing
564    // convention is `TestXxx` where `X` is upper.
565    let after = name.strip_prefix("Test").unwrap_or("");
566    let starts_upper = after.chars().next().map(|c| c.is_ascii_uppercase()).unwrap_or(false);
567    starts_upper
568}
569
570// -- Scala: `test("...") { ... }` calls ---------------------------------------
571fn scala_is_test_call(node: &Node, source: &[u8]) -> bool {
572    // tree-sitter-scala uses `call_expression`; the function child is a
573    // `simple_identifier`/`identifier` named "test". We also accept the
574    // FunSuite naming convention where a class extends `FunSuite` and
575    // calls `test(...)` at class scope. Pattern-matching just the call
576    // shape is sufficient for the common case.
577    if node.kind() != "call_expression" {
578        return false;
579    }
580    let func_node = match node.child_by_field_name("function") {
581        Some(n) => n,
582        None => return false,
583    };
584    let name = node_text(func_node, source);
585    matches!(name.as_str(), "test")
586}
587
588// -- Elixir: `test "..." do ... end` ------------------------------------------
589fn elixir_is_test_macro(node: &Node, source: &[u8]) -> bool {
590    // tree-sitter-elixir parses macros as `call` nodes; the head is the
591    // `target` child (an `identifier`), and the body is a `do_block`.
592    if node.kind() != "call" {
593        return false;
594    }
595    let target = match node.child_by_field_name("target") {
596        Some(n) => n,
597        None => return false,
598    };
599    node_text(target, source) == "test"
600}
601
602// -- Lua/Luau: `it(...)` / `describe(...)` (busted) ---------------------------
603fn lua_is_test_call(node: &Node, source: &[u8]) -> bool {
604    // tree-sitter-lua / -luau represent calls as `function_call`. The
605    // function name lives in the `name` field (an `identifier`).
606    if node.kind() != "function_call" && node.kind() != "function_call_statement" {
607        return false;
608    }
609    let mut cursor = node.walk();
610    for child in node.children(&mut cursor) {
611        if child.kind() == "identifier" {
612            return matches!(node_text(child, source).as_str(), "it" | "test");
613        }
614    }
615    false
616}
617
618// -- Helpers ------------------------------------------------------------------
619fn node_text(node: Node, source: &[u8]) -> String {
620    let start = node.start_byte();
621    let end = node.end_byte();
622    if end <= source.len() {
623        std::str::from_utf8(&source[start..end])
624            .unwrap_or("")
625            .to_string()
626    } else {
627        String::new()
628    }
629}
630
/// Detect language for a candidate test file. Returns `None` if the
/// extension isn't supported.
///
/// Thin wrapper: the decision is delegated entirely to
/// `Language::from_path`, whose result is returned unchanged.
pub fn detect_language(path: &Path) -> Option<Language> {
    Language::from_path(path)
}
636
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Write `body` to `dir/name` (creating parent directories on a
    /// best-effort basis) and return the full path.
    fn write(dir: &Path, name: &str, body: &str) -> std::path::PathBuf {
        let target = dir.join(name);
        if let Some(parent) = target.parent() {
            fs::create_dir_all(parent).ok();
        }
        fs::write(&target, body).unwrap();
        target
    }

    /// Materialise a fixture file in a fresh temp directory, read it back
    /// and run it through `recognize`.
    fn recognize_fixture(name: &str, body: &str, language: Language) -> TestFileInfo {
        let scratch = tempdir().unwrap();
        let path = write(scratch.path(), name, body);
        let source = fs::read_to_string(&path).unwrap();
        recognize(&path, &source, language)
    }

    #[test]
    fn python_test_function_counted() {
        let info = recognize_fixture(
            "test_x.py",
            "def test_one():\n    pass\n\ndef test_two():\n    pass\n",
            Language::Python,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn javascript_describe_it_counted() {
        let info = recognize_fixture(
            "foo.test.js",
            "describe('s', () => { it('a', () => {}); it('b', () => {}); });",
            Language::JavaScript,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn java_test_annotation_counted() {
        let info = recognize_fixture(
            "FooTest.java",
            "import org.junit.Test;\nclass FooTest {\n  @Test public void shouldFoo() {}\n  @Test public void shouldBar() {}\n}\n",
            Language::Java,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn php_phpunit_counted() {
        let info = recognize_fixture(
            "FooTest.php",
            "<?php\nclass FooTest {\n  public function testBar() {}\n  public function testBaz() {}\n}\n",
            Language::Php,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn swift_xctest_counted() {
        let info = recognize_fixture(
            "FooTests.swift",
            "import XCTest\nclass FooTests: XCTestCase {\n  func testBar() {}\n  func testBaz() {}\n}\n",
            Language::Swift,
        );
        assert!(info.is_test_file);
        assert!(info.test_function_count >= 2);
    }

    #[test]
    fn go_testing_counted() {
        let info = recognize_fixture(
            "foo_test.go",
            "package foo\nimport \"testing\"\nfunc TestFoo(t *testing.T) {}\nfunc TestBar(t *testing.T) {}\nfunc helper() {}\n",
            Language::Go,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn ruby_minitest_counted() {
        let info = recognize_fixture(
            "foo_test.rb",
            "class FooTest\n  def test_one; end\n  def test_two; end\nend\n",
            Language::Ruby,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }
}
745        let info = recognize(&p, &src, Language::Ruby);
746        assert!(info.is_test_file);
747        assert_eq!(info.test_function_count, 2);
748    }
749}