// spec_ai_core/tools/builtin/search.rs
use crate::tools::{Tool, ToolResult};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use regex::RegexBuilder;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fs;
use std::path::{Path, PathBuf};
use walkdir::WalkDir;
10
11const DEFAULT_MAX_RESULTS: usize = 20;
12const HARD_MAX_RESULTS: usize = 100;
13const DEFAULT_CONTEXT_LINES: usize = 2;
14const DEFAULT_MAX_FILE_BYTES: usize = 512 * 1024; #[derive(Debug, Deserialize)]
17struct SearchArgs {
18 query: String,
19 root: Option<String>,
20 #[serde(default)]
21 regex: bool,
22 #[serde(default)]
23 case_sensitive: bool,
24 file_extensions: Option<Vec<String>>,
25 max_results: Option<usize>,
26 context_lines: Option<usize>,
27}
28
29#[derive(Debug, Serialize)]
30struct SearchResultEntry {
31 path: String,
32 line: usize,
33 snippet: String,
34 score: f32,
35}
36
37#[derive(Debug, Serialize)]
38struct SearchResponse {
39 query: String,
40 results: Vec<SearchResultEntry>,
41}
42
43pub struct SearchTool {
45 root: PathBuf,
46 max_file_bytes: usize,
47}
48
49impl SearchTool {
50 pub fn new() -> Self {
51 let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
52 Self {
53 root,
54 max_file_bytes: DEFAULT_MAX_FILE_BYTES,
55 }
56 }
57
58 pub fn with_root(mut self, root: impl Into<PathBuf>) -> Self {
59 self.root = root.into();
60 self
61 }
62
63 pub fn with_max_file_bytes(mut self, max_file_bytes: usize) -> Self {
64 self.max_file_bytes = max_file_bytes;
65 self
66 }
67
68 fn resolve_root(&self, override_root: &Option<String>) -> PathBuf {
69 override_root
70 .as_ref()
71 .map(PathBuf::from)
72 .unwrap_or_else(|| self.root.clone())
73 }
74
75 fn filter_extension(&self, path: &Path, allowed: &Option<Vec<String>>) -> bool {
76 match allowed {
77 None => true,
78 Some(list) if list.is_empty() => true,
79 Some(list) => {
80 if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
81 let ext = ext.trim_start_matches('.');
82 list.iter().any(|allowed_ext| {
83 allowed_ext
84 .trim_start_matches('.')
85 .eq_ignore_ascii_case(ext)
86 })
87 } else {
88 false
89 }
90 }
91 }
92 }
93
94 fn literal_match(
95 &self,
96 query: &str,
97 line: &str,
98 case_sensitive: bool,
99 ) -> Option<(usize, usize)> {
100 if case_sensitive {
101 line.find(query).map(|start| (start, start + query.len()))
102 } else {
103 let lower_line = line.to_lowercase();
104 let lower_query = query.to_lowercase();
105 lower_line
106 .find(&lower_query)
107 .map(|start| (start, start + lower_query.len()))
108 }
109 }
110
111 fn build_snippet(lines: &[String], idx: usize, context_lines: usize) -> String {
112 let start = idx.saturating_sub(context_lines);
113 let end = (idx + context_lines).min(lines.len().saturating_sub(1));
114 lines[start..=end].join("\n")
115 }
116}
117
118impl Default for SearchTool {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124#[async_trait]
125impl Tool for SearchTool {
126 fn name(&self) -> &str {
127 "search"
128 }
129
130 fn description(&self) -> &str {
131 "Searches local files using literal or regex queries"
132 }
133
134 fn parameters(&self) -> Value {
135 serde_json::json!({
136 "type": "object",
137 "properties": {
138 "query": {
139 "type": "string",
140 "description": "Query string or regex pattern"
141 },
142 "root": {
143 "type": "string",
144 "description": "Directory to search (defaults to current workspace)"
145 },
146 "regex": {
147 "type": "boolean",
148 "description": "Interpret query as regular expression",
149 "default": false
150 },
151 "case_sensitive": {
152 "type": "boolean",
153 "description": "Case sensitive search (default false for literal matches)",
154 "default": false
155 },
156 "file_extensions": {
157 "type": "array",
158 "items": {"type": "string"},
159 "description": "Limit search to specific file extensions"
160 },
161 "max_results": {
162 "type": "integer",
163 "description": "Maximum number of results to return (max 100)"
164 },
165 "context_lines": {
166 "type": "integer",
167 "description": "Number of lines of context around matches",
168 "default": 2
169 }
170 },
171 "required": ["query"]
172 })
173 }
174
175 async fn execute(&self, args: Value) -> Result<ToolResult> {
176 let args: SearchArgs =
177 serde_json::from_value(args).context("Failed to parse search arguments")?;
178
179 if args.query.trim().is_empty() {
180 return Err(anyhow!("search query cannot be empty"));
181 }
182
183 let root = self.resolve_root(&args.root);
184 if !root.exists() {
185 return Err(anyhow!("Search root {} does not exist", root.display()));
186 }
187
188 let max_results = args
189 .max_results
190 .unwrap_or(DEFAULT_MAX_RESULTS)
191 .clamp(1, HARD_MAX_RESULTS);
192 let context_lines = args.context_lines.unwrap_or(DEFAULT_CONTEXT_LINES);
193
194 let regex = if args.regex {
195 Some(
196 RegexBuilder::new(&args.query)
197 .case_insensitive(!args.case_sensitive)
198 .build()
199 .context("Invalid regular expression for search")?,
200 )
201 } else {
202 None
203 };
204
205 let mut results = Vec::new();
206
207 for entry in WalkDir::new(root)
208 .follow_links(false)
209 .into_iter()
210 .filter_map(|e| e.ok())
211 {
212 if results.len() >= max_results {
213 break;
214 }
215
216 let path = entry.path();
217 if !entry.file_type().is_file() {
218 continue;
219 }
220
221 if !self.filter_extension(path, &args.file_extensions) {
222 continue;
223 }
224
225 let metadata = match entry.metadata() {
226 Ok(meta) => meta,
227 Err(_) => continue,
228 };
229
230 if metadata.len() as usize > self.max_file_bytes {
231 continue;
232 }
233
234 let data = match fs::read(path) {
235 Ok(bytes) => bytes,
236 Err(_) => continue,
237 };
238
239 let content = match String::from_utf8(data) {
240 Ok(text) => text,
241 Err(_) => continue,
242 };
243
244 let lines: Vec<String> = content.lines().map(|line| line.to_string()).collect();
245
246 for (idx, line) in lines.iter().enumerate() {
247 if results.len() >= max_results {
248 break;
249 }
250
251 let maybe_span = if let Some(regex) = ®ex {
252 regex.find(line).map(|m| (m.start(), m.end()))
253 } else {
254 self.literal_match(&args.query, line, args.case_sensitive)
255 };
256
257 if maybe_span.is_none() {
258 continue;
259 }
260
261 let snippet = Self::build_snippet(&lines, idx, context_lines);
262 let score = 1.0 / (1.0 + idx as f32);
263
264 results.push(SearchResultEntry {
265 path: path.display().to_string(),
266 line: idx + 1,
267 snippet,
268 score,
269 });
270 }
271 }
272
273 let response = SearchResponse {
274 query: args.query,
275 results,
276 };
277
278 Ok(ToolResult::success(
279 serde_json::to_string(&response).context("Failed to serialize search results")?,
280 ))
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A literal query finds both lines of the fixture that contain it.
    #[tokio::test]
    async fn test_literal_search() {
        let workspace = tempdir().unwrap();
        let sample = workspace.path().join("sample.txt");
        fs::write(&sample, "hello search tool\nsecond line\nhello again").unwrap();

        let tool = SearchTool::new().with_root(workspace.path());
        let request = serde_json::json!({
            "query": "hello",
            "root": workspace.path().to_string_lossy(),
            "max_results": 5
        });

        let outcome = tool.execute(request).await.unwrap();
        assert!(outcome.success);

        let payload: serde_json::Value = serde_json::from_str(&outcome.output).unwrap();
        let hits = payload["results"].as_array().unwrap();
        assert!(hits.len() >= 2);
    }

    /// A regex query restricted to `.rs` files matches exactly one function.
    #[tokio::test]
    async fn test_regex_search() {
        let workspace = tempdir().unwrap();
        let module = workspace.path().join("module.rs");
        fs::write(&module, "fn test_case() {}\nfn demo_case() {}\n").unwrap();

        let tool = SearchTool::new().with_root(workspace.path());
        let request = serde_json::json!({
            "query": "fn\\s+test_\\w+",
            "regex": true,
            "root": workspace.path().to_string_lossy(),
            "file_extensions": ["rs"]
        });

        let outcome = tool.execute(request).await.unwrap();
        assert!(outcome.success);

        let payload: serde_json::Value = serde_json::from_str(&outcome.output).unwrap();
        let hits = payload["results"].as_array().unwrap();
        assert_eq!(hits.len(), 1);
    }
}
328}