Skip to main content

soul_coder/tools/
find.rs

1//! Find tool — search for files by name/glob pattern.
2//!
3//! Uses VirtualFs for WASM compatibility. Recursively walks directories
4//! and matches filenames against glob patterns.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, MAX_BYTES};
18
19/// Maximum results returned.
20const MAX_RESULTS: usize = 1000;
21
22pub struct FindTool {
23    fs: Arc<dyn VirtualFs>,
24    cwd: String,
25}
26
27impl FindTool {
28    pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
29        Self {
30            fs,
31            cwd: cwd.into(),
32        }
33    }
34
35    fn resolve_path(&self, path: &str) -> String {
36        if path.starts_with('/') {
37            path.to_string()
38        } else {
39            format!("{}/{}", self.cwd.trim_end_matches('/'), path)
40        }
41    }
42}
43
/// Match a file against a glob pattern.
///
/// * Patterns without `/` are matched against the file `name` only
///   (`*.rs`, `Cargo.toml`, `test_*`, `**`).
/// * Patterns containing `/` are matched segment-by-segment against
///   `full_path` (see [`path_matches_glob`]). A `**` segment matches zero or
///   more directories, e.g. `src/**/*.ts` or `**/*.rs`.
///
/// Within a segment, `*` matches any run of characters and `?` matches
/// exactly one character. Surrounding whitespace in `pattern` is ignored.
fn matches_glob(name: &str, full_path: &str, pattern: &str) -> bool {
    let pattern = pattern.trim();

    if pattern.contains('/') {
        path_matches_glob(full_path, pattern)
    } else {
        matches_simple_glob(name, pattern)
    }
}

/// Match a single name (no `/` handling) against a wildcard pattern.
///
/// `*` matches any sequence of characters (including none) and `?` matches
/// exactly one character; every other character matches itself,
/// case-sensitively. Works on `char`s, so multi-byte names are handled.
///
/// Uses the classic greedy matcher with single-star backtracking:
/// no recursion, O(len(name) * len(pattern)) in the worst case.
fn matches_simple_glob(name: &str, pattern: &str) -> bool {
    let text: Vec<char> = name.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();

    let (mut t, mut p) = (0, 0);
    // (pattern index just after the most recent `*`, text index that `*` is
    // currently assumed to stop at)
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pat.get(p) {
            // Checked before literal equality so a `*` in the pattern is always
            // a wildcard, even when the name itself contains `*`.
            Some('*') => {
                backtrack = Some((p + 1, t));
                p += 1;
            }
            Some(&c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                // Let the last `*` swallow one more character and retry.
                Some((star_p, star_t)) => {
                    p = star_p;
                    t = star_t + 1;
                    backtrack = Some((star_p, star_t + 1));
                }
                None => return false,
            },
        }
    }

    // Name fully consumed: whatever is left of the pattern must be all `*`.
    pat[p..].iter().all(|&c| c == '*')
}

/// Match a `/`-separated path against a `/`-separated glob pattern.
///
/// Each pattern segment is matched with [`matches_simple_glob`]; a segment
/// that is exactly `**` matches zero or more path segments. Empty segments
/// (from leading, trailing or doubled `/`) are ignored on both sides.
///
/// A pattern starting with `/` is anchored at the start of `path`. Any other
/// pattern may start matching at any depth but must run through to the last
/// segment: `src/*.rs` matches `/project/src/main.rs` but not
/// `/project/src/sub/main.rs`. A pattern with no segments matches every path.
fn path_matches_glob(path: &str, pattern: &str) -> bool {
    let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    let mut pattern_parts: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();

    if pattern_parts.is_empty() {
        return true;
    }

    if !pattern.starts_with('/') {
        // Relative patterns float: behave as if prefixed with `**/`.
        pattern_parts.insert(0, "**");
    }
    // Adjacent `**` segments are equivalent to one; collapsing them keeps the
    // backtracking in `match_segments` from multiplying needlessly.
    pattern_parts.dedup_by(|a, b| *a == "**" && *b == "**");

    match_segments(&path_parts, &pattern_parts)
}

