Skip to main content

scute_core/code_similarity/
language.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::path::Path;
4
5use crate::parser::AstParser;
6
7type TestDetector = fn(&mut dyn AstParser, &Path, &str, usize, usize) -> bool;
8
9pub struct LanguageConfig {
10    language: tree_sitter::Language,
11    roles: HashMap<&'static str, NodeRole>,
12    test_detector: TestDetector,
13}
14
15impl fmt::Debug for LanguageConfig {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        f.debug_struct("LanguageConfig")
18            .field("roles", &self.roles)
19            .finish_non_exhaustive()
20    }
21}
22
23impl LanguageConfig {
24    fn new(
25        language: tree_sitter::Language,
26        table: &[(NodeRole, &[&'static str])],
27        test_detector: TestDetector,
28    ) -> Self {
29        let mut roles = HashMap::new();
30        for &(role, kinds) in table {
31            for &kind in kinds {
32                roles.insert(kind, role);
33            }
34        }
35        Self {
36            language,
37            roles,
38            test_detector,
39        }
40    }
41
42    #[must_use]
43    pub fn language(&self) -> &tree_sitter::Language {
44        &self.language
45    }
46
47    #[must_use]
48    pub fn classify(&self, kind: &str) -> NodeRole {
49        self.roles.get(kind).copied().unwrap_or(NodeRole::Other)
50    }
51
52    /// Returns `true` if the given line range is inside test code.
53    ///
54    /// Detection is language-specific: some languages need to parse
55    /// the source, others rely on file path conventions alone.
56    #[must_use]
57    pub fn is_test_context(
58        &self,
59        parser: &mut dyn AstParser,
60        path: &Path,
61        content: &str,
62        start_line: usize,
63        end_line: usize,
64    ) -> bool {
65        (self.test_detector)(parser, path, content, start_line, end_line)
66    }
67}
68
/// Coarse role assigned to a syntax-tree node kind by a language's role table.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so roles can be used as
/// set members or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeRole {
    /// Identifier-like kinds (names, type names, field names, ...).
    Identifier,
    /// Literal value kinds (strings, numbers, booleans, ...).
    Literal,
    /// Comment kinds.
    Comment,
    /// Attribute / decorator kinds.
    Decoration,
    /// Fallback for every kind that is not listed in a role table.
    Other,
}
77
78fn rust_is_test(
79    parser: &mut dyn AstParser,
80    path: &Path,
81    content: &str,
82    start_line: usize,
83    end_line: usize,
84) -> bool {
85    if path.components().any(|c| c.as_os_str() == "tests") {
86        return true;
87    }
88    let ranges = rust_test_ranges(parser, content);
89    ranges
90        .iter()
91        .any(|&(range_start, range_end)| start_line >= range_start && end_line <= range_end)
92}
93
94/// Finds line ranges of Rust test code: `#[cfg(test)] mod` blocks and `#[test]` functions.
95fn rust_test_ranges(parser: &mut dyn AstParser, content: &str) -> Vec<(usize, usize)> {
96    let Ok(tree) = parser.parse(content, &tree_sitter_rust::LANGUAGE.into()) else {
97        return vec![];
98    };
99
100    let src = content.as_bytes();
101    let mut ranges = vec![];
102    collect_test_ranges(tree.root_node(), src, &mut ranges);
103    ranges
104}
105
106fn collect_test_ranges(parent: tree_sitter::Node, src: &[u8], ranges: &mut Vec<(usize, usize)>) {
107    let mut cursor = parent.walk();
108    for node in parent.children(&mut cursor) {
109        match node.kind() {
110            "mod_item" if has_preceding_attr(&node, src, is_cfg_test_attr) => {
111                push_range_with_attrs(&node, ranges);
112            }
113            "mod_item" => recurse_into_mod_body(node, src, ranges),
114            "function_item" if has_preceding_attr(&node, src, |t| t == "#[test]") => {
115                push_range_with_attrs(&node, ranges);
116            }
117            _ => {}
118        }
119    }
120}
121
122fn push_range_with_attrs(node: &tree_sitter::Node, ranges: &mut Vec<(usize, usize)>) {
123    let start = first_preceding_attr_row(node).unwrap_or(node.start_position().row);
124    ranges.push((start + 1, node.end_position().row + 1));
125}
126
127fn recurse_into_mod_body(node: tree_sitter::Node, src: &[u8], ranges: &mut Vec<(usize, usize)>) {
128    if let Some(body) = node.child_by_field_name("body") {
129        collect_test_ranges(body, src, ranges);
130    }
131}
132
/// Returns `true` when `attr_text` is a `#[cfg(...)]` attribute whose predicate
/// mentions the bare `test` option outside of every `not(...)` group.
///
/// Matches `#[cfg(test)]` and compound forms like `#[cfg(all(test, ...))]`,
/// but not `#[cfg(not(test))]` or `#[cfg(not(any(test, ...)))]`.
///
/// The predicate is scanned token by token rather than by substring, so
/// spacing (`all(a,test)`, `cfg( test )`) and string values
/// (`feature = "test"`) do not influence the result. Other attributes,
/// including `#[cfg_attr(...)]`, never match.
fn is_cfg_test_attr(attr_text: &str) -> bool {
    let is_ident_char = |c: char| c == '_' || c.is_alphanumeric();

    // Peel `#[cfg(` ... `)]`, tolerating whitespace around the delimiters.
    let Some(predicate) = attr_text
        .trim()
        .strip_prefix("#[")
        .and_then(|rest| rest.strip_suffix(']'))
        .and_then(|rest| rest.trim().strip_prefix("cfg"))
        .and_then(|rest| rest.trim_start().strip_prefix('('))
        .and_then(|rest| rest.strip_suffix(')'))
    else {
        return false;
    };

    // One entry per currently open parenthesis: `true` when the group is `not(...)`.
    let mut enclosing_not: Vec<bool> = Vec::new();
    let mut chars = predicate.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            // Skip string values such as `feature = "test"`, honouring escapes.
            '"' => {
                while let Some(inner) = chars.next() {
                    match inner {
                        '\\' => {
                            chars.next();
                        }
                        '"' => break,
                        _ => {}
                    }
                }
            }
            // A parenthesis without a leading name still has to balance its `)`.
            '(' => enclosing_not.push(false),
            ')' => {
                enclosing_not.pop();
            }
            _ if is_ident_char(c) => {
                let mut word = String::from(c);
                while let Some(next) = chars.next_if(|&n| is_ident_char(n)) {
                    word.push(next);
                }
                while chars.next_if(|n| n.is_whitespace()).is_some() {}

                match chars.peek() {
                    // `all(`, `any(`, `not(`: open a group and remember negation.
                    Some('(') => {
                        chars.next();
                        enclosing_not.push(word == "not");
                    }
                    // `key = "value"` option: the key is never the bare `test` flag.
                    Some('=') => {}
                    _ => {
                        let negated = enclosing_not.iter().any(|&is_not| is_not);
                        if word == "test" && !negated {
                            return true;
                        }
                    }
                }
            }
            _ => {}
        }
    }
    false
}
144
145fn has_preceding_attr(node: &tree_sitter::Node, src: &[u8], pred: impl Fn(&str) -> bool) -> bool {
146    let mut sibling = node.prev_sibling();
147    while let Some(s) = sibling {
148        if s.kind() != "attribute_item" {
149            break;
150        }
151        if s.utf8_text(src).is_ok_and(&pred) {
152            return true;
153        }
154        sibling = s.prev_sibling();
155    }
156    false
157}
158
159fn first_preceding_attr_row(node: &tree_sitter::Node) -> Option<usize> {
160    let mut first_row = None;
161    let mut sibling = node.prev_sibling();
162    while let Some(s) = sibling {
163        if s.kind() != "attribute_item" {
164            break;
165        }
166        first_row = Some(s.start_position().row);
167        sibling = s.prev_sibling();
168    }
169    first_row
170}
171
172/// Detects test files by JS/TS conventions: `*.test.*`, `*.spec.*`, or `__tests__/` directory.
173fn js_is_test(
174    _parser: &mut dyn AstParser,
175    path: &Path,
176    _content: &str,
177    _start_line: usize,
178    _end_line: usize,
179) -> bool {
180    let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("");
181    Path::new(stem)
182        .extension()
183        .is_some_and(|ext| ext == "test" || ext == "spec")
184        || path.components().any(|c| c.as_os_str() == "__tests__")
185}
186
187#[must_use]
188pub fn rust() -> LanguageConfig {
189    LanguageConfig::new(
190        tree_sitter_rust::LANGUAGE.into(),
191        &[
192            (
193                NodeRole::Identifier,
194                &[
195                    "identifier",
196                    "type_identifier",
197                    "field_identifier",
198                    "shorthand_field_identifier",
199                    "primitive_type",
200                    "lifetime",
201                    "self",
202                    "metavariable",
203                    "crate",
204                    "super",
205                ],
206            ),
207            (
208                NodeRole::Literal,
209                &[
210                    "string_literal",
211                    "raw_string_literal",
212                    "char_literal",
213                    "integer_literal",
214                    "float_literal",
215                    "boolean_literal",
216                ],
217            ),
218            (NodeRole::Comment, &["line_comment", "block_comment"]),
219            (
220                NodeRole::Decoration,
221                &["attribute_item", "inner_attribute_item"],
222            ),
223        ],
224        rust_is_test,
225    )
226}
227
228const TS_ROLES: &[(NodeRole, &[&str])] = &[
229    (
230        NodeRole::Identifier,
231        &[
232            "identifier",
233            "shorthand_property_identifier",
234            "shorthand_property_identifier_pattern",
235            "property_identifier",
236            "type_identifier",
237            "predefined_type",
238        ],
239    ),
240    (
241        NodeRole::Literal,
242        &[
243            "string",
244            "template_string",
245            "number",
246            "true",
247            "false",
248            "null",
249            "undefined",
250            "regex",
251        ],
252    ),
253    (NodeRole::Comment, &["comment"]),
254    (NodeRole::Decoration, &["decorator"]),
255];
256
257#[must_use]
258pub fn javascript() -> LanguageConfig {
259    LanguageConfig::new(
260        tree_sitter_javascript::LANGUAGE.into(),
261        &[
262            (
263                NodeRole::Identifier,
264                &[
265                    "identifier",
266                    "shorthand_property_identifier",
267                    "shorthand_property_identifier_pattern",
268                    "property_identifier",
269                ],
270            ),
271            (
272                NodeRole::Literal,
273                &[
274                    "string",
275                    "template_string",
276                    "number",
277                    "true",
278                    "false",
279                    "null",
280                    "undefined",
281                    "regex",
282                ],
283            ),
284            (NodeRole::Comment, &["comment"]),
285            (NodeRole::Decoration, &["decorator"]),
286        ],
287        js_is_test,
288    )
289}
290
291#[must_use]
292pub fn typescript() -> LanguageConfig {
293    LanguageConfig::new(
294        tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
295        TS_ROLES,
296        js_is_test,
297    )
298}
299
300#[must_use]
301pub fn typescript_tsx() -> LanguageConfig {
302    LanguageConfig::new(
303        tree_sitter_typescript::LANGUAGE_TSX.into(),
304        TS_ROLES,
305        js_is_test,
306    )
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::TreeSitterParser;

    /// Parses `source` with a fresh parser and returns the detected test ranges.
    fn ranges_in(source: &str) -> Vec<(usize, usize)> {
        rust_test_ranges(&mut TreeSitterParser::new(), source)
    }

    #[test]
    fn rust_test_ranges_finds_cfg_test_module() {
        let source = "\
fn production() -> i32 { 42 }

#[cfg(test)]
mod tests {
    fn helper(x: i32) -> i32 { x + 1 }
}
";
        // The attribute line (3) through the closing brace (6).
        let expected = vec![(3, 6)];
        assert_eq!(ranges_in(source), expected);
    }

    #[test]
    fn detects_naked_test_fn_as_test_context() {
        let source = "\
fn production() -> i32 { 42 }

#[test]
fn test_something() {
    let x = production();
    assert_eq!(x, 42);
}
";
        let expected = vec![(3, 7)];
        assert_eq!(ranges_in(source), expected);
    }

    #[test]
    fn walks_past_multiple_attributes_to_find_test() {
        let source = "\
#[test]
#[should_panic]
fn test_something() {
    panic!(\"expected\");
}
";
        // The range starts at the topmost attribute, not at the `fn` line.
        let expected = vec![(1, 5)];
        assert_eq!(ranges_in(source), expected);
    }

    #[test]
    fn rejects_cfg_not_test_module() {
        let source = "\
#[cfg(not(test))]
mod prod_only {
    fn helper() -> i32 { 42 }
}
";
        let found = ranges_in(source);
        assert!(found.is_empty());
    }

    #[test]
    fn detects_compound_cfg_test_as_test_context() {
        let source = "\
#[cfg(all(test, feature = \"integration\"))]
mod integration_tests {
    fn helper(x: i32) -> i32 { x + 1 }
}
";
        let expected = vec![(1, 4)];
        assert_eq!(ranges_in(source), expected);
    }

    #[test]
    fn finds_test_fn_nested_in_non_test_module() {
        let source = "\
mod integration {
    #[test]
    fn test_flow() {
        assert!(true);
    }
}
";
        // Only the nested function is reported, not the surrounding module.
        let expected = vec![(2, 5)];
        assert_eq!(ranges_in(source), expected);
    }
}