scribe_analysis/language_support/
function_extraction.rs1use super::ast_language::AstLanguage;
7use scribe_core::{Result, ScribeError};
8use serde::{Deserialize, Serialize};
9use tree_sitter::{Language, Node, Parser, Query, QueryCursor, Tree};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct FunctionInfo {
14 pub name: String,
16 pub start_line: usize,
18 pub end_line: usize,
20 pub parameters: Vec<String>,
22 pub return_type: Option<String>,
24 pub documentation: Option<String>,
26 pub visibility: Option<String>,
28 pub is_method: bool,
30 pub parent_class: Option<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ClassInfo {
37 pub name: String,
39 pub start_line: usize,
41 pub end_line: usize,
43 pub parents: Vec<String>,
45 pub documentation: Option<String>,
47 pub visibility: Option<String>,
49 pub methods: Vec<FunctionInfo>,
51}
52
53pub struct FunctionExtractor {
55 language: AstLanguage,
56 parser: Parser,
57 function_query: Option<Query>,
58 class_query: Option<Query>,
59}
60
61impl FunctionExtractor {
62 pub fn new(language: AstLanguage) -> Result<Self> {
64 let mut parser = Parser::new();
65
66 let (function_query, class_query) =
68 if let Some(ts_language) = language.tree_sitter_language() {
69 parser
70 .set_language(ts_language)
71 .map_err(|e| ScribeError::Analysis {
72 message: format!("Failed to set tree-sitter language: {}", e),
73 source: None,
74 file: std::path::PathBuf::from("<unknown>"),
75 })?;
76
77 let function_query = Self::create_function_query(language, ts_language)?;
78 let class_query = Self::create_class_query(language, ts_language)?;
79 (function_query, class_query)
80 } else {
81 (None, None)
82 };
83
84 Ok(Self {
85 language,
86 parser,
87 function_query,
88 class_query,
89 })
90 }
91
92 fn create_function_query(
94 language: AstLanguage,
95 ts_language: Language,
96 ) -> Result<Option<Query>> {
97 let query_string = match language {
98 AstLanguage::Python => {
99 r#"
100 (function_definition) @function.definition
101 "#
102 }
103 AstLanguage::JavaScript | AstLanguage::TypeScript => {
104 r#"
105 (function_declaration) @function.definition
106 (method_definition) @function.definition
107 "#
108 }
109 AstLanguage::Rust => {
110 r#"
111 (function_item) @function.definition
112 "#
113 }
114 AstLanguage::Go => {
115 r#"
116 (function_declaration) @function.definition
117 (method_declaration) @function.definition
118 "#
119 }
120 AstLanguage::Java => {
122 r#"
123 (method_declaration) @function.definition
124 "#
125 }
126 AstLanguage::C | AstLanguage::Cpp => {
127 r#"
128 (function_definition) @function.definition
129 "#
130 }
131 AstLanguage::Ruby => {
132 r#"
133 (method) @function.definition
134 "#
135 }
136 AstLanguage::CSharp => {
137 r#"
138 (method_declaration) @function.definition
139 "#
140 }
141 _ => return Ok(None),
142 };
143
144 Query::new(ts_language, query_string)
145 .map(Some)
146 .map_err(|e| ScribeError::Analysis {
147 message: format!("Failed to create function query: {}", e),
148 source: None,
149 file: std::path::PathBuf::from("<unknown>"),
150 })
151 }
152
153 fn create_class_query(language: AstLanguage, ts_language: Language) -> Result<Option<Query>> {
155 let query_string = match language {
156 AstLanguage::Python => {
157 r#"
158 (class_definition) @class.definition
159 "#
160 }
161 AstLanguage::JavaScript | AstLanguage::TypeScript => {
162 r#"
163 (class_declaration) @class.definition
164 "#
165 }
166 AstLanguage::Rust => {
167 r#"
168 (struct_item) @class.definition
169 "#
170 }
171 AstLanguage::Go => {
172 r#"
173 (type_declaration) @class.definition
174 "#
175 }
176 AstLanguage::Java => {
178 r#"
179 (class_declaration) @class.definition
180 "#
181 }
182 AstLanguage::Cpp => {
183 r#"
184 (class_specifier) @class.definition
185 "#
186 }
187 AstLanguage::Ruby => {
188 r#"
189 (class) @class.definition
190 "#
191 }
192 AstLanguage::CSharp => {
193 r#"
194 (class_declaration) @class.definition
195 "#
196 }
197 _ => return Ok(None),
198 };
199
200 Query::new(ts_language, query_string)
201 .map(Some)
202 .map_err(|e| ScribeError::Analysis {
203 message: format!("Failed to create class query: {}", e),
204 source: None,
205 file: std::path::PathBuf::from("<unknown>"),
206 })
207 }
208
209 pub fn extract_functions(&mut self, content: &str) -> Result<Vec<FunctionInfo>> {
211 let tree = self
212 .parser
213 .parse(content, None)
214 .ok_or_else(|| ScribeError::Analysis {
215 message: "Failed to parse source code".to_string(),
216 source: None,
217 file: std::path::PathBuf::from("<unknown>"),
218 })?;
219
220 let mut functions = Vec::new();
221
222 if let Some(query) = &self.function_query {
223 let mut query_cursor = QueryCursor::new();
224 let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
225
226 for query_match in matches {
227 if let Some(function_info) =
228 self.extract_function_from_match(&query_match, content, &tree)?
229 {
230 functions.push(function_info);
231 }
232 }
233 }
234
235 Ok(functions)
236 }
237
238 pub fn extract_classes(&mut self, content: &str) -> Result<Vec<ClassInfo>> {
240 let tree = self
241 .parser
242 .parse(content, None)
243 .ok_or_else(|| ScribeError::Analysis {
244 message: "Failed to parse source code".to_string(),
245 source: None,
246 file: std::path::PathBuf::from("<unknown>"),
247 })?;
248
249 let mut classes = Vec::new();
250
251 if let Some(query) = &self.class_query {
252 let mut query_cursor = QueryCursor::new();
253 let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
254
255 for query_match in matches {
256 if let Some(class_info) =
257 self.extract_class_from_match(&query_match, content, &tree)?
258 {
259 classes.push(class_info);
260 }
261 }
262 }
263
264 Ok(classes)
265 }
266
267 fn extract_function_from_match(
269 &self,
270 query_match: &tree_sitter::QueryMatch,
271 content: &str,
272 tree: &Tree,
273 ) -> Result<Option<FunctionInfo>> {
274 for capture in query_match.captures {
275 let node = capture.node;
276 let start_line = node.start_position().row + 1;
277 let end_line = node.end_position().row + 1;
278
279 let name = self.extract_function_name(node, content);
281 let parameters = self.extract_function_parameters(node, content);
282
283 if let Some(function_name) = name {
284 return Ok(Some(FunctionInfo {
285 name: function_name,
286 start_line,
287 end_line,
288 parameters,
289 return_type: None, documentation: None, visibility: None, is_method: false, parent_class: None, }));
295 }
296 }
297 Ok(None)
298 }
299
300 fn extract_class_from_match(
302 &self,
303 query_match: &tree_sitter::QueryMatch,
304 content: &str,
305 tree: &Tree,
306 ) -> Result<Option<ClassInfo>> {
307 for capture in query_match.captures {
308 let node = capture.node;
309 let start_line = node.start_position().row + 1;
310 let end_line = node.end_position().row + 1;
311
312 let name = self.extract_class_name(node, content);
314 let parents = self.extract_class_parents(node, content);
315
316 if let Some(class_name) = name {
317 return Ok(Some(ClassInfo {
318 name: class_name,
319 start_line,
320 end_line,
321 parents,
322 documentation: None, visibility: None, methods: Vec::new(), }));
326 }
327 }
328 Ok(None)
329 }
330
331 fn extract_function_name(&self, node: Node, content: &str) -> Option<String> {
333 let mut cursor = node.walk();
335 cursor.goto_first_child();
336
337 loop {
338 let child = cursor.node();
339 match child.kind() {
340 "identifier" => {
341 if let Ok(name) = child.utf8_text(content.as_bytes()) {
342 return Some(name.to_string());
343 }
344 }
345 _ => {}
346 }
347
348 if !cursor.goto_next_sibling() {
349 break;
350 }
351 }
352 None
353 }
354
355 fn extract_function_parameters(&self, node: Node, content: &str) -> Vec<String> {
357 let mut parameters = Vec::new();
358 let mut cursor = node.walk();
359 cursor.goto_first_child();
360
361 loop {
362 let child = cursor.node();
363 match child.kind() {
364 "parameters" | "parameter_list" => {
365 let mut param_cursor = child.walk();
367 param_cursor.goto_first_child();
368
369 loop {
370 let param_node = param_cursor.node();
371 if param_node.kind() == "identifier" {
372 if let Ok(param_name) = param_node.utf8_text(content.as_bytes()) {
373 if param_name != "self" {
374 parameters.push(param_name.to_string());
375 }
376 }
377 }
378
379 if !param_cursor.goto_next_sibling() {
380 break;
381 }
382 }
383 break;
384 }
385 _ => {}
386 }
387
388 if !cursor.goto_next_sibling() {
389 break;
390 }
391 }
392 parameters
393 }
394
395 fn extract_class_name(&self, node: Node, content: &str) -> Option<String> {
397 let mut cursor = node.walk();
399 cursor.goto_first_child();
400
401 loop {
402 let child = cursor.node();
403 match child.kind() {
404 "identifier" | "type_identifier" => {
405 if let Ok(name) = child.utf8_text(content.as_bytes()) {
406 return Some(name.to_string());
407 }
408 }
409 _ => {}
410 }
411
412 if !cursor.goto_next_sibling() {
413 break;
414 }
415 }
416 None
417 }
418
419 fn extract_class_parents(&self, node: Node, content: &str) -> Vec<String> {
421 let mut parents = Vec::new();
422 let mut cursor = node.walk();
423 cursor.goto_first_child();
424
425 loop {
426 let child = cursor.node();
427 match child.kind() {
428 "argument_list" | "superclass" | "inheritance" => {
429 let mut parent_cursor = child.walk();
431 parent_cursor.goto_first_child();
432
433 loop {
434 let parent_node = parent_cursor.node();
435 if parent_node.kind() == "identifier"
436 || parent_node.kind() == "type_identifier"
437 {
438 if let Ok(parent_name) = parent_node.utf8_text(content.as_bytes()) {
439 parents.push(parent_name.to_string());
440 }
441 }
442
443 if !parent_cursor.goto_next_sibling() {
444 break;
445 }
446 }
447 }
448 _ => {}
449 }
450
451 if !cursor.goto_next_sibling() {
452 break;
453 }
454 }
455 parents
456 }
457
458 fn extract_parameters(&self, params_text: &str, _node: Node) -> Vec<String> {
460 params_text
462 .split(',')
463 .filter_map(|param| {
464 let param = param.trim();
465 if param.is_empty() || param == "self" {
466 None
467 } else {
468 let name = param.split(':').next().unwrap_or(param).trim();
470 if name.is_empty() {
471 None
472 } else {
473 Some(name.to_string())
474 }
475 }
476 })
477 .collect()
478 }
479
480 fn extract_parent_classes(&self, parents_text: &str) -> Vec<String> {
482 parents_text
484 .split(',')
485 .filter_map(|parent| {
486 let parent = parent.trim();
487 if parent.is_empty() {
488 None
489 } else {
490 Some(parent.to_string())
491 }
492 })
493 .collect()
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the extracted function called `name`, panicking with the full list if absent.
    fn find_function<'a>(functions: &'a [FunctionInfo], name: &str) -> &'a FunctionInfo {
        functions
            .iter()
            .find(|f| f.name == name)
            .unwrap_or_else(|| panic!("function `{}` not extracted; got {:?}", name, functions))
    }

    /// Returns the extracted class called `name`, panicking with the full list if absent.
    fn find_class<'a>(classes: &'a [ClassInfo], name: &str) -> &'a ClassInfo {
        classes
            .iter()
            .find(|c| c.name == name)
            .unwrap_or_else(|| panic!("class `{}` not extracted; got {:?}", name, classes))
    }

    #[test]
    fn test_function_extractor_creation() {
        let extractor = FunctionExtractor::new(AstLanguage::Python);
        assert!(extractor.is_ok());
    }

    #[test]
    fn test_python_function_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::Python).unwrap();
        let python_code = r#"
