scribe_analysis/language_support/
function_extraction.rs

1//! # Function and Class Extraction from AST
2//!
3//! Extracts function definitions, class definitions, and methods from source code
4//! using tree-sitter AST parsing for accurate analysis.
5
6use super::ast_language::AstLanguage;
7use scribe_core::{Result, ScribeError};
8use serde::{Deserialize, Serialize};
9use tree_sitter::{Language, Node, Parser, Query, QueryCursor, Tree};
10
/// Information about a function extracted from source code.
///
/// One value is produced per matched function node for which a name could be
/// determined. Several fields are placeholders that the extractor in this
/// file does not populate yet; see the per-field notes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    /// Function name as it appears in the source text
    pub name: String,
    /// Line number where the function starts (node start row + 1)
    pub start_line: usize,
    /// Line number where the function ends (node end row + 1)
    pub end_line: usize,
    /// Function parameter names, in declaration order
    pub parameters: Vec<String>,
    /// Return type (if available); currently always `None` from the extractor in this file
    pub return_type: Option<String>,
    /// Documentation/docstring; currently always `None` from the extractor in this file
    pub documentation: Option<String>,
    /// Function visibility (public, private, etc.); currently always `None` from the extractor in this file
    pub visibility: Option<String>,
    /// Whether this is a method (inside a class); currently always `false` from the extractor in this file
    pub is_method: bool,
    /// Parent class name (if this is a method); currently always `None` from the extractor in this file
    pub parent_class: Option<String>,
}
33
/// Information about a class extracted from source code.
///
/// One value is produced per matched class-like node for which a name could
/// be determined. Several fields are placeholders that the extractor in this
/// file does not populate yet; see the per-field notes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassInfo {
    /// Class name as it appears in the source text
    pub name: String,
    /// Line number where the class starts (node start row + 1)
    pub start_line: usize,
    /// Line number where the class ends (node end row + 1)
    pub end_line: usize,
    /// Parent classes/interfaces, in source order
    pub parents: Vec<String>,
    /// Documentation/docstring; currently always `None` from the extractor in this file
    pub documentation: Option<String>,
    /// Class visibility; currently always `None` from the extractor in this file
    pub visibility: Option<String>,
    /// Methods in this class; currently always empty from the extractor in this file
    pub methods: Vec<FunctionInfo>,
}
52
/// Extracts functions and classes from source code using tree-sitter.
///
/// Holds a parser together with the per-language queries that locate
/// function and class nodes in the parsed tree.
pub struct FunctionExtractor {
    // Language this extractor was created for.
    language: AstLanguage,
    // Parser; bound to the language's grammar in `new` when one is available.
    parser: Parser,
    // Query matching function-like nodes; `None` when no pattern exists for the language.
    function_query: Option<Query>,
    // Query matching class-like nodes; `None` when no pattern exists for the language.
    class_query: Option<Query>,
}
60
61impl FunctionExtractor {
62    /// Create a new function extractor for the given language
63    pub fn new(language: AstLanguage) -> Result<Self> {
64        let mut parser = Parser::new();
65
66        // Set up tree-sitter language if available
67        let (function_query, class_query) =
68            if let Some(ts_language) = language.tree_sitter_language() {
69                parser
70                    .set_language(ts_language)
71                    .map_err(|e| ScribeError::Analysis {
72                        message: format!("Failed to set tree-sitter language: {}", e),
73                        source: None,
74                        file: std::path::PathBuf::from("<unknown>"),
75                    })?;
76
77                let function_query = Self::create_function_query(language, ts_language)?;
78                let class_query = Self::create_class_query(language, ts_language)?;
79                (function_query, class_query)
80            } else {
81                (None, None)
82            };
83
84        Ok(Self {
85            language,
86            parser,
87            function_query,
88            class_query,
89        })
90    }
91
92    /// Create tree-sitter query for finding functions
93    fn create_function_query(
94        language: AstLanguage,
95        ts_language: Language,
96    ) -> Result<Option<Query>> {
97        let query_string = match language {
98            AstLanguage::Python => {
99                r#"
100                (function_definition) @function.definition
101            "#
102            }
103            AstLanguage::JavaScript | AstLanguage::TypeScript => {
104                r#"
105                (function_declaration) @function.definition
106                (method_definition) @function.definition
107            "#
108            }
109            AstLanguage::Rust => {
110                r#"
111                (function_item) @function.definition
112            "#
113            }
114            AstLanguage::Go => {
115                r#"
116                (function_declaration) @function.definition
117                (method_declaration) @function.definition
118            "#
119            }
120            // Future languages - placeholder queries
121            AstLanguage::Java => {
122                r#"
123                (method_declaration) @function.definition
124            "#
125            }
126            AstLanguage::C | AstLanguage::Cpp => {
127                r#"
128                (function_definition) @function.definition
129            "#
130            }
131            AstLanguage::Ruby => {
132                r#"
133                (method) @function.definition
134            "#
135            }
136            AstLanguage::CSharp => {
137                r#"
138                (method_declaration) @function.definition
139            "#
140            }
141            _ => return Ok(None),
142        };
143
144        Query::new(ts_language, query_string)
145            .map(Some)
146            .map_err(|e| ScribeError::Analysis {
147                message: format!("Failed to create function query: {}", e),
148                source: None,
149                file: std::path::PathBuf::from("<unknown>"),
150            })
151    }
152
153    /// Create tree-sitter query for finding classes
154    fn create_class_query(language: AstLanguage, ts_language: Language) -> Result<Option<Query>> {
155        let query_string = match language {
156            AstLanguage::Python => {
157                r#"
158                (class_definition) @class.definition
159            "#
160            }
161            AstLanguage::JavaScript | AstLanguage::TypeScript => {
162                r#"
163                (class_declaration) @class.definition
164            "#
165            }
166            AstLanguage::Rust => {
167                r#"
168                (struct_item) @class.definition
169            "#
170            }
171            AstLanguage::Go => {
172                r#"
173                (type_declaration) @class.definition
174            "#
175            }
176            // Future languages - placeholder queries
177            AstLanguage::Java => {
178                r#"
179                (class_declaration) @class.definition
180            "#
181            }
182            AstLanguage::Cpp => {
183                r#"
184                (class_specifier) @class.definition
185            "#
186            }
187            AstLanguage::Ruby => {
188                r#"
189                (class) @class.definition
190            "#
191            }
192            AstLanguage::CSharp => {
193                r#"
194                (class_declaration) @class.definition
195            "#
196            }
197            _ => return Ok(None),
198        };
199
200        Query::new(ts_language, query_string)
201            .map(Some)
202            .map_err(|e| ScribeError::Analysis {
203                message: format!("Failed to create class query: {}", e),
204                source: None,
205                file: std::path::PathBuf::from("<unknown>"),
206            })
207    }
208
209    /// Extract all functions from source code
210    pub fn extract_functions(&mut self, content: &str) -> Result<Vec<FunctionInfo>> {
211        let tree = self
212            .parser
213            .parse(content, None)
214            .ok_or_else(|| ScribeError::Analysis {
215                message: "Failed to parse source code".to_string(),
216                source: None,
217                file: std::path::PathBuf::from("<unknown>"),
218            })?;
219
220        let mut functions = Vec::new();
221
222        if let Some(query) = &self.function_query {
223            let mut query_cursor = QueryCursor::new();
224            let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
225
226            for query_match in matches {
227                if let Some(function_info) =
228                    self.extract_function_from_match(&query_match, content, &tree)?
229                {
230                    functions.push(function_info);
231                }
232            }
233        }
234
235        Ok(functions)
236    }
237
238    /// Extract all classes from source code
239    pub fn extract_classes(&mut self, content: &str) -> Result<Vec<ClassInfo>> {
240        let tree = self
241            .parser
242            .parse(content, None)
243            .ok_or_else(|| ScribeError::Analysis {
244                message: "Failed to parse source code".to_string(),
245                source: None,
246                file: std::path::PathBuf::from("<unknown>"),
247            })?;
248
249        let mut classes = Vec::new();
250
251        if let Some(query) = &self.class_query {
252            let mut query_cursor = QueryCursor::new();
253            let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
254
255            for query_match in matches {
256                if let Some(class_info) =
257                    self.extract_class_from_match(&query_match, content, &tree)?
258                {
259                    classes.push(class_info);
260                }
261            }
262        }
263
264        Ok(classes)
265    }
266
267    /// Extract function information from a query match
268    fn extract_function_from_match(
269        &self,
270        query_match: &tree_sitter::QueryMatch,
271        content: &str,
272        tree: &Tree,
273    ) -> Result<Option<FunctionInfo>> {
274        for capture in query_match.captures {
275            let node = capture.node;
276            let start_line = node.start_position().row + 1;
277            let end_line = node.end_position().row + 1;
278
279            // Extract function name from the AST node structure
280            let name = self.extract_function_name(node, content);
281            let parameters = self.extract_function_parameters(node, content);
282
283            if let Some(function_name) = name {
284                return Ok(Some(FunctionInfo {
285                    name: function_name,
286                    start_line,
287                    end_line,
288                    parameters,
289                    return_type: None,   // TODO: Extract return type
290                    documentation: None, // TODO: Extract documentation
291                    visibility: None,    // TODO: Extract visibility
292                    is_method: false,    // TODO: Determine if method
293                    parent_class: None,  // TODO: Find parent class
294                }));
295            }
296        }
297        Ok(None)
298    }
299
300    /// Extract class information from a query match
301    fn extract_class_from_match(
302        &self,
303        query_match: &tree_sitter::QueryMatch,
304        content: &str,
305        tree: &Tree,
306    ) -> Result<Option<ClassInfo>> {
307        for capture in query_match.captures {
308            let node = capture.node;
309            let start_line = node.start_position().row + 1;
310            let end_line = node.end_position().row + 1;
311
312            // Extract class name from the AST node structure
313            let name = self.extract_class_name(node, content);
314            let parents = self.extract_class_parents(node, content);
315
316            if let Some(class_name) = name {
317                return Ok(Some(ClassInfo {
318                    name: class_name,
319                    start_line,
320                    end_line,
321                    parents,
322                    documentation: None, // TODO: Extract documentation
323                    visibility: None,    // TODO: Extract visibility
324                    methods: Vec::new(), // TODO: Extract methods
325                }));
326            }
327        }
328        Ok(None)
329    }
330
331    /// Extract function name from AST node
332    fn extract_function_name(&self, node: Node, content: &str) -> Option<String> {
333        // Look for identifier child nodes that represent the function name
334        let mut cursor = node.walk();
335        cursor.goto_first_child();
336
337        loop {
338            let child = cursor.node();
339            match child.kind() {
340                "identifier" => {
341                    if let Ok(name) = child.utf8_text(content.as_bytes()) {
342                        return Some(name.to_string());
343                    }
344                }
345                _ => {}
346            }
347
348            if !cursor.goto_next_sibling() {
349                break;
350            }
351        }
352        None
353    }
354
355    /// Extract function parameters from AST node
356    fn extract_function_parameters(&self, node: Node, content: &str) -> Vec<String> {
357        let mut parameters = Vec::new();
358        let mut cursor = node.walk();
359        cursor.goto_first_child();
360
361        loop {
362            let child = cursor.node();
363            match child.kind() {
364                "parameters" | "parameter_list" => {
365                    // Extract parameter names from parameter list
366                    let mut param_cursor = child.walk();
367                    param_cursor.goto_first_child();
368
369                    loop {
370                        let param_node = param_cursor.node();
371                        if param_node.kind() == "identifier" {
372                            if let Ok(param_name) = param_node.utf8_text(content.as_bytes()) {
373                                if param_name != "self" {
374                                    parameters.push(param_name.to_string());
375                                }
376                            }
377                        }
378
379                        if !param_cursor.goto_next_sibling() {
380                            break;
381                        }
382                    }
383                    break;
384                }
385                _ => {}
386            }
387
388            if !cursor.goto_next_sibling() {
389                break;
390            }
391        }
392        parameters
393    }
394
395    /// Extract class name from AST node
396    fn extract_class_name(&self, node: Node, content: &str) -> Option<String> {
397        // Look for identifier child nodes that represent the class name
398        let mut cursor = node.walk();
399        cursor.goto_first_child();
400
401        loop {
402            let child = cursor.node();
403            match child.kind() {
404                "identifier" | "type_identifier" => {
405                    if let Ok(name) = child.utf8_text(content.as_bytes()) {
406                        return Some(name.to_string());
407                    }
408                }
409                _ => {}
410            }
411
412            if !cursor.goto_next_sibling() {
413                break;
414            }
415        }
416        None
417    }
418
419    /// Extract class parent classes from AST node
420    fn extract_class_parents(&self, node: Node, content: &str) -> Vec<String> {
421        let mut parents = Vec::new();
422        let mut cursor = node.walk();
423        cursor.goto_first_child();
424
425        loop {
426            let child = cursor.node();
427            match child.kind() {
428                "argument_list" | "superclass" | "inheritance" => {
429                    // Extract parent class names
430                    let mut parent_cursor = child.walk();
431                    parent_cursor.goto_first_child();
432
433                    loop {
434                        let parent_node = parent_cursor.node();
435                        if parent_node.kind() == "identifier"
436                            || parent_node.kind() == "type_identifier"
437                        {
438                            if let Ok(parent_name) = parent_node.utf8_text(content.as_bytes()) {
439                                parents.push(parent_name.to_string());
440                            }
441                        }
442
443                        if !parent_cursor.goto_next_sibling() {
444                            break;
445                        }
446                    }
447                }
448                _ => {}
449            }
450
451            if !cursor.goto_next_sibling() {
452                break;
453            }
454        }
455        parents
456    }
457
458    /// Extract parameter names from parameter list text
459    fn extract_parameters(&self, params_text: &str, _node: Node) -> Vec<String> {
460        // Simple parameter extraction - can be improved per language
461        params_text
462            .split(',')
463            .filter_map(|param| {
464                let param = param.trim();
465                if param.is_empty() || param == "self" {
466                    None
467                } else {
468                    // Extract just the parameter name (before type annotations)
469                    let name = param.split(':').next().unwrap_or(param).trim();
470                    if name.is_empty() {
471                        None
472                    } else {
473                        Some(name.to_string())
474                    }
475                }
476            })
477            .collect()
478    }
479
480    /// Extract parent class names from inheritance clause
481    fn extract_parent_classes(&self, parents_text: &str) -> Vec<String> {
482        // Simple parent class extraction - can be improved per language
483        parents_text
484            .split(',')
485            .filter_map(|parent| {
486                let parent = parent.trim();
487                if parent.is_empty() {
488                    None
489                } else {
490                    Some(parent.to_string())
491                }
492            })
493            .collect()
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// A Python extractor can be constructed without error.
    #[test]
    fn test_function_extractor_creation() {
        assert!(FunctionExtractor::new(AstLanguage::Python).is_ok());
    }

    /// Top-level Python functions are reported by name.
    #[test]
    fn test_python_function_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::Python).unwrap();
        let source = r#"
def hello_world():
    """A simple function."""
    print("Hello, World!")

def add_numbers(a, b):
    """Add two numbers together."""
    return a + b

class Calculator:
    """A simple calculator."""
    
    def multiply(self, x, y):
        """Multiply two numbers."""
        return x * y
"#;

        let found = extractor.extract_functions(source).unwrap();
        assert!(!found.is_empty());

        // Should find at least the standalone functions
        assert!(found.iter().any(|f| f.name == "hello_world"));
        assert!(found.iter().any(|f| f.name == "add_numbers"));
    }

    /// Python classes, with and without a base class, are reported by name.
    #[test]
    fn test_python_class_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::Python).unwrap();
        let source = r#"
class Calculator:
    """A simple calculator."""
    pass

class AdvancedCalculator(Calculator):
    """An advanced calculator that inherits from Calculator."""
    pass
"#;

        let found = extractor.extract_classes(source).unwrap();
        assert!(!found.is_empty());

        assert!(found.iter().any(|c| c.name == "Calculator"));
        assert!(found.iter().any(|c| c.name == "AdvancedCalculator"));
    }

    /// JavaScript source yields at least one extracted function.
    #[test]
    fn test_javascript_function_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::JavaScript).unwrap();
        let source = r#"
function greetUser(name) {
    return `Hello, ${name}!`;
}

class UserManager {
    constructor() {
        this.users = [];
    }
    
    addUser(user) {
        this.users.push(user);
    }
}
"#;

        let found = extractor.extract_functions(source).unwrap();
        assert!(!found.is_empty());
    }
}