1use tree_sitter::{Language, Node, Parser};
2
3#[derive(Debug, Clone, serde::Serialize)]
4pub struct CodeScope {
5 pub kind: String,
6 pub name: String,
7 pub start_line: usize, pub end_line: usize, }
10
11pub fn get_language(file_ext: &str) -> Option<Language> {
12 match file_ext {
13 "rs" => Some(tree_sitter_rust::LANGUAGE.into()),
14 "ts" | "tsx" => Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
15 "js" | "jsx" => Some(tree_sitter_javascript::LANGUAGE.into()),
16 "py" => Some(tree_sitter_python::LANGUAGE.into()),
17 "go" => Some(tree_sitter_go::LANGUAGE.into()),
18 "java" => Some(tree_sitter_java::LANGUAGE.into()),
19 "scala" | "sc" => Some(tree_sitter_scala::LANGUAGE.into()),
20 _ => None,
21 }
22}
23
24pub fn find_enclosing_scope(source: &str, file_ext: &str, line: usize) -> Option<CodeScope> {
27 let language = get_language(file_ext)?;
28 let mut parser = Parser::new();
29 parser.set_language(&language).ok()?;
30 let tree = parser.parse(source, None)?;
31
32 let target_line = line.checked_sub(1)?; let scope_node_types = scope_types_for_ext(file_ext);
34
35 let mut best: Option<(Node, &str)> = None;
36 walk_tree(tree.root_node(), target_line, &scope_node_types, &mut best);
37
38 let (node, kind) = best?;
39 let name = extract_name(node, source);
40
41 Some(CodeScope {
42 kind: kind.to_string(),
43 name: name.unwrap_or_else(|| "<anonymous>".to_string()),
44 start_line: node.start_position().row + 1,
45 end_line: node.end_position().row + 1,
46 })
47}
48
49fn walk_tree<'a>(
50 node: Node<'a>,
51 target_line: usize,
52 scope_types: &[(&str, &'a str)],
53 best: &mut Option<(Node<'a>, &'a str)>,
54) {
55 let start = node.start_position().row;
56 let end = node.end_position().row;
57
58 if target_line < start || target_line > end {
59 return;
60 }
61
62 for (node_type, kind) in scope_types {
63 if node.kind() == *node_type
64 && best.as_ref().is_none_or(|(b, _)| {
65 let b_range = b.end_position().row - b.start_position().row;
66 let n_range = end - start;
67 n_range < b_range
68 })
69 {
70 *best = Some((node, kind));
71 }
72 }
73
74 for i in 0..node.child_count() {
75 if let Some(child) = node.child(i) {
76 walk_tree(child, target_line, scope_types, best);
77 }
78 }
79}
80
/// Returns the `(node kind, scope label)` pairs treated as scopes for the
/// language selected by `ext`; the list is empty for unknown extensions.
fn scope_types_for_ext(ext: &str) -> Vec<(&'static str, &'static str)> {
    const RUST: &[(&str, &str)] = &[
        ("function_item", "function"),
        ("impl_item", "impl"),
        ("struct_item", "struct"),
        ("enum_item", "enum"),
        ("trait_item", "trait"),
        ("mod_item", "module"),
    ];
    const ECMASCRIPT: &[(&str, &str)] = &[
        ("function_declaration", "function"),
        ("method_definition", "method"),
        ("arrow_function", "function"),
        ("class_declaration", "class"),
        ("interface_declaration", "interface"),
    ];
    const PYTHON: &[(&str, &str)] = &[
        ("function_definition", "function"),
        ("class_definition", "class"),
    ];
    const GO: &[(&str, &str)] = &[
        ("function_declaration", "function"),
        ("method_declaration", "method"),
        ("type_declaration", "type"),
    ];
    const JAVA: &[(&str, &str)] = &[
        ("method_declaration", "method"),
        ("constructor_declaration", "constructor"),
        ("class_declaration", "class"),
        ("interface_declaration", "interface"),
        ("enum_declaration", "enum"),
    ];
    const SCALA: &[(&str, &str)] = &[
        ("function_definition", "function"),
        ("class_definition", "class"),
        ("object_definition", "object"),
        ("trait_definition", "trait"),
    ];

    let table: &[(&'static str, &'static str)] = match ext {
        "rs" => RUST,
        "ts" | "tsx" | "js" | "jsx" => ECMASCRIPT,
        "py" => PYTHON,
        "go" => GO,
        "java" => JAVA,
        "scala" | "sc" => SCALA,
        _ => &[],
    };
    table.to_vec()
}
123
124fn extract_name(node: Node, source: &str) -> Option<String> {
125 for i in 0..node.child_count() {
126 if let Some(child) = node.child(i) {
127 let field_name = node.field_name_for_child(i as u32);
128 if matches!(field_name, Some("name")) {
129 let text = &source[child.byte_range()];
130 return Some(text.to_string());
131 }
132 if child.kind() == "identifier" || child.kind() == "type_identifier" {
133 let text = &source[child.byte_range()];
134 return Some(text.to_string());
135 }
136 }
137 }
138 None
139}
140
141pub fn fallback_scope(source: &str, line: usize, context_lines: usize) -> CodeScope {
143 let total_lines = source.lines().count();
144 let start = line.saturating_sub(context_lines).max(1);
145 let end = (line + context_lines).min(total_lines);
146 CodeScope {
147 kind: "region".to_string(),
148 name: format!("lines {start}-{end}"),
149 start_line: start,
150 end_line: end,
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Five newline-terminated lines shared by the fallback-scope tests.
    const FIVE_LINES: &str = "line1\nline2\nline3\nline4\nline5\n";

    /// True when every extension in `exts` resolves to a grammar.
    fn all_supported(exts: &[&str]) -> bool {
        exts.iter().all(|ext| get_language(ext).is_some())
    }

    #[test]
    fn get_language_rust() {
        assert!(all_supported(&["rs"]));
    }

    #[test]
    fn get_language_typescript() {
        assert!(all_supported(&["ts", "tsx"]));
    }

    #[test]
    fn get_language_javascript() {
        assert!(all_supported(&["js", "jsx"]));
    }

    #[test]
    fn get_language_python() {
        assert!(all_supported(&["py"]));
    }

    #[test]
    fn get_language_go() {
        assert!(all_supported(&["go"]));
    }

    #[test]
    fn get_language_java() {
        assert!(all_supported(&["java"]));
    }

    #[test]
    fn get_language_scala() {
        assert!(all_supported(&["scala", "sc"]));
    }

    #[test]
    fn get_language_unknown_returns_none() {
        for ext in ["txt", "md"].iter() {
            assert!(get_language(ext).is_none());
        }
    }

    #[test]
    fn find_scope_unknown_language_returns_none() {
        // No grammar is registered for "txt", so no scope can be found.
        let text = "some text content\nline two\n";
        assert!(find_enclosing_scope(text, "txt", 1).is_none());
    }

    #[test]
    fn fallback_scope_near_start() {
        // The window start must not go below line 1.
        let region = fallback_scope(FIVE_LINES, 1, 10);
        assert_eq!(region.start_line, 1);
    }

    #[test]
    fn fallback_scope_near_end() {
        // The window end must not pass the last line.
        let region = fallback_scope(FIVE_LINES, 5, 10);
        assert_eq!(region.end_line, 5);
    }
}