def hello_world():
    """A simple function."""
    print("Hello, World!")

def add_numbers(a, b):
    """Add two numbers together."""
    return a + b

class Calculator:
    """A simple calculator."""

    def multiply(self, x, y):
        """Multiply two numbers."""
        return x * y
"#;

        let functions = extractor.extract_functions(python_code).unwrap();
        assert!(!functions.is_empty());

        // The snippet starts with a newline, so `def hello_world` is on line 2.
        let hello = find_function(&functions, "hello_world");
        assert_eq!(hello.start_line, 2);
        assert!(hello.parameters.is_empty());

        let add = find_function(&functions, "add_numbers");
        assert_eq!(add.parameters, vec!["a", "b"]);

        // `self` is never reported as a parameter.
        let multiply = find_function(&functions, "multiply");
        assert_eq!(multiply.parameters, vec!["x", "y"]);
    }

    #[test]
    fn test_python_class_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::Python).unwrap();
        let python_code = r#"
class Calculator:
    """A simple calculator."""
    pass

class AdvancedCalculator(Calculator):
    """An advanced calculator that inherits from Calculator."""
    pass
"#;

        let classes = extractor.extract_classes(python_code).unwrap();
        assert!(!classes.is_empty());

        let base = find_class(&classes, "Calculator");
        assert!(base.parents.is_empty());

        let derived = find_class(&classes, "AdvancedCalculator");
        assert_eq!(derived.parents, vec!["Calculator"]);
    }

    #[test]
    fn test_javascript_function_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::JavaScript).unwrap();
        let js_code = r#"
function greetUser(name) {
    return `Hello, ${name}!`;
}

class UserManager {
    constructor() {
        this.users = [];
    }

    addUser(user) {
        this.users.push(user);
    }
}
"#;

        let functions = extractor.extract_functions(js_code).unwrap();
        assert!(!functions.is_empty());
        find_function(&functions, "greetUser");
    }
}