tree_parser/
search.rs

1//! Search functionality for finding code constructs
2
3use crate::{languages::get_tree_sitter_language, CodeConstruct, Error, Language, ParsedFile};
4use regex::Regex;
5use tree_sitter::{Query, QueryCursor};
6use streaming_iterator::StreamingIterator;
7
8/// Search for code constructs by their tree-sitter node type
9/// 
10/// This function searches through all code constructs in a parsed file
11/// and returns those that match the specified node type. Optionally,
12/// results can be filtered by a regex pattern applied to construct names.
13/// 
14/// # Arguments
15/// 
16/// * `parsed_file` - The parsed file to search within
17/// * `node_type` - The tree-sitter node type to search for (e.g., "function_definition")
18/// * `name_pattern` - Optional regex pattern to filter results by construct name
19/// 
20/// # Returns
21/// 
22/// A vector of `CodeConstruct` objects that match the search criteria.
23/// 
24/// # Examples
25/// 
26/// ```rust
27/// use tree_parser::{parse_file, search_by_node_type, Language};
28/// 
29/// #[tokio::main]
30/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
31///     let parsed = parse_file("example.py", Language::Python).await?;
32///     
33///     // Find all function definitions
34///     let functions = search_by_node_type(&parsed, "function_definition", None);
35///     
36///     // Find functions with names starting with "test_"
37///     let test_functions = search_by_node_type(&parsed, "function_definition", Some(r"^test_"));
38///     
39///     println!("Found {} functions, {} are tests", functions.len(), test_functions.len());
40///     Ok(())
41/// }
42/// ```
43pub fn search_by_node_type(
44    parsed_file: &ParsedFile,
45    node_type: &str,
46    name_pattern: Option<&str>,
47) -> Vec<CodeConstruct> {
48    let mut results = Vec::new();
49    
50    // Compile regex pattern if provided
51    let regex = if let Some(pattern) = name_pattern {
52        match Regex::new(pattern) {
53            Ok(r) => Some(r),
54            Err(_) => return results, // Invalid regex, return empty results
55        }
56    } else {
57        None
58    };
59    
60    // Search through all constructs (already flattened, no need for recursive search)
61    for construct in &parsed_file.constructs {
62        if construct.node_type == node_type {
63            // Check name pattern if provided
64            if let Some(ref regex) = regex {
65                if let Some(ref name) = construct.name {
66                    if regex.is_match(name) {
67                        results.push(construct.clone());
68                    }
69                }
70            } else {
71                results.push(construct.clone());
72            }
73        }
74    }
75    
76    results
77}
78
79/// Search for code constructs matching any of the specified node types
80/// 
81/// This function extends `search_by_node_type` to search for multiple node types
82/// simultaneously. This is useful when looking for related constructs that may
83/// have different node types in different languages.
84/// 
85/// # Arguments
86/// 
87/// * `parsed_file` - The parsed file to search within
88/// * `node_types` - Array of tree-sitter node types to search for
89/// * `name_pattern` - Optional regex pattern to filter results by construct name
90/// 
91/// # Returns
92/// 
93/// A vector of `CodeConstruct` objects that match any of the specified node types
94/// and the optional name pattern.
95/// 
96/// # Examples
97/// 
98/// ```rust
99/// use tree_parser::{parse_file, search_by_multiple_node_types, Language};
100/// 
101/// #[tokio::main]
102/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
103///     let parsed = parse_file("example.js", Language::JavaScript).await?;
104///     
105///     // Find all function-like constructs
106///     let functions = search_by_multiple_node_types(
107///         &parsed,
108///         &["function_declaration", "function_expression", "arrow_function"],
109///         None
110///     );
111///     
112///     println!("Found {} function-like constructs", functions.len());
113///     Ok(())
114/// }
115/// ```
116pub fn search_by_multiple_node_types(
117    parsed_file: &ParsedFile,
118    node_types: &[&str],
119    name_pattern: Option<&str>,
120) -> Vec<CodeConstruct> {
121    let mut results = Vec::new();
122    
123    // Compile regex pattern if provided
124    let regex = if let Some(pattern) = name_pattern {
125        match Regex::new(pattern) {
126            Ok(r) => Some(r),
127            Err(_) => return results, // Invalid regex, return empty results
128        }
129    } else {
130        None
131    };
132    
133    // Search through all constructs (already flattened, no need for recursive search)
134    for construct in &parsed_file.constructs {
135        if node_types.contains(&construct.node_type.as_str()) {
136            // Check name pattern if provided
137            if let Some(ref regex) = regex {
138                if let Some(ref name) = construct.name {
139                    if regex.is_match(name) {
140                        results.push(construct.clone());
141                    }
142                }
143            } else {
144                results.push(construct.clone());
145            }
146        }
147    }
148    
149    results
150}
151
152/// Execute a custom tree-sitter query for advanced searching
153/// 
154/// This function allows you to use tree-sitter's powerful query language
155/// to perform complex searches on the syntax tree. This provides the most
156/// flexibility for finding specific code patterns.
157/// 
158/// # Arguments
159/// 
160/// * `parsed_file` - The parsed file to search within
161/// * `tree_sitter_query` - A tree-sitter query string
162/// 
163/// # Returns
164/// 
165/// A `Result` containing a vector of `CodeConstruct` objects that match
166/// the query, or an `Error` if the query is invalid or execution fails.
167/// 
168/// # Examples
169/// 
170/// ```rust
171/// use tree_parser::{parse_file, search_by_query, Language};
172/// 
173/// #[tokio::main]
174/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
175///     let parsed = parse_file("example.py", Language::Python).await?;
176///     
177///     // Find all function definitions with decorators
178///     let query = r#"
179///         (decorated_definition
180///           (function_definition
181///             name: (identifier) @func_name))
182///     "#;
183///     
184///     let decorated_functions = search_by_query(&parsed, query)?;
185///     println!("Found {} decorated functions", decorated_functions.len());
186///     Ok(())
187/// }
188/// ```
189/// 
190/// # Errors
191/// 
192/// This function will return an error if:
193/// - The query syntax is invalid
194/// - The syntax tree is not available
195/// - File I/O operations fail
196pub fn search_by_query(
197    parsed_file: &ParsedFile,
198    tree_sitter_query: &str,
199) -> Result<Vec<CodeConstruct>, Error> {
200    let mut results = Vec::new();
201    
202    // Get the syntax tree
203    let tree = parsed_file.syntax_tree.as_ref()
204        .ok_or_else(|| Error::Parse("No syntax tree available".to_string()))?;
205    
206    // Get the tree-sitter language
207    let ts_language = get_tree_sitter_language(&parsed_file.language)?;
208    
209    // Create and execute query
210    let query = Query::new(&ts_language, tree_sitter_query)
211        .map_err(|e| Error::InvalidQuery(e.to_string()))?;
212    
213    let mut cursor = QueryCursor::new();
214    
215    // Read the source code to extract text
216    let source = std::fs::read_to_string(&parsed_file.file_path)
217        .map_err(|e| Error::Io(e.to_string()))?;
218    
219    let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
220    while let Some(query_match) = matches.next() {
221        for capture in query_match.captures {
222            let node = capture.node;
223            let construct = create_code_construct_from_node(node, &source, &parsed_file.language);
224            results.push(construct);
225        }
226    }
227    
228    Ok(results)
229}
230
231/// Create a CodeConstruct from a tree-sitter node (used in query search)
232fn create_code_construct_from_node(
233    node: tree_sitter::Node,
234    source: &str,
235    _language: &Language,
236) -> CodeConstruct {
237    let start_byte = node.start_byte();
238    let end_byte = node.end_byte();
239    let source_code = source[start_byte..end_byte].to_string();
240    
241    let start_point = node.start_position();
242    let end_point = node.end_position();
243    
244    // Extract name if possible
245    let name = extract_node_name(node, source);
246    
247    CodeConstruct {
248        node_type: node.kind().to_string(),
249        name,
250        source_code,
251        start_line: start_point.row + 1, // Convert to 1-based
252        end_line: end_point.row + 1,
253        start_byte,
254        end_byte,
255        parent: None,
256        children: Vec::new(),
257        metadata: crate::ConstructMetadata {
258            visibility: None,
259            modifiers: Vec::new(),
260            parameters: Vec::new(),
261            return_type: None,
262            inheritance: Vec::new(),
263            annotations: Vec::new(),
264            documentation: None,
265        },
266    }
267}
268
269/// Extract name from a tree-sitter node
270fn extract_node_name(node: tree_sitter::Node, source: &str) -> Option<String> {
271    // Try to find identifier child
272    for i in 0..node.child_count() {
273        if let Some(child) = node.child(i) {
274            if child.kind() == "identifier" || child.kind() == "name" {
275                let start = child.start_byte();
276                let end = child.end_byte();
277                return Some(source[start..end].to_string());
278            }
279        }
280    }
281    None
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse_file, Language};
    use std::fs;

    /// Regression test: each matching function must be reported exactly once.
    ///
    /// Two classes (one nested in the other) each define a method called
    /// `_allocate_kv_cache`; a name-filtered search must return exactly those
    /// two constructs, each with a different parent class.
    #[tokio::test]
    async fn test_no_duplicate_results() {
        // Python fixture with same-named methods in an outer and an inner class
        let test_content = r#"
class CacheEngine:
    def __init__(self):
        pass
    
    def _allocate_kv_cache(self):
        return "cache allocated"

    class InnerClass:
        def _allocate_kv_cache(self):
            return "inner cache"
"#;

        // Write the fixture to the system temp directory under a per-process
        // name so the working tree is not polluted and concurrent test runs
        // do not collide.
        let test_path = std::env::temp_dir()
            .join(format!("test_cache_duplication_{}.py", std::process::id()));
        fs::write(&test_path, test_content).expect("Failed to write test file");
        let test_file = test_path
            .to_str()
            .expect("Temporary test path is not valid UTF-8");

        // Parse the file, then delete the fixture immediately: the search
        // below only reads the in-memory constructs, and cleaning up before
        // any assertion guarantees no file is left behind on failure.
        let parse_result = parse_file(test_file, Language::Python).await;
        fs::remove_file(&test_path).ok();
        let parsed = parse_result.expect("Failed to parse file");

        // Search for function definitions with specific name
        let functions =
            search_by_node_type(&parsed, "function_definition", Some("_allocate_kv_cache"));

        // Should find exactly 2 functions (one in CacheEngine, one in InnerClass).
        // Before the fix, this would return 4 (each function counted twice due to duplication).
        assert_eq!(
            functions.len(),
            2,
            "Expected exactly 2 functions, but found {}",
            functions.len()
        );

        // Collect the names of the parents of the matched functions
        let mut parent_names: Vec<_> = functions
            .iter()
            .filter_map(|func| func.parent.as_ref())
            .filter_map(|parent| parent.name.clone())
            .collect();

        // Should have 2 different parent classes
        parent_names.sort();
        parent_names.dedup();
        assert_eq!(parent_names.len(), 2, "Expected 2 different parent classes");
    }
}
339}