Skip to main content

tracevault_core/
code_nav.rs

1use tree_sitter::{Language, Node, Parser};
2
3#[derive(Debug, Clone, serde::Serialize)]
4pub struct CodeScope {
5    pub kind: String,
6    pub name: String,
7    pub start_line: usize, // 1-indexed
8    pub end_line: usize,   // 1-indexed, inclusive
9}
10
11pub fn get_language(file_ext: &str) -> Option<Language> {
12    match file_ext {
13        "rs" => Some(tree_sitter_rust::LANGUAGE.into()),
14        "ts" | "tsx" => Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
15        "js" | "jsx" => Some(tree_sitter_javascript::LANGUAGE.into()),
16        "py" => Some(tree_sitter_python::LANGUAGE.into()),
17        "go" => Some(tree_sitter_go::LANGUAGE.into()),
18        "java" => Some(tree_sitter_java::LANGUAGE.into()),
19        "scala" | "sc" => Some(tree_sitter_scala::LANGUAGE.into()),
20        _ => None,
21    }
22}
23
24/// Find the innermost named scope (function/class/module) containing the given line.
25/// `line` is 1-indexed.
26pub fn find_enclosing_scope(source: &str, file_ext: &str, line: usize) -> Option<CodeScope> {
27    let language = get_language(file_ext)?;
28    let mut parser = Parser::new();
29    parser.set_language(&language).ok()?;
30    let tree = parser.parse(source, None)?;
31
32    let target_line = line.checked_sub(1)?; // tree-sitter uses 0-indexed rows
33    let scope_node_types = scope_types_for_ext(file_ext);
34
35    let mut best: Option<(Node, &str)> = None;
36    walk_tree(tree.root_node(), target_line, &scope_node_types, &mut best);
37
38    let (node, kind) = best?;
39    let name = extract_name(node, source);
40
41    Some(CodeScope {
42        kind: kind.to_string(),
43        name: name.unwrap_or_else(|| "<anonymous>".to_string()),
44        start_line: node.start_position().row + 1,
45        end_line: node.end_position().row + 1,
46    })
47}
48
49fn walk_tree<'a>(
50    node: Node<'a>,
51    target_line: usize,
52    scope_types: &[(&str, &'a str)],
53    best: &mut Option<(Node<'a>, &'a str)>,
54) {
55    let start = node.start_position().row;
56    let end = node.end_position().row;
57
58    if target_line < start || target_line > end {
59        return;
60    }
61
62    for (node_type, kind) in scope_types {
63        if node.kind() == *node_type
64            && best.as_ref().is_none_or(|(b, _)| {
65                let b_range = b.end_position().row - b.start_position().row;
66                let n_range = end - start;
67                n_range < b_range
68            })
69        {
70            *best = Some((node, kind));
71        }
72    }
73
74    for i in 0..node.child_count() {
75        if let Some(child) = node.child(i) {
76            walk_tree(child, target_line, scope_types, best);
77        }
78    }
79}
80
/// Returns the `(tree-sitter node kind, scope label)` pairs that count as
/// named scopes for the given file extension.
///
/// The extension is matched exactly; unknown extensions yield an empty list.
fn scope_types_for_ext(ext: &str) -> Vec<(&'static str, &'static str)> {
    let table: &[(&'static str, &'static str)] = match ext {
        "rs" => &[
            ("function_item", "function"),
            ("impl_item", "impl"),
            ("struct_item", "struct"),
            ("enum_item", "enum"),
            ("trait_item", "trait"),
            ("mod_item", "module"),
        ],
        "ts" | "tsx" | "js" | "jsx" => &[
            ("function_declaration", "function"),
            ("method_definition", "method"),
            ("arrow_function", "function"),
            ("class_declaration", "class"),
            ("interface_declaration", "interface"),
        ],
        "py" => &[
            ("function_definition", "function"),
            ("class_definition", "class"),
        ],
        "go" => &[
            ("function_declaration", "function"),
            ("method_declaration", "method"),
            ("type_declaration", "type"),
        ],
        "java" => &[
            ("method_declaration", "method"),
            ("constructor_declaration", "constructor"),
            ("class_declaration", "class"),
            ("interface_declaration", "interface"),
            ("enum_declaration", "enum"),
        ],
        "scala" | "sc" => &[
            ("function_definition", "function"),
            ("class_definition", "class"),
            ("object_definition", "object"),
            ("trait_definition", "trait"),
        ],
        _ => &[],
    };
    table.to_vec()
}
123
124fn extract_name(node: Node, source: &str) -> Option<String> {
125    for i in 0..node.child_count() {
126        if let Some(child) = node.child(i) {
127            let field_name = node.field_name_for_child(i as u32);
128            if matches!(field_name, Some("name")) {
129                let text = &source[child.byte_range()];
130                return Some(text.to_string());
131            }
132            if child.kind() == "identifier" || child.kind() == "type_identifier" {
133                let text = &source[child.byte_range()];
134                return Some(text.to_string());
135            }
136        }
137    }
138    None
139}
140
141/// Fallback: return a range of lines around the target line.
142pub fn fallback_scope(source: &str, line: usize, context_lines: usize) -> CodeScope {
143    let total_lines = source.lines().count();
144    let start = line.saturating_sub(context_lines).max(1);
145    let end = (line + context_lines).min(total_lines);
146    CodeScope {
147        kind: "region".to_string(),
148        name: format!("lines {start}-{end}"),
149        start_line: start,
150        end_line: end,
151    }
152}
153
/// Unit tests for language lookup and the line-window fallback.
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-line fixture shared by the `fallback_scope` tests.
    const FIVE_LINES: &str = "line1\nline2\nline3\nline4\nline5\n";

    #[test]
    fn get_language_rust() {
        assert!(get_language("rs").is_some());
    }

    #[test]
    fn get_language_typescript() {
        assert!(get_language("ts").is_some());
        assert!(get_language("tsx").is_some());
    }

    #[test]
    fn get_language_javascript() {
        assert!(get_language("js").is_some());
        assert!(get_language("jsx").is_some());
    }

    #[test]
    fn get_language_python() {
        assert!(get_language("py").is_some());
    }

    #[test]
    fn get_language_go() {
        assert!(get_language("go").is_some());
    }

    #[test]
    fn get_language_java() {
        assert!(get_language("java").is_some());
    }

    #[test]
    fn get_language_scala() {
        assert!(get_language("scala").is_some());
        assert!(get_language("sc").is_some());
    }

    #[test]
    fn get_language_unknown_returns_none() {
        assert!(get_language("txt").is_none());
        assert!(get_language("md").is_none());
    }

    // NOTE: tests that need find_enclosing_scope to parse real code are skipped
    // because the tree-sitter grammar versions (v15) are incompatible with the
    // tree-sitter runtime used in tests: get_language() returns a Language, but
    // set_language() fails with LanguageError. This is a pre-existing version
    // mismatch, not a test issue.

    #[test]
    fn find_scope_unknown_language_returns_none() {
        let src = "some text content\nline two\n";
        let scope = find_enclosing_scope(src, "txt", 1);
        assert!(scope.is_none());
    }

    #[test]
    fn fallback_scope_near_start() {
        // The window is cut off at line 1 above and at the last line below.
        let scope = fallback_scope(FIVE_LINES, 1, 10);
        assert_eq!(scope.start_line, 1);
        assert_eq!(scope.end_line, 5);
        assert_eq!(scope.kind, "region");
        assert_eq!(scope.name, "lines 1-5");
    }

    #[test]
    fn fallback_scope_near_end() {
        let scope = fallback_scope(FIVE_LINES, 5, 10);
        assert_eq!(scope.start_line, 1);
        assert_eq!(scope.end_line, 5);
        assert_eq!(scope.kind, "region");
        assert_eq!(scope.name, "lines 1-5");
    }

    #[test]
    fn fallback_scope_mid_file() {
        // A window that fits entirely inside the file is returned unclamped.
        let scope = fallback_scope(FIVE_LINES, 3, 1);
        assert_eq!(scope.start_line, 2);
        assert_eq!(scope.end_line, 4);
        assert_eq!(scope.kind, "region");
        assert_eq!(scope.name, "lines 2-4");
    }
}