1use ast_grep_core::tree_sitter::StrDoc;
2use ast_grep_core::{AstGrep, Pattern};
3use ast_grep_language::{LanguageExt, SupportLang};
4use ignore::WalkBuilder;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::Path;
9use std::str::FromStr;
10use steer_macros::tool;
11use tokio::task;
12
13use crate::result::{AstGrepResult, SearchMatch, SearchResult};
14use crate::{ExecutionContext, ToolError};
15
16#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
17pub struct AstGrepParams {
18 pub pattern: String,
20 pub lang: Option<String>,
22 pub include: Option<String>,
24 pub exclude: Option<String>,
26 pub path: Option<String>,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31pub struct AstGrepMatch {
32 pub file: String,
33 pub line: usize,
34 pub column: usize,
35 pub matched_code: String,
36 pub context: String,
37}
38
39tool! {
40 AstGrepTool {
41 params: AstGrepParams,
42 output: AstGrepResult,
43 variant: Search,
44 description: r#"Structural code search using abstract syntax trees (AST).
45- Searches code by its syntactic structure, not just text patterns
46- Use $METAVAR placeholders (e.g., $VAR, $FUNC, $ARGS) to match any code element
47- Supports all major languages: rust, javascript, typescript, python, java, go, etc.
48Pattern examples:
49- "console.log($MSG)" - finds all console.log calls regardless of argument
50- "fn $NAME($PARAMS) { $BODY }" - finds all Rust function definitions
51- "if $COND { $THEN } else { $ELSE }" - finds all if-else statements
52- "import $WHAT from '$MODULE'" - finds all ES6 imports from specific modules
53- "$VAR = $VAR + $EXPR" - finds all self-incrementing assignments
54Advanced patterns:
55- "function $FUNC($$$ARGS) { $$$ }" - $$$ matches any number of elements
56- "foo($ARG, ...)" - ellipsis matches remaining arguments
57- Use any valid code as a pattern - ast-grep understands the syntax!
58Automatically respects .gitignore files"#,
59 name: "astgrep",
60 require_approval: false
61 }
62
63 async fn run(
64 _tool: &AstGrepTool,
65 params: AstGrepParams,
66 context: &ExecutionContext,
67 ) -> Result<AstGrepResult, ToolError> {
68 if context.is_cancelled() {
69 return Err(ToolError::Cancelled(AST_GREP_TOOL_NAME.to_string()));
70 }
71
72 let search_path = params.path.as_deref().unwrap_or(".");
73 let base_path = if Path::new(search_path).is_absolute() {
74 Path::new(search_path).to_path_buf()
75 } else {
76 context.working_directory.join(search_path)
77 };
78
79 let pattern = params.pattern.clone();
81 let lang = params.lang.clone();
82 let include = params.include.clone();
83 let exclude = params.exclude.clone();
84 let cancellation_token = context.cancellation_token.clone();
85
86 let result = task::spawn_blocking(move || {
87 astgrep_search_internal(&pattern, lang.as_deref(), include.as_deref(), exclude.as_deref(), &base_path, &cancellation_token)
88 }).await;
89
90 match result {
91 Ok(search_result) => search_result.map_err(|e| ToolError::execution(AST_GREP_TOOL_NAME, e)),
92 Err(e) => Err(ToolError::execution(AST_GREP_TOOL_NAME, format!("Task join error: {e}"))),
93 }
94 }
95}
96
97fn astgrep_search_internal(
98 pattern: &str,
99 lang: Option<&str>,
100 include: Option<&str>,
101 exclude: Option<&str>,
102 base_path: &Path,
103 cancellation_token: &tokio_util::sync::CancellationToken,
104) -> Result<AstGrepResult, String> {
105 if !base_path.exists() {
106 return Err(format!("Path does not exist: {}", base_path.display()));
107 }
108
109 let mut walker = WalkBuilder::new(base_path);
111 walker.hidden(false); walker.git_ignore(true); walker.git_global(true); walker.git_exclude(true); let include_pattern = include
117 .map(|p| glob::Pattern::new(p).map_err(|e| format!("Invalid include glob pattern: {e}")))
118 .transpose()?;
119
120 let exclude_pattern = exclude
121 .map(|p| glob::Pattern::new(p).map_err(|e| format!("Invalid exclude glob pattern: {e}")))
122 .transpose()?;
123
124 let mut all_matches = Vec::new();
125 let mut files_searched = 0;
126
127 for result in walker.build() {
128 if cancellation_token.is_cancelled() {
129 return Ok(AstGrepResult(SearchResult {
130 matches: all_matches,
131 total_files_searched: files_searched,
132 search_completed: false,
133 }));
134 }
135
136 let entry = match result {
137 Ok(e) => e,
138 Err(_) => continue,
139 };
140
141 let path = entry.path();
142 if !path.is_file() {
143 continue;
144 }
145
146 if let Some(ref pattern) = include_pattern {
148 if !path_matches_glob(path, pattern, base_path) {
149 continue;
150 }
151 }
152
153 if let Some(ref pattern) = exclude_pattern {
155 if path_matches_glob(path, pattern, base_path) {
156 continue;
157 }
158 }
159
160 let detected_lang = if let Some(l) = lang {
162 match SupportLang::from_str(l) {
163 Ok(lang) => Some(lang),
164 Err(_) => {
165 continue;
167 }
168 }
169 } else {
170 SupportLang::from_extension(path).or_else(|| {
172 path.extension()
174 .and_then(|ext| ext.to_str())
175 .and_then(|ext| match ext {
176 "jsx" => Some(SupportLang::JavaScript),
177 "mjs" => Some(SupportLang::JavaScript),
178 _ => None,
179 })
180 })
181 };
182
183 let Some(language) = detected_lang else {
185 continue;
186 };
187
188 files_searched += 1;
190 let content = match fs::read_to_string(path) {
191 Ok(c) => c,
192 Err(_) => continue, };
194
195 let ast_grep = language.ast_grep(&content);
197
198 let pattern_matcher = match Pattern::try_new(pattern, language) {
200 Ok(p) => p,
201 Err(e) => return Err(format!("Invalid pattern: {e}")),
202 };
203
204 let relative_path = path.strip_prefix(base_path).unwrap_or(path);
206 let file_matches = find_matches(&ast_grep, &pattern_matcher, relative_path, &content);
207
208 for m in file_matches {
210 all_matches.push(SearchMatch {
211 file_path: m.file,
212 line_number: m.line,
213 line_content: m.context.trim().to_string(),
214 column_range: Some((m.column, m.column + m.matched_code.len())),
215 });
216 }
217 }
218
219 all_matches.sort_by(|a, b| {
221 a.file_path
222 .cmp(&b.file_path)
223 .then(a.line_number.cmp(&b.line_number))
224 });
225
226 Ok(AstGrepResult(SearchResult {
227 matches: all_matches,
228 total_files_searched: files_searched,
229 search_completed: true,
230 }))
231}
232
233fn find_matches(
234 ast_grep: &AstGrep<StrDoc<SupportLang>>,
235 pattern: &Pattern,
236 path: &Path,
237 content: &str,
238) -> Vec<AstGrepMatch> {
239 let root = ast_grep.root();
240 let matches = root.find_all(pattern);
241
242 let mut results = Vec::new();
243 for node_match in matches {
244 let node = node_match.get_node();
245 let range = node.range();
246 let start_pos = node.start_pos();
247
248 let matched_code = node.text();
250
251 let line_start = content[..range.start]
253 .rfind('\n')
254 .map(|i| i + 1)
255 .unwrap_or(0);
256 let line_end = content[range.end..]
257 .find('\n')
258 .map(|i| range.end + i)
259 .unwrap_or(content.len());
260 let context = &content[line_start..line_end];
261
262 results.push(AstGrepMatch {
263 file: path.display().to_string(),
264 line: start_pos.line() + 1, column: start_pos.column(node) + 1,
266 matched_code: matched_code.to_string(),
267 context: context.to_string(),
268 });
269 }
270
271 results
272}
273
274fn path_matches_glob(path: &Path, pattern: &glob::Pattern, base_path: &Path) -> bool {
275 if pattern.matches_path(path) {
277 return true;
278 }
279
280 if let Ok(relative_path) = path.strip_prefix(base_path) {
282 if pattern.matches_path(relative_path) {
283 return true;
284 }
285 }
286
287 if let Some(filename) = path.file_name() {
289 if pattern.matches(&filename.to_string_lossy()) {
290 return true;
291 }
292 }
293
294 false
295}
296
297trait LanguageHelpers {
299 fn from_extension(path: &Path) -> Option<SupportLang>;
300}
301
302impl LanguageHelpers for SupportLang {
303 fn from_extension(path: &Path) -> Option<SupportLang> {
304 ast_grep_language::Language::from_path(path)
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionContext, Tool};
    use std::fs;
    use tempfile::tempdir;

    /// Builds an `ExecutionContext` whose working directory is `temp_dir`.
    fn create_test_context(temp_dir: &tempfile::TempDir) -> ExecutionContext {
        ExecutionContext::new("test-call-id".to_string())
            .with_working_directory(temp_dir.path().to_path_buf())
    }

    #[tokio::test]
    async fn test_astgrep_rust_function() {
        let temp_dir = tempdir().unwrap();

        fs::write(
            temp_dir.path().join("test.rs"),
            r#"fn main() {
    println!("Hello, world!");
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}

async fn fetch_data() -> Result<String, Error> {
    Ok("data".to_string())
}"#,
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "fn $NAME($$$ARGS) { $$$ }".to_string(),
            lang: Some("rust".to_string()),
            include: None,
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // Only `main` matches: `add` and `fetch_data` declare a return type,
        // which this pattern has no placeholder for.
        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("test.rs"));
        assert_eq!(result.0.matches[0].line_number, 1);
        assert!(result.0.matches[0].line_content.contains("fn main() {"));
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_javascript_console_log() {
        let temp_dir = tempdir().unwrap();

        fs::write(
            temp_dir.path().join("app.js"),
            r#"console.log("Starting application");

function processData(data) {
    console.log("Processing:", data);
    console.error("An error occurred");
    return data;
}

console.log("Application ready");"#,
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        // No `lang`: the language is detected from the `.js` extension.
        let params = AstGrepParams {
            pattern: "console.log($ARGS)".to_string(),
            lang: None,
            include: None,
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // `$ARGS` matches a single argument, so the two-argument call inside
        // `processData` is not reported.
        assert_eq!(result.0.matches.len(), 2);
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("app.js")
                && m.line_number == 1
                && m.line_content
                    .contains("console.log(\"Starting application\")")
        }));
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("app.js")
                && m.line_number == 9
                && m.line_content
                    .contains("console.log(\"Application ready\")")
        }));
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_with_include_pattern() {
        let temp_dir = tempdir().unwrap();

        fs::write(
            temp_dir.path().join("module.ts"),
            "export function getData() { return fetch('/api/data'); }",
        )
        .unwrap();

        fs::write(
            temp_dir.path().join("test.spec.ts"),
            "describe('test', () => { it('works', () => {}); });",
        )
        .unwrap();

        fs::create_dir(temp_dir.path().join("src")).unwrap();
        fs::write(
            temp_dir.path().join("src/utils.ts"),
            "export function processData() { return []; }",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        // `$$$` accepts any number of parameters and statements, so both
        // `getData` and `processData` would match without the include filter.
        // Only then does this test actually exercise `include`.
        let params = AstGrepParams {
            pattern: "function $NAME($$$ARGS) { $$$BODY }".to_string(),
            lang: Some("typescript".to_string()),
            include: Some("src/**/*.ts".to_string()),
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // Only the file under `src/` passes the include glob.
        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("utils.ts"));
        assert_eq!(result.0.matches[0].line_number, 1);
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_no_matches() {
        let temp_dir = tempdir().unwrap();

        fs::write(
            temp_dir.path().join("simple.py"),
            "x = 1\ny = 2\nprint(x + y)",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "class $NAME($BASE): $BODY".to_string(),
            lang: Some("python".to_string()),
            include: None,
            exclude: None,
            path: None,
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await.unwrap();

        // The file has no class definitions; an empty result is still a
        // completed search, not an error.
        assert_eq!(result.0.matches.len(), 0);
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_astgrep_invalid_path() {
        let temp_dir = tempdir().unwrap();
        let context = create_test_context(&temp_dir);

        let tool = AstGrepTool;
        let params = AstGrepParams {
            pattern: "fn $NAME()".to_string(),
            lang: Some("rust".to_string()),
            include: None,
            exclude: None,
            path: Some("non-existent-dir".to_string()),
        };
        let params_json = serde_json::to_value(params).unwrap();

        let result = tool.execute(params_json, &context).await;

        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Path does not exist"));
        }
    }
}