1#![allow(clippy::io_other_error)]
2
3use similarity_core::language_parser::{
4 GenericFunctionDef, GenericTypeDef, Language, LanguageParser,
5};
6use similarity_core::tree::TreeNode;
7use std::error::Error;
8use std::rc::Rc;
9use tree_sitter::{Node, Parser};
10
/// Python implementation of [`LanguageParser`], backed by the tree-sitter
/// Python grammar.
pub struct PythonParser {
    // tree-sitter parser configured with the Python grammar in
    // `PythonParser::new`; reused by every parse/extract call.
    parser: Parser,
}
14
15impl PythonParser {
16 pub fn new() -> Result<Self, Box<dyn Error + Send + Sync>> {
17 let mut parser = Parser::new();
18 parser.set_language(&tree_sitter_python::LANGUAGE.into()).map_err(|e| {
19 Box::new(std::io::Error::new(
20 std::io::ErrorKind::Other,
21 format!("Failed to set Python language: {e:?}"),
22 )) as Box<dyn Error + Send + Sync>
23 })?;
24
25 Ok(Self { parser })
26 }
27
28 #[allow(clippy::only_used_in_recursion)]
29 fn convert_node(&self, node: Node, source: &str, id_counter: &mut usize) -> TreeNode {
30 let current_id = *id_counter;
31 *id_counter += 1;
32
33 let label = node.kind().to_string();
34 let value = match node.kind() {
35 "identifier" | "string" | "integer" | "float" | "true" | "false" | "none" => {
36 node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
37 }
38 _ => "".to_string(),
39 };
40
41 let mut tree_node = TreeNode::new(label, value, current_id);
42
43 for child in node.children(&mut node.walk()) {
44 let child_node = self.convert_node(child, source, id_counter);
45 tree_node.add_child(Rc::new(child_node));
46 }
47
48 tree_node
49 }
50
51 fn extract_functions_from_node(
52 &self,
53 node: Node,
54 source: &str,
55 class_name: Option<&str>,
56 ) -> Vec<GenericFunctionDef> {
57 let mut functions = Vec::new();
58
59 fn visit_node(
61 node: Node,
62 source: &str,
63 functions: &mut Vec<GenericFunctionDef>,
64 class_name: Option<&str>,
65 ) {
66 match node.kind() {
67 "function_definition" => {
68 if let Some(name_node) = node.child_by_field_name("name") {
69 if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
70 let params_node = node.child_by_field_name("parameters");
71 let body_node = node.child_by_field_name("body");
72
73 let params = extract_params(params_node, source);
74
75 functions.push(GenericFunctionDef {
76 name: name.to_string(),
77 start_line: node.start_position().row as u32 + 1,
78 end_line: node.end_position().row as u32 + 1,
79 body_start_line: body_node
80 .map(|n| n.start_position().row as u32 + 1)
81 .unwrap_or(0),
82 body_end_line: body_node
83 .map(|n| n.end_position().row as u32 + 1)
84 .unwrap_or(0),
85 parameters: params,
86 is_method: class_name.is_some(),
87 class_name: class_name.map(|s| s.to_string()),
88 is_async: is_async_def(node, source),
89 is_generator: is_generator_def(node, source),
90 decorators: extract_decorators(node, source),
91 });
92 }
93 }
94 }
95 "decorated_definition" => {
96 if let Some(child) = node.child(node.child_count().saturating_sub(1)) {
98 if child.kind() == "function_definition" {
99 if let Some(name_node) = child.child_by_field_name("name") {
100 if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
101 let params_node = child.child_by_field_name("parameters");
102 let body_node = child.child_by_field_name("body");
103
104 let params = extract_params(params_node, source);
105
106 functions.push(GenericFunctionDef {
107 name: name.to_string(),
108 start_line: node.start_position().row as u32 + 1,
109 end_line: node.end_position().row as u32 + 1,
110 body_start_line: body_node
111 .map(|n| n.start_position().row as u32 + 1)
112 .unwrap_or(0),
113 body_end_line: body_node
114 .map(|n| n.end_position().row as u32 + 1)
115 .unwrap_or(0),
116 parameters: params,
117 is_method: class_name.is_some(),
118 class_name: class_name.map(|s| s.to_string()),
119 is_async: is_async_def(child, source),
120 is_generator: is_generator_def(child, source),
121 decorators: extract_decorators(child, source),
122 });
123 }
124 }
125 }
126 }
127 }
128 "class_definition" => {
129 if class_name.is_none() {
131 if let Some(name_node) = node.child_by_field_name("name") {
132 if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
133 let mut subcursor = node.walk();
135 for child in node.children(&mut subcursor) {
136 visit_node(child, source, functions, Some(name));
137 }
138 }
139 }
140 }
141 }
142 _ => {
143 let mut subcursor = node.walk();
145 for child in node.children(&mut subcursor) {
146 visit_node(child, source, functions, class_name);
147 }
148 }
149 }
150 }
151
152 fn is_async_def(node: Node, source: &str) -> bool {
153 if let Ok(text) = node.utf8_text(source.as_bytes()) {
154 text.starts_with("async ")
155 } else {
156 false
157 }
158 }
159
160 fn is_generator_def(node: Node, source: &str) -> bool {
161 if let Some(body) = node.child_by_field_name("body") {
164 if let Ok(body_text) = body.utf8_text(source.as_bytes()) {
165 return body_text.contains("yield");
166 }
167 }
168 false
169 }
170
171 fn extract_decorators(node: Node, source: &str) -> Vec<String> {
172 let mut decorators = Vec::new();
173 let mut cursor = node.walk();
174
175 if let Some(parent) = node.parent() {
177 for child in parent.children(&mut cursor) {
178 if child.kind() == "decorator"
179 && child.end_position().row < node.start_position().row
180 {
181 if let Ok(decorator_text) = child.utf8_text(source.as_bytes()) {
182 decorators.push(decorator_text.trim_start_matches('@').to_string());
183 }
184 }
185 }
186 }
187
188 decorators
189 }
190
191 fn extract_params(params_node: Option<Node>, source: &str) -> Vec<String> {
192 if let Some(node) = params_node {
193 let mut params = Vec::new();
194 let mut cursor = node.walk();
195
196 for child in node.children(&mut cursor) {
197 match child.kind() {
198 "identifier" => {
199 if let Ok(param_text) = child.utf8_text(source.as_bytes()) {
200 params.push(param_text.to_string());
201 }
202 }
203 "typed_parameter" | "default_parameter" => {
204 if let Some(ident) = child.child_by_field_name("name") {
205 if let Ok(param_text) = ident.utf8_text(source.as_bytes()) {
206 params.push(param_text.to_string());
207 }
208 }
209 }
210 _ => {}
211 }
212 }
213
214 params
215 } else {
216 Vec::new()
217 }
218 }
219
220 visit_node(node, source, &mut functions, class_name);
221 functions
222 }
223}
224
225impl LanguageParser for PythonParser {
226 fn parse(
227 &mut self,
228 source: &str,
229 _filename: &str,
230 ) -> Result<Rc<TreeNode>, Box<dyn Error + Send + Sync>> {
231 let tree = self.parser.parse(source, None).ok_or_else(|| {
232 Box::new(std::io::Error::new(
233 std::io::ErrorKind::InvalidData,
234 "Failed to parse Python source",
235 )) as Box<dyn Error + Send + Sync>
236 })?;
237
238 let root_node = tree.root_node();
239 let mut id_counter = 0;
240 Ok(Rc::new(self.convert_node(root_node, source, &mut id_counter)))
241 }
242
243 fn extract_functions(
244 &mut self,
245 source: &str,
246 _filename: &str,
247 ) -> Result<Vec<GenericFunctionDef>, Box<dyn Error + Send + Sync>> {
248 let tree = self.parser.parse(source, None).ok_or_else(|| {
249 Box::new(std::io::Error::new(
250 std::io::ErrorKind::InvalidData,
251 "Failed to parse Python source",
252 )) as Box<dyn Error + Send + Sync>
253 })?;
254
255 let root_node = tree.root_node();
256 Ok(self.extract_functions_from_node(root_node, source, None))
257 }
258
259 fn extract_types(
260 &mut self,
261 source: &str,
262 _filename: &str,
263 ) -> Result<Vec<GenericTypeDef>, Box<dyn Error + Send + Sync>> {
264 let tree = self.parser.parse(source, None).ok_or_else(|| {
265 Box::new(std::io::Error::new(
266 std::io::ErrorKind::InvalidData,
267 "Failed to parse Python source",
268 )) as Box<dyn Error + Send + Sync>
269 })?;
270
271 let root_node = tree.root_node();
272 let mut types = Vec::new();
273
274 fn visit_node_for_types(node: Node, source: &str, types: &mut Vec<GenericTypeDef>) {
275 if node.kind() == "class_definition" {
276 if let Some(name_node) = node.child_by_field_name("name") {
277 if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
278 types.push(GenericTypeDef {
279 name: name.to_string(),
280 kind: "class".to_string(),
281 start_line: node.start_position().row as u32 + 1,
282 end_line: node.end_position().row as u32 + 1,
283 fields: extract_class_fields(node, source),
284 });
285 }
286 }
287 }
288
289 let mut cursor = node.walk();
291 for child in node.children(&mut cursor) {
292 visit_node_for_types(child, source, types);
293 }
294 }
295
296 fn extract_class_fields(node: Node, source: &str) -> Vec<String> {
297 let mut fields = Vec::new();
298
299 if let Some(body) = node.child_by_field_name("body") {
300 let mut cursor = body.walk();
301 for child in body.children(&mut cursor) {
302 if child.kind() == "function_definition" {
304 if let Some(name_node) = child.child_by_field_name("name") {
305 if let Ok(name) = name_node.utf8_text(source.as_bytes()) {
306 if name == "__init__" {
307 if let Some(func_body) = child.child_by_field_name("body") {
309 extract_self_assignments(func_body, source, &mut fields);
310 }
311 }
312 }
313 }
314 }
315 }
316 }
317
318 fields
319 }
320
321 fn extract_self_assignments(node: Node, source: &str, fields: &mut Vec<String>) {
322 let mut cursor = node.walk();
323 for child in node.children(&mut cursor) {
324 if child.kind() == "assignment" {
325 if let Some(left) = child.child(0) {
326 if left.kind() == "attribute" {
327 if let Ok(text) = left.utf8_text(source.as_bytes()) {
328 if text.starts_with("self.") {
329 let field_name = text.trim_start_matches("self.");
330 if !fields.contains(&field_name.to_string()) {
331 fields.push(field_name.to_string());
332 }
333 }
334 }
335 }
336 }
337 }
338 extract_self_assignments(child, source, fields);
340 }
341 }
342
343 visit_node_for_types(root_node, source, &mut types);
344 Ok(types)
345 }
346
347 fn language(&self) -> Language {
348 Language::Python
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Top-level functions come first, then the class methods, each tagged
    /// with whether it is a method and (for methods) its class.
    #[test]
    fn test_python_functions() {
        let source = r#"
def hello(name):
    return f"Hello, {name}!"

def add(a, b=0):
    return a + b

class Calculator:
    def __init__(self):
        self.result = 0

    def add(self, x):
        self.result += x
        return self.result
"#;
        let mut parser = PythonParser::new().unwrap();
        let functions = parser.extract_functions(source, "test.py").unwrap();

        let names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["hello", "add", "__init__", "add"]);

        let (free_add, init, method_add) = (&functions[1], &functions[2], &functions[3]);
        assert!(!free_add.is_method);
        assert!(init.is_method);
        assert_eq!(init.class_name.as_deref(), Some("Calculator"));
        assert!(method_add.is_method);
    }

    /// Every class (including subclasses) is reported as a `"class"` type.
    #[test]
    fn test_python_classes() {
        let source = r#"
class User:
    def __init__(self, name):
        self.name = name

class Admin(User):
    def __init__(self, name, level):
        super().__init__(name)
        self.level = level
"#;
        let mut parser = PythonParser::new().unwrap();
        let types = parser.extract_types(source, "test.py").unwrap();

        let names: Vec<&str> = types.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["User", "Admin"]);
        assert_eq!(types[0].kind, "class");
    }
}