1use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, MAX_BYTES};
18
19const MAX_RESULTS: usize = 1000;
21
22pub struct FindTool {
23 fs: Arc<dyn VirtualFs>,
24 cwd: String,
25}
26
27impl FindTool {
28 pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
29 Self {
30 fs,
31 cwd: cwd.into(),
32 }
33 }
34
35 fn resolve_path(&self, path: &str) -> String {
36 if path.starts_with('/') {
37 path.to_string()
38 } else {
39 format!("{}/{}", self.cwd.trim_end_matches('/'), path)
40 }
41 }
42}
43
/// Decides whether a file matches a user-supplied glob `pattern`.
///
/// * `name` is the file's final path component.
/// * `full_path` is the file's complete path.
/// * `pattern` is trimmed of surrounding whitespace before use.
///
/// A pattern without `/` is matched against `name` only. A pattern that
/// contains `/` is matched segment by segment against the trailing part of
/// `full_path`, so `src/*.rs` matches `/project/src/main.rs`. Within such a
/// pattern a `**` segment stands for zero or more directories, which makes
/// `src/**/*.ts` match files at any depth below a `src` directory.
fn matches_glob(name: &str, full_path: &str, pattern: &str) -> bool {
    let pattern = pattern.trim();

    if pattern.contains('/') {
        path_matches_glob(full_path, pattern)
    } else {
        matches_simple_glob(name, pattern)
    }
}

/// Matches a single path component against a wildcard pattern.
///
/// `*` matches any run of characters (including none) and `?` matches exactly
/// one character; every other character must match literally and
/// case-sensitively. Wildcards may appear anywhere in the pattern.
fn matches_simple_glob(name: &str, pattern: &str) -> bool {
    // Fast path: without metacharacters the pattern is just a literal name.
    if !pattern.contains(|c: char| c == '*' || c == '?') {
        return name == pattern;
    }

    // Work on chars rather than bytes so `?` consumes one full character even
    // for multi-byte UTF-8 names.
    let name_chars: Vec<char> = name.chars().collect();
    let pattern_chars: Vec<char> = pattern.chars().collect();

    wildcard_match(
        &name_chars,
        &pattern_chars,
        |p| *p == '*',
        |c, p| *p == '?' || c == p,
    )
}

/// Matches `path` against a `/`-separated glob `pattern`.
///
/// Empty segments (leading, trailing or doubled slashes) are ignored on both
/// sides. The pattern is anchored at the end of the path but floats on the
/// left: it matches when some trailing run of path segments satisfies it.
/// A `**` segment matches zero or more whole segments; any other segment is
/// compared with [`matches_simple_glob`]. A pattern with no segments matches
/// every path.
fn path_matches_glob(path: &str, pattern: &str) -> bool {
    let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();

    // An implicit leading `**` gives the "match any trailing run of segments"
    // behaviour without a separate search over start positions.
    let mut pattern_parts: Vec<&str> = vec!["**"];
    pattern_parts.extend(pattern.split('/').filter(|s| !s.is_empty()));

    wildcard_match(
        &path_parts,
        &pattern_parts,
        |segment| *segment == "**",
        |part, segment| matches_simple_glob(part, segment),
    )
}

