Skip to main content

soul_coder/tools/
grep.rs

//! Grep tool — search file contents using regex or literal patterns.
//!
//! Uses VirtualFs for WASM compatibility. In WASM mode, searches every file found
//! under the search directory in the VFS; patterns are currently matched as plain
//! substrings (no regex engine is linked in). In native mode, can delegate to
//! ripgrep via VirtualExecutor.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, truncate_line, GREP_MAX_LINE_LENGTH, MAX_BYTES};
18
19/// Maximum number of matches returned.
20const MAX_MATCHES: usize = 100;
21
22use super::resolve_path;
23
24pub struct GrepTool {
25    fs: Arc<dyn VirtualFs>,
26    cwd: String,
27}
28
29impl GrepTool {
30    pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31        Self {
32            fs,
33            cwd: cwd.into(),
34        }
35    }
36}
37
/// Reports whether `line` contains `pattern` as a plain substring.
///
/// No regex engine is used here: the pattern is matched literally whether or
/// not `literal` is set, so regex metacharacters only ever match themselves.
/// With `ignore_case`, both the line and the pattern are lowercased before
/// the comparison.
fn matches_pattern(line: &str, pattern: &str, literal: bool, ignore_case: bool) -> bool {
    // Regex mode deliberately falls back to the same substring search as
    // literal mode (no regex crate dependency for WASM). The flag stays in the
    // signature so callers do not change if real regex support is added.
    let _ = literal;

    if ignore_case {
        line.to_lowercase().contains(&pattern.to_lowercase())
    } else {
        line.contains(pattern)
    }
}
56
57/// Recursively collect all file paths from a VFS directory.
58async fn collect_files(
59    fs: &dyn VirtualFs,
60    dir: &str,
61    files: &mut Vec<String>,
62    glob_filter: Option<&str>,
63) -> SoulResult<()> {
64    let entries = fs.read_dir(dir).await?;
65    for entry in entries {
66        let path = if dir == "/" || dir.is_empty() {
67            format!("/{}", entry.name)
68        } else {
69            format!("{}/{}", dir.trim_end_matches('/'), entry.name)
70        };
71
72        if entry.is_dir {
73            // Skip hidden dirs
74            if !entry.name.starts_with('.') {
75                Box::pin(collect_files(fs, &path, files, glob_filter)).await?;
76            }
77        } else if entry.is_file {
78            if let Some(glob) = glob_filter {
79                if matches_glob(&entry.name, glob) {
80                    files.push(path);
81                }
82            } else {
83                files.push(path);
84            }
85        }
86    }
87    Ok(())
88}
89
/// Matches a bare file name against a glob in which `*` stands for any run of
/// characters (including none). Every other character matches itself.
///
/// Only the file name is available here, so directory components of the glob
/// cannot be checked: the glob's final `/`-separated component is used
/// (`**/*.rs` and `src/*.rs` both behave like `*.rs`).
///
/// A glob without `*` must equal the file name exactly.
fn matches_glob(filename: &str, glob: &str) -> bool {
    // `rsplit` always yields at least one item; the fallback is never taken.
    let pattern = glob.rsplit('/').next().unwrap_or(glob);

    // Split on `*`: the first piece is a required prefix, the last a required
    // suffix, and the pieces in between must appear in order in the middle.
    let mut pieces = pattern.split('*');
    let prefix = pieces.next().unwrap_or("");
    let mut inner: Vec<&str> = pieces.collect();
    let suffix = match inner.pop() {
        Some(s) => s,
        // No `*` at all: plain equality.
        None => return filename == pattern,
    };

    // Consuming the prefix first guarantees prefix and suffix cannot overlap
    // (e.g. `ab*ba` must not match `aba`).
    let mut rest = match filename.strip_prefix(prefix) {
        Some(r) => r,
        None => return false,
    };

    // Leftmost matching of each middle piece leaves the most room for the
    // pieces that follow, which is sufficient for `*`-only globs.
    for piece in inner {
        match rest.find(piece) {
            Some(at) => rest = &rest[at + piece.len()..],
            None => return false,
        }
    }

    rest.ends_with(suffix)
}
107
108#[async_trait]
109impl Tool for GrepTool {
110    fn name(&self) -> &str {
111        "grep"
112    }
113
114    fn definition(&self) -> ToolDefinition {
115        ToolDefinition {
116            name: "grep".into(),
117            description: "Search file contents for a pattern. Returns matching lines with file paths and line numbers.".into(),
118            input_schema: json!({
119                "type": "object",
120                "properties": {
121                    "pattern": {
122                        "type": "string",
123                        "description": "Search pattern (literal string or regex)"
124                    },
125                    "path": {
126                        "type": "string",
127                        "description": "Directory to search in (defaults to working directory)"
128                    },
129                    "glob": {
130                        "type": "string",
131                        "description": "Glob pattern to filter files (e.g., '*.rs', '*.ts')"
132                    },
133                    "ignore_case": {
134                        "type": "boolean",
135                        "description": "Case-insensitive search"
136                    },
137                    "literal": {
138                        "type": "boolean",
139                        "description": "Treat pattern as literal string (no regex)"
140                    },
141                    "context": {
142                        "type": "integer",
143                        "description": "Number of context lines before and after each match"
144                    },
145                    "max_matches": {
146                        "type": "integer",
147                        "description": "Maximum number of matches to return (default: 100)"
148                    }
149                },
150                "required": ["pattern"]
151            }),
152        }
153    }
154
155    async fn execute(
156        &self,
157        _call_id: &str,
158        arguments: serde_json::Value,
159        _partial_tx: Option<mpsc::UnboundedSender<String>>,
160    ) -> SoulResult<ToolOutput> {
161        let pattern = arguments
162            .get("pattern")
163            .and_then(|v| v.as_str())
164            .unwrap_or("");
165
166        if pattern.is_empty() {
167            return Ok(ToolOutput::error("Missing required parameter: pattern"));
168        }
169
170        let search_path = arguments
171            .get("path")
172            .and_then(|v| v.as_str())
173            .map(|p| resolve_path(&self.cwd, p))
174            .unwrap_or_else(|| self.cwd.clone());
175
176        let glob_filter = arguments.get("glob").and_then(|v| v.as_str());
177        let ignore_case = arguments
178            .get("ignore_case")
179            .and_then(|v| v.as_bool())
180            .unwrap_or(false);
181        let literal = arguments
182            .get("literal")
183            .and_then(|v| v.as_bool())
184            .unwrap_or(false);
185        let context_lines = arguments
186            .get("context")
187            .and_then(|v| v.as_u64())
188            .unwrap_or(0) as usize;
189        let max_matches = arguments
190            .get("max_matches")
191            .and_then(|v| v.as_u64())
192            .map(|v| (v as usize).min(MAX_MATCHES))
193            .unwrap_or(MAX_MATCHES);
194
195        // Collect files to search
196        let mut files = Vec::new();
197        if let Err(e) = collect_files(self.fs.as_ref(), &search_path, &mut files, glob_filter).await
198        {
199            return Ok(ToolOutput::error(format!(
200                "Failed to enumerate files in {}: {}",
201                search_path, e
202            )));
203        }
204
205        files.sort();
206
207        let mut output = String::new();
208        let mut total_matches = 0;
209        let mut files_with_matches = 0;
210
211        'files: for file_path in &files {
212            let content = match self.fs.read_to_string(file_path).await {
213                Ok(c) => c,
214                Err(_) => continue, // Skip unreadable files
215            };
216
217            let lines: Vec<&str> = content.lines().collect();
218            let mut file_had_match = false;
219
220            for (line_idx, line) in lines.iter().enumerate() {
221                if matches_pattern(line, pattern, literal, ignore_case) {
222                    if !file_had_match {
223                        if !output.is_empty() {
224                            output.push('\n');
225                        }
226                        files_with_matches += 1;
227                        file_had_match = true;
228                    }
229
230                    // Context before
231                    let ctx_start = line_idx.saturating_sub(context_lines);
232                    for ctx_idx in ctx_start..line_idx {
233                        output.push_str(&format!(
234                            "{}:{}-{}\n",
235                            display_path(file_path, &self.cwd),
236                            ctx_idx + 1,
237                            truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
238                        ));
239                    }
240
241                    // Match line
242                    output.push_str(&format!(
243                        "{}:{}:{}\n",
244                        display_path(file_path, &self.cwd),
245                        line_idx + 1,
246                        truncate_line(line, GREP_MAX_LINE_LENGTH)
247                    ));
248
249                    // Context after
250                    let ctx_end = (line_idx + context_lines + 1).min(lines.len());
251                    for ctx_idx in (line_idx + 1)..ctx_end {
252                        output.push_str(&format!(
253                            "{}:{}-{}\n",
254                            display_path(file_path, &self.cwd),
255                            ctx_idx + 1,
256                            truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
257                        ));
258                    }
259
260                    total_matches += 1;
261                    if total_matches >= max_matches {
262                        break 'files;
263                    }
264                }
265            }
266        }
267
268        if total_matches == 0 {
269            return Ok(ToolOutput::success(format!(
270                "No matches found for pattern '{}' in {}",
271                pattern,
272                display_path(&search_path, &self.cwd)
273            ))
274            .with_metadata(json!({"matches": 0, "files": 0})));
275        }
276
277        // Apply byte truncation
278        let truncated = truncate_head(&output, total_matches + (total_matches * context_lines * 2), MAX_BYTES);
279
280        let notice = truncated.truncation_notice();
281        let is_truncated = truncated.is_truncated();
282        let mut result = truncated.content;
283        if total_matches >= max_matches {
284            result.push_str(&format!(
285                "\n[Reached max matches limit: {}]",
286                max_matches
287            ));
288        }
289        if let Some(notice) = notice {
290            result.push_str(&format!("\n{}", notice));
291        }
292
293        Ok(ToolOutput::success(result).with_metadata(json!({
294            "matches": total_matches,
295            "files_with_matches": files_with_matches,
296            "truncated": is_truncated,
297        })))
298    }
299}
300
/// Returns `path` relative to `cwd` when it lies strictly below `cwd`;
/// otherwise returns `path` unchanged. Trailing slashes on `cwd` are ignored.
fn display_path(path: &str, cwd: &str) -> String {
    // Normalise to exactly one trailing slash so `/project` is not treated as
    // a prefix of `/projectx/...`.
    let mut base = cwd.trim_end_matches('/').to_owned();
    base.push('/');

    path.strip_prefix(base.as_str()).unwrap_or(path).to_owned()
}
310
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Builds an empty in-memory filesystem and a grep tool whose working
    /// directory is `/project`.
    async fn setup() -> (Arc<MemoryFs>, GrepTool) {
        let memory = Arc::new(MemoryFs::new());
        let grep = GrepTool::new(memory.clone() as Arc<dyn VirtualFs>, "/project");
        (memory, grep)
    }

    /// Matching lines are reported with relative path and 1-based line number.
    #[tokio::test]
    async fn grep_simple_match() {
        let (memory, grep) = setup().await;
        memory
            .write("/project/file.txt", "hello world\nfoo bar\nhello again")
            .await
            .unwrap();

        let args = json!({"pattern": "hello"});
        let outcome = grep.execute("c1", args, None).await.unwrap();

        assert!(!outcome.is_error);
        for expected in ["file.txt:1:hello world", "file.txt:3:hello again"] {
            assert!(outcome.content.contains(expected));
        }
    }

    /// `ignore_case` makes differently-cased lines match.
    #[tokio::test]
    async fn grep_case_insensitive() {
        let (memory, grep) = setup().await;
        memory
            .write("/project/file.txt", "Hello World\nhello world")
            .await
            .unwrap();

        let args = json!({"pattern": "HELLO", "ignore_case": true});
        let outcome = grep.execute("c2", args, None).await.unwrap();

        assert!(!outcome.is_error);
        assert_eq!(outcome.metadata["matches"].as_u64(), Some(2));
    }

    /// Files excluded by the glob are not searched.
    #[tokio::test]
    async fn grep_with_glob_filter() {
        let (memory, grep) = setup().await;
        memory
            .write("/project/code.rs", "fn main() {}")
            .await
            .unwrap();
        memory
            .write("/project/readme.md", "fn main() {}")
            .await
            .unwrap();

        let args = json!({"pattern": "fn main", "glob": "*.rs"});
        let outcome = grep.execute("c3", args, None).await.unwrap();

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("code.rs"));
        assert!(!outcome.content.contains("readme.md"));
    }

    /// No match is a successful result with an explanatory message.
    #[tokio::test]
    async fn grep_no_matches() {
        let (memory, grep) = setup().await;
        memory
            .write("/project/file.txt", "nothing here")
            .await
            .unwrap();

        let args = json!({"pattern": "missing"});
        let outcome = grep.execute("c4", args, None).await.unwrap();

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("No matches"));
    }

    /// An empty pattern is rejected.
    #[tokio::test]
    async fn grep_empty_pattern() {
        let (_memory, grep) = setup().await;

        let args = json!({"pattern": ""});
        let outcome = grep.execute("c5", args, None).await.unwrap();

        assert!(outcome.is_error);
    }

    /// Lines around a match are included when `context` is set.
    #[tokio::test]
    async fn grep_with_context() {
        let (memory, grep) = setup().await;
        memory
            .write("/project/file.txt", "a\nb\nc\nd\ne")
            .await
            .unwrap();

        let args = json!({"pattern": "c", "context": 1});
        let outcome = grep.execute("c6", args, None).await.unwrap();

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("b")); // before context
        assert!(outcome.content.contains("d")); // after context
    }

    /// Extension globs compare against the end of the file name.
    #[test]
    fn glob_matching() {
        assert!(matches_glob("file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "*.rs"));
        assert!(matches_glob("test.spec.ts", "*.ts"));
    }

    /// Paths under the working directory are shown relative to it; others are untouched.
    #[test]
    fn display_path_relative() {
        let inside = display_path("/project/src/main.rs", "/project");
        let outside = display_path("/other/file.txt", "/project");

        assert_eq!(inside, "src/main.rs");
        assert_eq!(outside, "/other/file.txt");
    }

    /// The trait name and the advertised definition name agree.
    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_memory, grep) = setup().await;

        assert_eq!(grep.name(), "grep");
        assert_eq!(grep.definition().name, "grep");
    }
}