similarity_core/
generic_tree_sitter_parser.rs

1#![allow(clippy::io_other_error)]
2
3use crate::generic_parser_config::GenericParserConfig;
4use crate::language_parser::{GenericFunctionDef, GenericTypeDef, Language, LanguageParser};
5use crate::tree::TreeNode;
6use std::error::Error;
7use std::rc::Rc;
8use tree_sitter::{Node, Parser};
9
/// Tree-sitter backed parser whose traversal is driven by a
/// [`GenericParserConfig`] instead of per-language code.
pub struct GenericTreeSitterParser {
    /// Underlying tree-sitter parser; its grammar is set once in `new`.
    parser: Parser,
    /// Node-kind lists, field mappings and language name consulted while walking the tree.
    config: GenericParserConfig,
}
14
15impl GenericTreeSitterParser {
16    /// Create a new generic parser with the given tree-sitter language and configuration
17    pub fn new(
18        language: tree_sitter::Language,
19        config: GenericParserConfig,
20    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
21        let mut parser = Parser::new();
22        parser.set_language(&language).map_err(|e| {
23            Box::new(std::io::Error::new(
24                std::io::ErrorKind::Other,
25                format!("Failed to set language: {:?}", e),
26            )) as Box<dyn Error + Send + Sync>
27        })?;
28
29        Ok(Self { parser, config })
30    }
31
32    /// Create from a pre-configured language
33    pub fn from_language_name(language_name: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
34        let (language, config) = match language_name {
35            "go" => (tree_sitter_go::LANGUAGE.into(), GenericParserConfig::go()),
36            "java" => (tree_sitter_java::LANGUAGE.into(), GenericParserConfig::java()),
37            "c" => (tree_sitter_c::LANGUAGE.into(), GenericParserConfig::c()),
38            "cpp" | "c++" => (tree_sitter_cpp::LANGUAGE.into(), GenericParserConfig::cpp()),
39            "csharp" | "cs" => {
40                (tree_sitter_c_sharp::LANGUAGE.into(), GenericParserConfig::csharp())
41            }
42            "ruby" | "rb" => (tree_sitter_ruby::LANGUAGE.into(), GenericParserConfig::ruby()),
43            _ => {
44                return Err(Box::new(std::io::Error::new(
45                    std::io::ErrorKind::InvalidInput,
46                    format!("Unsupported language: {}", language_name),
47                )) as Box<dyn Error + Send + Sync>)
48            }
49        };
50
51        Self::new(language, config)
52    }
53
54    fn convert_node(&self, node: Node, source: &str, id_counter: &mut usize) -> TreeNode {
55        let current_id = *id_counter;
56        *id_counter += 1;
57
58        let label = node.kind().to_string();
59        let value = if self.config.value_nodes.contains(&node.kind().to_string()) {
60            node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
61        } else {
62            "".to_string()
63        };
64
65        let mut tree_node = TreeNode::new(label, value, current_id);
66
67        for child in node.children(&mut node.walk()) {
68            let child_node = self.convert_node(child, source, id_counter);
69            tree_node.add_child(Rc::new(child_node));
70        }
71
72        tree_node
73    }
74
75    fn extract_functions_from_node(
76        &self,
77        node: Node,
78        source: &str,
79        functions: &mut Vec<GenericFunctionDef>,
80        class_name: Option<&str>,
81    ) {
82        let node_kind = node.kind();
83
84        // Skip anonymous class bodies in Java
85        if self.config.language == "java" && node_kind == "object_creation_expression" {
86            // Skip the class_body of anonymous classes
87            return;
88        }
89
90        // Check if this is a function node
91        if self.config.function_nodes.contains(&node_kind.to_string()) {
92            if let Some(func_def) = self.extract_function_definition(node, source, class_name) {
93                functions.push(func_def);
94            }
95        }
96
97        // Check if this is a type/class node
98        if self.config.type_nodes.contains(&node_kind.to_string()) {
99            // Extract class name for nested functions
100            let new_class_name = node
101                .child_by_field_name(&self.config.field_mappings.name_field)
102                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
103                .unwrap_or("");
104
105            // Recursively extract methods from class
106            for child in node.children(&mut node.walk()) {
107                self.extract_functions_from_node(child, source, functions, Some(new_class_name));
108            }
109            return; // Don't continue normal traversal for type nodes
110        }
111
112        // Continue searching in children
113        for child in node.children(&mut node.walk()) {
114            self.extract_functions_from_node(child, source, functions, class_name);
115        }
116    }
117
118    fn extract_function_definition(
119        &self,
120        node: Node,
121        source: &str,
122        class_name: Option<&str>,
123    ) -> Option<GenericFunctionDef> {
124        // Extract the function name first, which might require special handling
125        let name_string = if (self.config.language == "c" || self.config.language == "cpp")
126            && node.kind() == "function_definition"
127        {
128            // In C/C++, the declarator contains both name and parameters
129            let declarator = node.child_by_field_name("declarator")?;
130
131            match declarator.kind() {
132                "function_declarator" => declarator
133                    .child_by_field_name("declarator")
134                    .and_then(|n| n.utf8_text(source.as_bytes()).ok())
135                    .map(String::from)?,
136                "pointer_declarator" => {
137                    // Handle functions returning pointers
138                    let func_decl = declarator
139                        .children(&mut declarator.walk())
140                        .find(|n| n.kind() == "function_declarator")?;
141                    func_decl
142                        .child_by_field_name("declarator")
143                        .and_then(|n| n.utf8_text(source.as_bytes()).ok())
144                        .map(String::from)?
145                }
146                _ => {
147                    // Simple function without parameters
148                    declarator.utf8_text(source.as_bytes()).ok().map(String::from)?
149                }
150            }
151        } else if self.config.language == "csharp" {
152            // Special handling for C# methods
153            match node.kind() {
154                "operator_declaration" => {
155                    // For operators, construct name as "operator <symbol>"
156                    let operator_symbol = node
157                        .child_by_field_name("operator")
158                        .and_then(|n| n.utf8_text(source.as_bytes()).ok())?;
159                    format!("operator {}", operator_symbol)
160                }
161                "destructor_declaration" => {
162                    // For destructors, add ~ prefix
163                    let class_name = node
164                        .child_by_field_name("name")
165                        .and_then(|n| n.utf8_text(source.as_bytes()).ok())?;
166                    format!("~{}", class_name)
167                }
168                _ => {
169                    // Standard C# methods
170                    let name_node =
171                        node.child_by_field_name(&self.config.field_mappings.name_field)?;
172                    name_node.utf8_text(source.as_bytes()).ok().map(String::from)?
173                }
174            }
175        } else if self.config.language == "elixir" && node.kind() == "call" {
176            // Special handling for Elixir functions
177            // The function name is in the arguments field (first call node)
178            // For Elixir def/defp, arguments is the second child (index 1)
179            let name_result = node
180                .child(1)
181                .filter(|n| n.kind() == "arguments")
182                .and_then(|args| args.child(0))
183                .and_then(|call_node| {
184                    if call_node.kind() == "call" {
185                        // Get the target of the inner call (the function name)
186                        call_node.child_by_field_name("target")
187                    } else {
188                        None
189                    }
190                })
191                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
192                .map(String::from);
193            name_result?
194        } else {
195            // For other languages, use the standard field mapping
196            let name_node = node.child_by_field_name(&self.config.field_mappings.name_field)?;
197            name_node.utf8_text(source.as_bytes()).ok().map(String::from)?
198        };
199
200        // Extract parameters - special handling for C/C++
201        let params_node = if (self.config.language == "c" || self.config.language == "cpp")
202            && node.kind() == "function_definition"
203        {
204            let declarator = node.child_by_field_name("declarator")?;
205            match declarator.kind() {
206                "function_declarator" => declarator.child_by_field_name("parameters"),
207                "pointer_declarator" => declarator
208                    .children(&mut declarator.walk())
209                    .find(|n| n.kind() == "function_declarator")
210                    .and_then(|n| n.child_by_field_name("parameters")),
211                _ => None,
212            }
213        } else {
214            node.child_by_field_name(&self.config.field_mappings.params_field)
215        };
216
217        let body_node = node.child_by_field_name(&self.config.field_mappings.body_field);
218
219        let params = self.extract_parameters(params_node, source);
220        let decorators = self.extract_decorators(node, source);
221        let is_async = self.is_async_function(node, source);
222        let is_generator = self.is_generator_function(node, source);
223
224        Some(GenericFunctionDef {
225            name: name_string,
226            start_line: node.start_position().row as u32 + 1,
227            end_line: node.end_position().row as u32 + 1,
228            body_start_line: body_node.map(|n| n.start_position().row as u32 + 1).unwrap_or(0),
229            body_end_line: body_node.map(|n| n.end_position().row as u32 + 1).unwrap_or(0),
230            parameters: params,
231            is_method: class_name.is_some(),
232            class_name: class_name.map(String::from),
233            is_async,
234            is_generator,
235            decorators,
236        })
237    }
238
239    fn extract_parameters(&self, params_node: Option<Node>, source: &str) -> Vec<String> {
240        let Some(node) = params_node else {
241            return Vec::new();
242        };
243
244        let mut params = Vec::new();
245        let mut cursor = node.walk();
246
247        for child in node.children(&mut cursor) {
248            if self.config.value_nodes.contains(&child.kind().to_string()) {
249                if let Ok(param_text) = child.utf8_text(source.as_bytes()) {
250                    params.push(param_text.to_string());
251                }
252            } else if let Some(name_child) =
253                child.child_by_field_name(&self.config.field_mappings.name_field)
254            {
255                if let Ok(param_text) = name_child.utf8_text(source.as_bytes()) {
256                    params.push(param_text.to_string());
257                }
258            }
259        }
260
261        params
262    }
263
264    fn extract_decorators(&self, node: Node, source: &str) -> Vec<String> {
265        let mut decorators = Vec::new();
266
267        if let Some(decorator_field) = &self.config.field_mappings.decorator_field {
268            // Look for decorator nodes
269            if let Some(parent) = node.parent() {
270                let mut cursor = parent.walk();
271                for child in parent.children(&mut cursor) {
272                    if child.kind() == decorator_field
273                        && child.end_position().row < node.start_position().row
274                    {
275                        if let Ok(decorator_text) = child.utf8_text(source.as_bytes()) {
276                            decorators.push(decorator_text.trim_start_matches('@').to_string());
277                        }
278                    }
279                }
280            }
281        }
282
283        decorators
284    }
285
286    fn is_async_function(&self, node: Node, source: &str) -> bool {
287        // Check if the function definition contains async keyword
288        if let Ok(text) = node.utf8_text(source.as_bytes()) {
289            return text.starts_with("async ");
290        }
291        false
292    }
293
294    fn is_generator_function(&self, node: Node, source: &str) -> bool {
295        // Check if function body contains yield
296        if let Some(body) = node.child_by_field_name(&self.config.field_mappings.body_field) {
297            if let Ok(body_text) = body.utf8_text(source.as_bytes()) {
298                return body_text.contains("yield");
299            }
300        }
301        false
302    }
303
304    fn extract_types_from_node(&self, node: Node, source: &str, types: &mut Vec<GenericTypeDef>) {
305        let node_kind = node.kind();
306
307        // Check if this is a type node
308        if self.config.type_nodes.contains(&node_kind.to_string()) {
309            if let Some(type_def) = self.extract_type_definition(node, source) {
310                types.push(type_def);
311            }
312        }
313
314        // Continue searching in children
315        for child in node.children(&mut node.walk()) {
316            self.extract_types_from_node(child, source, types);
317        }
318    }
319
320    fn extract_type_definition(&self, node: Node, source: &str) -> Option<GenericTypeDef> {
321        // Special handling for Go's type_declaration
322        let (name, actual_type_node) = if node.kind() == "type_declaration"
323            && self.config.language == "go"
324        {
325            // In Go, type_declaration -> type_spec -> actual type
326            let type_spec = node
327                .child_by_field_name("spec")
328                .or_else(|| node.children(&mut node.walk()).find(|n| n.kind() == "type_spec"))?;
329
330            let name_node = type_spec.child_by_field_name("name").or_else(|| {
331                type_spec.children(&mut type_spec.walk()).find(|n| n.kind() == "type_identifier")
332            })?;
333            let name = name_node.utf8_text(source.as_bytes()).ok()?;
334
335            // Get the actual type (struct_type, interface_type, etc.)
336            let actual_type = type_spec
337                .child_by_field_name("type")
338                .or_else(|| type_spec.children(&mut type_spec.walk()).nth(1))?;
339
340            (name, actual_type)
341        } else if node.kind() == "type_definition" && self.config.language == "c" {
342            // In C, type_definition has a declarator field
343            let declarator = node.child_by_field_name("declarator")?;
344            let name = declarator.utf8_text(source.as_bytes()).ok()?;
345
346            // Get the actual type from the type field
347            let actual_type = node.child_by_field_name("type").unwrap_or(node);
348
349            (name, actual_type)
350        } else {
351            // For other languages, use the standard field mapping
352            let name_node = node.child_by_field_name(&self.config.field_mappings.name_field)?;
353            let name = name_node.utf8_text(source.as_bytes()).ok()?;
354            (name, node)
355        };
356
357        Some(GenericTypeDef {
358            name: name.to_string(),
359            kind: actual_type_node.kind().to_string(),
360            start_line: node.start_position().row as u32 + 1,
361            end_line: node.end_position().row as u32 + 1,
362            fields: Vec::new(), // TODO: Extract fields based on language
363        })
364    }
365}
366
367impl LanguageParser for GenericTreeSitterParser {
368    fn parse(
369        &mut self,
370        source: &str,
371        _filename: &str,
372    ) -> Result<Rc<TreeNode>, Box<dyn Error + Send + Sync>> {
373        let tree = self.parser.parse(source, None).ok_or_else(|| {
374            Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse source"))
375                as Box<dyn Error + Send + Sync>
376        })?;
377
378        let root_node = tree.root_node();
379        let mut id_counter = 0;
380        Ok(Rc::new(self.convert_node(root_node, source, &mut id_counter)))
381    }
382
383    fn extract_functions(
384        &mut self,
385        source: &str,
386        _filename: &str,
387    ) -> Result<Vec<GenericFunctionDef>, Box<dyn Error + Send + Sync>> {
388        let tree = self.parser.parse(source, None).ok_or_else(|| {
389            Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse source"))
390                as Box<dyn Error + Send + Sync>
391        })?;
392
393        let root_node = tree.root_node();
394        let mut functions = Vec::new();
395        self.extract_functions_from_node(root_node, source, &mut functions, None);
396        Ok(functions)
397    }
398
399    fn extract_types(
400        &mut self,
401        source: &str,
402        _filename: &str,
403    ) -> Result<Vec<GenericTypeDef>, Box<dyn Error + Send + Sync>> {
404        let tree = self.parser.parse(source, None).ok_or_else(|| {
405            Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse source"))
406                as Box<dyn Error + Send + Sync>
407        })?;
408
409        let root_node = tree.root_node();
410        let mut types = Vec::new();
411        self.extract_types_from_node(root_node, source, &mut types);
412        Ok(types)
413    }
414
415    fn language(&self) -> Language {
416        match self.config.language.as_str() {
417            "python" => Language::Python,
418            "rust" => Language::Rust,
419            "javascript" | "typescript" => Language::TypeScript,
420            "go" => Language::Go,
421            "java" => Language::Java,
422            "c" => Language::C,
423            "cpp" => Language::Cpp,
424            "csharp" => Language::CSharp,
425            "ruby" => Language::Ruby,
426            "php" => Language::Php,
427            _ => Language::Unknown,
428        }
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the bundled grammar for `language` and returns the
    /// names of the extracted functions in the order they were reported.
    fn extracted_names(language: &str, source: &str, filename: &str) -> Vec<String> {
        let mut parser = GenericTreeSitterParser::from_language_name(language).unwrap();
        let functions = parser.extract_functions(source, filename).unwrap();
        functions.into_iter().map(|function| function.name).collect()
    }

    /// A free function and a receiver method are both reported for Go.
    #[test]
    fn test_generic_parser_with_go() {
        let source = r#"
package main

func hello(name string) string {
    return "Hello, " + name + "!"
}

type Greeter struct {}

func (g *Greeter) greet(name string) string {
    return "Hi, " + name + "!"
}
"#;

        let names = extracted_names("go", source, "test.go");
        assert_eq!(names, ["hello", "greet"]);
    }

    /// Both methods of a Java class are reported, in declaration order.
    #[test]
    fn test_generic_parser_with_java() {
        let source = r#"
public class Calculator {
    public int add(int a, int b) {
        return a + b;
    }
    
    public int multiply(int x, int y) {
        return x * y;
    }
}
"#;

        let names = extracted_names("java", source, "Test.java");
        assert_eq!(names, ["add", "multiply"]);
    }
}