tree_parser/search.rs
1//! Search functionality for finding code constructs
2
3use crate::{languages::get_tree_sitter_language, CodeConstruct, Error, Language, ParsedFile};
4use regex::Regex;
5use tree_sitter::{Query, QueryCursor};
6use streaming_iterator::StreamingIterator;
7
8/// Search for code constructs by their tree-sitter node type
9///
10/// This function searches through all code constructs in a parsed file
11/// and returns those that match the specified node type. Optionally,
12/// results can be filtered by a regex pattern applied to construct names.
13///
14/// # Arguments
15///
16/// * `parsed_file` - The parsed file to search within
17/// * `node_type` - The tree-sitter node type to search for (e.g., "function_definition")
18/// * `name_pattern` - Optional regex pattern to filter results by construct name
19///
20/// # Returns
21///
22/// A vector of `CodeConstruct` objects that match the search criteria.
23///
24/// # Examples
25///
26/// ```rust
27/// use tree_parser::{parse_file, search_by_node_type, Language};
28///
29/// #[tokio::main]
30/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
31/// let parsed = parse_file("example.py", Language::Python).await?;
32///
33/// // Find all function definitions
34/// let functions = search_by_node_type(&parsed, "function_definition", None);
35///
36/// // Find functions with names starting with "test_"
37/// let test_functions = search_by_node_type(&parsed, "function_definition", Some(r"^test_"));
38///
39/// println!("Found {} functions, {} are tests", functions.len(), test_functions.len());
40/// Ok(())
41/// }
42/// ```
43pub fn search_by_node_type(
44 parsed_file: &ParsedFile,
45 node_type: &str,
46 name_pattern: Option<&str>,
47) -> Vec<CodeConstruct> {
48 let mut results = Vec::new();
49
50 // Compile regex pattern if provided
51 let regex = if let Some(pattern) = name_pattern {
52 match Regex::new(pattern) {
53 Ok(r) => Some(r),
54 Err(_) => return results, // Invalid regex, return empty results
55 }
56 } else {
57 None
58 };
59
60 // Search through all constructs (already flattened, no need for recursive search)
61 for construct in &parsed_file.constructs {
62 if construct.node_type == node_type {
63 // Check name pattern if provided
64 if let Some(ref regex) = regex {
65 if let Some(ref name) = construct.name {
66 if regex.is_match(name) {
67 results.push(construct.clone());
68 }
69 }
70 } else {
71 results.push(construct.clone());
72 }
73 }
74 }
75
76 results
77}
78
79/// Search for code constructs matching any of the specified node types
80///
81/// This function extends `search_by_node_type` to search for multiple node types
82/// simultaneously. This is useful when looking for related constructs that may
83/// have different node types in different languages.
84///
85/// # Arguments
86///
87/// * `parsed_file` - The parsed file to search within
88/// * `node_types` - Array of tree-sitter node types to search for
89/// * `name_pattern` - Optional regex pattern to filter results by construct name
90///
91/// # Returns
92///
93/// A vector of `CodeConstruct` objects that match any of the specified node types
94/// and the optional name pattern.
95///
96/// # Examples
97///
98/// ```rust
99/// use tree_parser::{parse_file, search_by_multiple_node_types, Language};
100///
101/// #[tokio::main]
102/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
103/// let parsed = parse_file("example.js", Language::JavaScript).await?;
104///
105/// // Find all function-like constructs
106/// let functions = search_by_multiple_node_types(
107/// &parsed,
108/// &["function_declaration", "function_expression", "arrow_function"],
109/// None
110/// );
111///
112/// println!("Found {} function-like constructs", functions.len());
113/// Ok(())
114/// }
115/// ```
116pub fn search_by_multiple_node_types(
117 parsed_file: &ParsedFile,
118 node_types: &[&str],
119 name_pattern: Option<&str>,
120) -> Vec<CodeConstruct> {
121 let mut results = Vec::new();
122
123 // Compile regex pattern if provided
124 let regex = if let Some(pattern) = name_pattern {
125 match Regex::new(pattern) {
126 Ok(r) => Some(r),
127 Err(_) => return results, // Invalid regex, return empty results
128 }
129 } else {
130 None
131 };
132
133 // Search through all constructs (already flattened, no need for recursive search)
134 for construct in &parsed_file.constructs {
135 if node_types.contains(&construct.node_type.as_str()) {
136 // Check name pattern if provided
137 if let Some(ref regex) = regex {
138 if let Some(ref name) = construct.name {
139 if regex.is_match(name) {
140 results.push(construct.clone());
141 }
142 }
143 } else {
144 results.push(construct.clone());
145 }
146 }
147 }
148
149 results
150}
151
152/// Execute a custom tree-sitter query for advanced searching
153///
154/// This function allows you to use tree-sitter's powerful query language
155/// to perform complex searches on the syntax tree. This provides the most
156/// flexibility for finding specific code patterns.
157///
158/// # Arguments
159///
160/// * `parsed_file` - The parsed file to search within
161/// * `tree_sitter_query` - A tree-sitter query string
162///
163/// # Returns
164///
165/// A `Result` containing a vector of `CodeConstruct` objects that match
166/// the query, or an `Error` if the query is invalid or execution fails.
167///
168/// # Examples
169///
170/// ```rust
171/// use tree_parser::{parse_file, search_by_query, Language};
172///
173/// #[tokio::main]
174/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
175/// let parsed = parse_file("example.py", Language::Python).await?;
176///
177/// // Find all function definitions with decorators
178/// let query = r#"
179/// (decorated_definition
180/// (function_definition
181/// name: (identifier) @func_name))
182/// "#;
183///
184/// let decorated_functions = search_by_query(&parsed, query)?;
185/// println!("Found {} decorated functions", decorated_functions.len());
186/// Ok(())
187/// }
188/// ```
189///
190/// # Errors
191///
192/// This function will return an error if:
193/// - The query syntax is invalid
194/// - The syntax tree is not available
195/// - File I/O operations fail
196pub fn search_by_query(
197 parsed_file: &ParsedFile,
198 tree_sitter_query: &str,
199) -> Result<Vec<CodeConstruct>, Error> {
200 let mut results = Vec::new();
201
202 // Get the syntax tree
203 let tree = parsed_file.syntax_tree.as_ref()
204 .ok_or_else(|| Error::Parse("No syntax tree available".to_string()))?;
205
206 // Get the tree-sitter language
207 let ts_language = get_tree_sitter_language(&parsed_file.language)?;
208
209 // Create and execute query
210 let query = Query::new(&ts_language, tree_sitter_query)
211 .map_err(|e| Error::InvalidQuery(e.to_string()))?;
212
213 let mut cursor = QueryCursor::new();
214
215 // Read the source code to extract text
216 let source = std::fs::read_to_string(&parsed_file.file_path)
217 .map_err(|e| Error::Io(e.to_string()))?;
218
219 let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
220 while let Some(query_match) = matches.next() {
221 for capture in query_match.captures {
222 let node = capture.node;
223 let construct = create_code_construct_from_node(node, &source, &parsed_file.language);
224 results.push(construct);
225 }
226 }
227
228 Ok(results)
229}
230
231/// Create a CodeConstruct from a tree-sitter node (used in query search)
232fn create_code_construct_from_node(
233 node: tree_sitter::Node,
234 source: &str,
235 _language: &Language,
236) -> CodeConstruct {
237 let start_byte = node.start_byte();
238 let end_byte = node.end_byte();
239 let source_code = source[start_byte..end_byte].to_string();
240
241 let start_point = node.start_position();
242 let end_point = node.end_position();
243
244 // Extract name if possible
245 let name = extract_node_name(node, source);
246
247 CodeConstruct {
248 node_type: node.kind().to_string(),
249 name,
250 source_code,
251 start_line: start_point.row + 1, // Convert to 1-based
252 end_line: end_point.row + 1,
253 start_byte,
254 end_byte,
255 parent: None,
256 children: Vec::new(),
257 metadata: crate::ConstructMetadata {
258 visibility: None,
259 modifiers: Vec::new(),
260 parameters: Vec::new(),
261 return_type: None,
262 inheritance: Vec::new(),
263 annotations: Vec::new(),
264 documentation: None,
265 },
266 }
267}
268
269/// Extract name from a tree-sitter node
270fn extract_node_name(node: tree_sitter::Node, source: &str) -> Option<String> {
271 // Try to find identifier child
272 for i in 0..node.child_count() {
273 if let Some(child) = node.child(i) {
274 if child.kind() == "identifier" || child.kind() == "name" {
275 let start = child.start_byte();
276 let end = child.end_byte();
277 return Some(source[start..end].to_string());
278 }
279 }
280 }
281 None
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{parse_file, Language};
    use std::fs;
    use std::path::PathBuf;

    /// Deletes the wrapped path on drop, so the fixture is removed even when an
    /// assertion fails before the end of the test.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a failed delete must not mask the test result.
            let _ = fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn test_no_duplicate_results() {
        // A Python file with a method and a same-named method in a nested class
        let test_content = r#"
class CacheEngine:
    def __init__(self):
        pass

    def _allocate_kv_cache(self):
        return "cache allocated"

    class InnerClass:
        def _allocate_kv_cache(self):
            return "inner cache"
"#;

        // Write the fixture to the system temp directory instead of the working
        // directory; the process id keeps concurrent runs from clashing.
        let fixture = TempFile(std::env::temp_dir().join(format!(
            "test_cache_duplication_{}.py",
            std::process::id()
        )));
        fs::write(&fixture.0, test_content).expect("Failed to write test file");
        let test_file = fixture
            .0
            .to_str()
            .expect("temp file path is not valid UTF-8");

        // Parse the file
        let parsed = parse_file(test_file, Language::Python)
            .await
            .expect("Failed to parse file");

        // Search for function definitions with specific name
        let functions =
            search_by_node_type(&parsed, "function_definition", Some("_allocate_kv_cache"));

        // Should find exactly 2 functions (one in CacheEngine, one in InnerClass).
        // A regression that duplicates constructs would report 4.
        assert_eq!(
            functions.len(),
            2,
            "Expected exactly 2 functions, but found {}",
            functions.len()
        );

        // Verify the functions have different parents
        let mut parent_names = Vec::new();
        for func in &functions {
            if let Some(parent) = &func.parent {
                if let Some(parent_name) = &parent.name {
                    parent_names.push(parent_name.clone());
                }
            }
        }

        // Should have 2 different parent classes
        parent_names.sort();
        parent_names.dedup();
        assert_eq!(parent_names.len(), 2, "Expected 2 different parent classes");
        // `fixture` is dropped here (or during unwinding), deleting the file.
    }
}