scribe_analysis/language_support/
function_extraction.rs1use serde::{Deserialize, Serialize};
7use tree_sitter::{Parser, Language, Node, Tree, Query, QueryCursor};
8use scribe_core::{Result, ScribeError};
9use super::ast_language::AstLanguage;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct FunctionInfo {
14 pub name: String,
16 pub start_line: usize,
18 pub end_line: usize,
20 pub parameters: Vec<String>,
22 pub return_type: Option<String>,
24 pub documentation: Option<String>,
26 pub visibility: Option<String>,
28 pub is_method: bool,
30 pub parent_class: Option<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ClassInfo {
37 pub name: String,
39 pub start_line: usize,
41 pub end_line: usize,
43 pub parents: Vec<String>,
45 pub documentation: Option<String>,
47 pub visibility: Option<String>,
49 pub methods: Vec<FunctionInfo>,
51}
52
53pub struct FunctionExtractor {
55 language: AstLanguage,
56 parser: Parser,
57 function_query: Option<Query>,
58 class_query: Option<Query>,
59}
60
61impl FunctionExtractor {
62 pub fn new(language: AstLanguage) -> Result<Self> {
64 let mut parser = Parser::new();
65
66 let (function_query, class_query) = if let Some(ts_language) = language.tree_sitter_language() {
68 parser.set_language(ts_language)
69 .map_err(|e| ScribeError::Analysis {
70 message: format!("Failed to set tree-sitter language: {}", e),
71 source: None,
72 file: std::path::PathBuf::from("<unknown>"),
73 })?;
74
75 let function_query = Self::create_function_query(language, ts_language)?;
76 let class_query = Self::create_class_query(language, ts_language)?;
77 (function_query, class_query)
78 } else {
79 (None, None)
80 };
81
82 Ok(Self {
83 language,
84 parser,
85 function_query,
86 class_query,
87 })
88 }
89
90 fn create_function_query(language: AstLanguage, ts_language: Language) -> Result<Option<Query>> {
92 let query_string = match language {
93 AstLanguage::Python => r#"
94 (function_definition) @function.definition
95 "#,
96 AstLanguage::JavaScript | AstLanguage::TypeScript => r#"
97 (function_declaration) @function.definition
98 (method_definition) @function.definition
99 "#,
100 AstLanguage::Rust => r#"
101 (function_item) @function.definition
102 "#,
103 AstLanguage::Go => r#"
104 (function_declaration) @function.definition
105 (method_declaration) @function.definition
106 "#,
107 AstLanguage::Java => r#"
109 (method_declaration) @function.definition
110 "#,
111 AstLanguage::C | AstLanguage::Cpp => r#"
112 (function_definition) @function.definition
113 "#,
114 AstLanguage::Ruby => r#"
115 (method) @function.definition
116 "#,
117 AstLanguage::CSharp => r#"
118 (method_declaration) @function.definition
119 "#,
120 _ => return Ok(None),
121 };
122
123 Query::new(ts_language, query_string)
124 .map(Some)
125 .map_err(|e| ScribeError::Analysis {
126 message: format!("Failed to create function query: {}", e),
127 source: None,
128 file: std::path::PathBuf::from("<unknown>"),
129 })
130 }
131
132 fn create_class_query(language: AstLanguage, ts_language: Language) -> Result<Option<Query>> {
134 let query_string = match language {
135 AstLanguage::Python => r#"
136 (class_definition) @class.definition
137 "#,
138 AstLanguage::JavaScript | AstLanguage::TypeScript => r#"
139 (class_declaration) @class.definition
140 "#,
141 AstLanguage::Rust => r#"
142 (struct_item) @class.definition
143 "#,
144 AstLanguage::Go => r#"
145 (type_declaration) @class.definition
146 "#,
147 AstLanguage::Java => r#"
149 (class_declaration) @class.definition
150 "#,
151 AstLanguage::Cpp => r#"
152 (class_specifier) @class.definition
153 "#,
154 AstLanguage::Ruby => r#"
155 (class) @class.definition
156 "#,
157 AstLanguage::CSharp => r#"
158 (class_declaration) @class.definition
159 "#,
160 _ => return Ok(None),
161 };
162
163 Query::new(ts_language, query_string)
164 .map(Some)
165 .map_err(|e| ScribeError::Analysis {
166 message: format!("Failed to create class query: {}", e),
167 source: None,
168 file: std::path::PathBuf::from("<unknown>"),
169 })
170 }
171
172 pub fn extract_functions(&mut self, content: &str) -> Result<Vec<FunctionInfo>> {
174 let tree = self.parser.parse(content, None)
175 .ok_or_else(|| ScribeError::Analysis {
176 message: "Failed to parse source code".to_string(),
177 source: None,
178 file: std::path::PathBuf::from("<unknown>"),
179 })?;
180
181 let mut functions = Vec::new();
182
183 if let Some(query) = &self.function_query {
184 let mut query_cursor = QueryCursor::new();
185 let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
186
187 for query_match in matches {
188 if let Some(function_info) = self.extract_function_from_match(&query_match, content, &tree)? {
189 functions.push(function_info);
190 }
191 }
192 }
193
194 Ok(functions)
195 }
196
197 pub fn extract_classes(&mut self, content: &str) -> Result<Vec<ClassInfo>> {
199 let tree = self.parser.parse(content, None)
200 .ok_or_else(|| ScribeError::Analysis {
201 message: "Failed to parse source code".to_string(),
202 source: None,
203 file: std::path::PathBuf::from("<unknown>"),
204 })?;
205
206 let mut classes = Vec::new();
207
208 if let Some(query) = &self.class_query {
209 let mut query_cursor = QueryCursor::new();
210 let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
211
212 for query_match in matches {
213 if let Some(class_info) = self.extract_class_from_match(&query_match, content, &tree)? {
214 classes.push(class_info);
215 }
216 }
217 }
218
219 Ok(classes)
220 }
221
222 fn extract_function_from_match(
224 &self,
225 query_match: &tree_sitter::QueryMatch,
226 content: &str,
227 tree: &Tree,
228 ) -> Result<Option<FunctionInfo>> {
229 for capture in query_match.captures {
230 let node = capture.node;
231 let start_line = node.start_position().row + 1;
232 let end_line = node.end_position().row + 1;
233
234 let name = self.extract_function_name(node, content);
236 let parameters = self.extract_function_parameters(node, content);
237
238 if let Some(function_name) = name {
239 return Ok(Some(FunctionInfo {
240 name: function_name,
241 start_line,
242 end_line,
243 parameters,
244 return_type: None, documentation: None, visibility: None, is_method: false, parent_class: None, }));
250 }
251 }
252 Ok(None)
253 }
254
255 fn extract_class_from_match(
257 &self,
258 query_match: &tree_sitter::QueryMatch,
259 content: &str,
260 tree: &Tree,
261 ) -> Result<Option<ClassInfo>> {
262 for capture in query_match.captures {
263 let node = capture.node;
264 let start_line = node.start_position().row + 1;
265 let end_line = node.end_position().row + 1;
266
267 let name = self.extract_class_name(node, content);
269 let parents = self.extract_class_parents(node, content);
270
271 if let Some(class_name) = name {
272 return Ok(Some(ClassInfo {
273 name: class_name,
274 start_line,
275 end_line,
276 parents,
277 documentation: None, visibility: None, methods: Vec::new(), }));
281 }
282 }
283 Ok(None)
284 }
285
286 fn extract_function_name(&self, node: Node, content: &str) -> Option<String> {
288 let mut cursor = node.walk();
290 cursor.goto_first_child();
291
292 loop {
293 let child = cursor.node();
294 match child.kind() {
295 "identifier" => {
296 if let Ok(name) = child.utf8_text(content.as_bytes()) {
297 return Some(name.to_string());
298 }
299 }
300 _ => {}
301 }
302
303 if !cursor.goto_next_sibling() {
304 break;
305 }
306 }
307 None
308 }
309
310 fn extract_function_parameters(&self, node: Node, content: &str) -> Vec<String> {
312 let mut parameters = Vec::new();
313 let mut cursor = node.walk();
314 cursor.goto_first_child();
315
316 loop {
317 let child = cursor.node();
318 match child.kind() {
319 "parameters" | "parameter_list" => {
320 let mut param_cursor = child.walk();
322 param_cursor.goto_first_child();
323
324 loop {
325 let param_node = param_cursor.node();
326 if param_node.kind() == "identifier" {
327 if let Ok(param_name) = param_node.utf8_text(content.as_bytes()) {
328 if param_name != "self" {
329 parameters.push(param_name.to_string());
330 }
331 }
332 }
333
334 if !param_cursor.goto_next_sibling() {
335 break;
336 }
337 }
338 break;
339 }
340 _ => {}
341 }
342
343 if !cursor.goto_next_sibling() {
344 break;
345 }
346 }
347 parameters
348 }
349
350 fn extract_class_name(&self, node: Node, content: &str) -> Option<String> {
352 let mut cursor = node.walk();
354 cursor.goto_first_child();
355
356 loop {
357 let child = cursor.node();
358 match child.kind() {
359 "identifier" | "type_identifier" => {
360 if let Ok(name) = child.utf8_text(content.as_bytes()) {
361 return Some(name.to_string());
362 }
363 }
364 _ => {}
365 }
366
367 if !cursor.goto_next_sibling() {
368 break;
369 }
370 }
371 None
372 }
373
374 fn extract_class_parents(&self, node: Node, content: &str) -> Vec<String> {
376 let mut parents = Vec::new();
377 let mut cursor = node.walk();
378 cursor.goto_first_child();
379
380 loop {
381 let child = cursor.node();
382 match child.kind() {
383 "argument_list" | "superclass" | "inheritance" => {
384 let mut parent_cursor = child.walk();
386 parent_cursor.goto_first_child();
387
388 loop {
389 let parent_node = parent_cursor.node();
390 if parent_node.kind() == "identifier" || parent_node.kind() == "type_identifier" {
391 if let Ok(parent_name) = parent_node.utf8_text(content.as_bytes()) {
392 parents.push(parent_name.to_string());
393 }
394 }
395
396 if !parent_cursor.goto_next_sibling() {
397 break;
398 }
399 }
400 }
401 _ => {}
402 }
403
404 if !cursor.goto_next_sibling() {
405 break;
406 }
407 }
408 parents
409 }
410
411 fn extract_parameters(&self, params_text: &str, _node: Node) -> Vec<String> {
413 params_text
415 .split(',')
416 .filter_map(|param| {
417 let param = param.trim();
418 if param.is_empty() || param == "self" {
419 None
420 } else {
421 let name = param.split(':').next().unwrap_or(param).trim();
423 if name.is_empty() {
424 None
425 } else {
426 Some(name.to_string())
427 }
428 }
429 })
430 .collect()
431 }
432
433 fn extract_parent_classes(&self, parents_text: &str) -> Vec<String> {
435 parents_text
437 .split(',')
438 .filter_map(|parent| {
439 let parent = parent.trim();
440 if parent.is_empty() {
441 None
442 } else {
443 Some(parent.to_string())
444 }
445 })
446 .collect()
447 }
448}
449
450impl AstLanguage {
451 pub fn tree_sitter_language(&self) -> Option<tree_sitter::Language> {
453 match self {
454 AstLanguage::Python => Some(tree_sitter_python::language()),
455 AstLanguage::JavaScript => Some(tree_sitter_javascript::language()),
456 AstLanguage::TypeScript => Some(tree_sitter_typescript::language_typescript()),
457 AstLanguage::Go => Some(tree_sitter_go::language()),
458 AstLanguage::Rust => Some(tree_sitter_rust::language()),
459 AstLanguage::Html => Some(tree_sitter_html::language()),
460 _ => None,
467 }
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an extractor for `language`, failing the test if construction errors.
    fn extractor_for(language: AstLanguage) -> FunctionExtractor {
        FunctionExtractor::new(language).expect("extractor should build")
    }

    /// Returns the extracted function called `name`, failing the test if it is missing.
    fn find_function<'a>(functions: &'a [FunctionInfo], name: &str) -> &'a FunctionInfo {
        functions
            .iter()
            .find(|function| function.name == name)
            .unwrap_or_else(|| panic!("function `{}` was not extracted", name))
    }

    /// Returns the extracted class called `name`, failing the test if it is missing.
    fn find_class<'a>(classes: &'a [ClassInfo], name: &str) -> &'a ClassInfo {
        classes
            .iter()
            .find(|class| class.name == name)
            .unwrap_or_else(|| panic!("class `{}` was not extracted", name))
    }

    #[test]
    fn test_function_extractor_creation() {
        assert!(FunctionExtractor::new(AstLanguage::Python).is_ok());
    }

    #[test]
    fn test_python_function_extraction() {
        let mut extractor = extractor_for(AstLanguage::Python);
        let python_code = r#"
def hello_world():
    """A simple function."""
    print("Hello, World!")

def add_numbers(a, b):
    """Add two numbers together."""
    return a + b

class Calculator:
    """A simple calculator."""

    def multiply(self, x, y):
        """Multiply two numbers."""
        return x * y
"#;

        let functions = extractor.extract_functions(python_code).unwrap();
        assert!(!functions.is_empty());

        // The raw string opens with a newline, so `def hello_world` sits on line 2.
        let hello = find_function(&functions, "hello_world");
        assert_eq!(hello.start_line, 2);
        assert!(hello.parameters.is_empty());

        let add = find_function(&functions, "add_numbers");
        assert_eq!(add.start_line, 6);
        assert_eq!(add.parameters, vec!["a", "b"]);

        // Methods nested in a class are found too, and `self` is never a parameter.
        let multiply = find_function(&functions, "multiply");
        assert_eq!(multiply.parameters, vec!["x", "y"]);
    }

    #[test]
    fn test_python_class_extraction() {
        let mut extractor = extractor_for(AstLanguage::Python);
        let python_code = r#"
class Calculator:
    """A simple calculator."""
    pass

class AdvancedCalculator(Calculator):
    """An advanced calculator that inherits from Calculator."""
    pass
"#;

        let classes = extractor.extract_classes(python_code).unwrap();
        assert!(!classes.is_empty());

        let base = find_class(&classes, "Calculator");
        assert!(base.parents.is_empty());

        let derived = find_class(&classes, "AdvancedCalculator");
        assert_eq!(derived.parents, vec!["Calculator"]);
    }

    #[test]
    fn test_javascript_function_extraction() {
        let mut extractor = extractor_for(AstLanguage::JavaScript);
        let js_code = r#"
function greetUser(name) {
    return `Hello, ${name}!`;
}

class UserManager {
    constructor() {
        this.users = [];
    }

    addUser(user) {
        this.users.push(user);
    }
}
"#;

        let functions = extractor.extract_functions(js_code).unwrap();
        assert!(!functions.is_empty());
        find_function(&functions, "greetUser");
    }
}