scribe_analysis/language_support/
function_extraction.rs

1//! # Function and Class Extraction from AST
2//!
3//! Extracts function definitions, class definitions, and methods from source code
4//! using tree-sitter AST parsing for accurate analysis.
5
6use serde::{Deserialize, Serialize};
7use tree_sitter::{Parser, Language, Node, Tree, Query, QueryCursor};
8use scribe_core::{Result, ScribeError};
9use super::ast_language::AstLanguage;
10
/// Information about a function extracted from source code.
///
/// One value is produced per function-like syntax node matched by
/// `FunctionExtractor::extract_functions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    /// Function name
    pub name: String,
    /// Line number where function starts (1-based: the extractor stores the
    /// syntax node's start row plus one)
    pub start_line: usize,
    /// Line number where function ends (1-based, same convention as
    /// `start_line`)
    pub end_line: usize,
    /// Function parameters (names only; the extractor skips `self`)
    pub parameters: Vec<String>,
    /// Return type (if available); the extractor in this file currently
    /// always leaves this as `None`
    pub return_type: Option<String>,
    /// Documentation/docstring; currently always `None` when produced by the
    /// extractor in this file
    pub documentation: Option<String>,
    /// Function visibility (public, private, etc.); currently always `None`
    /// when produced by the extractor in this file
    pub visibility: Option<String>,
    /// Whether this is a method (inside a class); currently always `false`
    /// when produced by the extractor in this file
    pub is_method: bool,
    /// Parent class name (if this is a method); currently always `None`
    /// when produced by the extractor in this file
    pub parent_class: Option<String>,
}
33
/// Information about a class extracted from source code.
///
/// One value is produced per class-like syntax node matched by
/// `FunctionExtractor::extract_classes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassInfo {
    /// Class name
    pub name: String,
    /// Line number where class starts (1-based: the extractor stores the
    /// syntax node's start row plus one)
    pub start_line: usize,
    /// Line number where class ends (1-based, same convention as
    /// `start_line`)
    pub end_line: usize,
    /// Parent classes/interfaces (names only)
    pub parents: Vec<String>,
    /// Documentation/docstring; currently always `None` when produced by the
    /// extractor in this file
    pub documentation: Option<String>,
    /// Class visibility; currently always `None` when produced by the
    /// extractor in this file
    pub visibility: Option<String>,
    /// Methods in this class; currently always empty when produced by the
    /// extractor in this file
    pub methods: Vec<FunctionInfo>,
}
52
/// Extracts functions and classes from source code using tree-sitter.
///
/// Built via `FunctionExtractor::new`, which configures the parser and
/// compiles the per-language queries once so they can be reused across calls.
pub struct FunctionExtractor {
    // Language this extractor was created for.
    language: AstLanguage,
    // Parser instance; a grammar is assigned only when the language has one.
    parser: Parser,
    // Compiled query capturing function-like nodes, or `None` when the
    // language has no grammar or no function query defined.
    function_query: Option<Query>,
    // Compiled query capturing class-like nodes, or `None` when the language
    // has no grammar or no class query defined.
    class_query: Option<Query>,
}
60
61impl FunctionExtractor {
62    /// Create a new function extractor for the given language
63    pub fn new(language: AstLanguage) -> Result<Self> {
64        let mut parser = Parser::new();
65        
66        // Set up tree-sitter language if available
67        let (function_query, class_query) = if let Some(ts_language) = language.tree_sitter_language() {
68            parser.set_language(ts_language)
69                .map_err(|e| ScribeError::Analysis {
70                    message: format!("Failed to set tree-sitter language: {}", e),
71                    source: None,
72                    file: std::path::PathBuf::from("<unknown>"),
73                })?;
74            
75            let function_query = Self::create_function_query(language, ts_language)?;
76            let class_query = Self::create_class_query(language, ts_language)?;
77            (function_query, class_query)
78        } else {
79            (None, None)
80        };
81        
82        Ok(Self {
83            language,
84            parser,
85            function_query,
86            class_query,
87        })
88    }
89    
90    /// Create tree-sitter query for finding functions
91    fn create_function_query(language: AstLanguage, ts_language: Language) -> Result<Option<Query>> {
92        let query_string = match language {
93            AstLanguage::Python => r#"
94                (function_definition) @function.definition
95            "#,
96            AstLanguage::JavaScript | AstLanguage::TypeScript => r#"
97                (function_declaration) @function.definition
98                (method_definition) @function.definition
99            "#,
100            AstLanguage::Rust => r#"
101                (function_item) @function.definition
102            "#,
103            AstLanguage::Go => r#"
104                (function_declaration) @function.definition
105                (method_declaration) @function.definition
106            "#,
107            // Future languages - placeholder queries
108            AstLanguage::Java => r#"
109                (method_declaration) @function.definition
110            "#,
111            AstLanguage::C | AstLanguage::Cpp => r#"
112                (function_definition) @function.definition
113            "#,
114            AstLanguage::Ruby => r#"
115                (method) @function.definition
116            "#,
117            AstLanguage::CSharp => r#"
118                (method_declaration) @function.definition
119            "#,
120            _ => return Ok(None),
121        };
122        
123        Query::new(ts_language, query_string)
124            .map(Some)
125            .map_err(|e| ScribeError::Analysis {
126                message: format!("Failed to create function query: {}", e),
127                source: None,
128                file: std::path::PathBuf::from("<unknown>"),
129            })
130    }
131    
132    /// Create tree-sitter query for finding classes
133    fn create_class_query(language: AstLanguage, ts_language: Language) -> Result<Option<Query>> {
134        let query_string = match language {
135            AstLanguage::Python => r#"
136                (class_definition) @class.definition
137            "#,
138            AstLanguage::JavaScript | AstLanguage::TypeScript => r#"
139                (class_declaration) @class.definition
140            "#,
141            AstLanguage::Rust => r#"
142                (struct_item) @class.definition
143            "#,
144            AstLanguage::Go => r#"
145                (type_declaration) @class.definition
146            "#,
147            // Future languages - placeholder queries
148            AstLanguage::Java => r#"
149                (class_declaration) @class.definition
150            "#,
151            AstLanguage::Cpp => r#"
152                (class_specifier) @class.definition
153            "#,
154            AstLanguage::Ruby => r#"
155                (class) @class.definition
156            "#,
157            AstLanguage::CSharp => r#"
158                (class_declaration) @class.definition
159            "#,
160            _ => return Ok(None),
161        };
162        
163        Query::new(ts_language, query_string)
164            .map(Some)
165            .map_err(|e| ScribeError::Analysis {
166                message: format!("Failed to create class query: {}", e),
167                source: None,
168                file: std::path::PathBuf::from("<unknown>"),
169            })
170    }
171    
172    /// Extract all functions from source code
173    pub fn extract_functions(&mut self, content: &str) -> Result<Vec<FunctionInfo>> {
174        let tree = self.parser.parse(content, None)
175            .ok_or_else(|| ScribeError::Analysis {
176                message: "Failed to parse source code".to_string(),
177                source: None,
178                file: std::path::PathBuf::from("<unknown>"),
179            })?;
180        
181        let mut functions = Vec::new();
182        
183        if let Some(query) = &self.function_query {
184            let mut query_cursor = QueryCursor::new();
185            let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
186            
187            for query_match in matches {
188                if let Some(function_info) = self.extract_function_from_match(&query_match, content, &tree)? {
189                    functions.push(function_info);
190                }
191            }
192        }
193        
194        Ok(functions)
195    }
196    
197    /// Extract all classes from source code
198    pub fn extract_classes(&mut self, content: &str) -> Result<Vec<ClassInfo>> {
199        let tree = self.parser.parse(content, None)
200            .ok_or_else(|| ScribeError::Analysis {
201                message: "Failed to parse source code".to_string(),
202                source: None,
203                file: std::path::PathBuf::from("<unknown>"),
204            })?;
205        
206        let mut classes = Vec::new();
207        
208        if let Some(query) = &self.class_query {
209            let mut query_cursor = QueryCursor::new();
210            let matches = query_cursor.matches(query, tree.root_node(), content.as_bytes());
211            
212            for query_match in matches {
213                if let Some(class_info) = self.extract_class_from_match(&query_match, content, &tree)? {
214                    classes.push(class_info);
215                }
216            }
217        }
218        
219        Ok(classes)
220    }
221    
222    /// Extract function information from a query match
223    fn extract_function_from_match(
224        &self,
225        query_match: &tree_sitter::QueryMatch,
226        content: &str,
227        tree: &Tree,
228    ) -> Result<Option<FunctionInfo>> {
229        for capture in query_match.captures {
230            let node = capture.node;
231            let start_line = node.start_position().row + 1;
232            let end_line = node.end_position().row + 1;
233            
234            // Extract function name from the AST node structure
235            let name = self.extract_function_name(node, content);
236            let parameters = self.extract_function_parameters(node, content);
237            
238            if let Some(function_name) = name {
239                return Ok(Some(FunctionInfo {
240                    name: function_name,
241                    start_line,
242                    end_line,
243                    parameters,
244                    return_type: None, // TODO: Extract return type
245                    documentation: None, // TODO: Extract documentation
246                    visibility: None, // TODO: Extract visibility
247                    is_method: false, // TODO: Determine if method
248                    parent_class: None, // TODO: Find parent class
249                }));
250            }
251        }
252        Ok(None)
253    }
254    
255    /// Extract class information from a query match
256    fn extract_class_from_match(
257        &self,
258        query_match: &tree_sitter::QueryMatch,
259        content: &str,
260        tree: &Tree,
261    ) -> Result<Option<ClassInfo>> {
262        for capture in query_match.captures {
263            let node = capture.node;
264            let start_line = node.start_position().row + 1;
265            let end_line = node.end_position().row + 1;
266            
267            // Extract class name from the AST node structure
268            let name = self.extract_class_name(node, content);
269            let parents = self.extract_class_parents(node, content);
270            
271            if let Some(class_name) = name {
272                return Ok(Some(ClassInfo {
273                    name: class_name,
274                    start_line,
275                    end_line,
276                    parents,
277                    documentation: None, // TODO: Extract documentation
278                    visibility: None, // TODO: Extract visibility
279                    methods: Vec::new(), // TODO: Extract methods
280                }));
281            }
282        }
283        Ok(None)
284    }
285    
286    /// Extract function name from AST node
287    fn extract_function_name(&self, node: Node, content: &str) -> Option<String> {
288        // Look for identifier child nodes that represent the function name
289        let mut cursor = node.walk();
290        cursor.goto_first_child();
291        
292        loop {
293            let child = cursor.node();
294            match child.kind() {
295                "identifier" => {
296                    if let Ok(name) = child.utf8_text(content.as_bytes()) {
297                        return Some(name.to_string());
298                    }
299                }
300                _ => {}
301            }
302            
303            if !cursor.goto_next_sibling() {
304                break;
305            }
306        }
307        None
308    }
309    
310    /// Extract function parameters from AST node
311    fn extract_function_parameters(&self, node: Node, content: &str) -> Vec<String> {
312        let mut parameters = Vec::new();
313        let mut cursor = node.walk();
314        cursor.goto_first_child();
315        
316        loop {
317            let child = cursor.node();
318            match child.kind() {
319                "parameters" | "parameter_list" => {
320                    // Extract parameter names from parameter list
321                    let mut param_cursor = child.walk();
322                    param_cursor.goto_first_child();
323                    
324                    loop {
325                        let param_node = param_cursor.node();
326                        if param_node.kind() == "identifier" {
327                            if let Ok(param_name) = param_node.utf8_text(content.as_bytes()) {
328                                if param_name != "self" {
329                                    parameters.push(param_name.to_string());
330                                }
331                            }
332                        }
333                        
334                        if !param_cursor.goto_next_sibling() {
335                            break;
336                        }
337                    }
338                    break;
339                }
340                _ => {}
341            }
342            
343            if !cursor.goto_next_sibling() {
344                break;
345            }
346        }
347        parameters
348    }
349    
350    /// Extract class name from AST node
351    fn extract_class_name(&self, node: Node, content: &str) -> Option<String> {
352        // Look for identifier child nodes that represent the class name
353        let mut cursor = node.walk();
354        cursor.goto_first_child();
355        
356        loop {
357            let child = cursor.node();
358            match child.kind() {
359                "identifier" | "type_identifier" => {
360                    if let Ok(name) = child.utf8_text(content.as_bytes()) {
361                        return Some(name.to_string());
362                    }
363                }
364                _ => {}
365            }
366            
367            if !cursor.goto_next_sibling() {
368                break;
369            }
370        }
371        None
372    }
373    
374    /// Extract class parent classes from AST node
375    fn extract_class_parents(&self, node: Node, content: &str) -> Vec<String> {
376        let mut parents = Vec::new();
377        let mut cursor = node.walk();
378        cursor.goto_first_child();
379        
380        loop {
381            let child = cursor.node();
382            match child.kind() {
383                "argument_list" | "superclass" | "inheritance" => {
384                    // Extract parent class names
385                    let mut parent_cursor = child.walk();
386                    parent_cursor.goto_first_child();
387                    
388                    loop {
389                        let parent_node = parent_cursor.node();
390                        if parent_node.kind() == "identifier" || parent_node.kind() == "type_identifier" {
391                            if let Ok(parent_name) = parent_node.utf8_text(content.as_bytes()) {
392                                parents.push(parent_name.to_string());
393                            }
394                        }
395                        
396                        if !parent_cursor.goto_next_sibling() {
397                            break;
398                        }
399                    }
400                }
401                _ => {}
402            }
403            
404            if !cursor.goto_next_sibling() {
405                break;
406            }
407        }
408        parents
409    }
410    
411    /// Extract parameter names from parameter list text
412    fn extract_parameters(&self, params_text: &str, _node: Node) -> Vec<String> {
413        // Simple parameter extraction - can be improved per language
414        params_text
415            .split(',')
416            .filter_map(|param| {
417                let param = param.trim();
418                if param.is_empty() || param == "self" {
419                    None
420                } else {
421                    // Extract just the parameter name (before type annotations)
422                    let name = param.split(':').next().unwrap_or(param).trim();
423                    if name.is_empty() {
424                        None
425                    } else {
426                        Some(name.to_string())
427                    }
428                }
429            })
430            .collect()
431    }
432    
433    /// Extract parent class names from inheritance clause
434    fn extract_parent_classes(&self, parents_text: &str) -> Vec<String> {
435        // Simple parent class extraction - can be improved per language
436        parents_text
437            .split(',')
438            .filter_map(|parent| {
439                let parent = parent.trim();
440                if parent.is_empty() {
441                    None
442                } else {
443                    Some(parent.to_string())
444                }
445            })
446            .collect()
447    }
448}
449
450impl AstLanguage {
451    /// Get tree-sitter language for supported languages
452    pub fn tree_sitter_language(&self) -> Option<tree_sitter::Language> {
453        match self {
454            AstLanguage::Python => Some(tree_sitter_python::language()),
455            AstLanguage::JavaScript => Some(tree_sitter_javascript::language()),
456            AstLanguage::TypeScript => Some(tree_sitter_typescript::language_typescript()),
457            AstLanguage::Go => Some(tree_sitter_go::language()),
458            AstLanguage::Rust => Some(tree_sitter_rust::language()),
459            AstLanguage::Html => Some(tree_sitter_html::language()),
460            // TODO: Add when dependencies are available
461            // AstLanguage::Java => Some(tree_sitter_java::language()),
462            // AstLanguage::C => Some(tree_sitter_c::language()),
463            // AstLanguage::Cpp => Some(tree_sitter_cpp::language()),
464            // AstLanguage::Ruby => Some(tree_sitter_ruby::language()),
465            // AstLanguage::CSharp => Some(tree_sitter_c_sharp::language()),
466            _ => None,
467        }
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Creating an extractor for a language with a grammar succeeds.
    #[test]
    fn test_function_extractor_creation() {
        assert!(FunctionExtractor::new(AstLanguage::Python).is_ok());
    }

    /// Top-level Python functions are reported by name.
    #[test]
    fn test_python_function_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::Python).unwrap();
        let python_code = r#"
def hello_world():
    """A simple function."""
    print("Hello, World!")

def add_numbers(a, b):
    """Add two numbers together."""
    return a + b

class Calculator:
    """A simple calculator."""
    
    def multiply(self, x, y):
        """Multiply two numbers."""
        return x * y
"#;

        let functions = extractor.extract_functions(python_code).unwrap();
        assert!(!functions.is_empty());

        // Should find at least the standalone functions
        let has_function = |wanted: &str| functions.iter().any(|f| f.name == wanted);
        assert!(has_function("hello_world"));
        assert!(has_function("add_numbers"));
    }

    /// Python classes, with and without a base class, are reported by name.
    #[test]
    fn test_python_class_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::Python).unwrap();
        let python_code = r#"
class Calculator:
    """A simple calculator."""
    pass

class AdvancedCalculator(Calculator):
    """An advanced calculator that inherits from Calculator."""
    pass
"#;

        let classes = extractor.extract_classes(python_code).unwrap();
        assert!(!classes.is_empty());

        let has_class = |wanted: &str| classes.iter().any(|c| c.name == wanted);
        assert!(has_class("Calculator"));
        assert!(has_class("AdvancedCalculator"));
    }

    /// JavaScript source yields at least one extracted function.
    #[test]
    fn test_javascript_function_extraction() {
        let mut extractor = FunctionExtractor::new(AstLanguage::JavaScript).unwrap();
        let js_code = r#"
function greetUser(name) {
    return `Hello, ${name}!`;
}

class UserManager {
    constructor() {
        this.users = [];
    }
    
    addUser(user) {
        this.users.push(user);
    }
}
"#;

        let functions = extractor.extract_functions(js_code).unwrap();
        assert!(!functions.is_empty());
    }
}