steer_tools/tools/
grep.rs

1use glob;
2use grep_regex::RegexMatcherBuilder;
3use grep_searcher::sinks::UTF8;
4use grep_searcher::{BinaryDetection, SearcherBuilder};
5use ignore::WalkBuilder;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use steer_macros::tool;
10use tokio::task;
11
12use crate::{
13    ExecutionContext, ToolError,
14    result::{GrepResult, SearchResult},
15};
16
/// A single matching line reported by the grep tool.
///
/// Alias for `crate::result::SearchMatch`; values of this type are collected into the
/// `matches` field of the `SearchResult` returned by the search.
pub type GrepMatch = crate::result::SearchMatch;
19
// Input parameters accepted by the grep tool.
// NOTE(review): the `///` field docs below sit under `#[derive(JsonSchema)]` and are
// presumably emitted as schema descriptions - confirm before rewording them.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GrepParams {
    /// The search pattern (regex or literal string). If invalid regex, searches for literal text
    pub pattern: String,
    /// Optional glob pattern to filter files by name (e.g., "*.rs", "*.{ts,tsx}")
    pub include: Option<String>,
    /// Optional directory to search in (defaults to current working directory)
    pub path: Option<String>,
}
29
30tool! {
31    GrepTool {
32        params: GrepParams,
33        output: GrepResult,
34        variant: Search,
35        description: r#"Fast content search built on ripgrep for blazing performance at any scale.
36- Searches using regular expressions or literal strings
37- Supports regex syntax like "log.*Error", "function\\s+\\w+", etc.
38- If the pattern isn't valid regex, it automatically searches for the literal text
39- Filter files by name pattern with include parameter (e.g., "*.js", "*.{ts,tsx}")
40- Automatically respects .gitignore files
41- Returns matches as "filepath:line_number: line_content""#,
42        name: "grep",
43        require_approval: false
44    }
45
46    async fn run(
47        _tool: &GrepTool,
48        params: GrepParams,
49        context: &ExecutionContext,
50    ) -> Result<GrepResult, ToolError> {
51        if context.is_cancelled() {
52            return Err(ToolError::Cancelled(GREP_TOOL_NAME.to_string()));
53        }
54
55        let search_path = params.path.as_deref().unwrap_or(".");
56        let base_path = if Path::new(search_path).is_absolute() {
57            Path::new(search_path).to_path_buf()
58        } else {
59            context.working_directory.join(search_path)
60        };
61
62        // Run the blocking search operation in a separate task
63        let pattern = params.pattern.clone();
64        let include = params.include.clone();
65        let cancellation_token = context.cancellation_token.clone();
66
67        let result = task::spawn_blocking(move || {
68            grep_search_internal(&pattern, include.as_deref(), &base_path, &cancellation_token)
69        }).await;
70
71        match result {
72            Ok(search_result) => search_result.map_err(|e| ToolError::execution(GREP_TOOL_NAME, e)),
73            Err(e) => Err(ToolError::execution(GREP_TOOL_NAME, format!("Task join error: {e}"))),
74        }
75    }
76}
77
78fn grep_search_internal(
79    pattern: &str,
80    include: Option<&str>,
81    base_path: &Path,
82    cancellation_token: &tokio_util::sync::CancellationToken,
83) -> Result<GrepResult, String> {
84    if !base_path.exists() {
85        return Err(format!("Path does not exist: {}", base_path.display()));
86    }
87
88    // Create matcher - try regex first, fall back to literal if it fails
89    let matcher = match RegexMatcherBuilder::new()
90        .line_terminator(Some(b'\n'))
91        .build(pattern)
92    {
93        Ok(m) => m,
94        Err(_) => {
95            // Fall back to literal search by escaping the pattern
96            let escaped = regex::escape(pattern);
97            RegexMatcherBuilder::new()
98                .line_terminator(Some(b'\n'))
99                .build(&escaped)
100                .map_err(|e| format!("Failed to create matcher: {e}"))?
101        }
102    };
103
104    // Build the searcher with binary detection
105    let mut searcher = SearcherBuilder::new()
106        .binary_detection(BinaryDetection::quit(b'\x00'))
107        .line_number(true)
108        .build();
109
110    // Use ignore crate's WalkBuilder for respecting .gitignore
111    let mut walker = WalkBuilder::new(base_path);
112    walker.hidden(false); // Include hidden files by default
113    walker.git_ignore(true); // Respect .gitignore files
114    walker.git_global(true); // Respect global gitignore
115    walker.git_exclude(true); // Respect .git/info/exclude
116
117    let include_pattern = include
118        .map(|p| glob::Pattern::new(p).map_err(|e| format!("Invalid glob pattern: {e}")))
119        .transpose()?;
120
121    let mut all_matches = Vec::new();
122    let mut files_searched = 0;
123
124    for result in walker.build() {
125        if cancellation_token.is_cancelled() {
126            return Ok(GrepResult(SearchResult {
127                matches: all_matches,
128                total_files_searched: files_searched,
129                search_completed: false,
130            }));
131        }
132
133        let entry = match result {
134            Ok(e) => e,
135            Err(_) => continue,
136        };
137
138        let path = entry.path();
139        if !path.is_file() {
140            continue;
141        }
142
143        // Check include pattern if specified
144        if let Some(ref pattern) = include_pattern {
145            if !path_matches_glob(path, pattern, base_path) {
146                continue;
147            }
148        }
149
150        files_searched += 1;
151
152        // Search the file
153        let mut matches_in_file = Vec::new();
154        let search_result = searcher.search_path(
155            &matcher,
156            path,
157            UTF8(|line_num, line| {
158                // Canonicalize the path for clean output
159                let display_path = match path.canonicalize() {
160                    Ok(canonical) => canonical.display().to_string(),
161                    // If canonicalization fails (e.g., file doesn't exist), fall back to regular display
162                    Err(_) => path.display().to_string(),
163                };
164                matches_in_file.push(GrepMatch {
165                    file_path: display_path,
166                    line_number: line_num as usize,
167                    line_content: line.trim_end().to_string(),
168                    column_range: None,
169                });
170                Ok(true)
171            }),
172        );
173
174        if let Err(e) = search_result {
175            // Skip files that can't be searched (e.g., binary files)
176            if e.kind() == std::io::ErrorKind::InvalidData {
177                continue;
178            }
179        }
180
181        // Add all matches from this file to the overall results
182        all_matches.extend(matches_in_file);
183    }
184
185    // Sort matches by file modification time (newest first)
186    if !all_matches.is_empty() {
187        // Group matches by file
188        let mut file_groups: std::collections::HashMap<String, Vec<GrepMatch>> =
189            std::collections::HashMap::new();
190        for match_item in all_matches {
191            file_groups
192                .entry(match_item.file_path.clone())
193                .or_default()
194                .push(match_item);
195        }
196
197        // Get modification times and sort
198        let mut sorted_files: Vec<(String, std::time::SystemTime)> = Vec::new();
199        for file_path in file_groups.keys() {
200            if cancellation_token.is_cancelled() {
201                return Ok(GrepResult(SearchResult {
202                    matches: Vec::new(),
203                    total_files_searched: files_searched,
204                    search_completed: false,
205                }));
206            }
207
208            let mtime = Path::new(file_path)
209                .metadata()
210                .and_then(|m| m.modified())
211                .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
212            sorted_files.push((file_path.clone(), mtime));
213        }
214        sorted_files.sort_by(|a, b| b.1.cmp(&a.1));
215
216        // Rebuild matches in sorted order
217        let mut sorted_matches = Vec::new();
218        for (file_path, _) in sorted_files {
219            if let Some(file_matches) = file_groups.remove(&file_path) {
220                sorted_matches.extend(file_matches);
221            }
222        }
223        all_matches = sorted_matches;
224    }
225
226    Ok(GrepResult(SearchResult {
227        matches: all_matches,
228        total_files_searched: files_searched,
229        search_completed: true,
230    }))
231}
232
233fn path_matches_glob(path: &Path, pattern: &glob::Pattern, base_path: &Path) -> bool {
234    // Check if the full path matches
235    if pattern.matches_path(path) {
236        return true;
237    }
238
239    // Check if the relative path from base_path matches
240    if let Ok(relative_path) = path.strip_prefix(base_path) {
241        if pattern.matches_path(relative_path) {
242            return true;
243        }
244    }
245
246    // Also check if just the filename matches (for patterns like "*.rs")
247    if let Some(filename) = path.file_name() {
248        if pattern.matches(&filename.to_string_lossy()) {
249            return true;
250        }
251    }
252
253    false
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionContext, Tool};
    use std::fs;
    use std::path::Path;
    use tempfile::tempdir;
    use tokio_util::sync::CancellationToken;

    /// Populates `dir` with the shared fixture: two top-level text files, one nested text
    /// file under `subdir/`, and one file containing non-UTF-8 bytes.
    fn create_test_files(dir: &Path) {
        fs::write(dir.join("file1.txt"), "hello world\nfind me here").unwrap();
        fs::write(
            dir.join("file2.log"),
            "another file\nwith some logs\nLOG-123: an error",
        )
        .unwrap();
        fs::create_dir(dir.join("subdir")).unwrap();
        fs::write(
            dir.join("subdir/file3.txt"),
            "nested file\nshould also be found",
        )
        .unwrap();
        // A file with non-utf8 content to test robustness
        fs::write(dir.join("binary.dat"), [0, 159, 146, 150]).unwrap();
    }

    /// Builds an execution context whose working directory is the temporary directory.
    fn create_test_context(temp_dir: &tempfile::TempDir) -> ExecutionContext {
        ExecutionContext::new("test-call-id".to_string())
            .with_working_directory(temp_dir.path().to_path_buf())
    }

    /// Serializes grep parameters into the JSON value handed to `Tool::execute`, so each
    /// test states only the inputs it cares about.
    fn grep_params(pattern: &str, include: Option<&str>, path: Option<&str>) -> serde_json::Value {
        let params = GrepParams {
            pattern: pattern.to_string(),
            include: include.map(str::to_string),
            path: path.map(str::to_string),
        };
        serde_json::to_value(params).unwrap()
    }

    #[tokio::test]
    async fn test_grep_simple_match() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());
        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("find me", None, None);

        let result = tool.execute(params_json, &context).await.unwrap();

        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("file1.txt"));
        assert_eq!(result.0.matches[0].line_number, 2);
        assert_eq!(result.0.matches[0].line_content, "find me here");
        assert!(result.0.search_completed);
        assert!(result.0.total_files_searched > 0);
    }

    #[tokio::test]
    async fn test_grep_regex_match() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());
        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params(r"LOG-\d+", None, None);

        let result = tool.execute(params_json, &context).await.unwrap();

        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("file2.log"));
        assert_eq!(result.0.matches[0].line_number, 3);
        assert_eq!(result.0.matches[0].line_content, "LOG-123: an error");
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_no_matches() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());
        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("non-existent pattern", None, None);

        let result = tool.execute(params_json, &context).await.unwrap();

        assert_eq!(result.0.matches.len(), 0);
        assert!(result.0.search_completed);
        assert!(result.0.total_files_searched > 0);
    }

    #[tokio::test]
    async fn test_grep_with_path() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());
        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("nested", None, Some("subdir"));

        let result = tool.execute(params_json, &context).await.unwrap();

        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("subdir/file3.txt"));
        assert_eq!(result.0.matches[0].line_number, 1);
        assert_eq!(result.0.matches[0].line_content, "nested file");
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_with_include() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());
        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("file", Some("*.log"), None);

        let result = tool.execute(params_json, &context).await.unwrap();

        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("file2.log"));
        assert_eq!(result.0.matches[0].line_number, 1);
        assert_eq!(result.0.matches[0].line_content, "another file");
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_non_existent_path() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());
        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("any", None, Some("non-existent-dir"));

        let result = tool.execute(params_json, &context).await;

        assert!(matches!(result, Err(ToolError::Execution { .. })));
        if let Err(ToolError::Execution { message, .. }) = result {
            assert!(message.contains("Path does not exist"));
        }
    }

    #[tokio::test]
    async fn test_grep_cancellation() {
        let temp_dir = tempdir().unwrap();
        create_test_files(temp_dir.path());

        let token = CancellationToken::new();
        token.cancel(); // Cancel immediately

        let context = ExecutionContext::new("test-call-id".to_string())
            .with_working_directory(temp_dir.path().to_path_buf())
            .with_cancellation_token(token);

        let tool = GrepTool;
        let params_json = grep_params("hello", None, None);

        let result = tool.execute(params_json, &context).await;

        assert!(matches!(result, Err(ToolError::Cancelled(_))));
    }

    #[tokio::test]
    async fn test_grep_respects_gitignore() {
        let temp_dir = tempdir().unwrap();

        // Initialize a git repository (required for .gitignore to work)
        fs::create_dir(temp_dir.path().join(".git")).unwrap();

        // Create test files
        fs::write(
            temp_dir.path().join("file1.txt"),
            "hello world\nfind me here",
        )
        .unwrap();
        fs::write(
            temp_dir.path().join("ignored.txt"),
            "this should be ignored\nfind me here",
        )
        .unwrap();
        fs::write(
            temp_dir.path().join("also_ignored.log"),
            "another ignored file\nfind me here",
        )
        .unwrap();

        // Create .gitignore file
        fs::write(temp_dir.path().join(".gitignore"), "ignored.txt\n*.log").unwrap();

        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("find me here", None, None);

        let result = tool.execute(params_json, &context).await.unwrap();

        // Should find the match in file1.txt but not in ignored files
        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("file1.txt"));
        assert_eq!(result.0.matches[0].line_number, 2);
        assert_eq!(result.0.matches[0].line_content, "find me here");
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_literal_fallback() {
        let temp_dir = tempdir().unwrap();

        // Create test files with patterns that would fail as regex
        fs::write(
            temp_dir.path().join("code.rs"),
            "fn main() {\n    format_message(\"hello\");\n    println!(\"world\");\n}",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        // This would fail as regex due to unclosed (
        let params_json = grep_params("format_message(", None, None);

        let result = tool.execute(params_json, &context).await.unwrap();

        // Should find the literal match
        assert_eq!(result.0.matches.len(), 1);
        assert!(result.0.matches[0].file_path.contains("code.rs"));
        assert_eq!(result.0.matches[0].line_number, 2);
        assert!(
            result.0.matches[0]
                .line_content
                .contains("format_message(\"hello\");")
        );
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_relative_path_glob_matching() {
        let temp_dir = tempdir().unwrap();

        // Create nested directory structure similar to steer project
        fs::create_dir_all(temp_dir.path().join("steer/src/session")).unwrap();
        fs::create_dir_all(temp_dir.path().join("steer/src/utils")).unwrap();
        fs::create_dir_all(temp_dir.path().join("other/src")).unwrap();

        // Create test files
        fs::write(
            temp_dir.path().join("steer/src/session/state.rs"),
            "pub struct SessionConfig {\n    pub field: String,\n}",
        )
        .unwrap();
        fs::write(
            temp_dir.path().join("steer/src/utils/session.rs"),
            "use crate::SessionConfig;\nfn test() -> SessionConfig {\n    SessionConfig { field: \"test\".to_string() }\n}",
        )
        .unwrap();
        fs::write(
            temp_dir.path().join("other/src/main.rs"),
            "struct SessionConfig;\nfn main() {}",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        let params_json = grep_params("SessionConfig \\{", Some("steer/src/**/*.rs"), None);

        let result = tool.execute(params_json, &context).await.unwrap();

        // Should find 3 matches in steer/src files but not in other/src
        assert_eq!(result.0.matches.len(), 3);
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("steer/src/session/state.rs")
                && m.line_number == 1
                && m.line_content == "pub struct SessionConfig {"
        }));
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("steer/src/utils/session.rs")
                && m.line_number == 2
                && m.line_content == "fn test() -> SessionConfig {"
        }));
        assert!(result.0.matches.iter().any(|m| {
            m.file_path.contains("steer/src/utils/session.rs")
                && m.line_number == 3
                && m.line_content
                    .contains("SessionConfig { field: \"test\".to_string() }")
        }));
        // Ensure no matches from other/src
        assert!(
            !result
                .0
                .matches
                .iter()
                .any(|m| m.file_path.contains("other/src"))
        );
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_complex_relative_patterns() {
        let temp_dir = tempdir().unwrap();

        // Create complex directory structure
        fs::create_dir_all(temp_dir.path().join("src/api/client")).unwrap();
        fs::create_dir_all(temp_dir.path().join("src/tools")).unwrap();
        fs::create_dir_all(temp_dir.path().join("tests/integration")).unwrap();

        // Create test files
        fs::write(
            temp_dir.path().join("src/api/client/mod.rs"),
            "pub mod client;\npub use client::ApiClient;",
        )
        .unwrap();
        fs::write(
            temp_dir.path().join("src/tools/grep.rs"),
            "pub struct GrepTool;\nimpl Tool for GrepTool {}",
        )
        .unwrap();
        fs::write(
            temp_dir.path().join("tests/integration/api_test.rs"),
            "use crate::api::ApiClient;\n#[test]\nfn test_api() {}",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        // Test pattern that should match only src/**/*.rs files
        let tool = GrepTool;
        let params_json = grep_params("pub", Some("src/**/*.rs"), None);

        let result = tool.execute(params_json, &context).await.unwrap();

        // Should find matches in src/ but not in tests/
        assert!(result.0.matches.len() >= 3);
        assert!(
            result
                .0
                .matches
                .iter()
                .any(|m| m.file_path.contains("src/api/client/mod.rs")
                    && m.line_number == 1
                    && m.line_content == "pub mod client;")
        );
        assert!(
            result
                .0
                .matches
                .iter()
                .any(|m| m.file_path.contains("src/api/client/mod.rs")
                    && m.line_number == 2
                    && m.line_content == "pub use client::ApiClient;")
        );
        assert!(
            result
                .0
                .matches
                .iter()
                .any(|m| m.file_path.contains("src/tools/grep.rs")
                    && m.line_number == 1
                    && m.line_content == "pub struct GrepTool;")
        );
        assert!(
            !result
                .0
                .matches
                .iter()
                .any(|m| m.file_path.contains("tests/"))
        );
        assert!(result.0.search_completed);
    }

    #[tokio::test]
    async fn test_grep_canonicalized_paths() {
        let temp_dir = tempdir().unwrap();

        // Create a test file
        fs::write(
            temp_dir.path().join("test.txt"),
            "line one\nfind this line\nline three",
        )
        .unwrap();

        let context = create_test_context(&temp_dir);

        let tool = GrepTool;
        // Use "." as the path to simulate the issue
        let params_json = grep_params("find this", None, Some("."));

        let result = tool.execute(params_json, &context).await.unwrap();

        // The result should contain a canonicalized path without "./"
        assert_eq!(result.0.matches.len(), 1);
        assert_eq!(result.0.matches[0].line_number, 2);
        assert_eq!(result.0.matches[0].line_content, "find this line");
        // Ensure no "./" appears in the path
        assert!(!result.0.matches[0].file_path.contains("./"));
        assert!(result.0.matches[0].file_path.contains("test.txt"));
        assert!(result.0.search_completed);

        // The path should be absolute and canonical
        let canonical_path = temp_dir.path().join("test.txt").canonicalize().unwrap();
        assert_eq!(
            result.0.matches[0].file_path,
            canonical_path.display().to_string()
        );
    }
}