1#![allow(clippy::io_other_error)]
2
3use crate::generic_parser_config::GenericParserConfig;
4use crate::language_parser::{GenericFunctionDef, GenericTypeDef, Language, LanguageParser};
5use crate::tree::TreeNode;
6use std::error::Error;
7use std::rc::Rc;
8use tree_sitter::{Node, Parser};
9
10pub struct GenericTreeSitterParser {
11 parser: Parser,
12 config: GenericParserConfig,
13}
14
15impl GenericTreeSitterParser {
16 pub fn new(
18 language: tree_sitter::Language,
19 config: GenericParserConfig,
20 ) -> Result<Self, Box<dyn Error + Send + Sync>> {
21 let mut parser = Parser::new();
22 parser.set_language(&language).map_err(|e| {
23 Box::new(std::io::Error::new(
24 std::io::ErrorKind::Other,
25 format!("Failed to set language: {:?}", e),
26 )) as Box<dyn Error + Send + Sync>
27 })?;
28
29 Ok(Self { parser, config })
30 }
31
32 pub fn from_language_name(language_name: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
34 let (language, config) = match language_name {
35 "go" => (tree_sitter_go::LANGUAGE.into(), GenericParserConfig::go()),
36 "java" => (tree_sitter_java::LANGUAGE.into(), GenericParserConfig::java()),
37 "c" => (tree_sitter_c::LANGUAGE.into(), GenericParserConfig::c()),
38 "cpp" | "c++" => (tree_sitter_cpp::LANGUAGE.into(), GenericParserConfig::cpp()),
39 "csharp" | "cs" => {
40 (tree_sitter_c_sharp::LANGUAGE.into(), GenericParserConfig::csharp())
41 }
42 "ruby" | "rb" => (tree_sitter_ruby::LANGUAGE.into(), GenericParserConfig::ruby()),
43 _ => {
44 return Err(Box::new(std::io::Error::new(
45 std::io::ErrorKind::InvalidInput,
46 format!("Unsupported language: {}", language_name),
47 )) as Box<dyn Error + Send + Sync>)
48 }
49 };
50
51 Self::new(language, config)
52 }
53
54 fn convert_node(&self, node: Node, source: &str, id_counter: &mut usize) -> TreeNode {
55 let current_id = *id_counter;
56 *id_counter += 1;
57
58 let label = node.kind().to_string();
59 let value = if self.config.value_nodes.contains(&node.kind().to_string()) {
60 node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
61 } else {
62 "".to_string()
63 };
64
65 let mut tree_node = TreeNode::new(label, value, current_id);
66
67 for child in node.children(&mut node.walk()) {
68 let child_node = self.convert_node(child, source, id_counter);
69 tree_node.add_child(Rc::new(child_node));
70 }
71
72 tree_node
73 }
74
75 fn extract_functions_from_node(
76 &self,
77 node: Node,
78 source: &str,
79 functions: &mut Vec<GenericFunctionDef>,
80 class_name: Option<&str>,
81 ) {
82 let node_kind = node.kind();
83
84 if self.config.language == "java" && node_kind == "object_creation_expression" {
86 return;
88 }
89
90 if self.config.function_nodes.contains(&node_kind.to_string()) {
92 if let Some(func_def) = self.extract_function_definition(node, source, class_name) {
93 functions.push(func_def);
94 }
95 }
96
97 if self.config.type_nodes.contains(&node_kind.to_string()) {
99 let new_class_name = node
101 .child_by_field_name(&self.config.field_mappings.name_field)
102 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
103 .unwrap_or("");
104
105 for child in node.children(&mut node.walk()) {
107 self.extract_functions_from_node(child, source, functions, Some(new_class_name));
108 }
109 return; }
111
112 for child in node.children(&mut node.walk()) {
114 self.extract_functions_from_node(child, source, functions, class_name);
115 }
116 }
117
118 fn extract_function_definition(
119 &self,
120 node: Node,
121 source: &str,
122 class_name: Option<&str>,
123 ) -> Option<GenericFunctionDef> {
124 let name_string = if (self.config.language == "c" || self.config.language == "cpp")
126 && node.kind() == "function_definition"
127 {
128 let declarator = node.child_by_field_name("declarator")?;
130
131 match declarator.kind() {
132 "function_declarator" => declarator
133 .child_by_field_name("declarator")
134 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
135 .map(String::from)?,
136 "pointer_declarator" => {
137 let func_decl = declarator
139 .children(&mut declarator.walk())
140 .find(|n| n.kind() == "function_declarator")?;
141 func_decl
142 .child_by_field_name("declarator")
143 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
144 .map(String::from)?
145 }
146 _ => {
147 declarator.utf8_text(source.as_bytes()).ok().map(String::from)?
149 }
150 }
151 } else if self.config.language == "csharp" {
152 match node.kind() {
154 "operator_declaration" => {
155 let operator_symbol = node
157 .child_by_field_name("operator")
158 .and_then(|n| n.utf8_text(source.as_bytes()).ok())?;
159 format!("operator {}", operator_symbol)
160 }
161 "destructor_declaration" => {
162 let class_name = node
164 .child_by_field_name("name")
165 .and_then(|n| n.utf8_text(source.as_bytes()).ok())?;
166 format!("~{}", class_name)
167 }
168 _ => {
169 let name_node =
171 node.child_by_field_name(&self.config.field_mappings.name_field)?;
172 name_node.utf8_text(source.as_bytes()).ok().map(String::from)?
173 }
174 }
175 } else if self.config.language == "elixir" && node.kind() == "call" {
176 let name_result = node
180 .child(1)
181 .filter(|n| n.kind() == "arguments")
182 .and_then(|args| args.child(0))
183 .and_then(|call_node| {
184 if call_node.kind() == "call" {
185 call_node.child_by_field_name("target")
187 } else {
188 None
189 }
190 })
191 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
192 .map(String::from);
193 name_result?
194 } else {
195 let name_node = node.child_by_field_name(&self.config.field_mappings.name_field)?;
197 name_node.utf8_text(source.as_bytes()).ok().map(String::from)?
198 };
199
200 let params_node = if (self.config.language == "c" || self.config.language == "cpp")
202 && node.kind() == "function_definition"
203 {
204 let declarator = node.child_by_field_name("declarator")?;
205 match declarator.kind() {
206 "function_declarator" => declarator.child_by_field_name("parameters"),
207 "pointer_declarator" => declarator
208 .children(&mut declarator.walk())
209 .find(|n| n.kind() == "function_declarator")
210 .and_then(|n| n.child_by_field_name("parameters")),
211 _ => None,
212 }
213 } else {
214 node.child_by_field_name(&self.config.field_mappings.params_field)
215 };
216
217 let body_node = node.child_by_field_name(&self.config.field_mappings.body_field);
218
219 let params = self.extract_parameters(params_node, source);
220 let decorators = self.extract_decorators(node, source);
221 let is_async = self.is_async_function(node, source);
222 let is_generator = self.is_generator_function(node, source);
223
224 Some(GenericFunctionDef {
225 name: name_string,
226 start_line: node.start_position().row as u32 + 1,
227 end_line: node.end_position().row as u32 + 1,
228 body_start_line: body_node.map(|n| n.start_position().row as u32 + 1).unwrap_or(0),
229 body_end_line: body_node.map(|n| n.end_position().row as u32 + 1).unwrap_or(0),
230 parameters: params,
231 is_method: class_name.is_some(),
232 class_name: class_name.map(String::from),
233 is_async,
234 is_generator,
235 decorators,
236 })
237 }
238
239 fn extract_parameters(&self, params_node: Option<Node>, source: &str) -> Vec<String> {
240 let Some(node) = params_node else {
241 return Vec::new();
242 };
243
244 let mut params = Vec::new();
245 let mut cursor = node.walk();
246
247 for child in node.children(&mut cursor) {
248 if self.config.value_nodes.contains(&child.kind().to_string()) {
249 if let Ok(param_text) = child.utf8_text(source.as_bytes()) {
250 params.push(param_text.to_string());
251 }
252 } else if let Some(name_child) =
253 child.child_by_field_name(&self.config.field_mappings.name_field)
254 {
255 if let Ok(param_text) = name_child.utf8_text(source.as_bytes()) {
256 params.push(param_text.to_string());
257 }
258 }
259 }
260
261 params
262 }
263
264 fn extract_decorators(&self, node: Node, source: &str) -> Vec<String> {
265 let mut decorators = Vec::new();
266
267 if let Some(decorator_field) = &self.config.field_mappings.decorator_field {
268 if let Some(parent) = node.parent() {
270 let mut cursor = parent.walk();
271 for child in parent.children(&mut cursor) {
272 if child.kind() == decorator_field
273 && child.end_position().row < node.start_position().row
274 {
275 if let Ok(decorator_text) = child.utf8_text(source.as_bytes()) {
276 decorators.push(decorator_text.trim_start_matches('@').to_string());
277 }
278 }
279 }
280 }
281 }
282
283 decorators
284 }
285
286 fn is_async_function(&self, node: Node, source: &str) -> bool {
287 if let Ok(text) = node.utf8_text(source.as_bytes()) {
289 return text.starts_with("async ");
290 }
291 false
292 }
293
294 fn is_generator_function(&self, node: Node, source: &str) -> bool {
295 if let Some(body) = node.child_by_field_name(&self.config.field_mappings.body_field) {
297 if let Ok(body_text) = body.utf8_text(source.as_bytes()) {
298 return body_text.contains("yield");
299 }
300 }
301 false
302 }
303
304 fn extract_types_from_node(&self, node: Node, source: &str, types: &mut Vec<GenericTypeDef>) {
305 let node_kind = node.kind();
306
307 if self.config.type_nodes.contains(&node_kind.to_string()) {
309 if let Some(type_def) = self.extract_type_definition(node, source) {
310 types.push(type_def);
311 }
312 }
313
314 for child in node.children(&mut node.walk()) {
316 self.extract_types_from_node(child, source, types);
317 }
318 }
319
320 fn extract_type_definition(&self, node: Node, source: &str) -> Option<GenericTypeDef> {
321 let (name, actual_type_node) = if node.kind() == "type_declaration"
323 && self.config.language == "go"
324 {
325 let type_spec = node
327 .child_by_field_name("spec")
328 .or_else(|| node.children(&mut node.walk()).find(|n| n.kind() == "type_spec"))?;
329
330 let name_node = type_spec.child_by_field_name("name").or_else(|| {
331 type_spec.children(&mut type_spec.walk()).find(|n| n.kind() == "type_identifier")
332 })?;
333 let name = name_node.utf8_text(source.as_bytes()).ok()?;
334
335 let actual_type = type_spec
337 .child_by_field_name("type")
338 .or_else(|| type_spec.children(&mut type_spec.walk()).nth(1))?;
339
340 (name, actual_type)
341 } else if node.kind() == "type_definition" && self.config.language == "c" {
342 let declarator = node.child_by_field_name("declarator")?;
344 let name = declarator.utf8_text(source.as_bytes()).ok()?;
345
346 let actual_type = node.child_by_field_name("type").unwrap_or(node);
348
349 (name, actual_type)
350 } else {
351 let name_node = node.child_by_field_name(&self.config.field_mappings.name_field)?;
353 let name = name_node.utf8_text(source.as_bytes()).ok()?;
354 (name, node)
355 };
356
357 Some(GenericTypeDef {
358 name: name.to_string(),
359 kind: actual_type_node.kind().to_string(),
360 start_line: node.start_position().row as u32 + 1,
361 end_line: node.end_position().row as u32 + 1,
362 fields: Vec::new(), })
364 }
365}
366
367impl LanguageParser for GenericTreeSitterParser {
368 fn parse(
369 &mut self,
370 source: &str,
371 _filename: &str,
372 ) -> Result<Rc<TreeNode>, Box<dyn Error + Send + Sync>> {
373 let tree = self.parser.parse(source, None).ok_or_else(|| {
374 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse source"))
375 as Box<dyn Error + Send + Sync>
376 })?;
377
378 let root_node = tree.root_node();
379 let mut id_counter = 0;
380 Ok(Rc::new(self.convert_node(root_node, source, &mut id_counter)))
381 }
382
383 fn extract_functions(
384 &mut self,
385 source: &str,
386 _filename: &str,
387 ) -> Result<Vec<GenericFunctionDef>, Box<dyn Error + Send + Sync>> {
388 let tree = self.parser.parse(source, None).ok_or_else(|| {
389 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse source"))
390 as Box<dyn Error + Send + Sync>
391 })?;
392
393 let root_node = tree.root_node();
394 let mut functions = Vec::new();
395 self.extract_functions_from_node(root_node, source, &mut functions, None);
396 Ok(functions)
397 }
398
399 fn extract_types(
400 &mut self,
401 source: &str,
402 _filename: &str,
403 ) -> Result<Vec<GenericTypeDef>, Box<dyn Error + Send + Sync>> {
404 let tree = self.parser.parse(source, None).ok_or_else(|| {
405 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse source"))
406 as Box<dyn Error + Send + Sync>
407 })?;
408
409 let root_node = tree.root_node();
410 let mut types = Vec::new();
411 self.extract_types_from_node(root_node, source, &mut types);
412 Ok(types)
413 }
414
415 fn language(&self) -> Language {
416 match self.config.language.as_str() {
417 "python" => Language::Python,
418 "rust" => Language::Rust,
419 "javascript" | "typescript" => Language::TypeScript,
420 "go" => Language::Go,
421 "java" => Language::Java,
422 "c" => Language::C,
423 "cpp" => Language::Cpp,
424 "csharp" => Language::CSharp,
425 "ruby" => Language::Ruby,
426 "php" => Language::Php,
427 _ => Language::Unknown,
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs function extraction on `source` and returns the names in the
    /// order the parser reported them.
    fn function_names(
        parser: &mut GenericTreeSitterParser,
        source: &str,
        filename: &str,
    ) -> Vec<String> {
        parser
            .extract_functions(source, filename)
            .unwrap()
            .into_iter()
            .map(|function| function.name)
            .collect()
    }

    #[test]
    fn test_generic_parser_with_go() {
        let mut parser = GenericTreeSitterParser::from_language_name("go").unwrap();

        let source = r#"
package main

func hello(name string) string {
    return "Hello, " + name + "!"
}

type Greeter struct {}

func (g *Greeter) greet(name string) string {
    return "Hi, " + name + "!"
}
"#;

        let names = function_names(&mut parser, source, "test.go");
        assert_eq!(names, ["hello", "greet"]);
    }

    #[test]
    fn test_generic_parser_with_java() {
        let mut parser = GenericTreeSitterParser::from_language_name("java").unwrap();

        let source = r#"
public class Calculator {
    public int add(int a, int b) {
        return a + b;
    }

    public int multiply(int x, int y) {
        return x * y;
    }
}
"#;

        let names = function_names(&mut parser, source, "Test.java");
        assert_eq!(names, ["add", "multiply"]);
    }
}