spec_ai_core/tools/builtin/
grep.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use regex::{Regex, RegexBuilder};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::fs;
8use std::path::{Path, PathBuf};
9use walkdir::WalkDir;
10
/// Number of matches returned when the caller does not supply `max_matches`.
const DEFAULT_MAX_MATCHES: usize = 50;
/// Ceiling applied to any requested `max_matches`, to keep responses small.
const HARD_MAX_MATCHES: usize = 200;
/// Context lines shown around a match when none are requested.
const DEFAULT_CONTEXT_LINES: usize = 0;
/// Default file-size limit in bytes (1 MiB); larger files are not searched.
const DEFAULT_MAX_FILE_BYTES: usize = 1 << 20;
/// Longest line, in bytes, emitted verbatim; longer lines are truncated.
const MAX_LINE_LENGTH: usize = 500;
21
22#[derive(Debug, Deserialize)]
23struct GrepArgs {
24    /// Pattern to search for (regex or literal)
25    pattern: String,
26    /// Root directory or file to search in
27    path: Option<String>,
28    /// Glob pattern to filter files (e.g., "*.rs", "**/*.py")
29    #[serde(default)]
30    glob: Option<String>,
31    /// Interpret pattern as regex (default: true)
32    #[serde(default = "default_true")]
33    regex: bool,
34    /// Case insensitive search
35    #[serde(default)]
36    case_insensitive: bool,
37    /// Lines of context before match
38    #[serde(rename = "before_context")]
39    before: Option<usize>,
40    /// Lines of context after match
41    #[serde(rename = "after_context")]
42    after: Option<usize>,
43    /// Lines of context before and after (overrides before/after if set)
44    context: Option<usize>,
45    /// Maximum number of matches to return
46    max_matches: Option<usize>,
47}
48
/// Serde default provider that yields `true`; referenced by name from
/// `#[serde(default = "default_true")]` so an omitted `regex` flag means
/// "treat the pattern as a regular expression".
fn default_true() -> bool {
    true
}
52
53#[derive(Debug, Serialize, Deserialize)]
54struct GrepMatch {
55    /// File path containing the match
56    file: String,
57    /// Line number (1-indexed)
58    line_number: usize,
59    /// The matching line content
60    content: String,
61    /// Context lines before the match (if requested)
62    #[serde(skip_serializing_if = "Vec::is_empty", default)]
63    before_context: Vec<ContextLine>,
64    /// Context lines after the match (if requested)
65    #[serde(skip_serializing_if = "Vec::is_empty", default)]
66    after_context: Vec<ContextLine>,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70struct ContextLine {
71    line_number: usize,
72    content: String,
73}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct GrepResponse {
77    pattern: String,
78    total_matches: usize,
79    matches: Vec<GrepMatch>,
80    truncated: bool,
81}
82
/// Tool that uses grep-like pattern matching to read specific parts of files.
///
/// This tool is designed to help avoid context overload by returning only
/// the relevant portions of files that match a given pattern, with optional
/// surrounding context lines.
///
/// The type is cheap to clone (a path and an integer) and implements `Debug`
/// so it can be logged or embedded in other `Debug` types.
#[derive(Debug, Clone)]
pub struct GrepTool {
    /// Directory searched when a call does not specify `path`.
    root: PathBuf,
    /// Files larger than this many bytes are skipped instead of being read.
    max_file_bytes: usize,
}
92
93impl GrepTool {
94    pub fn new() -> Self {
95        let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
96        Self {
97            root,
98            max_file_bytes: DEFAULT_MAX_FILE_BYTES,
99        }
100    }
101
102    pub fn with_root(mut self, root: impl Into<PathBuf>) -> Self {
103        self.root = root.into();
104        self
105    }
106
107    pub fn with_max_file_bytes(mut self, max_file_bytes: usize) -> Self {
108        self.max_file_bytes = max_file_bytes;
109        self
110    }
111
112    fn resolve_path(&self, override_path: &Option<String>) -> PathBuf {
113        override_path
114            .as_ref()
115            .map(PathBuf::from)
116            .unwrap_or_else(|| self.root.clone())
117    }
118
119    fn matches_glob(&self, path: &Path, glob_regex: &Option<Regex>) -> bool {
120        match glob_regex {
121            None => true,
122            Some(regex) => {
123                // Try matching against the full path and just the filename
124                let path_str = path.to_string_lossy();
125                let filename = path.file_name().map(|s| s.to_string_lossy());
126
127                regex.is_match(&path_str)
128                    || filename.map(|f| regex.is_match(&f)).unwrap_or(false)
129            }
130        }
131    }
132
133    /// Convert a glob pattern to a regex pattern
134    fn glob_to_regex(glob: &str) -> Result<Regex> {
135        let mut regex = String::with_capacity(glob.len() * 2);
136        regex.push('^');
137
138        let mut chars = glob.chars().peekable();
139        while let Some(c) = chars.next() {
140            match c {
141                '*' => {
142                    if chars.peek() == Some(&'*') {
143                        chars.next(); // consume second *
144                        // Skip path separator if present after **
145                        if chars.peek() == Some(&'/') {
146                            chars.next();
147                        }
148                        // ** matches any path including separators
149                        regex.push_str(".*");
150                    } else {
151                        // * matches anything except path separator
152                        regex.push_str("[^/]*");
153                    }
154                }
155                '?' => regex.push_str("[^/]"),
156                '.' | '+' | '^' | '$' | '(' | ')' | '{' | '}' | '|' | '\\' => {
157                    regex.push('\\');
158                    regex.push(c);
159                }
160                '[' => {
161                    // Character class - pass through but escape special regex chars inside
162                    regex.push('[');
163                    while let Some(c) = chars.next() {
164                        if c == ']' {
165                            regex.push(']');
166                            break;
167                        }
168                        regex.push(c);
169                    }
170                }
171                _ => regex.push(c),
172            }
173        }
174
175        regex.push('$');
176        Regex::new(&regex).context("Failed to compile glob pattern as regex")
177    }
178
179    fn truncate_line(line: &str) -> String {
180        if line.len() > MAX_LINE_LENGTH {
181            format!("{}...", &line[..MAX_LINE_LENGTH])
182        } else {
183            line.to_string()
184        }
185    }
186
187    fn collect_matches(
188        &self,
189        path: &Path,
190        regex: &regex::Regex,
191        args: &GrepArgs,
192        max_matches: usize,
193        current_count: &mut usize,
194    ) -> Result<Vec<GrepMatch>> {
195        let metadata = fs::metadata(path).context("Failed to read file metadata")?;
196
197        if metadata.len() as usize > self.max_file_bytes {
198            return Ok(Vec::new());
199        }
200
201        let data = fs::read(path).context("Failed to read file")?;
202        let content = match String::from_utf8(data) {
203            Ok(text) => text,
204            Err(_) => return Ok(Vec::new()), // Skip binary files
205        };
206
207        let lines: Vec<&str> = content.lines().collect();
208        let mut matches = Vec::new();
209
210        // Determine context sizes
211        let (before_ctx, after_ctx) = match args.context {
212            Some(c) => (c, c),
213            None => (
214                args.before.unwrap_or(DEFAULT_CONTEXT_LINES),
215                args.after.unwrap_or(DEFAULT_CONTEXT_LINES),
216            ),
217        };
218
219        for (idx, line) in lines.iter().enumerate() {
220            if *current_count >= max_matches {
221                break;
222            }
223
224            if regex.is_match(line) {
225                let line_number = idx + 1;
226
227                // Collect before context
228                let before_context: Vec<ContextLine> = if before_ctx > 0 {
229                    let start = idx.saturating_sub(before_ctx);
230                    (start..idx)
231                        .map(|i| ContextLine {
232                            line_number: i + 1,
233                            content: Self::truncate_line(lines[i]),
234                        })
235                        .collect()
236                } else {
237                    Vec::new()
238                };
239
240                // Collect after context
241                let after_context: Vec<ContextLine> = if after_ctx > 0 {
242                    let end = (idx + 1 + after_ctx).min(lines.len());
243                    ((idx + 1)..end)
244                        .map(|i| ContextLine {
245                            line_number: i + 1,
246                            content: Self::truncate_line(lines[i]),
247                        })
248                        .collect()
249                } else {
250                    Vec::new()
251                };
252
253                matches.push(GrepMatch {
254                    file: path.display().to_string(),
255                    line_number,
256                    content: Self::truncate_line(line),
257                    before_context,
258                    after_context,
259                });
260
261                *current_count += 1;
262            }
263        }
264
265        Ok(matches)
266    }
267}
268
269impl Default for GrepTool {
270    fn default() -> Self {
271        Self::new()
272    }
273}
274
275#[async_trait]
276impl Tool for GrepTool {
277    fn name(&self) -> &str {
278        "grep"
279    }
280
281    fn description(&self) -> &str {
282        "Search for patterns in files using grep-like matching. Returns matching lines with optional context to avoid loading entire files into context."
283    }
284
285    fn parameters(&self) -> Value {
286        serde_json::json!({
287            "type": "object",
288            "properties": {
289                "pattern": {
290                    "type": "string",
291                    "description": "Pattern to search for (regex by default, or literal if regex=false)"
292                },
293                "path": {
294                    "type": "string",
295                    "description": "File or directory to search in (defaults to current workspace)"
296                },
297                "glob": {
298                    "type": "string",
299                    "description": "Glob pattern to filter files (e.g., '*.rs', '**/*.py', 'src/**/*.ts')"
300                },
301                "regex": {
302                    "type": "boolean",
303                    "description": "Interpret pattern as regex (default: true)",
304                    "default": true
305                },
306                "case_insensitive": {
307                    "type": "boolean",
308                    "description": "Case insensitive search (default: false)",
309                    "default": false
310                },
311                "before_context": {
312                    "type": "integer",
313                    "description": "Number of lines to show before each match (like grep -B)"
314                },
315                "after_context": {
316                    "type": "integer",
317                    "description": "Number of lines to show after each match (like grep -A)"
318                },
319                "context": {
320                    "type": "integer",
321                    "description": "Number of lines to show before and after each match (like grep -C, overrides before/after)"
322                },
323                "max_matches": {
324                    "type": "integer",
325                    "description": "Maximum number of matches to return (default: 50, max: 200)"
326                }
327            },
328            "required": ["pattern"]
329        })
330    }
331
332    async fn execute(&self, args: Value) -> Result<ToolResult> {
333        let args: GrepArgs =
334            serde_json::from_value(args).context("Failed to parse grep arguments")?;
335
336        if args.pattern.trim().is_empty() {
337            return Err(anyhow!("grep pattern cannot be empty"));
338        }
339
340        let search_path = self.resolve_path(&args.path);
341        if !search_path.exists() {
342            return Err(anyhow!(
343                "Search path {} does not exist",
344                search_path.display()
345            ));
346        }
347
348        let max_matches = args
349            .max_matches
350            .unwrap_or(DEFAULT_MAX_MATCHES)
351            .clamp(1, HARD_MAX_MATCHES);
352
353        // Build the regex pattern
354        let regex = if args.regex {
355            RegexBuilder::new(&args.pattern)
356                .case_insensitive(args.case_insensitive)
357                .build()
358                .context("Invalid regular expression pattern")?
359        } else {
360            // Escape the pattern for literal matching
361            let escaped = regex::escape(&args.pattern);
362            RegexBuilder::new(&escaped)
363                .case_insensitive(args.case_insensitive)
364                .build()
365                .context("Failed to build literal pattern")?
366        };
367
368        // Parse glob pattern if provided
369        let glob_regex = args
370            .glob
371            .as_ref()
372            .map(|g| Self::glob_to_regex(g))
373            .transpose()?;
374
375        let mut all_matches = Vec::new();
376        let mut match_count = 0;
377
378        if search_path.is_file() {
379            // Search single file
380            if self.matches_glob(&search_path, &glob_regex) {
381                let file_matches =
382                    self.collect_matches(&search_path, &regex, &args, max_matches, &mut match_count)?;
383                all_matches.extend(file_matches);
384            }
385        } else {
386            // Walk directory
387            for entry in WalkDir::new(&search_path)
388                .follow_links(false)
389                .into_iter()
390                .filter_map(|e| e.ok())
391            {
392                if match_count >= max_matches {
393                    break;
394                }
395
396                let path = entry.path();
397                if !entry.file_type().is_file() {
398                    continue;
399                }
400
401                // Skip hidden files and common non-text directories
402                let path_str = path.to_string_lossy();
403                if path_str.contains("/.git/")
404                    || path_str.contains("/node_modules/")
405                    || path_str.contains("/target/")
406                    || path_str.contains("/.venv/")
407                    || path_str.contains("/__pycache__/")
408                {
409                    continue;
410                }
411
412                if !self.matches_glob(path, &glob_regex) {
413                    continue;
414                }
415
416                match self.collect_matches(path, &regex, &args, max_matches, &mut match_count) {
417                    Ok(file_matches) => all_matches.extend(file_matches),
418                    Err(_) => continue, // Skip files we can't read
419                }
420            }
421        }
422
423        let truncated = match_count >= max_matches;
424        let response = GrepResponse {
425            pattern: args.pattern,
426            total_matches: all_matches.len(),
427            matches: all_matches,
428            truncated,
429        };
430
431        Ok(ToolResult::success(
432            serde_json::to_string(&response).context("Failed to serialize grep results")?,
433        ))
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Executes `tool` with `args`, checks that the call succeeded, and
    /// decodes the JSON output into a `GrepResponse`.
    async fn run_grep(tool: &GrepTool, args: serde_json::Value) -> GrepResponse {
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        serde_json::from_str(&result.output).unwrap()
    }

    /// A regex search over a directory finds every matching line.
    #[tokio::test]
    async fn test_basic_grep() {
        let dir = tempdir().unwrap();
        fs::write(
            dir.path().join("test.rs"),
            "fn main() {\n    println!(\"Hello\");\n}\n\nfn other() {\n    println!(\"World\");\n}",
        )
        .unwrap();

        let tool = GrepTool::new().with_root(dir.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "fn \\w+",
                "path": dir.path().to_string_lossy()
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 2);
    }

    /// `context` yields the requested lines on both sides of the match.
    #[tokio::test]
    async fn test_grep_with_context() {
        let dir = tempdir().unwrap();
        let target = dir.path().join("test.txt");
        fs::write(&target, "line 1\nline 2\nMATCH\nline 4\nline 5").unwrap();

        let tool = GrepTool::new().with_root(dir.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "MATCH",
                "path": target.to_string_lossy(),
                "context": 1,
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 1);
        let hit = &payload.matches[0];
        assert_eq!(hit.before_context.len(), 1);
        assert_eq!(hit.after_context.len(), 1);
        assert_eq!(hit.before_context[0].content, "line 2");
        assert_eq!(hit.after_context[0].content, "line 4");
    }

    /// `case_insensitive` matches regardless of letter case.
    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let dir = tempdir().unwrap();
        let target = dir.path().join("test.txt");
        fs::write(&target, "Hello World\nhello world\nHELLO WORLD").unwrap();

        let tool = GrepTool::new().with_root(dir.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "hello",
                "path": target.to_string_lossy(),
                "case_insensitive": true,
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 3);
    }

    /// A glob restricts the search to files whose names match it.
    #[tokio::test]
    async fn test_grep_with_glob() {
        let dir = tempdir().unwrap();
        let fixtures = [
            ("test.rs", "fn main()"),
            ("test.py", "def main()"),
            ("test.txt", "main function"),
        ];
        for (name, body) in fixtures.iter() {
            fs::write(dir.path().join(name), body).unwrap();
        }

        let tool = GrepTool::new().with_root(dir.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "main",
                "path": dir.path().to_string_lossy(),
                "glob": "*.rs",
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 1);
        assert!(payload.matches[0].file.ends_with("test.rs"));
    }

    /// `max_matches` caps the result set and flags the response as truncated.
    #[tokio::test]
    async fn test_grep_max_matches() {
        let dir = tempdir().unwrap();
        let target = dir.path().join("test.txt");
        let body = (0..100)
            .map(|i| format!("line {}", i))
            .collect::<Vec<_>>()
            .join("\n");
        fs::write(&target, body).unwrap();

        let tool = GrepTool::new().with_root(dir.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "line",
                "path": target.to_string_lossy(),
                "max_matches": 5,
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 5);
        assert!(payload.truncated);
    }
}