spec_ai_core/tools/builtin/
grep.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use regex::{Regex, RegexBuilder};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::fs;
8use std::path::{Path, PathBuf};
9use walkdir::WalkDir;
10
/// Number of matches returned when the caller does not supply `max_matches`.
const DEFAULT_MAX_MATCHES: usize = 50;
/// Upper bound applied to any requested `max_matches`, to prevent context overload.
const HARD_MAX_MATCHES: usize = 200;
/// Number of context lines used before/after a match when none are requested.
const DEFAULT_CONTEXT_LINES: usize = 0;
/// Default per-file size limit in bytes (1 MiB).
const DEFAULT_MAX_FILE_BYTES: usize = 1 << 20;
/// Longest line, in bytes, that is included verbatim; longer lines are truncated.
const MAX_LINE_LENGTH: usize = 500;
21
22#[derive(Debug, Deserialize)]
23struct GrepArgs {
24    /// Pattern to search for (regex or literal)
25    pattern: String,
26    /// Root directory or file to search in
27    path: Option<String>,
28    /// Glob pattern to filter files (e.g., "*.rs", "**/*.py")
29    #[serde(default)]
30    glob: Option<String>,
31    /// Interpret pattern as regex (default: true)
32    #[serde(default = "default_true")]
33    regex: bool,
34    /// Case insensitive search
35    #[serde(default)]
36    case_insensitive: bool,
37    /// Lines of context before match
38    #[serde(rename = "before_context")]
39    before: Option<usize>,
40    /// Lines of context after match
41    #[serde(rename = "after_context")]
42    after: Option<usize>,
43    /// Lines of context before and after (overrides before/after if set)
44    context: Option<usize>,
45    /// Maximum number of matches to return
46    max_matches: Option<usize>,
47}
48
/// Serde default provider for boolean fields that should be `true` when omitted.
fn default_true() -> bool {
    true
}
52
53#[derive(Debug, Serialize, Deserialize)]
54struct GrepMatch {
55    /// File path containing the match
56    file: String,
57    /// Line number (1-indexed)
58    line_number: usize,
59    /// The matching line content
60    content: String,
61    /// Context lines before the match (if requested)
62    #[serde(skip_serializing_if = "Vec::is_empty", default)]
63    before_context: Vec<ContextLine>,
64    /// Context lines after the match (if requested)
65    #[serde(skip_serializing_if = "Vec::is_empty", default)]
66    after_context: Vec<ContextLine>,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70struct ContextLine {
71    line_number: usize,
72    content: String,
73}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct GrepResponse {
77    pattern: String,
78    total_matches: usize,
79    matches: Vec<GrepMatch>,
80    truncated: bool,
81}
82
/// Grep-like search tool that returns only the lines of files matching a pattern.
///
/// Instead of loading whole files, callers receive the matching lines, optionally
/// with surrounding context lines, which keeps the amount of returned text small.
pub struct GrepTool {
    /// Location searched when a call does not specify a path.
    root: PathBuf,
    /// Files larger than this many bytes are not searched.
    max_file_bytes: usize,
}
92
93impl GrepTool {
94    pub fn new() -> Self {
95        let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
96        Self {
97            root,
98            max_file_bytes: DEFAULT_MAX_FILE_BYTES,
99        }
100    }
101
102    pub fn with_root(mut self, root: impl Into<PathBuf>) -> Self {
103        self.root = root.into();
104        self
105    }
106
107    pub fn with_max_file_bytes(mut self, max_file_bytes: usize) -> Self {
108        self.max_file_bytes = max_file_bytes;
109        self
110    }
111
112    fn resolve_path(&self, override_path: &Option<String>) -> PathBuf {
113        override_path
114            .as_ref()
115            .map(PathBuf::from)
116            .unwrap_or_else(|| self.root.clone())
117    }
118
119    fn matches_glob(&self, path: &Path, glob_regex: &Option<Regex>) -> bool {
120        match glob_regex {
121            None => true,
122            Some(regex) => {
123                // Try matching against the full path and just the filename
124                let path_str = path.to_string_lossy();
125                let filename = path.file_name().map(|s| s.to_string_lossy());
126
127                regex.is_match(&path_str) || filename.map(|f| regex.is_match(&f)).unwrap_or(false)
128            }
129        }
130    }
131
132    /// Convert a glob pattern to a regex pattern
133    fn glob_to_regex(glob: &str) -> Result<Regex> {
134        let mut regex = String::with_capacity(glob.len() * 2);
135        regex.push('^');
136
137        let mut chars = glob.chars().peekable();
138        while let Some(c) = chars.next() {
139            match c {
140                '*' => {
141                    if chars.peek() == Some(&'*') {
142                        chars.next(); // consume second *
143                                      // Skip path separator if present after **
144                        if chars.peek() == Some(&'/') {
145                            chars.next();
146                        }
147                        // ** matches any path including separators
148                        regex.push_str(".*");
149                    } else {
150                        // * matches anything except path separator
151                        regex.push_str("[^/]*");
152                    }
153                }
154                '?' => regex.push_str("[^/]"),
155                '.' | '+' | '^' | '$' | '(' | ')' | '{' | '}' | '|' | '\\' => {
156                    regex.push('\\');
157                    regex.push(c);
158                }
159                '[' => {
160                    // Character class - pass through but escape special regex chars inside
161                    regex.push('[');
162                    while let Some(c) = chars.next() {
163                        if c == ']' {
164                            regex.push(']');
165                            break;
166                        }
167                        regex.push(c);
168                    }
169                }
170                _ => regex.push(c),
171            }
172        }
173
174        regex.push('$');
175        Regex::new(&regex).context("Failed to compile glob pattern as regex")
176    }
177
178    fn truncate_line(line: &str) -> String {
179        if line.len() > MAX_LINE_LENGTH {
180            format!("{}...", &line[..MAX_LINE_LENGTH])
181        } else {
182            line.to_string()
183        }
184    }
185
186    fn collect_matches(
187        &self,
188        path: &Path,
189        regex: &regex::Regex,
190        args: &GrepArgs,
191        max_matches: usize,
192        current_count: &mut usize,
193    ) -> Result<Vec<GrepMatch>> {
194        let metadata = fs::metadata(path).context("Failed to read file metadata")?;
195
196        if metadata.len() as usize > self.max_file_bytes {
197            return Ok(Vec::new());
198        }
199
200        let data = fs::read(path).context("Failed to read file")?;
201        let content = match String::from_utf8(data) {
202            Ok(text) => text,
203            Err(_) => return Ok(Vec::new()), // Skip binary files
204        };
205
206        let lines: Vec<&str> = content.lines().collect();
207        let mut matches = Vec::new();
208
209        // Determine context sizes
210        let (before_ctx, after_ctx) = match args.context {
211            Some(c) => (c, c),
212            None => (
213                args.before.unwrap_or(DEFAULT_CONTEXT_LINES),
214                args.after.unwrap_or(DEFAULT_CONTEXT_LINES),
215            ),
216        };
217
218        for (idx, line) in lines.iter().enumerate() {
219            if *current_count >= max_matches {
220                break;
221            }
222
223            if regex.is_match(line) {
224                let line_number = idx + 1;
225
226                // Collect before context
227                let before_context: Vec<ContextLine> = if before_ctx > 0 {
228                    let start = idx.saturating_sub(before_ctx);
229                    (start..idx)
230                        .map(|i| ContextLine {
231                            line_number: i + 1,
232                            content: Self::truncate_line(lines[i]),
233                        })
234                        .collect()
235                } else {
236                    Vec::new()
237                };
238
239                // Collect after context
240                let after_context: Vec<ContextLine> = if after_ctx > 0 {
241                    let end = (idx + 1 + after_ctx).min(lines.len());
242                    ((idx + 1)..end)
243                        .map(|i| ContextLine {
244                            line_number: i + 1,
245                            content: Self::truncate_line(lines[i]),
246                        })
247                        .collect()
248                } else {
249                    Vec::new()
250                };
251
252                matches.push(GrepMatch {
253                    file: path.display().to_string(),
254                    line_number,
255                    content: Self::truncate_line(line),
256                    before_context,
257                    after_context,
258                });
259
260                *current_count += 1;
261            }
262        }
263
264        Ok(matches)
265    }
266}
267
268impl Default for GrepTool {
269    fn default() -> Self {
270        Self::new()
271    }
272}
273
274#[async_trait]
275impl Tool for GrepTool {
276    fn name(&self) -> &str {
277        "grep"
278    }
279
280    fn description(&self) -> &str {
281        "Search for patterns in files using grep-like matching. Returns matching lines with optional context to avoid loading entire files into context."
282    }
283
284    fn parameters(&self) -> Value {
285        serde_json::json!({
286            "type": "object",
287            "properties": {
288                "pattern": {
289                    "type": "string",
290                    "description": "Pattern to search for (regex by default, or literal if regex=false)"
291                },
292                "path": {
293                    "type": "string",
294                    "description": "File or directory to search in (defaults to current workspace)"
295                },
296                "glob": {
297                    "type": "string",
298                    "description": "Glob pattern to filter files (e.g., '*.rs', '**/*.py', 'src/**/*.ts')"
299                },
300                "regex": {
301                    "type": "boolean",
302                    "description": "Interpret pattern as regex (default: true)",
303                    "default": true
304                },
305                "case_insensitive": {
306                    "type": "boolean",
307                    "description": "Case insensitive search (default: false)",
308                    "default": false
309                },
310                "before_context": {
311                    "type": "integer",
312                    "description": "Number of lines to show before each match (like grep -B)"
313                },
314                "after_context": {
315                    "type": "integer",
316                    "description": "Number of lines to show after each match (like grep -A)"
317                },
318                "context": {
319                    "type": "integer",
320                    "description": "Number of lines to show before and after each match (like grep -C, overrides before/after)"
321                },
322                "max_matches": {
323                    "type": "integer",
324                    "description": "Maximum number of matches to return (default: 50, max: 200)"
325                }
326            },
327            "required": ["pattern"]
328        })
329    }
330
331    async fn execute(&self, args: Value) -> Result<ToolResult> {
332        let args: GrepArgs =
333            serde_json::from_value(args).context("Failed to parse grep arguments")?;
334
335        if args.pattern.trim().is_empty() {
336            return Err(anyhow!("grep pattern cannot be empty"));
337        }
338
339        let search_path = self.resolve_path(&args.path);
340        if !search_path.exists() {
341            return Err(anyhow!(
342                "Search path {} does not exist",
343                search_path.display()
344            ));
345        }
346
347        let max_matches = args
348            .max_matches
349            .unwrap_or(DEFAULT_MAX_MATCHES)
350            .clamp(1, HARD_MAX_MATCHES);
351
352        // Build the regex pattern
353        let regex = if args.regex {
354            RegexBuilder::new(&args.pattern)
355                .case_insensitive(args.case_insensitive)
356                .build()
357                .context("Invalid regular expression pattern")?
358        } else {
359            // Escape the pattern for literal matching
360            let escaped = regex::escape(&args.pattern);
361            RegexBuilder::new(&escaped)
362                .case_insensitive(args.case_insensitive)
363                .build()
364                .context("Failed to build literal pattern")?
365        };
366
367        // Parse glob pattern if provided
368        let glob_regex = args
369            .glob
370            .as_ref()
371            .map(|g| Self::glob_to_regex(g))
372            .transpose()?;
373
374        let mut all_matches = Vec::new();
375        let mut match_count = 0;
376
377        if search_path.is_file() {
378            // Search single file
379            if self.matches_glob(&search_path, &glob_regex) {
380                let file_matches = self.collect_matches(
381                    &search_path,
382                    &regex,
383                    &args,
384                    max_matches,
385                    &mut match_count,
386                )?;
387                all_matches.extend(file_matches);
388            }
389        } else {
390            // Walk directory
391            for entry in WalkDir::new(&search_path)
392                .follow_links(false)
393                .into_iter()
394                .filter_map(|e| e.ok())
395            {
396                if match_count >= max_matches {
397                    break;
398                }
399
400                let path = entry.path();
401                if !entry.file_type().is_file() {
402                    continue;
403                }
404
405                // Skip hidden files and common non-text directories
406                let path_str = path.to_string_lossy();
407                if path_str.contains("/.git/")
408                    || path_str.contains("/node_modules/")
409                    || path_str.contains("/target/")
410                    || path_str.contains("/.venv/")
411                    || path_str.contains("/__pycache__/")
412                {
413                    continue;
414                }
415
416                if !self.matches_glob(path, &glob_regex) {
417                    continue;
418                }
419
420                match self.collect_matches(path, &regex, &args, max_matches, &mut match_count) {
421                    Ok(file_matches) => all_matches.extend(file_matches),
422                    Err(_) => continue, // Skip files we can't read
423                }
424            }
425        }
426
427        let truncated = match_count >= max_matches;
428        let response = GrepResponse {
429            pattern: args.pattern,
430            total_matches: all_matches.len(),
431            matches: all_matches,
432            truncated,
433        };
434
435        Ok(ToolResult::success(
436            serde_json::to_string(&response).context("Failed to serialize grep results")?,
437        ))
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Executes the tool, asserts that it reported success and decodes its JSON output.
    async fn run_grep(tool: &GrepTool, args: serde_json::Value) -> GrepResponse {
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        serde_json::from_str(&result.output).unwrap()
    }

    /// A regex search over a directory finds both function definitions.
    #[tokio::test]
    async fn test_basic_grep() {
        let workspace = tempdir().unwrap();
        fs::write(
            workspace.path().join("test.rs"),
            "fn main() {\n    println!(\"Hello\");\n}\n\nfn other() {\n    println!(\"World\");\n}",
        )
        .unwrap();

        let tool = GrepTool::new().with_root(workspace.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "fn \\w+",
                "path": workspace.path().to_string_lossy()
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 2);
    }

    /// `context: 1` yields exactly one line on each side of the match.
    #[tokio::test]
    async fn test_grep_with_context() {
        let workspace = tempdir().unwrap();
        let target = workspace.path().join("test.txt");
        fs::write(&target, "line 1\nline 2\nMATCH\nline 4\nline 5").unwrap();

        let tool = GrepTool::new().with_root(workspace.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "MATCH",
                "path": target.to_string_lossy(),
                "context": 1,
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 1);
        let hit = &payload.matches[0];
        assert_eq!(hit.before_context.len(), 1);
        assert_eq!(hit.after_context.len(), 1);
        assert_eq!(hit.before_context[0].content, "line 2");
        assert_eq!(hit.after_context[0].content, "line 4");
    }

    /// A case-insensitive literal search matches every casing of the word.
    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let workspace = tempdir().unwrap();
        let target = workspace.path().join("test.txt");
        fs::write(&target, "Hello World\nhello world\nHELLO WORLD").unwrap();

        let tool = GrepTool::new().with_root(workspace.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "hello",
                "path": target.to_string_lossy(),
                "case_insensitive": true,
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 3);
    }

    /// A glob filter restricts the search to files with the matching extension.
    #[tokio::test]
    async fn test_grep_with_glob() {
        let workspace = tempdir().unwrap();
        let fixtures = [
            ("test.rs", "fn main()"),
            ("test.py", "def main()"),
            ("test.txt", "main function"),
        ];
        for (name, body) in fixtures {
            fs::write(workspace.path().join(name), body).unwrap();
        }

        let tool = GrepTool::new().with_root(workspace.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "main",
                "path": workspace.path().to_string_lossy(),
                "glob": "*.rs",
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 1);
        assert!(payload.matches[0].file.ends_with("test.rs"));
    }

    /// `max_matches` caps the result count and sets the `truncated` flag.
    #[tokio::test]
    async fn test_grep_max_matches() {
        let workspace = tempdir().unwrap();
        let target = workspace.path().join("test.txt");
        let numbered: Vec<String> = (0..100).map(|i| format!("line {}", i)).collect();
        fs::write(&target, numbered.join("\n")).unwrap();

        let tool = GrepTool::new().with_root(workspace.path());
        let payload = run_grep(
            &tool,
            serde_json::json!({
                "pattern": "line",
                "path": target.to_string_lossy(),
                "max_matches": 5,
                "regex": false
            }),
        )
        .await;

        assert_eq!(payload.total_matches, 5);
        assert!(payload.truncated);
    }
}