spec_ai_core/tools/builtin/search.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use regex::RegexBuilder;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::fs;
8use std::path::{Path, PathBuf};
9use walkdir::WalkDir;
10
/// Upper bound applied to `max_results`, whatever the caller requests.
const HARD_MAX_RESULTS: usize = 100;
/// Number of results returned when the caller does not set `max_results`.
const DEFAULT_MAX_RESULTS: usize = 20;
/// Lines of context shown on each side of a match when `context_lines` is absent.
const DEFAULT_CONTEXT_LINES: usize = 2;
/// Per-file size limit (512 KiB); larger files are skipped unless reconfigured.
const DEFAULT_MAX_FILE_BYTES: usize = 512 * 1024;
15
16#[derive(Debug, Deserialize)]
17struct SearchArgs {
18    query: String,
19    root: Option<String>,
20    #[serde(default)]
21    regex: bool,
22    #[serde(default)]
23    case_sensitive: bool,
24    file_extensions: Option<Vec<String>>,
25    max_results: Option<usize>,
26    context_lines: Option<usize>,
27}
28
29#[derive(Debug, Serialize)]
30struct SearchResultEntry {
31    path: String,
32    line: usize,
33    snippet: String,
34    score: f32,
35}
36
37#[derive(Debug, Serialize)]
38struct SearchResponse {
39    query: String,
40    results: Vec<SearchResultEntry>,
41}
42
/// Tool that searches local files for literal or regex matches.
///
/// Build one with [`SearchTool::new`] and adjust it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct SearchTool {
    /// Directory searched when a call does not supply its own `root`.
    root: PathBuf,
    /// Files larger than this many bytes are skipped without being read.
    max_file_bytes: usize,
}
48
49impl SearchTool {
50    pub fn new() -> Self {
51        let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
52        Self {
53            root,
54            max_file_bytes: DEFAULT_MAX_FILE_BYTES,
55        }
56    }
57
58    pub fn with_root(mut self, root: impl Into<PathBuf>) -> Self {
59        self.root = root.into();
60        self
61    }
62
63    pub fn with_max_file_bytes(mut self, max_file_bytes: usize) -> Self {
64        self.max_file_bytes = max_file_bytes;
65        self
66    }
67
68    fn resolve_root(&self, override_root: &Option<String>) -> PathBuf {
69        override_root
70            .as_ref()
71            .map(PathBuf::from)
72            .unwrap_or_else(|| self.root.clone())
73    }
74
75    fn filter_extension(&self, path: &Path, allowed: &Option<Vec<String>>) -> bool {
76        match allowed {
77            None => true,
78            Some(list) if list.is_empty() => true,
79            Some(list) => {
80                if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
81                    let ext = ext.trim_start_matches('.');
82                    list.iter().any(|allowed_ext| {
83                        allowed_ext
84                            .trim_start_matches('.')
85                            .eq_ignore_ascii_case(ext)
86                    })
87                } else {
88                    false
89                }
90            }
91        }
92    }
93
94    fn literal_match(
95        &self,
96        query: &str,
97        line: &str,
98        case_sensitive: bool,
99    ) -> Option<(usize, usize)> {
100        if case_sensitive {
101            line.find(query).map(|start| (start, start + query.len()))
102        } else {
103            let lower_line = line.to_lowercase();
104            let lower_query = query.to_lowercase();
105            lower_line
106                .find(&lower_query)
107                .map(|start| (start, start + lower_query.len()))
108        }
109    }
110
111    fn build_snippet(lines: &[String], idx: usize, context_lines: usize) -> String {
112        let start = idx.saturating_sub(context_lines);
113        let end = (idx + context_lines).min(lines.len().saturating_sub(1));
114        lines[start..=end].join("\n")
115    }
116}
117
118impl Default for SearchTool {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124#[async_trait]
125impl Tool for SearchTool {
126    fn name(&self) -> &str {
127        "search"
128    }
129
130    fn description(&self) -> &str {
131        "Searches local files using literal or regex queries"
132    }
133
134    fn parameters(&self) -> Value {
135        serde_json::json!({
136            "type": "object",
137            "properties": {
138                "query": {
139                    "type": "string",
140                    "description": "Query string or regex pattern"
141                },
142                "root": {
143                    "type": "string",
144                    "description": "Directory to search (defaults to current workspace)"
145                },
146                "regex": {
147                    "type": "boolean",
148                    "description": "Interpret query as regular expression",
149                    "default": false
150                },
151                "case_sensitive": {
152                    "type": "boolean",
153                    "description": "Case sensitive search (default false for literal matches)",
154                    "default": false
155                },
156                "file_extensions": {
157                    "type": "array",
158                    "items": {"type": "string"},
159                    "description": "Limit search to specific file extensions"
160                },
161                "max_results": {
162                    "type": "integer",
163                    "description": "Maximum number of results to return (max 100)"
164                },
165                "context_lines": {
166                    "type": "integer",
167                    "description": "Number of lines of context around matches",
168                    "default": 2
169                }
170            },
171            "required": ["query"]
172        })
173    }
174
175    async fn execute(&self, args: Value) -> Result<ToolResult> {
176        let args: SearchArgs =
177            serde_json::from_value(args).context("Failed to parse search arguments")?;
178
179        if args.query.trim().is_empty() {
180            return Err(anyhow!("search query cannot be empty"));
181        }
182
183        let root = self.resolve_root(&args.root);
184        if !root.exists() {
185            return Err(anyhow!("Search root {} does not exist", root.display()));
186        }
187
188        let max_results = args
189            .max_results
190            .unwrap_or(DEFAULT_MAX_RESULTS)
191            .clamp(1, HARD_MAX_RESULTS);
192        let context_lines = args.context_lines.unwrap_or(DEFAULT_CONTEXT_LINES);
193
194        let regex = if args.regex {
195            Some(
196                RegexBuilder::new(&args.query)
197                    .case_insensitive(!args.case_sensitive)
198                    .build()
199                    .context("Invalid regular expression for search")?,
200            )
201        } else {
202            None
203        };
204
205        let mut results = Vec::new();
206
207        for entry in WalkDir::new(root)
208            .follow_links(false)
209            .into_iter()
210            .filter_map(|e| e.ok())
211        {
212            if results.len() >= max_results {
213                break;
214            }
215
216            let path = entry.path();
217            if !entry.file_type().is_file() {
218                continue;
219            }
220
221            if !self.filter_extension(path, &args.file_extensions) {
222                continue;
223            }
224
225            let metadata = match entry.metadata() {
226                Ok(meta) => meta,
227                Err(_) => continue,
228            };
229
230            if metadata.len() as usize > self.max_file_bytes {
231                continue;
232            }
233
234            let data = match fs::read(path) {
235                Ok(bytes) => bytes,
236                Err(_) => continue,
237            };
238
239            let content = match String::from_utf8(data) {
240                Ok(text) => text,
241                Err(_) => continue,
242            };
243
244            let lines: Vec<String> = content.lines().map(|line| line.to_string()).collect();
245
246            for (idx, line) in lines.iter().enumerate() {
247                if results.len() >= max_results {
248                    break;
249                }
250
251                let maybe_span = if let Some(regex) = &regex {
252                    regex.find(line).map(|m| (m.start(), m.end()))
253                } else {
254                    self.literal_match(&args.query, line, args.case_sensitive)
255                };
256
257                if maybe_span.is_none() {
258                    continue;
259                }
260
261                let snippet = Self::build_snippet(&lines, idx, context_lines);
262                let score = 1.0 / (1.0 + idx as f32);
263
264                results.push(SearchResultEntry {
265                    path: path.display().to_string(),
266                    line: idx + 1,
267                    snippet,
268                    score,
269                });
270            }
271        }
272
273        let response = SearchResponse {
274            query: args.query,
275            results,
276        };
277
278        Ok(ToolResult::success(
279            serde_json::to_string(&response).context("Failed to serialize search results")?,
280        ))
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Runs a search that must succeed and returns the decoded `results` array.
    async fn run_search(tool: &SearchTool, args: serde_json::Value) -> Vec<serde_json::Value> {
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        payload["results"].as_array().unwrap().clone()
    }

    /// Extracts the 1-based `line` field of every result, in order.
    fn line_numbers(results: &[serde_json::Value]) -> Vec<u64> {
        results
            .iter()
            .map(|r| r["line"].as_u64().unwrap())
            .collect()
    }

    #[tokio::test]
    async fn test_literal_search() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("sample.txt");
        fs::write(&file_path, "hello search tool\nsecond line\nhello again").unwrap();

        let tool = SearchTool::new().with_root(dir.path());
        let args = serde_json::json!({
            "query": "hello",
            "root": dir.path().to_string_lossy(),
            "max_results": 5
        });

        // Exactly the two "hello" lines, in file order; a `>= 2` check would hide
        // duplicated or misplaced results.
        let results = run_search(&tool, args).await;
        assert_eq!(line_numbers(&results), vec![1, 3]);
    }

    #[tokio::test]
    async fn test_regex_search() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("module.rs");
        fs::write(&file_path, "fn test_case() {}\nfn demo_case() {}\n").unwrap();

        let tool = SearchTool::new().with_root(dir.path());
        let args = serde_json::json!({
            "query": "fn\\s+test_\\w+",
            "regex": true,
            "root": dir.path().to_string_lossy(),
            "file_extensions": ["rs"]
        });

        let results = run_search(&tool, args).await;
        assert_eq!(line_numbers(&results), vec![1]);
    }

    #[tokio::test]
    async fn test_literal_search_case_sensitivity() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("greeting.txt"), "Hello World\nhello\n").unwrap();
        let tool = SearchTool::new().with_root(dir.path());

        let insensitive = run_search(&tool, serde_json::json!({ "query": "HELLO" })).await;
        assert_eq!(line_numbers(&insensitive), vec![1, 2]);

        let sensitive = run_search(
            &tool,
            serde_json::json!({ "query": "hello", "case_sensitive": true }),
        )
        .await;
        assert_eq!(line_numbers(&sensitive), vec![2]);
    }

    #[tokio::test]
    async fn test_max_results_caps_output() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("many.txt"), "match\n".repeat(10)).unwrap();
        let tool = SearchTool::new().with_root(dir.path());

        let results = run_search(
            &tool,
            serde_json::json!({ "query": "match", "max_results": 3 }),
        )
        .await;
        assert_eq!(line_numbers(&results), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_snippet_includes_context_lines() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("letters.txt"), "a\nb\nc\nd\ne").unwrap();
        let tool = SearchTool::new().with_root(dir.path());

        let results = run_search(
            &tool,
            serde_json::json!({ "query": "c", "case_sensitive": true, "context_lines": 1 }),
        )
        .await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0]["snippet"], "b\nc\nd");
    }

    #[tokio::test]
    async fn test_extension_filter_accepts_dotted_uppercase_entries() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("keep.rs"), "needle").unwrap();
        fs::write(dir.path().join("skip.txt"), "needle").unwrap();
        let tool = SearchTool::new().with_root(dir.path());

        let results = run_search(
            &tool,
            serde_json::json!({ "query": "needle", "file_extensions": [".RS"] }),
        )
        .await;
        assert_eq!(results.len(), 1);
        assert!(results[0]["path"].as_str().unwrap().ends_with("keep.rs"));
    }

    #[tokio::test]
    async fn test_files_over_size_limit_are_skipped() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("big.txt"), "needle in a haystack").unwrap();
        let tool = SearchTool::new()
            .with_root(dir.path())
            .with_max_file_bytes(4);

        let results = run_search(&tool, serde_json::json!({ "query": "needle" })).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_blank_query_and_missing_root_are_rejected() {
        let dir = tempdir().unwrap();
        let tool = SearchTool::new().with_root(dir.path());

        assert!(tool
            .execute(serde_json::json!({ "query": "   " }))
            .await
            .is_err());

        let missing = dir.path().join("does-not-exist");
        assert!(tool
            .execute(serde_json::json!({ "query": "x", "root": missing.to_string_lossy() }))
            .await
            .is_err());
    }
}