1use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, truncate_line, GREP_MAX_LINE_LENGTH, MAX_BYTES};
18
/// Upper bound on the number of matches a single grep call reports; also the
/// value used when the caller does not pass `max_matches`.
const MAX_MATCHES: usize = 100;
21
22use super::resolve_path;
23
24pub struct GrepTool {
25 fs: Arc<dyn VirtualFs>,
26 cwd: String,
27}
28
29impl GrepTool {
30 pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31 Self {
32 fs,
33 cwd: cwd.into(),
34 }
35 }
36}
37
/// Returns `true` when `line` contains `pattern` as a substring.
///
/// With `ignore_case`, both `line` and `pattern` are lowercased with
/// `str::to_lowercase` before the substring test.
///
/// NOTE(review): `literal` is accepted for API compatibility but currently has
/// no effect. Regex matching is not implemented (this file imports no regex
/// engine), so every pattern is matched as a plain substring. The previous
/// version had two byte-identical branches for the literal and "regex" modes.
fn matches_pattern(line: &str, pattern: &str, literal: bool, ignore_case: bool) -> bool {
    // Both modes share one substring search; see the note above.
    let _ = literal;
    if ignore_case {
        line.to_lowercase().contains(&pattern.to_lowercase())
    } else {
        line.contains(pattern)
    }
}
56
57async fn collect_files(
59 fs: &dyn VirtualFs,
60 dir: &str,
61 files: &mut Vec<String>,
62 glob_filter: Option<&str>,
63) -> SoulResult<()> {
64 let entries = fs.read_dir(dir).await?;
65 for entry in entries {
66 let path = if dir == "/" || dir.is_empty() {
67 format!("/{}", entry.name)
68 } else {
69 format!("{}/{}", dir.trim_end_matches('/'), entry.name)
70 };
71
72 if entry.is_dir {
73 if !entry.name.starts_with('.') {
75 Box::pin(collect_files(fs, &path, files, glob_filter)).await?;
76 }
77 } else if entry.is_file {
78 if let Some(glob) = glob_filter {
79 if matches_glob(&entry.name, glob) {
80 files.push(path);
81 }
82 } else {
83 files.push(path);
84 }
85 }
86 }
87 Ok(())
88}
89
/// Returns `true` if `filename` matches the shell-style wildcard `glob`.
///
/// Supported syntax: `*` matches any run of characters (including none) and
/// `?` matches exactly one character; every other character matches itself.
/// Only base names are tested by the caller, so any directory part of `glob`
/// (e.g. the `**/` in `**/*.rs`) is dropped before matching. Character classes
/// (`[...]`) and brace alternatives (`{a,b}`) are not supported and are
/// compared literally.
///
/// The previous implementation accepted every file for globs containing two
/// or more `*`, and let the prefix and suffix of a single `*` overlap
/// (`ab*ba` matched `aba`).
fn matches_glob(filename: &str, glob: &str) -> bool {
    let glob = glob.rsplit_once('/').map_or(glob, |(_, base)| base);
    let name: Vec<char> = filename.chars().collect();
    let pat: Vec<char> = glob.chars().collect();

    let (mut n, mut p) = (0usize, 0usize);
    // Most recent `*`: (its index in `pat`, the `name` index it currently stops
    // at). On a mismatch, that `*` swallows one more character and matching
    // resumes right after it, so no recursion is needed.
    let mut star: Option<(usize, usize)> = None;

    while n < name.len() {
        match pat.get(p).copied() {
            Some('*') => {
                star = Some((p, n));
                p += 1;
            }
            Some(c) if c == '?' || c == name[n] => {
                p += 1;
                n += 1;
            }
            _ => match star {
                Some((star_p, star_n)) => {
                    star = Some((star_p, star_n + 1));
                    p = star_p + 1;
                    n = star_n + 1;
                }
                None => return false,
            },
        }
    }

    // The name is consumed; the remaining pattern must be able to match empty.
    pat[p..].iter().all(|&c| c == '*')
}
107
108#[async_trait]
109impl Tool for GrepTool {
110 fn name(&self) -> &str {
111 "grep"
112 }
113
114 fn definition(&self) -> ToolDefinition {
115 ToolDefinition {
116 name: "grep".into(),
117 description: "Search file contents for a pattern. Returns matching lines with file paths and line numbers.".into(),
118 input_schema: json!({
119 "type": "object",
120 "properties": {
121 "pattern": {
122 "type": "string",
123 "description": "Search pattern (literal string or regex)"
124 },
125 "path": {
126 "type": "string",
127 "description": "Directory to search in (defaults to working directory)"
128 },
129 "glob": {
130 "type": "string",
131 "description": "Glob pattern to filter files (e.g., '*.rs', '*.ts')"
132 },
133 "ignore_case": {
134 "type": "boolean",
135 "description": "Case-insensitive search"
136 },
137 "literal": {
138 "type": "boolean",
139 "description": "Treat pattern as literal string (no regex)"
140 },
141 "context": {
142 "type": "integer",
143 "description": "Number of context lines before and after each match"
144 },
145 "max_matches": {
146 "type": "integer",
147 "description": "Maximum number of matches to return (default: 100)"
148 }
149 },
150 "required": ["pattern"]
151 }),
152 }
153 }
154
155 async fn execute(
156 &self,
157 _call_id: &str,
158 arguments: serde_json::Value,
159 _partial_tx: Option<mpsc::UnboundedSender<String>>,
160 ) -> SoulResult<ToolOutput> {
161 let pattern = arguments
162 .get("pattern")
163 .and_then(|v| v.as_str())
164 .unwrap_or("");
165
166 if pattern.is_empty() {
167 return Ok(ToolOutput::error("Missing required parameter: pattern"));
168 }
169
170 let search_path = arguments
171 .get("path")
172 .and_then(|v| v.as_str())
173 .map(|p| resolve_path(&self.cwd, p))
174 .unwrap_or_else(|| self.cwd.clone());
175
176 let glob_filter = arguments.get("glob").and_then(|v| v.as_str());
177 let ignore_case = arguments
178 .get("ignore_case")
179 .and_then(|v| v.as_bool())
180 .unwrap_or(false);
181 let literal = arguments
182 .get("literal")
183 .and_then(|v| v.as_bool())
184 .unwrap_or(false);
185 let context_lines = arguments
186 .get("context")
187 .and_then(|v| v.as_u64())
188 .unwrap_or(0) as usize;
189 let max_matches = arguments
190 .get("max_matches")
191 .and_then(|v| v.as_u64())
192 .map(|v| (v as usize).min(MAX_MATCHES))
193 .unwrap_or(MAX_MATCHES);
194
195 let mut files = Vec::new();
197 if let Err(e) = collect_files(self.fs.as_ref(), &search_path, &mut files, glob_filter).await
198 {
199 return Ok(ToolOutput::error(format!(
200 "Failed to enumerate files in {}: {}",
201 search_path, e
202 )));
203 }
204
205 files.sort();
206
207 let mut output = String::new();
208 let mut total_matches = 0;
209 let mut files_with_matches = 0;
210
211 'files: for file_path in &files {
212 let content = match self.fs.read_to_string(file_path).await {
213 Ok(c) => c,
214 Err(_) => continue, };
216
217 let lines: Vec<&str> = content.lines().collect();
218 let mut file_had_match = false;
219
220 for (line_idx, line) in lines.iter().enumerate() {
221 if matches_pattern(line, pattern, literal, ignore_case) {
222 if !file_had_match {
223 if !output.is_empty() {
224 output.push('\n');
225 }
226 files_with_matches += 1;
227 file_had_match = true;
228 }
229
230 let ctx_start = line_idx.saturating_sub(context_lines);
232 for ctx_idx in ctx_start..line_idx {
233 output.push_str(&format!(
234 "{}:{}-{}\n",
235 display_path(file_path, &self.cwd),
236 ctx_idx + 1,
237 truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
238 ));
239 }
240
241 output.push_str(&format!(
243 "{}:{}:{}\n",
244 display_path(file_path, &self.cwd),
245 line_idx + 1,
246 truncate_line(line, GREP_MAX_LINE_LENGTH)
247 ));
248
249 let ctx_end = (line_idx + context_lines + 1).min(lines.len());
251 for ctx_idx in (line_idx + 1)..ctx_end {
252 output.push_str(&format!(
253 "{}:{}-{}\n",
254 display_path(file_path, &self.cwd),
255 ctx_idx + 1,
256 truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
257 ));
258 }
259
260 total_matches += 1;
261 if total_matches >= max_matches {
262 break 'files;
263 }
264 }
265 }
266 }
267
268 if total_matches == 0 {
269 return Ok(ToolOutput::success(format!(
270 "No matches found for pattern '{}' in {}",
271 pattern,
272 display_path(&search_path, &self.cwd)
273 ))
274 .with_metadata(json!({"matches": 0, "files": 0})));
275 }
276
277 let truncated = truncate_head(&output, total_matches + (total_matches * context_lines * 2), MAX_BYTES);
279
280 let notice = truncated.truncation_notice();
281 let is_truncated = truncated.is_truncated();
282 let mut result = truncated.content;
283 if total_matches >= max_matches {
284 result.push_str(&format!(
285 "\n[Reached max matches limit: {}]",
286 max_matches
287 ));
288 }
289 if let Some(notice) = notice {
290 result.push_str(&format!("\n{}", notice));
291 }
292
293 Ok(ToolOutput::success(result).with_metadata(json!({
294 "matches": total_matches,
295 "files_with_matches": files_with_matches,
296 "truncated": is_truncated,
297 })))
298 }
299}
300
/// Renders `path` relative to `cwd` when it lies strictly inside `cwd`.
///
/// A path starting with `cwd/` (trailing slashes on `cwd` are ignored) loses
/// that prefix. Any other path, including `cwd` itself, is returned unchanged.
fn display_path(path: &str, cwd: &str) -> String {
    let cwd_prefix = format!("{}/", cwd.trim_end_matches('/'));
    path.strip_prefix(cwd_prefix.as_str())
        .unwrap_or(path)
        .to_string()
}
310
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// In-memory filesystem plus a grep tool whose working directory is `/project`.
    async fn setup() -> (Arc<MemoryFs>, GrepTool) {
        let memory = Arc::new(MemoryFs::new());
        let as_vfs: Arc<dyn VirtualFs> = memory.clone();
        let grep = GrepTool::new(as_vfs, "/project");
        (memory, grep)
    }

    #[tokio::test]
    async fn grep_simple_match() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "hello world\nfoo bar\nhello again")
            .await
            .unwrap();

        let out = tool
            .execute("c1", json!({"pattern": "hello"}), None)
            .await
            .unwrap();

        assert!(!out.is_error);
        for expected in ["file.txt:1:hello world", "file.txt:3:hello again"] {
            assert!(out.content.contains(expected), "missing {:?}", expected);
        }
    }

    #[tokio::test]
    async fn grep_case_insensitive() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "Hello World\nhello world")
            .await
            .unwrap();

        let args = json!({"pattern": "HELLO", "ignore_case": true});
        let out = tool.execute("c2", args, None).await.unwrap();

        assert!(!out.is_error);
        assert_eq!(out.metadata["matches"].as_u64(), Some(2));
    }

    #[tokio::test]
    async fn grep_with_glob_filter() {
        let (fs, tool) = setup().await;
        for path in ["/project/code.rs", "/project/readme.md"] {
            fs.write(path, "fn main() {}").await.unwrap();
        }

        let args = json!({"pattern": "fn main", "glob": "*.rs"});
        let out = tool.execute("c3", args, None).await.unwrap();

        assert!(!out.is_error);
        assert!(out.content.contains("code.rs"));
        assert!(!out.content.contains("readme.md"));
    }

    #[tokio::test]
    async fn grep_no_matches() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "nothing here").await.unwrap();

        let out = tool
            .execute("c4", json!({"pattern": "missing"}), None)
            .await
            .unwrap();

        assert!(!out.is_error);
        assert!(out.content.contains("No matches"));
    }

    #[tokio::test]
    async fn grep_empty_pattern() {
        let (_fs, tool) = setup().await;

        let out = tool
            .execute("c5", json!({"pattern": ""}), None)
            .await
            .unwrap();

        assert!(out.is_error);
    }

    #[tokio::test]
    async fn grep_with_context() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "a\nb\nc\nd\ne").await.unwrap();

        let args = json!({"pattern": "c", "context": 1});
        let out = tool.execute("c6", args, None).await.unwrap();

        assert!(!out.is_error);
        // One line of context on each side of the match on "c".
        assert!(out.content.contains("b"));
        assert!(out.content.contains("d"));
    }

    #[test]
    fn glob_matching() {
        assert!(matches_glob("file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "*.rs"));
        assert!(matches_glob("test.spec.ts", "*.ts"));
    }

    #[test]
    fn display_path_relative() {
        let cases = [
            ("/project/src/main.rs", "src/main.rs"),
            ("/other/file.txt", "/other/file.txt"),
        ];
        for (input, expected) in cases {
            assert_eq!(display_path(input, "/project"), expected);
        }
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup().await;
        assert_eq!(tool.name(), "grep");
        assert_eq!(tool.definition().name, "grep");
    }
}