/// Generic greedy wildcard matcher shared by the character-level and
/// segment-level glob functions.
///
/// `pattern` elements for which `is_wildcard` returns true match any run of
/// `text` elements (including an empty run). Every other pattern element must
/// match exactly one text element according to `matches_one`.
fn wildcard_match<T, P>(
    text: &[T],
    pattern: &[P],
    is_wildcard: impl Fn(&P) -> bool,
    matches_one: impl Fn(&T, &P) -> bool,
) -> bool {
    let mut ti = 0;
    let mut pi = 0;
    // Most recent wildcard seen: (text index where it started absorbing,
    // pattern index right after it). Only the latest one needs remembering,
    // because an earlier wildcard can never do better than the latest.
    let mut resume: Option<(usize, usize)> = None;

    while ti < text.len() {
        if pi < pattern.len() && is_wildcard(&pattern[pi]) {
            pi += 1;
            resume = Some((ti, pi));
        } else if pi < pattern.len() && matches_one(&text[ti], &pattern[pi]) {
            ti += 1;
            pi += 1;
        } else if let Some((resume_ti, resume_pi)) = resume {
            // Mismatch after a wildcard: let it swallow one more element and
            // retry the rest of the pattern from there.
            ti = resume_ti + 1;
            pi = resume_pi;
            resume = Some((ti, resume_pi));
        } else {
            return false;
        }
    }

    // Text is exhausted; whatever pattern remains must be able to match nothing.
    pattern[pi..].iter().all(|p| is_wildcard(p))
}
116
117async fn find_files(
119 fs: &dyn VirtualFs,
120 dir: &str,
121 pattern: &str,
122 results: &mut Vec<String>,
123 limit: usize,
124) -> SoulResult<()> {
125 if results.len() >= limit {
126 return Ok(());
127 }
128
129 let entries = match fs.read_dir(dir).await {
130 Ok(e) => e,
131 Err(_) => return Ok(()), };
133
134 for entry in entries {
135 if results.len() >= limit {
136 break;
137 }
138
139 let path = if dir == "/" || dir.is_empty() {
140 format!("/{}", entry.name)
141 } else {
142 format!("{}/{}", dir.trim_end_matches('/'), entry.name)
143 };
144
145 if entry.is_dir {
146 if !entry.name.starts_with('.') {
147 Box::pin(find_files(fs, &path, pattern, results, limit)).await?;
148 }
149 } else if entry.is_file && matches_glob(&entry.name, &path, pattern) {
150 results.push(path);
151 }
152 }
153
154 Ok(())
155}
156
157#[async_trait]
158impl Tool for FindTool {
159 fn name(&self) -> &str {
160 "find"
161 }
162
163 fn definition(&self) -> ToolDefinition {
164 ToolDefinition {
165 name: "find".into(),
166 description: "Find files matching a glob pattern. Returns matching file paths.".into(),
167 input_schema: json!({
168 "type": "object",
169 "properties": {
170 "pattern": {
171 "type": "string",
172 "description": "Glob pattern to match files (e.g., '*.rs', 'src/**/*.ts', 'Cargo.toml')"
173 },
174 "path": {
175 "type": "string",
176 "description": "Directory to search in (defaults to working directory)"
177 },
178 "limit": {
179 "type": "integer",
180 "description": "Maximum number of results (default: 1000)"
181 }
182 },
183 "required": ["pattern"]
184 }),
185 }
186 }
187
188 async fn execute(
189 &self,
190 _call_id: &str,
191 arguments: serde_json::Value,
192 _partial_tx: Option<mpsc::UnboundedSender<String>>,
193 ) -> SoulResult<ToolOutput> {
194 let pattern = arguments
195 .get("pattern")
196 .and_then(|v| v.as_str())
197 .unwrap_or("");
198
199 if pattern.is_empty() {
200 return Ok(ToolOutput::error("Missing required parameter: pattern"));
201 }
202
203 let search_path = arguments
204 .get("path")
205 .and_then(|v| v.as_str())
206 .map(|p| self.resolve_path(p))
207 .unwrap_or_else(|| self.cwd.clone());
208
209 let limit = arguments
210 .get("limit")
211 .and_then(|v| v.as_u64())
212 .map(|v| (v as usize).min(MAX_RESULTS))
213 .unwrap_or(MAX_RESULTS);
214
215 let mut results = Vec::new();
216 if let Err(e) =
217 find_files(self.fs.as_ref(), &search_path, pattern, &mut results, limit).await
218 {
219 return Ok(ToolOutput::error(format!(
220 "Failed to search {}: {}",
221 search_path, e
222 )));
223 }
224
225 results.sort();
226
227 if results.is_empty() {
228 return Ok(ToolOutput::success(format!(
229 "No files matching '{}' found",
230 pattern
231 ))
232 .with_metadata(json!({"count": 0})));
233 }
234
235 let cwd_prefix = format!("{}/", self.cwd.trim_end_matches('/'));
237 let relative: Vec<String> = results
238 .iter()
239 .map(|p| {
240 if p.starts_with(&cwd_prefix) {
241 p[cwd_prefix.len()..].to_string()
242 } else {
243 p.clone()
244 }
245 })
246 .collect();
247
248 let output = relative.join("\n");
249 let truncated = truncate_head(&output, results.len(), MAX_BYTES);
250
251 let notice = truncated.truncation_notice();
252 let mut result = truncated.content;
253 if results.len() >= limit {
254 result.push_str(&format!("\n[Reached limit: {} results]", limit));
255 }
256 if let Some(notice) = notice {
257 result.push_str(&format!("\n{}", notice));
258 }
259
260 Ok(ToolOutput::success(result).with_metadata(json!({
261 "count": results.len(),
262 "limit_reached": results.len() >= limit,
263 })))
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Path/content pairs written into the in-memory filesystem by `populate`.
    const FIXTURE_FILES: [(&str, &str); 5] = [
        ("/project/src/main.rs", "fn main() {}"),
        ("/project/src/lib.rs", "pub mod foo;"),
        ("/project/src/utils.ts", "export {}"),
        ("/project/Cargo.toml", "[package]"),
        ("/project/README.md", "# readme"),
    ];

    /// Creates an empty in-memory filesystem and a tool rooted at `/project`.
    async fn setup() -> (Arc<MemoryFs>, FindTool) {
        let memory = Arc::new(MemoryFs::new());
        let finder = FindTool::new(memory.clone() as Arc<dyn VirtualFs>, "/project");
        (memory, finder)
    }

    /// Writes every fixture file into `fs`.
    async fn populate(fs: &MemoryFs) {
        for &(path, body) in FIXTURE_FILES.iter() {
            fs.write(path, body).await.unwrap();
        }
    }

    /// Executes the tool with `args` and unwraps the outer result.
    async fn run(tool: &FindTool, call_id: &str, args: serde_json::Value) -> ToolOutput {
        tool.execute(call_id, args, None).await.unwrap()
    }

    #[tokio::test]
    async fn find_by_extension() {
        let (fs, tool) = setup().await;
        populate(&fs).await;

        let output = run(&tool, "c1", json!({"pattern": "*.rs"})).await;

        assert!(!output.is_error);
        assert!(output.content.contains("main.rs"));
        assert!(output.content.contains("lib.rs"));
        assert!(!output.content.contains("utils.ts"));
    }

    #[tokio::test]
    async fn find_exact_name() {
        let (fs, tool) = setup().await;
        populate(&fs).await;

        let output = run(&tool, "c2", json!({"pattern": "Cargo.toml"})).await;

        assert!(!output.is_error);
        assert!(output.content.contains("Cargo.toml"));
        assert_eq!(output.metadata["count"].as_u64(), Some(1));
    }

    #[tokio::test]
    async fn find_no_results() {
        let (fs, tool) = setup().await;
        populate(&fs).await;

        let output = run(&tool, "c3", json!({"pattern": "*.py"})).await;

        assert!(!output.is_error);
        assert!(output.content.contains("No files"));
    }

    #[tokio::test]
    async fn find_with_limit() {
        let (fs, tool) = setup().await;
        populate(&fs).await;

        let output = run(&tool, "c4", json!({"pattern": "*", "limit": 2})).await;

        assert!(!output.is_error);
        assert_eq!(output.metadata["count"].as_u64(), Some(2));
    }

    #[tokio::test]
    async fn find_empty_pattern() {
        let (_fs, tool) = setup().await;

        let output = run(&tool, "c5", json!({"pattern": ""})).await;

        assert!(output.is_error);
    }

    #[test]
    fn glob_extensions() {
        assert!(matches_glob("file.rs", "/src/file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "/src/file.ts", "*.rs"));
    }

    #[test]
    fn glob_prefix() {
        assert!(matches_glob("Cargo.toml", "/Cargo.toml", "Cargo*"));
        assert!(!matches_glob("package.json", "/package.json", "Cargo*"));
    }

    #[test]
    fn glob_exact() {
        assert!(matches_glob("Makefile", "/Makefile", "Makefile"));
        assert!(!matches_glob("makefile", "/makefile", "Makefile"));
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup().await;

        assert_eq!(tool.name(), "find");
        assert_eq!(tool.definition().name, "find");
    }
}
386}