/// Segment-level matcher behind [`path_matches_glob`]: `**` consumes any
/// number of path segments, every other segment consumes exactly one.
fn match_segments(path: &[&str], pattern: &[&str]) -> bool {
    match pattern.split_first() {
        None => path.is_empty(),
        Some((&"**", rest)) => (0..=path.len()).any(|skip| match_segments(&path[skip..], rest)),
        Some((segment, rest)) => match path.split_first() {
            Some((head, tail)) => matches_simple_glob(head, segment) && match_segments(tail, rest),
            None => false,
        },
    }
}
116
117/// Recursively collect matching files.
118async fn find_files(
119    fs: &dyn VirtualFs,
120    dir: &str,
121    pattern: &str,
122    results: &mut Vec<String>,
123    limit: usize,
124) -> SoulResult<()> {
125    if results.len() >= limit {
126        return Ok(());
127    }
128
129    let entries = match fs.read_dir(dir).await {
130        Ok(e) => e,
131        Err(_) => return Ok(()), // Skip unreadable dirs
132    };
133
134    for entry in entries {
135        if results.len() >= limit {
136            break;
137        }
138
139        let path = if dir == "/" || dir.is_empty() {
140            format!("/{}", entry.name)
141        } else {
142            format!("{}/{}", dir.trim_end_matches('/'), entry.name)
143        };
144
145        if entry.is_dir {
146            if !entry.name.starts_with('.') {
147                Box::pin(find_files(fs, &path, pattern, results, limit)).await?;
148            }
149        } else if entry.is_file && matches_glob(&entry.name, &path, pattern) {
150            results.push(path);
151        }
152    }
153
154    Ok(())
155}
156
157#[async_trait]
158impl Tool for FindTool {
159    fn name(&self) -> &str {
160        "find"
161    }
162
163    fn definition(&self) -> ToolDefinition {
164        ToolDefinition {
165            name: "find".into(),
166            description: "Find files matching a glob pattern. Returns matching file paths.".into(),
167            input_schema: json!({
168                "type": "object",
169                "properties": {
170                    "pattern": {
171                        "type": "string",
172                        "description": "Glob pattern to match files (e.g., '*.rs', 'src/**/*.ts', 'Cargo.toml')"
173                    },
174                    "path": {
175                        "type": "string",
176                        "description": "Directory to search in (defaults to working directory)"
177                    },
178                    "limit": {
179                        "type": "integer",
180                        "description": "Maximum number of results (default: 1000)"
181                    }
182                },
183                "required": ["pattern"]
184            }),
185        }
186    }
187
188    async fn execute(
189        &self,
190        _call_id: &str,
191        arguments: serde_json::Value,
192        _partial_tx: Option<mpsc::UnboundedSender<String>>,
193    ) -> SoulResult<ToolOutput> {
194        let pattern = arguments
195            .get("pattern")
196            .and_then(|v| v.as_str())
197            .unwrap_or("");
198
199        if pattern.is_empty() {
200            return Ok(ToolOutput::error("Missing required parameter: pattern"));
201        }
202
203        let search_path = arguments
204            .get("path")
205            .and_then(|v| v.as_str())
206            .map(|p| self.resolve_path(p))
207            .unwrap_or_else(|| self.cwd.clone());
208
209        let limit = arguments
210            .get("limit")
211            .and_then(|v| v.as_u64())
212            .map(|v| (v as usize).min(MAX_RESULTS))
213            .unwrap_or(MAX_RESULTS);
214
215        let mut results = Vec::new();
216        if let Err(e) =
217            find_files(self.fs.as_ref(), &search_path, pattern, &mut results, limit).await
218        {
219            return Ok(ToolOutput::error(format!(
220                "Failed to search {}: {}",
221                search_path, e
222            )));
223        }
224
225        results.sort();
226
227        if results.is_empty() {
228            return Ok(ToolOutput::success(format!(
229                "No files matching '{}' found",
230                pattern
231            ))
232            .with_metadata(json!({"count": 0})));
233        }
234
235        // Make paths relative to cwd
236        let cwd_prefix = format!("{}/", self.cwd.trim_end_matches('/'));
237        let relative: Vec<String> = results
238            .iter()
239            .map(|p| {
240                if p.starts_with(&cwd_prefix) {
241                    p[cwd_prefix.len()..].to_string()
242                } else {
243                    p.clone()
244                }
245            })
246            .collect();
247
248        let output = relative.join("\n");
249        let truncated = truncate_head(&output, results.len(), MAX_BYTES);
250
251        let notice = truncated.truncation_notice();
252        let mut result = truncated.content;
253        if results.len() >= limit {
254            result.push_str(&format!("\n[Reached limit: {} results]", limit));
255        }
256        if let Some(notice) = notice {
257            result.push_str(&format!("\n{}", notice));
258        }
259
260        Ok(ToolOutput::success(result).with_metadata(json!({
261            "count": results.len(),
262            "limit_reached": results.len() >= limit,
263        })))
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Build an empty in-memory filesystem plus a `FindTool` whose working
    /// directory is `/project`. The concrete handle is returned for seeding.
    fn setup() -> (Arc<MemoryFs>, FindTool) {
        let memory = Arc::new(MemoryFs::new());
        let shared: Arc<dyn VirtualFs> = memory.clone();
        (memory, FindTool::new(shared, "/project"))
    }

    /// Seed `fs` with a small mixed Rust/TypeScript project layout.
    async fn populate(fs: &MemoryFs) {
        let files = [
            ("/project/src/main.rs", "fn main() {}"),
            ("/project/src/lib.rs", "pub mod foo;"),
            ("/project/src/utils.ts", "export {}"),
            ("/project/Cargo.toml", "[package]"),
            ("/project/README.md", "# readme"),
        ];
        for (path, contents) in files {
            fs.write(path, contents).await.unwrap();
        }
    }

    /// Invoke the tool once, failing the test if `execute` itself errors.
    async fn run(tool: &FindTool, call_id: &str, args: serde_json::Value) -> ToolOutput {
        tool.execute(call_id, args, None).await.unwrap()
    }

    #[tokio::test]
    async fn find_by_extension() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c1", json!({"pattern": "*.rs"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("main.rs"));
        assert!(out.content.contains("lib.rs"));
        assert!(!out.content.contains("utils.ts"));
    }

    #[tokio::test]
    async fn find_exact_name() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c2", json!({"pattern": "Cargo.toml"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("Cargo.toml"));
        assert_eq!(out.metadata["count"].as_u64().unwrap(), 1);
    }

    #[tokio::test]
    async fn find_no_results() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c3", json!({"pattern": "*.py"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("No files"));
    }

    #[tokio::test]
    async fn find_with_limit() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c4", json!({"pattern": "*", "limit": 2})).await;

        assert!(!out.is_error);
        assert_eq!(out.metadata["count"].as_u64().unwrap(), 2);
    }

    #[tokio::test]
    async fn find_empty_pattern() {
        let (_fs, tool) = setup();

        let out = run(&tool, "c5", json!({"pattern": ""})).await;

        assert!(out.is_error);
    }

    #[test]
    fn glob_extensions() {
        let (yes, no) = (
            matches_glob("file.rs", "/src/file.rs", "*.rs"),
            matches_glob("file.ts", "/src/file.ts", "*.rs"),
        );
        assert!(yes);
        assert!(!no);
    }

    #[test]
    fn glob_prefix() {
        let (yes, no) = (
            matches_glob("Cargo.toml", "/Cargo.toml", "Cargo*"),
            matches_glob("package.json", "/package.json", "Cargo*"),
        );
        assert!(yes);
        assert!(!no);
    }

    #[test]
    fn glob_exact() {
        let (yes, no) = (
            matches_glob("Makefile", "/Makefile", "Makefile"),
            matches_glob("makefile", "/makefile", "Makefile"),
        );
        assert!(yes);
        assert!(!no);
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup();

        assert_eq!(tool.name(), "find");
        assert_eq!(tool.definition().name, "find");
    }
}