1use rig::completion::ToolDefinition;
6use rig::tool::Tool;
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::fs;
10use std::path::PathBuf;
11use walkdir::WalkDir;
12use regex::Regex;
13
14#[derive(Debug, Deserialize)]
19pub struct SearchCodeArgs {
20 pub pattern: String,
22 pub path: Option<String>,
24 pub extension: Option<String>,
26 pub regex: Option<bool>,
28 pub case_insensitive: Option<bool>,
30 pub max_results: Option<usize>,
32}
33
34#[derive(Debug, thiserror::Error)]
35#[error("Search error: {0}")]
36pub struct SearchCodeError(String);
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SearchCodeTool {
40 project_path: PathBuf,
41}
42
43impl SearchCodeTool {
44 pub fn new(project_path: PathBuf) -> Self {
45 Self { project_path }
46 }
47
48 fn should_skip_dir(name: &str) -> bool {
49 matches!(
50 name,
51 "node_modules"
52 | ".git"
53 | "target"
54 | "__pycache__"
55 | ".venv"
56 | "dist"
57 | "build"
58 | ".next"
59 | ".nuxt"
60 | "vendor"
61 | ".cache"
62 | "coverage"
63 )
64 }
65
66 fn is_text_file(path: &PathBuf) -> bool {
67 let text_extensions = [
68 "rs", "go", "js", "ts", "jsx", "tsx", "py", "java", "kt", "scala",
69 "rb", "php", "cs", "cpp", "c", "h", "hpp", "swift", "dart", "elm",
70 "clj", "hs", "ml", "r", "sh", "bash", "zsh", "ps1", "bat", "cmd",
71 "json", "yaml", "yml", "toml", "xml", "html", "css", "scss", "sass",
72 "less", "md", "txt", "sql", "graphql", "prisma", "env", "dockerfile",
73 "makefile", "cmake", "gradle", "sbt", "ex", "exs", "erl", "hrl",
74 ];
75
76 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
77 return text_extensions.contains(&ext.to_lowercase().as_str());
78 }
79
80 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
82 let lower = name.to_lowercase();
83 return matches!(lower.as_str(), "dockerfile" | "makefile" | "rakefile" | "gemfile" | "procfile" | "justfile");
84 }
85
86 false
87 }
88}
89
90#[derive(Debug, Serialize)]
91struct SearchMatch {
92 file: String,
93 line_number: usize,
94 line: String,
95 context_before: Vec<String>,
96 context_after: Vec<String>,
97}
98
99impl Tool for SearchCodeTool {
100 const NAME: &'static str = "search_code";
101
102 type Error = SearchCodeError;
103 type Args = SearchCodeArgs;
104 type Output = String;
105
106 async fn definition(&self, _prompt: String) -> ToolDefinition {
107 ToolDefinition {
108 name: Self::NAME.to_string(),
109 description: "Search for code patterns, function names, variables, or any text across the codebase. Returns matching lines with context. Use this to find where something is defined, used, or imported.".to_string(),
110 parameters: json!({
111 "type": "object",
112 "properties": {
113 "pattern": {
114 "type": "string",
115 "description": "Search pattern - can be a function name, variable, string literal, or regex pattern"
116 },
117 "path": {
118 "type": "string",
119 "description": "Optional subdirectory to search within (e.g., 'src', 'backend/api')"
120 },
121 "extension": {
122 "type": "string",
123 "description": "Filter by file extension (e.g., 'rs', 'ts', 'py'). Omit for all file types."
124 },
125 "regex": {
126 "type": "boolean",
127 "description": "Treat pattern as regex. Default: false (literal string match)"
128 },
129 "case_insensitive": {
130 "type": "boolean",
131 "description": "Case insensitive search. Default: true"
132 },
133 "max_results": {
134 "type": "integer",
135 "description": "Maximum results to return. Default: 50"
136 }
137 },
138 "required": ["pattern"]
139 }),
140 }
141 }
142
143 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
144 let search_root = if let Some(ref subpath) = args.path {
145 self.project_path.join(subpath)
146 } else {
147 self.project_path.clone()
148 };
149
150 if !search_root.exists() {
151 return Err(SearchCodeError(format!(
152 "Path does not exist: {}",
153 args.path.unwrap_or_default()
154 )));
155 }
156
157 let case_insensitive = args.case_insensitive.unwrap_or(true);
158 let is_regex = args.regex.unwrap_or(false);
159 let max_results = args.max_results.unwrap_or(50);
160
161 let pattern_str = if is_regex {
163 if case_insensitive {
164 format!("(?i){}", args.pattern)
165 } else {
166 args.pattern.clone()
167 }
168 } else {
169 let escaped = regex::escape(&args.pattern);
170 if case_insensitive {
171 format!("(?i){}", escaped)
172 } else {
173 escaped
174 }
175 };
176
177 let regex = Regex::new(&pattern_str)
178 .map_err(|e| SearchCodeError(format!("Invalid pattern: {}", e)))?;
179
180 let mut matches: Vec<SearchMatch> = Vec::new();
181
182 for entry in WalkDir::new(&search_root)
183 .into_iter()
184 .filter_entry(|e| {
185 if e.file_type().is_dir() {
186 if let Some(name) = e.file_name().to_str() {
187 return !Self::should_skip_dir(name);
188 }
189 }
190 true
191 })
192 .filter_map(|e| e.ok())
193 {
194 if matches.len() >= max_results {
195 break;
196 }
197
198 let path = entry.path();
199 if !path.is_file() {
200 continue;
201 }
202
203 if let Some(ref ext_filter) = args.extension {
205 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
206 if ext.to_lowercase() != ext_filter.to_lowercase() {
207 continue;
208 }
209 } else {
210 continue;
211 }
212 }
213
214 let path_buf = path.to_path_buf();
216 if !Self::is_text_file(&path_buf) {
217 continue;
218 }
219
220 let content = match fs::read_to_string(path) {
222 Ok(c) => c,
223 Err(_) => continue, };
225
226 let lines: Vec<&str> = content.lines().collect();
227 for (line_idx, line) in lines.iter().enumerate() {
228 if matches.len() >= max_results {
229 break;
230 }
231
232 if regex.is_match(line) {
233 let relative_path = path
234 .strip_prefix(&self.project_path)
235 .unwrap_or(path)
236 .to_string_lossy()
237 .to_string();
238
239 let context_before = if line_idx > 0 {
241 vec![lines[line_idx - 1].to_string()]
242 } else {
243 vec![]
244 };
245
246 let context_after = if line_idx + 1 < lines.len() {
247 vec![lines[line_idx + 1].to_string()]
248 } else {
249 vec![]
250 };
251
252 matches.push(SearchMatch {
253 file: relative_path,
254 line_number: line_idx + 1,
255 line: line.to_string(),
256 context_before,
257 context_after,
258 });
259 }
260 }
261 }
262
263 let result = json!({
264 "pattern": args.pattern,
265 "total_matches": matches.len(),
266 "matches": matches,
267 "truncated": matches.len() >= max_results
268 });
269
270 serde_json::to_string_pretty(&result)
271 .map_err(|e| SearchCodeError(format!("Serialization error: {}", e)))
272 }
273}
274
275#[derive(Debug, Deserialize)]
280pub struct FindFilesArgs {
281 pub pattern: String,
283 pub path: Option<String>,
285 pub extension: Option<String>,
287 pub include_dirs: Option<bool>,
289 pub max_results: Option<usize>,
291}
292
293#[derive(Debug, thiserror::Error)]
294#[error("Find files error: {0}")]
295pub struct FindFilesError(String);
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct FindFilesTool {
299 project_path: PathBuf,
300}
301
302impl FindFilesTool {
303 pub fn new(project_path: PathBuf) -> Self {
304 Self { project_path }
305 }
306
307 fn matches_pattern(name: &str, pattern: &str) -> bool {
308 let pattern_lower = pattern.to_lowercase();
309 let name_lower = name.to_lowercase();
310
311 if pattern == "*" {
313 return true;
314 }
315
316 if pattern.contains('*') || pattern.contains('?') {
318 let regex_pattern = pattern_lower
319 .replace('.', r"\.")
320 .replace('*', ".*")
321 .replace('?', ".");
322
323 if let Ok(re) = Regex::new(&format!("^{}$", regex_pattern)) {
324 return re.is_match(&name_lower);
325 }
326 }
327
328 name_lower.contains(&pattern_lower)
330 }
331}
332
333#[derive(Debug, Serialize)]
334struct FileInfo {
335 name: String,
336 path: String,
337 file_type: String,
338 size: Option<u64>,
339 extension: Option<String>,
340}
341
342impl Tool for FindFilesTool {
343 const NAME: &'static str = "find_files";
344
345 type Error = FindFilesError;
346 type Args = FindFilesArgs;
347 type Output = String;
348
349 async fn definition(&self, _prompt: String) -> ToolDefinition {
350 ToolDefinition {
351 name: Self::NAME.to_string(),
352 description: "Find files by name pattern. Use wildcards (* for any characters, ? for single character). Great for locating config files, finding all files of a type, or discovering project structure.".to_string(),
353 parameters: json!({
354 "type": "object",
355 "properties": {
356 "pattern": {
357 "type": "string",
358 "description": "File name pattern with optional wildcards. Examples: 'package.json', '*.config.ts', 'Dockerfile*', 'api*.rs'"
359 },
360 "path": {
361 "type": "string",
362 "description": "Subdirectory to search in (e.g., 'src', 'backend')"
363 },
364 "extension": {
365 "type": "string",
366 "description": "Filter by extension (e.g., 'ts', 'rs', 'yaml')"
367 },
368 "include_dirs": {
369 "type": "boolean",
370 "description": "Include directories in results. Default: false"
371 },
372 "max_results": {
373 "type": "integer",
374 "description": "Maximum results. Default: 100"
375 }
376 },
377 "required": ["pattern"]
378 }),
379 }
380 }
381
382 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
383 let search_root = if let Some(ref subpath) = args.path {
384 self.project_path.join(subpath)
385 } else {
386 self.project_path.clone()
387 };
388
389 if !search_root.exists() {
390 return Err(FindFilesError(format!(
391 "Path does not exist: {}",
392 args.path.unwrap_or_default()
393 )));
394 }
395
396 let include_dirs = args.include_dirs.unwrap_or(false);
397 let max_results = args.max_results.unwrap_or(100);
398 let skip_dirs = [
399 "node_modules", ".git", "target", "__pycache__", ".venv",
400 "dist", "build", ".next", ".nuxt", "vendor", ".cache", "coverage"
401 ];
402
403 let mut results: Vec<FileInfo> = Vec::new();
404
405 for entry in WalkDir::new(&search_root)
406 .into_iter()
407 .filter_entry(|e| {
408 if e.file_type().is_dir() {
409 if let Some(name) = e.file_name().to_str() {
410 return !skip_dirs.contains(&name);
411 }
412 }
413 true
414 })
415 .filter_map(|e| e.ok())
416 {
417 if results.len() >= max_results {
418 break;
419 }
420
421 let path = entry.path();
422 let is_dir = path.is_dir();
423
424 if is_dir && !include_dirs {
426 continue;
427 }
428
429 let file_name = match path.file_name().and_then(|n| n.to_str()) {
430 Some(n) => n,
431 None => continue,
432 };
433
434 if let Some(ref ext_filter) = args.extension {
436 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
437 if ext.to_lowercase() != ext_filter.to_lowercase() {
438 continue;
439 }
440 } else {
441 continue;
442 }
443 }
444
445 if !Self::matches_pattern(file_name, &args.pattern) {
447 continue;
448 }
449
450 let relative_path = path
451 .strip_prefix(&self.project_path)
452 .unwrap_or(path)
453 .to_string_lossy()
454 .to_string();
455
456 let metadata = path.metadata().ok();
457 let size = if is_dir { None } else { metadata.as_ref().map(|m| m.len()) };
458
459 results.push(FileInfo {
460 name: file_name.to_string(),
461 path: relative_path,
462 file_type: if is_dir { "directory".to_string() } else { "file".to_string() },
463 size,
464 extension: path.extension().and_then(|e| e.to_str()).map(|s| s.to_string()),
465 });
466 }
467
468 let result = json!({
469 "pattern": args.pattern,
470 "total_found": results.len(),
471 "files": results,
472 "truncated": results.len() >= max_results
473 });
474
475 serde_json::to_string_pretty(&result)
476 .map_err(|e| FindFilesError(format!("Serialization error: {}", e)))
477 }
478}