Skip to main content

soul_coder/tools/
find.rs

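//! Find tool — search for files by name/glob pattern.
//!
//! Uses VirtualFs for WASM compatibility. Recursively walks directories
//! (without descending into hidden ones) and matches file names — or, for
//! patterns containing `/`, the trailing components of each file's path —
//! against glob patterns.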
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, MAX_BYTES};
18
/// Upper bound on how many paths a single `find` call may return.
///
/// Also used as the default when the caller does not supply `limit`.
const MAX_RESULTS: usize = 1_000;
21
22use super::resolve_path;
23
24pub struct FindTool {
25    fs: Arc<dyn VirtualFs>,
26    cwd: String,
27}
28
29impl FindTool {
30    pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31        Self {
32            fs,
33            cwd: cwd.into(),
34        }
35    }
36}
37
/// Match a file against a glob pattern.
///
/// * A pattern without `/` is matched against the file `name` only, e.g.
///   `*.rs`, `Cargo*`, `*test*`, `Makefile`, `a?.rs`.
/// * A pattern containing `/` is matched component-by-component against the
///   trailing components of `full_path`, so `src/*.rs` matches
///   `/project/src/main.rs`. A `**` component matches zero or more whole
///   directories (`src/**/*.ts`, `**/*.rs`, `docs/**`).
///
/// Within a component, `*` matches any run of characters (including none) and
/// `?` matches exactly one character. Everything else is literal and
/// case-sensitive. Leading/trailing whitespace in `pattern` is ignored.
fn matches_glob(name: &str, full_path: &str, pattern: &str) -> bool {
    let pattern = pattern.trim();

    if pattern.contains('/') {
        path_matches_glob(full_path, pattern)
    } else {
        matches_simple_glob(name, pattern)
    }
}

/// Match a single name (no `/` semantics) against a wildcard pattern where
/// `*` matches any run of characters and `?` matches exactly one.
///
/// Uses the classic greedy algorithm that backtracks only to the most recent
/// `*`. It is linear in practice and never recurses. The algorithm works on
/// `char`s, so `?` consumes one Unicode scalar value rather than one byte.
fn matches_simple_glob(name: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }

    let text: Vec<char> = name.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();
    let (mut t, mut p) = (0usize, 0usize);
    // Most recent `*`: its index in `pat`, and the text index where the next
    // retry should resume if the rest of the pattern fails to match.
    let mut star: Option<(usize, usize)> = None;

    while t < text.len() {
        match pat.get(p) {
            // `*` is checked first so a literal `*` in the name cannot shadow it.
            Some(&'*') => {
                star = Some((p, t));
                p += 1;
            }
            Some(&c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match star {
                // Let the last `*` absorb one more character and retry.
                Some((star_p, star_t)) => {
                    star = Some((star_p, star_t + 1));
                    p = star_p + 1;
                    t = star_t + 1;
                }
                None => return false,
            },
        }
    }

    // The name is exhausted; whatever is left of the pattern must be all `*`.
    pat[p..].iter().all(|&c| c == '*')
}

/// Match `path` against a `/`-separated glob `pattern`.
///
/// Empty components (from leading, trailing or doubled `/`) are ignored in
/// both inputs, and `.` components are ignored in the pattern. The pattern is
/// unanchored at the start: it matches if it matches any trailing run of
/// `path`'s components, and it is always anchored at the end.
fn path_matches_glob(path: &str, pattern: &str) -> bool {
    let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();

    // An implicit leading `**` implements "match any trailing run of components".
    let mut pattern_parts: Vec<&str> = vec!["**"];
    pattern_parts.extend(
        pattern
            .split('/')
            .filter(|s| !s.is_empty() && *s != "."),
    );
    // Consecutive `**` are equivalent to a single one; collapsing them keeps
    // the backtracking in `segments_match` bounded.
    pattern_parts.dedup_by(|a, b| *a == "**" && *b == "**");

    segments_match(&path_parts, &pattern_parts)
}

/// Recursively match path components against pattern components. A `**`
/// pattern component matches zero or more whole path components.
fn segments_match(path: &[&str], pattern: &[&str]) -> bool {
    match pattern.split_first() {
        None => path.is_empty(),
        Some((&"**", rest)) => (0..=path.len()).any(|skip| segments_match(&path[skip..], rest)),
        Some((&segment, rest)) => match path.split_first() {
            Some((&component, tail)) => {
                matches_simple_glob(component, segment) && segments_match(tail, rest)
            }
            None => false,
        },
    }
}
110
111/// Recursively collect matching files.
112async fn find_files(
113    fs: &dyn VirtualFs,
114    dir: &str,
115    pattern: &str,
116    results: &mut Vec<String>,
117    limit: usize,
118) -> SoulResult<()> {
119    if results.len() >= limit {
120        return Ok(());
121    }
122
123    let entries = match fs.read_dir(dir).await {
124        Ok(e) => e,
125        Err(_) => return Ok(()), // Skip unreadable dirs
126    };
127
128    for entry in entries {
129        if results.len() >= limit {
130            break;
131        }
132
133        let path = if dir == "/" || dir.is_empty() {
134            format!("/{}", entry.name)
135        } else {
136            format!("{}/{}", dir.trim_end_matches('/'), entry.name)
137        };
138
139        if entry.is_dir {
140            if !entry.name.starts_with('.') {
141                Box::pin(find_files(fs, &path, pattern, results, limit)).await?;
142            }
143        } else if entry.is_file && matches_glob(&entry.name, &path, pattern) {
144            results.push(path);
145        }
146    }
147
148    Ok(())
149}
150
151#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
152#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
153impl Tool for FindTool {
154    fn name(&self) -> &str {
155        "find"
156    }
157
158    fn definition(&self) -> ToolDefinition {
159        ToolDefinition {
160            name: "find".into(),
161            description: "Find files matching a glob pattern. Returns matching file paths.".into(),
162            input_schema: json!({
163                "type": "object",
164                "properties": {
165                    "pattern": {
166                        "type": "string",
167                        "description": "Glob pattern to match files (e.g., '*.rs', 'src/**/*.ts', 'Cargo.toml')"
168                    },
169                    "path": {
170                        "type": "string",
171                        "description": "Directory to search in (defaults to working directory)"
172                    },
173                    "limit": {
174                        "type": "integer",
175                        "description": "Maximum number of results (default: 1000)"
176                    }
177                },
178                "required": ["pattern"]
179            }),
180        }
181    }
182
183    async fn execute(
184        &self,
185        _call_id: &str,
186        arguments: serde_json::Value,
187        _partial_tx: Option<mpsc::UnboundedSender<String>>,
188    ) -> SoulResult<ToolOutput> {
189        let pattern = arguments
190            .get("pattern")
191            .and_then(|v| v.as_str())
192            .unwrap_or("");
193
194        if pattern.is_empty() {
195            return Ok(ToolOutput::error("Missing required parameter: pattern"));
196        }
197
198        let search_path = arguments
199            .get("path")
200            .and_then(|v| v.as_str())
201            .map(|p| resolve_path(&self.cwd, p))
202            .unwrap_or_else(|| self.cwd.clone());
203
204        let limit = arguments
205            .get("limit")
206            .and_then(|v| v.as_u64())
207            .map(|v| (v as usize).min(MAX_RESULTS))
208            .unwrap_or(MAX_RESULTS);
209
210        let mut results = Vec::new();
211        if let Err(e) =
212            find_files(self.fs.as_ref(), &search_path, pattern, &mut results, limit).await
213        {
214            return Ok(ToolOutput::error(format!(
215                "Failed to search {}: {}",
216                search_path, e
217            )));
218        }
219
220        results.sort();
221
222        if results.is_empty() {
223            return Ok(ToolOutput::success(format!(
224                "No files matching '{}' found",
225                pattern
226            ))
227            .with_metadata(json!({"count": 0})));
228        }
229
230        // Make paths relative to cwd
231        let cwd_prefix = format!("{}/", self.cwd.trim_end_matches('/'));
232        let relative: Vec<String> = results
233            .iter()
234            .map(|p| {
235                if p.starts_with(&cwd_prefix) {
236                    p[cwd_prefix.len()..].to_string()
237                } else {
238                    p.clone()
239                }
240            })
241            .collect();
242
243        let output = relative.join("\n");
244        let truncated = truncate_head(&output, results.len(), MAX_BYTES);
245
246        let notice = truncated.truncation_notice();
247        let mut result = truncated.content;
248        if results.len() >= limit {
249            result.push_str(&format!("\n[Reached limit: {} results]", limit));
250        }
251        if let Some(notice) = notice {
252            result.push_str(&format!("\n{}", notice));
253        }
254
255        Ok(ToolOutput::success(result).with_metadata(json!({
256            "count": results.len(),
257            "limit_reached": results.len() >= limit,
258        })))
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Empty in-memory filesystem plus a `FindTool` whose cwd is `/project`.
    fn setup() -> (Arc<MemoryFs>, FindTool) {
        let fs = Arc::new(MemoryFs::new());
        let tool = FindTool::new(Arc::clone(&fs) as Arc<dyn VirtualFs>, "/project");
        (fs, tool)
    }

    /// Write a small Rust/TS project layout under `/project`.
    async fn populate(fs: &MemoryFs) {
        let files = [
            ("/project/src/main.rs", "fn main() {}"),
            ("/project/src/lib.rs", "pub mod foo;"),
            ("/project/src/utils.ts", "export {}"),
            ("/project/Cargo.toml", "[package]"),
            ("/project/README.md", "# readme"),
        ];
        for &(path, contents) in &files {
            fs.write(path, contents).await.unwrap();
        }
    }

    /// Invoke the tool and unwrap the transport-level result.
    async fn run(tool: &FindTool, call_id: &str, args: serde_json::Value) -> ToolOutput {
        tool.execute(call_id, args, None).await.unwrap()
    }

    #[tokio::test]
    async fn find_by_extension() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c1", json!({"pattern": "*.rs"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("main.rs"));
        assert!(out.content.contains("lib.rs"));
        assert!(!out.content.contains("utils.ts"));
    }

    #[tokio::test]
    async fn find_exact_name() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c2", json!({"pattern": "Cargo.toml"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("Cargo.toml"));
        assert_eq!(out.metadata["count"].as_u64().unwrap(), 1);
    }

    #[tokio::test]
    async fn find_no_results() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c3", json!({"pattern": "*.py"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("No files"));
    }

    #[tokio::test]
    async fn find_with_limit() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c4", json!({"pattern": "*", "limit": 2})).await;

        assert!(!out.is_error);
        assert_eq!(out.metadata["count"].as_u64().unwrap(), 2);
    }

    #[tokio::test]
    async fn find_empty_pattern() {
        let (_fs, tool) = setup();

        let out = run(&tool, "c5", json!({"pattern": ""})).await;

        assert!(out.is_error);
    }

    #[test]
    fn glob_extensions() {
        assert!(matches_glob("file.rs", "/src/file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "/src/file.ts", "*.rs"));
    }

    #[test]
    fn glob_prefix() {
        assert!(matches_glob("Cargo.toml", "/Cargo.toml", "Cargo*"));
        assert!(!matches_glob("package.json", "/package.json", "Cargo*"));
    }

    #[test]
    fn glob_exact() {
        assert!(matches_glob("Makefile", "/Makefile", "Makefile"));
        assert!(!matches_glob("makefile", "/makefile", "Makefile"));
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup();

        assert_eq!(tool.name(), "find");
        assert_eq!(tool.definition().name, "find");
    }
}