steer_tools/tools/
astgrep.rs

1use ast_grep_core::tree_sitter::StrDoc;
2use ast_grep_core::{AstGrep, Pattern};
3use ast_grep_language::{LanguageExt, SupportLang};
4use ignore::WalkBuilder;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::Path;
9use std::str::FromStr;
10use steer_macros::tool;
11use tokio::task;
12
13use crate::result::{AstGrepResult, SearchMatch, SearchResult};
14use crate::{ExecutionContext, ToolError};
15
16#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
17pub struct AstGrepParams {
18    /// The search pattern (code pattern with $METAVAR placeholders)
19    pub pattern: String,
20    /// Language (rust, tsx, python, etc.)
21    pub lang: Option<String>,
22    /// Optional glob pattern to filter files by name (e.g., "*.rs", "*.{ts,tsx}")
23    pub include: Option<String>,
24    /// Optional glob pattern to exclude files
25    pub exclude: Option<String>,
26    /// Optional directory to search in (defaults to current working directory)
27    pub path: Option<String>,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31pub struct AstGrepMatch {
32    pub file: String,
33    pub line: usize,
34    pub column: usize,
35    pub matched_code: String,
36    pub context: String,
37}
38
39tool! {
40    AstGrepTool {
41        params: AstGrepParams,
42        output: AstGrepResult,
43        variant: Search,
44        description: r#"Structural code search using abstract syntax trees (AST).
45- Searches code by its syntactic structure, not just text patterns
46- Use $METAVAR placeholders (e.g., $VAR, $FUNC, $ARGS) to match any code element
47- Supports all major languages: rust, javascript, typescript, python, java, go, etc.
48Pattern examples:
49- "console.log($MSG)" - finds all console.log calls regardless of argument
50- "fn $NAME($PARAMS) { $BODY }" - finds all Rust function definitions
51- "if $COND { $THEN } else { $ELSE }" - finds all if-else statements
52- "import $WHAT from '$MODULE'" - finds all ES6 imports from specific modules
53- "$VAR = $VAR + $EXPR" - finds all self-incrementing assignments
54Advanced patterns:
55- "function $FUNC($$$ARGS) { $$$ }" - $$$ matches any number of elements
56- "foo($ARG, ...)" - ellipsis matches remaining arguments
57- Use any valid code as a pattern - ast-grep understands the syntax!
58Automatically respects .gitignore files"#,
59        name: "astgrep",
60        require_approval: false
61    }
62
63    async fn run(
64        _tool: &AstGrepTool,
65        params: AstGrepParams,
66        context: &ExecutionContext,
67    ) -> Result<AstGrepResult, ToolError> {
68        if context.is_cancelled() {
69            return Err(ToolError::Cancelled(AST_GREP_TOOL_NAME.to_string()));
70        }
71
72        let search_path = params.path.as_deref().unwrap_or(".");
73        let base_path = if Path::new(search_path).is_absolute() {
74            Path::new(search_path).to_path_buf()
75        } else {
76            context.working_directory.join(search_path)
77        };
78
79        // Run the blocking search operation in a separate task
80        let pattern = params.pattern.clone();
81        let lang = params.lang.clone();
82        let include = params.include.clone();
83        let exclude = params.exclude.clone();
84        let cancellation_token = context.cancellation_token.clone();
85
86        let result = task::spawn_blocking(move || {
87            astgrep_search_internal(&pattern, lang.as_deref(), include.as_deref(), exclude.as_deref(), &base_path, &cancellation_token)
88        }).await;
89
90        match result {
91            Ok(search_result) => search_result.map_err(|e| ToolError::execution(AST_GREP_TOOL_NAME, e)),
92            Err(e) => Err(ToolError::execution(AST_GREP_TOOL_NAME, format!("Task join error: {e}"))),
93        }
94    }
95}
96
97fn astgrep_search_internal(
98    pattern: &str,
99    lang: Option<&str>,
100    include: Option<&str>,
101    exclude: Option<&str>,
102    base_path: &Path,
103    cancellation_token: &tokio_util::sync::CancellationToken,
104) -> Result<AstGrepResult, String> {
105    if !base_path.exists() {
106        return Err(format!("Path does not exist: {}", base_path.display()));
107    }
108
109    // Use ignore crate's WalkBuilder for respecting .gitignore
110    let mut walker = WalkBuilder::new(base_path);
111    walker.hidden(false); // Include hidden files by default
112    walker.git_ignore(true); // Respect .gitignore files
113    walker.git_global(true); // Respect global gitignore
114    walker.git_exclude(true); // Respect .git/info/exclude
115
116    let include_pattern = include
117        .map(|p| glob::Pattern::new(p).map_err(|e| format!("Invalid include glob pattern: {e}")))
118        .transpose()?;
119
120    let exclude_pattern = exclude
121        .map(|p| glob::Pattern::new(p).map_err(|e| format!("Invalid exclude glob pattern: {e}")))
122        .transpose()?;
123
124    let mut all_matches = Vec::new();
125    let mut files_searched = 0;
126
127    for result in walker.build() {
128        if cancellation_token.is_cancelled() {
129            return Ok(AstGrepResult(SearchResult {
130                matches: all_matches,
131                total_files_searched: files_searched,
132                search_completed: false,
133            }));
134        }
135
136        let entry = match result {
137            Ok(e) => e,
138            Err(_) => continue,
139        };
140
141        let path = entry.path();
142        if !path.is_file() {
143            continue;
144        }
145
146        // Check include pattern if specified
147        if let Some(ref pattern) = include_pattern {
148            if !path_matches_glob(path, pattern, base_path) {
149                continue;
150            }
151        }
152
153        // Check exclude pattern if specified
154        if let Some(ref pattern) = exclude_pattern {
155            if path_matches_glob(path, pattern, base_path) {
156                continue;
157            }
158        }
159
160        // Determine the language based on file extension or user specification
161        let detected_lang = if let Some(l) = lang {
162            match SupportLang::from_str(l) {
163                Ok(lang) => Some(lang),
164                Err(_) => {
165                    // Skip files with unsupported language
166                    continue;
167                }
168            }
169        } else {
170            // Auto-detect language from file extension
171            SupportLang::from_extension(path).or_else(|| {
172                // Fallback to manual extension matching for common cases
173                path.extension()
174                    .and_then(|ext| ext.to_str())
175                    .and_then(|ext| match ext {
176                        "jsx" => Some(SupportLang::JavaScript),
177                        "mjs" => Some(SupportLang::JavaScript),
178                        _ => None,
179                    })
180            })
181        };
182
183        // Skip files without detectable language
184        let Some(language) = detected_lang else {
185            continue;
186        };
187
188        // Read file content
189        files_searched += 1;
190        let content = match fs::read_to_string(path) {
191            Ok(c) => c,
192            Err(_) => continue, // Skip files that can't be read
193        };
194
195        // Parse the file using ast-grep
196        let ast_grep = language.ast_grep(&content);
197
198        // Create pattern matcher
199        let pattern_matcher = match Pattern::try_new(pattern, language) {
200            Ok(p) => p,
201            Err(e) => return Err(format!("Invalid pattern: {e}")),
202        };
203
204        // Find all matches in the file
205        let relative_path = path.strip_prefix(base_path).unwrap_or(path);
206        let file_matches = find_matches(&ast_grep, &pattern_matcher, relative_path, &content);
207
208        // Convert AstGrepMatch to SearchMatch
209        for m in file_matches {
210            all_matches.push(SearchMatch {
211                file_path: m.file,
212                line_number: m.line,
213                line_content: m.context.trim().to_string(),
214                column_range: Some((m.column, m.column + m.matched_code.len())),
215            });
216        }
217    }
218
219    // Sort by file path for consistent output
220    all_matches.sort_by(|a, b| {
221        a.file_path
222            .cmp(&b.file_path)
223            .then(a.line_number.cmp(&b.line_number))
224    });
225
226    Ok(AstGrepResult(SearchResult {
227        matches: all_matches,
228        total_files_searched: files_searched,
229        search_completed: true,
230    }))
231}
232
233fn find_matches(
234    ast_grep: &AstGrep<StrDoc<SupportLang>>,
235    pattern: &Pattern,
236    path: &Path,
237    content: &str,
238) -> Vec<AstGrepMatch> {
239    let root = ast_grep.root();
240    let matches = root.find_all(pattern);
241
242    let mut results = Vec::new();
243    for node_match in matches {
244        let node = node_match.get_node();
245        let range = node.range();
246        let start_pos = node.start_pos();
247
248        // Get the matched code
249        let matched_code = node.text();
250
251        // Get the line content for context
252        let line_start = content[..range.start]
253            .rfind('\n')
254            .map(|i| i + 1)
255            .unwrap_or(0);
256        let line_end = content[range.end..]
257            .find('\n')
258            .map(|i| range.end + i)
259            .unwrap_or(content.len());
260        let context = &content[line_start..line_end];
261
262        results.push(AstGrepMatch {
263            file: path.display().to_string(),
264            line: start_pos.line() + 1, // Convert 0-based to 1-based
265            column: start_pos.column(node) + 1,
266            matched_code: matched_code.to_string(),
267            context: context.to_string(),
268        });
269    }
270
271    results
272}
273
274fn path_matches_glob(path: &Path, pattern: &glob::Pattern, base_path: &Path) -> bool {
275    // Check if the full path matches
276    if pattern.matches_path(path) {
277        return true;
278    }
279
280    // Check if the relative path from base_path matches
281    if let Ok(relative_path) = path.strip_prefix(base_path) {
282        if pattern.matches_path(relative_path) {
283            return true;
284        }
285    }
286
287    // Also check if just the filename matches (for patterns like "*.rs")
288    if let Some(filename) = path.file_name() {
289        if pattern.matches(&filename.to_string_lossy()) {
290            return true;
291        }
292    }
293
294    false
295}
296
297// Helper trait to make language handling cleaner
298trait LanguageHelpers {
299    fn from_extension(path: &Path) -> Option<SupportLang>;
300}
301
302impl LanguageHelpers for SupportLang {
303    fn from_extension(path: &Path) -> Option<SupportLang> {
304        ast_grep_language::Language::from_path(path)
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionContext, Tool};
    use std::fs;
    use tempfile::tempdir;

    /// Builds an execution context whose working directory is the temp dir, so
    /// a search with `path: None` walks only the files the test created.
    fn create_test_context(temp_dir: &tempfile::TempDir) -> ExecutionContext {
        ExecutionContext::new("test-call-id".to_string())
            .with_working_directory(temp_dir.path().to_path_buf())
    }

    #[tokio::test]
    async fn test_astgrep_rust_function() {
        let temp_dir = tempdir().unwrap();

        // Create a Rust file with functions
        fs::write(
            temp_dir.path().join("test.rs"),
            r#"fn main() {
    println!("Hello, world!");
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}

async fn fetch_data() -> Result<String, Error> {
    Ok("data".to_string())
}"#,
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "fn $NAME($$$ARGS) { $$$ }".to_string(),
            lang: Some("rust".to_string()),
            include: None,
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // Only fn main() matches the pattern - functions with return types have different AST structure
        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("test.rs"));
        assert_eq!(result.0.matches[0].line_number, 1);
        assert!(result.0.matches[0].line_content.contains("fn main() {"));
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_javascript_console_log() {
        let temp_dir = tempdir().unwrap();

        // Create a JavaScript file
        fs::write(
            temp_dir.path().join("app.js"),
            r#"console.log("Starting application");

function processData(data) {
    console.log("Processing:", data);
    console.error("An error occurred");
    return data;
}

console.log("Application ready");"#,
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "console.log($ARGS)".to_string(),
            lang: None, // Should auto-detect from .js extension
            include: None,
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // `$ARGS` captures a single node, so only the one-argument calls on
        // lines 1 and 9 match; the two-argument call on line 4 does not.
        assert_eq!(result.0.matches.len(), 2);
        // Check first match
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("app.js")
                && m.line_number == 1
                && m.line_content
                    .contains("console.log(\"Starting application\")")
        }));
        // Check second match
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("app.js")
                && m.line_number == 9
                && m.line_content
                    .contains("console.log(\"Application ready\")")
        }));
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_with_include_pattern() {
        let temp_dir = tempdir().unwrap();

        // Create multiple files
        fs::write(
            temp_dir.path().join("module.ts"),
            "export function getData() { return fetch('/api/data'); }",
        )
        .unwrap();

        fs::write(
            temp_dir.path().join("test.spec.ts"),
            "describe('test', () => { it('works', () => {}); });",
        )
        .unwrap();

        fs::create_dir(temp_dir.path().join("src")).unwrap();
        fs::write(
            temp_dir.path().join("src/utils.ts"),
            "export function processData() { return []; }",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "function $NAME($ARGS) { $BODY }".to_string(),
            lang: Some("typescript".to_string()),
            include: Some("src/**/*.ts".to_string()),
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // The include glob must restrict the walk to src/utils.ts; without this
        // check the test would pass even if the filter were ignored.
        assert_eq!(result.0.total_files_searched, 1);
        // No match is expected in src/utils.ts - most likely because `$ARGS`
        // needs exactly one parameter node and `processData()` has none.
        assert_eq!(result.0.matches.len(), 0);
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_no_matches() {
        let temp_dir = tempdir().unwrap();

        fs::write(
            temp_dir.path().join("simple.py"),
            "x = 1\ny = 2\nprint(x + y)",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "class $NAME($BASE): $BODY".to_string(),
            lang: Some("python".to_string()),
            include: None,
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        assert_eq!(result.0.matches.len(), 0);
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_invalid_path() {
        let temp_dir = tempdir().unwrap();
        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "fn $NAME()".to_string(),
            lang: Some("rust".to_string()),
            include: None,
            exclude: None,
            path: Some("non-existent-dir".to_string()),
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await;

        // A relative path that does not exist under the working directory must
        // surface as an error rather than an empty result.
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Path does not exist"));
        }
    }
}