1use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, truncate_line, GREP_MAX_LINE_LENGTH, MAX_BYTES};
18
/// Default number of matches returned by one grep call, and the ceiling applied to
/// any caller-supplied `max_matches`.
const MAX_MATCHES: usize = 100;
21
22use super::resolve_path;
23
24pub struct GrepTool {
25 fs: Arc<dyn VirtualFs>,
26 cwd: String,
27}
28
29impl GrepTool {
30 pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31 Self {
32 fs,
33 cwd: cwd.into(),
34 }
35 }
36}
37
/// Reports whether `line` contains `pattern` as a substring.
///
/// * `literal` - accepted for interface compatibility; it currently has no
///   effect because regex syntax is not interpreted in either mode.
/// * `ignore_case` - when `true`, both sides are lowercased before comparing.
///
/// An empty `pattern` matches every line.
fn matches_pattern(line: &str, pattern: &str, literal: bool, ignore_case: bool) -> bool {
    // Both modes are a plain substring search, so the flag does not pick a code path.
    let _ = literal;

    if !ignore_case {
        return line.contains(pattern);
    }

    let needle = pattern.to_lowercase();
    line.to_lowercase().contains(&needle)
}
56
57async fn collect_files(
59 fs: &dyn VirtualFs,
60 dir: &str,
61 files: &mut Vec<String>,
62 glob_filter: Option<&str>,
63) -> SoulResult<()> {
64 let entries = fs.read_dir(dir).await?;
65 for entry in entries {
66 let path = if dir == "/" || dir.is_empty() {
67 format!("/{}", entry.name)
68 } else {
69 format!("{}/{}", dir.trim_end_matches('/'), entry.name)
70 };
71
72 if entry.is_dir {
73 if !entry.name.starts_with('.') {
75 Box::pin(collect_files(fs, &path, files, glob_filter)).await?;
76 }
77 } else if entry.is_file {
78 if let Some(glob) = glob_filter {
79 if matches_glob(&entry.name, glob) {
80 files.push(path);
81 }
82 } else {
83 files.push(path);
84 }
85 }
86 }
87 Ok(())
88}
89
/// Tests `filename` against a glob in which `*` matches any run of characters,
/// including an empty one.
///
/// `*` is the only wildcard; every other character must match literally and
/// the comparison is case-sensitive. A glob without `*` must equal the file
/// name exactly. Any number of `*` is supported.
fn matches_glob(filename: &str, glob: &str) -> bool {
    let (prefix, rest) = match glob.split_once('*') {
        Some(parts) => parts,
        None => return filename == glob,
    };

    // Text after the last `*` anchors the end; anything between the first and
    // last `*` is a list of segments that must occur in order.
    let (middle, suffix) = match rest.rsplit_once('*') {
        Some(parts) => parts,
        None => ("", rest),
    };

    // Removing the prefix before the suffix keeps the two anchors from
    // overlapping (e.g. `ab*ba` must not match `aba`).
    let mut remaining = match filename
        .strip_prefix(prefix)
        .and_then(|tail| tail.strip_suffix(suffix))
    {
        Some(core) => core,
        None => return false,
    };

    // Leftmost match for each segment leaves the most room for later ones.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }

    true
}
107
108#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
109#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
110impl Tool for GrepTool {
111 fn name(&self) -> &str {
112 "grep"
113 }
114
115 fn definition(&self) -> ToolDefinition {
116 ToolDefinition {
117 name: "grep".into(),
118 description: "Search file contents for a pattern. Returns matching lines with file paths and line numbers.".into(),
119 input_schema: json!({
120 "type": "object",
121 "properties": {
122 "pattern": {
123 "type": "string",
124 "description": "Search pattern (literal string or regex)"
125 },
126 "path": {
127 "type": "string",
128 "description": "Directory to search in (defaults to working directory)"
129 },
130 "glob": {
131 "type": "string",
132 "description": "Glob pattern to filter files (e.g., '*.rs', '*.ts')"
133 },
134 "ignore_case": {
135 "type": "boolean",
136 "description": "Case-insensitive search"
137 },
138 "literal": {
139 "type": "boolean",
140 "description": "Treat pattern as literal string (no regex)"
141 },
142 "context": {
143 "type": "integer",
144 "description": "Number of context lines before and after each match"
145 },
146 "max_matches": {
147 "type": "integer",
148 "description": "Maximum number of matches to return (default: 100)"
149 }
150 },
151 "required": ["pattern"]
152 }),
153 }
154 }
155
156 async fn execute(
157 &self,
158 _call_id: &str,
159 arguments: serde_json::Value,
160 _partial_tx: Option<mpsc::UnboundedSender<String>>,
161 ) -> SoulResult<ToolOutput> {
162 let pattern = arguments
163 .get("pattern")
164 .and_then(|v| v.as_str())
165 .unwrap_or("");
166
167 if pattern.is_empty() {
168 return Ok(ToolOutput::error("Missing required parameter: pattern"));
169 }
170
171 let search_path = arguments
172 .get("path")
173 .and_then(|v| v.as_str())
174 .map(|p| resolve_path(&self.cwd, p))
175 .unwrap_or_else(|| self.cwd.clone());
176
177 let glob_filter = arguments.get("glob").and_then(|v| v.as_str());
178 let ignore_case = arguments
179 .get("ignore_case")
180 .and_then(|v| v.as_bool())
181 .unwrap_or(false);
182 let literal = arguments
183 .get("literal")
184 .and_then(|v| v.as_bool())
185 .unwrap_or(false);
186 let context_lines = arguments
187 .get("context")
188 .and_then(|v| v.as_u64())
189 .unwrap_or(0) as usize;
190 let max_matches = arguments
191 .get("max_matches")
192 .and_then(|v| v.as_u64())
193 .map(|v| (v as usize).min(MAX_MATCHES))
194 .unwrap_or(MAX_MATCHES);
195
196 let mut files = Vec::new();
198 if let Err(e) = collect_files(self.fs.as_ref(), &search_path, &mut files, glob_filter).await
199 {
200 return Ok(ToolOutput::error(format!(
201 "Failed to enumerate files in {}: {}",
202 search_path, e
203 )));
204 }
205
206 files.sort();
207
208 let mut output = String::new();
209 let mut total_matches = 0;
210 let mut files_with_matches = 0;
211
212 'files: for file_path in &files {
213 let content = match self.fs.read_to_string(file_path).await {
214 Ok(c) => c,
215 Err(_) => continue, };
217
218 let lines: Vec<&str> = content.lines().collect();
219 let mut file_had_match = false;
220
221 for (line_idx, line) in lines.iter().enumerate() {
222 if matches_pattern(line, pattern, literal, ignore_case) {
223 if !file_had_match {
224 if !output.is_empty() {
225 output.push('\n');
226 }
227 files_with_matches += 1;
228 file_had_match = true;
229 }
230
231 let ctx_start = line_idx.saturating_sub(context_lines);
233 for ctx_idx in ctx_start..line_idx {
234 output.push_str(&format!(
235 "{}:{}-{}\n",
236 display_path(file_path, &self.cwd),
237 ctx_idx + 1,
238 truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
239 ));
240 }
241
242 output.push_str(&format!(
244 "{}:{}:{}\n",
245 display_path(file_path, &self.cwd),
246 line_idx + 1,
247 truncate_line(line, GREP_MAX_LINE_LENGTH)
248 ));
249
250 let ctx_end = (line_idx + context_lines + 1).min(lines.len());
252 for ctx_idx in (line_idx + 1)..ctx_end {
253 output.push_str(&format!(
254 "{}:{}-{}\n",
255 display_path(file_path, &self.cwd),
256 ctx_idx + 1,
257 truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
258 ));
259 }
260
261 total_matches += 1;
262 if total_matches >= max_matches {
263 break 'files;
264 }
265 }
266 }
267 }
268
269 if total_matches == 0 {
270 return Ok(ToolOutput::success(format!(
271 "No matches found for pattern '{}' in {}",
272 pattern,
273 display_path(&search_path, &self.cwd)
274 ))
275 .with_metadata(json!({"matches": 0, "files": 0})));
276 }
277
278 let truncated = truncate_head(&output, total_matches + (total_matches * context_lines * 2), MAX_BYTES);
280
281 let notice = truncated.truncation_notice();
282 let is_truncated = truncated.is_truncated();
283 let mut result = truncated.content;
284 if total_matches >= max_matches {
285 result.push_str(&format!(
286 "\n[Reached max matches limit: {}]",
287 max_matches
288 ));
289 }
290 if let Some(notice) = notice {
291 result.push_str(&format!("\n{}", notice));
292 }
293
294 Ok(ToolOutput::success(result).with_metadata(json!({
295 "matches": total_matches,
296 "files_with_matches": files_with_matches,
297 "truncated": is_truncated,
298 })))
299 }
300}
301
/// Returns `path` relative to `cwd` when it lies strictly inside `cwd`;
/// otherwise returns `path` unchanged.
///
/// Trailing slashes on `cwd` are ignored. A path equal to `cwd` itself is
/// returned as-is.
fn display_path(path: &str, cwd: &str) -> String {
    let base = cwd.trim_end_matches('/');
    path.strip_prefix(base)
        .and_then(|rest| rest.strip_prefix('/'))
        .unwrap_or(path)
        .to_string()
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Builds an empty in-memory filesystem and a grep tool rooted at `/project`.
    async fn setup() -> (Arc<MemoryFs>, GrepTool) {
        let fs = Arc::new(MemoryFs::new());
        let tool = GrepTool::new(fs.clone() as Arc<dyn VirtualFs>, "/project");
        (fs, tool)
    }

    #[tokio::test]
    async fn grep_simple_match() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "hello world\nfoo bar\nhello again")
            .await
            .unwrap();

        let result = tool
            .execute("c1", json!({"pattern": "hello"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("file.txt:1:hello world"));
        assert!(result.content.contains("file.txt:3:hello again"));
        assert!(!result.content.contains("foo bar"));
    }

    #[tokio::test]
    async fn grep_case_insensitive() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "Hello World\nhello world")
            .await
            .unwrap();

        let result = tool
            .execute(
                "c2",
                json!({"pattern": "HELLO", "ignore_case": true}),
                None,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.metadata["matches"].as_u64(), Some(2));
    }

    #[tokio::test]
    async fn grep_with_glob_filter() {
        let (fs, tool) = setup().await;
        fs.write("/project/code.rs", "fn main() {}")
            .await
            .unwrap();
        fs.write("/project/readme.md", "fn main() {}")
            .await
            .unwrap();

        let result = tool
            .execute(
                "c3",
                json!({"pattern": "fn main", "glob": "*.rs"}),
                None,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("code.rs"));
        assert!(!result.content.contains("readme.md"));
    }

    #[tokio::test]
    async fn grep_no_matches() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "nothing here")
            .await
            .unwrap();

        let result = tool
            .execute("c4", json!({"pattern": "missing"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("No matches"));
        assert_eq!(result.metadata["matches"].as_u64(), Some(0));
    }

    #[tokio::test]
    async fn grep_empty_pattern() {
        let (_fs, tool) = setup().await;
        let result = tool
            .execute("c5", json!({"pattern": ""}), None)
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn grep_with_context() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "a\nb\nc\nd\ne")
            .await
            .unwrap();

        let result = tool
            .execute(
                "c6",
                json!({"pattern": "c", "context": 1}),
                None,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        // Check whole rows so the test cannot pass on an unrelated "b" or "d".
        assert!(result.content.contains("file.txt:2-b"));
        assert!(result.content.contains("file.txt:3:c"));
        assert!(result.content.contains("file.txt:4-d"));
    }

    #[test]
    fn glob_matching() {
        assert!(matches_glob("file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "*.rs"));
        assert!(matches_glob("test.spec.ts", "*.ts"));
        assert!(matches_glob("test_a.rs", "test_*.rs"));
        assert!(matches_glob("Makefile", "Makefile"));
        assert!(!matches_glob("Makefile", "makefile"));
    }

    #[test]
    fn display_path_relative() {
        assert_eq!(display_path("/project/src/main.rs", "/project"), "src/main.rs");
        assert_eq!(display_path("/other/file.txt", "/project"), "/other/file.txt");
        assert_eq!(display_path("/project/a.txt", "/project/"), "a.txt");
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup().await;
        assert_eq!(tool.name(), "grep");
        let def = tool.definition();
        assert_eq!(def.name, "grep");
    }
}