Skip to main content

similarity_py/
python_parser.rs

1#![allow(clippy::io_other_error)]
2
3use similarity_core::language_parser::{
4    GenericFunctionDef, GenericTypeDef, Language, LanguageParser,
5};
6use similarity_core::tree::TreeNode;
7use std::error::Error;
8use std::rc::Rc;
9use tree_sitter::{Node, Parser};
10
11pub struct PythonParser {
12    parser: Parser,
13}
14
15impl PythonParser {
16    pub fn new() -> Result<Self, Box<dyn Error + Send + Sync>> {
17        let mut parser = Parser::new();
18        parser.set_language(&tree_sitter_python::LANGUAGE.into()).map_err(|e| {
19            Box::new(std::io::Error::new(
20                std::io::ErrorKind::Other,
21                format!("Failed to set Python language: {e:?}"),
22            )) as Box<dyn Error + Send + Sync>
23        })?;
24
25        Ok(Self { parser })
26    }
27
28    #[allow(clippy::only_used_in_recursion)]
29    fn convert_node(&self, node: Node, source: &str, id_counter: &mut usize) -> TreeNode {
30        let current_id = *id_counter;
31        *id_counter += 1;
32
33        let label = node.kind().to_string();
34        let value = match node.kind() {
35            "identifier" | "string" | "integer" | "float" | "true" | "false" | "none" => {
36                node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
37            }
38            _ => "".to_string(),
39        };
40
41        let mut tree_node = TreeNode::new(label, value, current_id);
42
43        for child in node.children(&mut node.walk()) {
44            let child_node = self.convert_node(child, source, id_counter);
45            tree_node.add_child(Rc::new(child_node));
46        }
47
48        tree_node
49    }
50
51    fn extract_functions_from_node(
52        &self,
53        node: Node,
54        source: &str,
55        class_name: Option<&str>,
56    ) -> Vec<GenericFunctionDef> {
57        let mut functions = Vec::new();
58
59        // Visit all nodes
60        fn visit_node(
61            node: Node,
62            source: &str,
63            functions: &mut Vec<GenericFunctionDef>,
64            class_name: Option<&str>,
65        ) {
66            match node.kind() {
67                "function_definition" => {
68                    if let Some(name_node) = node.child_by_field_name("name") {
69                        if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
70                            let params_node = node.child_by_field_name("parameters");
71                            let body_node = node.child_by_field_name("body");
72
73                            let params = extract_params(params_node, source);
74
75                            functions.push(GenericFunctionDef {
76                                name: name.to_string(),
77                                start_line: node.start_position().row as u32 + 1,
78                                end_line: node.end_position().row as u32 + 1,
79                                body_start_line: body_node
80                                    .map(|n| n.start_position().row as u32 + 1)
81                                    .unwrap_or(0),
82                                body_end_line: body_node
83                                    .map(|n| n.end_position().row as u32 + 1)
84                                    .unwrap_or(0),
85                                parameters: params,
86                                is_method: class_name.is_some(),
87                                class_name: class_name.map(|s| s.to_string()),
88                                is_async: is_async_def(node, source),
89                                is_generator: is_generator_def(node, source),
90                                decorators: extract_decorators(node, source),
91                            });
92                        }
93                    }
94                }
95                "decorated_definition" => {
96                    // Check if it decorates a function
97                    if let Some(child) = node.child(node.child_count().saturating_sub(1)) {
98                        if child.kind() == "function_definition" {
99                            if let Some(name_node) = child.child_by_field_name("name") {
100                                if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
101                                    let params_node = child.child_by_field_name("parameters");
102                                    let body_node = child.child_by_field_name("body");
103
104                                    let params = extract_params(params_node, source);
105
106                                    functions.push(GenericFunctionDef {
107                                        name: name.to_string(),
108                                        start_line: node.start_position().row as u32 + 1,
109                                        end_line: node.end_position().row as u32 + 1,
110                                        body_start_line: body_node
111                                            .map(|n| n.start_position().row as u32 + 1)
112                                            .unwrap_or(0),
113                                        body_end_line: body_node
114                                            .map(|n| n.end_position().row as u32 + 1)
115                                            .unwrap_or(0),
116                                        parameters: params,
117                                        is_method: class_name.is_some(),
118                                        class_name: class_name.map(|s| s.to_string()),
119                                        is_async: is_async_def(child, source),
120                                        is_generator: is_generator_def(child, source),
121                                        decorators: extract_decorators(child, source),
122                                    });
123                                }
124                            }
125                        }
126                    }
127                }
128                "class_definition" => {
129                    // Don't recurse into nested classes when we're already in a class
130                    if class_name.is_none() {
131                        if let Some(name_node) = node.child_by_field_name("name") {
132                            if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
133                                // Recursively extract methods from this class
134                                let mut subcursor = node.walk();
135                                for child in node.children(&mut subcursor) {
136                                    visit_node(child, source, functions, Some(name));
137                                }
138                            }
139                        }
140                    }
141                }
142                _ => {
143                    // Continue traversing for other node types
144                    let mut subcursor = node.walk();
145                    for child in node.children(&mut subcursor) {
146                        visit_node(child, source, functions, class_name);
147                    }
148                }
149            }
150        }
151
152        fn is_async_def(node: Node, source: &str) -> bool {
153            if let Ok(text) = node.utf8_text(source.as_bytes()) {
154                text.starts_with("async ")
155            } else {
156                false
157            }
158        }
159
160        fn is_generator_def(node: Node, source: &str) -> bool {
161            // Python generators are functions that contain yield statements
162            // For simplicity, we'll just check if the function body contains "yield"
163            if let Some(body) = node.child_by_field_name("body") {
164                if let Ok(body_text) = body.utf8_text(source.as_bytes()) {
165                    return body_text.contains("yield");
166                }
167            }
168            false
169        }
170
171        fn extract_decorators(node: Node, source: &str) -> Vec<String> {
172            let mut decorators = Vec::new();
173            let mut cursor = node.walk();
174
175            // Look for decorator nodes before the function definition
176            if let Some(parent) = node.parent() {
177                for child in parent.children(&mut cursor) {
178                    if child.kind() == "decorator"
179                        && child.end_position().row < node.start_position().row
180                    {
181                        if let Ok(decorator_text) = child.utf8_text(source.as_bytes()) {
182                            decorators.push(decorator_text.trim_start_matches('@').to_string());
183                        }
184                    }
185                }
186            }
187
188            decorators
189        }
190
191        fn extract_params(params_node: Option<Node>, source: &str) -> Vec<String> {
192            if let Some(node) = params_node {
193                let mut params = Vec::new();
194                let mut cursor = node.walk();
195
196                for child in node.children(&mut cursor) {
197                    match child.kind() {
198                        "identifier" => {
199                            if let Ok(param_text) = child.utf8_text(source.as_bytes()) {
200                                params.push(param_text.to_string());
201                            }
202                        }
203                        "typed_parameter" | "default_parameter" => {
204                            if let Some(ident) = child.child_by_field_name("name") {
205                                if let Ok(param_text) = ident.utf8_text(source.as_bytes()) {
206                                    params.push(param_text.to_string());
207                                }
208                            }
209                        }
210                        _ => {}
211                    }
212                }
213
214                params
215            } else {
216                Vec::new()
217            }
218        }
219
220        visit_node(node, source, &mut functions, class_name);
221        functions
222    }
223}
224
225impl LanguageParser for PythonParser {
226    fn parse(
227        &mut self,
228        source: &str,
229        _filename: &str,
230    ) -> Result<Rc<TreeNode>, Box<dyn Error + Send + Sync>> {
231        let tree = self.parser.parse(source, None).ok_or_else(|| {
232            Box::new(std::io::Error::new(
233                std::io::ErrorKind::InvalidData,
234                "Failed to parse Python source",
235            )) as Box<dyn Error + Send + Sync>
236        })?;
237
238        let root_node = tree.root_node();
239        let mut id_counter = 0;
240        Ok(Rc::new(self.convert_node(root_node, source, &mut id_counter)))
241    }
242
243    fn extract_functions(
244        &mut self,
245        source: &str,
246        _filename: &str,
247    ) -> Result<Vec<GenericFunctionDef>, Box<dyn Error + Send + Sync>> {
248        let tree = self.parser.parse(source, None).ok_or_else(|| {
249            Box::new(std::io::Error::new(
250                std::io::ErrorKind::InvalidData,
251                "Failed to parse Python source",
252            )) as Box<dyn Error + Send + Sync>
253        })?;
254
255        let root_node = tree.root_node();
256        Ok(self.extract_functions_from_node(root_node, source, None))
257    }
258
259    fn extract_types(
260        &mut self,
261        source: &str,
262        _filename: &str,
263    ) -> Result<Vec<GenericTypeDef>, Box<dyn Error + Send + Sync>> {
264        let tree = self.parser.parse(source, None).ok_or_else(|| {
265            Box::new(std::io::Error::new(
266                std::io::ErrorKind::InvalidData,
267                "Failed to parse Python source",
268            )) as Box<dyn Error + Send + Sync>
269        })?;
270
271        let root_node = tree.root_node();
272        let mut types = Vec::new();
273
274        fn visit_node_for_types(node: Node, source: &str, types: &mut Vec<GenericTypeDef>) {
275            if node.kind() == "class_definition" {
276                if let Some(name_node) = node.child_by_field_name("name") {
277                    if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
278                        types.push(GenericTypeDef {
279                            name: name.to_string(),
280                            kind: "class".to_string(),
281                            start_line: node.start_position().row as u32 + 1,
282                            end_line: node.end_position().row as u32 + 1,
283                            fields: extract_class_fields(node, source),
284                        });
285                    }
286                }
287            }
288
289            // Continue traversing
290            let mut cursor = node.walk();
291            for child in node.children(&mut cursor) {
292                visit_node_for_types(child, source, types);
293            }
294        }
295
296        fn extract_class_fields(node: Node, source: &str) -> Vec<String> {
297            let mut fields = Vec::new();
298
299            if let Some(body) = node.child_by_field_name("body") {
300                let mut cursor = body.walk();
301                for child in body.children(&mut cursor) {
302                    // Look for instance variable assignments in __init__ method
303                    if child.kind() == "function_definition" {
304                        if let Some(name_node) = child.child_by_field_name("name") {
305                            if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
306                                if name == "__init__" {
307                                    // Extract self.field assignments from __init__
308                                    if let Some(func_body) = child.child_by_field_name("body") {
309                                        extract_self_assignments(func_body, source, &mut fields);
310                                    }
311                                }
312                            }
313                        }
314                    }
315                }
316            }
317
318            fields
319        }
320
321        fn extract_self_assignments(node: Node, source: &str, fields: &mut Vec<String>) {
322            let mut cursor = node.walk();
323            for child in node.children(&mut cursor) {
324                if child.kind() == "assignment" {
325                    if let Some(left) = child.child(0) {
326                        if left.kind() == "attribute" {
327                            if let Ok(text) = left.utf8_text(source.as_bytes()) {
328                                if text.starts_with("self.") {
329                                    let field_name = text.trim_start_matches("self.");
330                                    if !fields.contains(&field_name.to_string()) {
331                                        fields.push(field_name.to_string());
332                                    }
333                                }
334                            }
335                        }
336                    }
337                }
338                // Recursively check nested nodes
339                extract_self_assignments(child, source, fields);
340            }
341        }
342
343        visit_node_for_types(root_node, source, &mut types);
344        Ok(types)
345    }
346
347    fn language(&self) -> Language {
348        Language::Python
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Top-level functions come first in source order, followed by the
    /// methods of `Calculator`, which carry their class name.
    #[test]
    fn test_python_functions() {
        let mut py = PythonParser::new().expect("Python grammar should load");
        let source = r#"
def hello(name):
    return f"Hello, {name}!"

def add(a, b=0):
    return a + b

class Calculator:
    def __init__(self):
        self.result = 0
    
    def add(self, x):
        self.result += x
        return self.result
"#;

        let found = py.extract_functions(source, "test.py").unwrap();
        let names: Vec<&str> = found.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["hello", "add", "__init__", "add"]);

        let (free_add, init, method_add) = (&found[1], &found[2], &found[3]);
        assert!(!free_add.is_method);
        assert!(init.is_method);
        assert_eq!(init.class_name.as_deref(), Some("Calculator"));
        assert!(method_add.is_method);
    }

    /// Both classes, including the subclass, are reported in source order.
    #[test]
    fn test_python_classes() {
        let mut py = PythonParser::new().expect("Python grammar should load");
        let source = r#"
class User:
    def __init__(self, name):
        self.name = name

class Admin(User):
    def __init__(self, name, level):
        super().__init__(name)
        self.level = level
"#;

        let found = py.extract_types(source, "test.py").unwrap();
        let names: Vec<&str> = found.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["User", "Admin"]);
        assert_eq!(found[0].kind, "class");
    }
}