sparrow/tools/
search_and_web.rs1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{Tool, ToolCtx, ToolResult, resolve_workspace_path};
5use crate::event::{Block, RiskLevel};
6
7pub struct Search;
10
11#[async_trait]
12impl Tool for Search {
13 fn name(&self) -> &str {
14 "search"
15 }
16 fn description(&self) -> &str {
17 "Search code in the workspace using ripgrep (regex patterns)"
18 }
19 fn schema(&self) -> serde_json::Value {
20 json!({
21 "type": "object",
22 "properties": {
23 "pattern": { "type": "string", "description": "Regex pattern to search for" },
24 "path": { "type": "string", "description": "Directory or file to search (default: workspace root)" },
25 "include": { "type": "string", "description": "File pattern filter (e.g. '*.rs')" },
26 "max_results": { "type": "integer", "description": "Max results (default: 50)" }
27 },
28 "required": ["pattern"]
29 })
30 }
31 fn risk(&self) -> RiskLevel {
32 RiskLevel::ReadOnly
33 }
34 async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
35 let pattern = args["pattern"].as_str().unwrap_or("");
36 let path = args["path"].as_str().unwrap_or(".");
37 let include = args["include"].as_str();
38 let max_results = args["max_results"].as_u64().unwrap_or(50) as usize;
39
40 let search_path = resolve_workspace_path(&ctx.workspace_root, path)?;
41
42 let mut cmd = std::process::Command::new("rg");
43 cmd.arg("--line-number")
44 .arg("--no-heading")
45 .arg("--color=never")
46 .arg("-M")
47 .arg(max_results.to_string())
48 .arg(pattern)
49 .arg(&search_path);
50
51 if let Some(inc) = include {
52 cmd.arg("--glob").arg(inc);
53 }
54
55 match cmd.output() {
56 Ok(output) => {
57 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
58 if stdout.is_empty() {
59 Ok(ToolResult::text("No matches found."))
60 } else {
61 Ok(ToolResult::text(stdout))
62 }
63 }
64 Err(e) => {
65 if e.kind() == std::io::ErrorKind::NotFound {
67 let mut results = Vec::new();
68 basic_grep(&search_path, pattern, include, &mut results, 0, max_results)?;
69 if results.is_empty() {
70 Ok(ToolResult::text(
71 "No matches found (rg not installed, used basic search).",
72 ))
73 } else {
74 Ok(ToolResult::text(results.join("\n")))
75 }
76 } else {
77 Err(e.into())
78 }
79 }
80 }
81 }
82}
83
/// Dependency-free fallback search used when ripgrep is unavailable.
///
/// Walks `dir` recursively (at most 10 levels below the starting depth) and
/// appends one `path:line: text` entry to `results` for every line that
/// contains `pattern` as a case-insensitive substring (this is NOT a regex
/// match). Entries whose name starts with `.` or equals `target` /
/// `node_modules` are skipped. Reported paths are relative to `dir`; when
/// `dir` is itself a file, that file is searched and reported by its name.
///
/// * `include` - optional file-name filter, see [`include_accepts`].
/// * `depth`   - current recursion depth; callers pass `0`.
/// * `max`     - stop once `results` holds this many entries.
///
/// # Errors
/// Propagates the error from `std::fs::read_dir` on a directory that cannot
/// be listed. Unreadable or non-UTF-8 files are skipped silently.
fn basic_grep(
    dir: &std::path::Path,
    pattern: &str,
    include: Option<&str>,
    results: &mut Vec<String>,
    depth: usize,
    max: usize,
) -> std::io::Result<()> {
    // Lower-case the needle once instead of once per scanned line.
    let needle = pattern.to_lowercase();

    if dir.is_file() {
        let label = match dir.file_name() {
            Some(name) => name.to_string_lossy().into_owned(),
            None => dir.display().to_string(),
        };
        basic_grep_file(dir, &label, &needle, results, max);
        return Ok(());
    }

    basic_grep_dir(dir, dir, &needle, include, results, depth, max)
}

/// Recursive directory walk behind [`basic_grep`]; `root` stays fixed so that
/// reported paths remain relative to the directory the search started from.
fn basic_grep_dir(
    root: &std::path::Path,
    dir: &std::path::Path,
    needle: &str,
    include: Option<&str>,
    results: &mut Vec<String>,
    depth: usize,
    max: usize,
) -> std::io::Result<()> {
    if depth > 10 || results.len() >= max || !dir.is_dir() {
        return Ok(());
    }

    for entry in std::fs::read_dir(dir)?.flatten() {
        if results.len() >= max {
            break;
        }
        let path = entry.path();
        let name = entry.file_name().to_string_lossy().into_owned();
        if name.starts_with('.') || name == "target" || name == "node_modules" {
            continue;
        }

        if path.is_dir() {
            basic_grep_dir(root, &path, needle, include, results, depth + 1, max)?;
        } else if path.is_file() {
            if !include.map_or(true, |glob| include_accepts(glob, &name)) {
                continue;
            }
            let label = path
                .strip_prefix(root)
                .unwrap_or(&path)
                .display()
                .to_string();
            basic_grep_file(&path, &label, needle, results, max);
        }
    }
    Ok(())
}

/// Appends the matching lines of one file to `results`, never exceeding `max`
/// entries in total. `needle` must already be lower-cased.
fn basic_grep_file(
    path: &std::path::Path,
    label: &str,
    needle: &str,
    results: &mut Vec<String>,
    max: usize,
) {
    // Best effort: binary or unreadable files are simply not searched.
    if let Ok(content) = std::fs::read_to_string(path) {
        let remaining = max.saturating_sub(results.len());
        results.extend(
            content
                .lines()
                .enumerate()
                .filter(|(_, line)| line.to_lowercase().contains(needle))
                .take(remaining)
                .map(|(idx, line)| format!("{}:{}: {}", label, idx + 1, line.trim())),
        );
    }
}

/// Decides whether `file_name` passes the `include` filter.
///
/// * Without `*` or `?` the filter is a plain substring test on the name.
/// * With wildcards, the last `/`-separated component of the filter is matched
///   against the whole file name (`*` = any run of characters, `?` = exactly
///   one character).
/// * Syntax this matcher does not implement (`{a,b}`, `[abc]`, leading `!`)
///   accepts every file rather than silently hiding matches.
fn include_accepts(glob: &str, file_name: &str) -> bool {
    if !glob.contains('*') && !glob.contains('?') {
        return file_name.contains(glob);
    }
    if glob.contains('{') || glob.contains('[') || glob.starts_with('!') {
        return true;
    }
    let name_glob = glob.rsplit('/').next().unwrap_or(glob);
    wildcard_match(name_glob, file_name)
}

/// Iterative `*` / `?` wildcard matcher that backtracks to the most recent `*`.
fn wildcard_match(glob: &str, text: &str) -> bool {
    let g: Vec<char> = glob.chars().collect();
    let t: Vec<char> = text.chars().collect();
    let (mut gi, mut ti) = (0usize, 0usize);
    // Position of the last `*` in the glob and the text index it was tried at.
    let mut star: Option<(usize, usize)> = None;

    while ti < t.len() {
        if gi < g.len() && g[gi] == '*' {
            star = Some((gi, ti));
            gi += 1;
        } else if gi < g.len() && (g[gi] == '?' || g[gi] == t[ti]) {
            gi += 1;
            ti += 1;
        } else if let Some((star_g, star_t)) = star {
            // Let the last `*` swallow one more character and retry.
            gi = star_g + 1;
            ti = star_t + 1;
            star = Some((star_g, star_t + 1));
        } else {
            return false;
        }
    }
    g[gi..].iter().all(|&c| c == '*')
}
130
131pub struct WebSearch;
134
135#[async_trait]
136impl Tool for WebSearch {
137 fn name(&self) -> &str {
138 "web_search"
139 }
140 fn description(&self) -> &str {
141 "Search the web for information"
142 }
143 fn schema(&self) -> serde_json::Value {
144 json!({
145 "type": "object",
146 "properties": {
147 "query": { "type": "string", "description": "Search query" },
148 "num_results": { "type": "integer", "description": "Number of results (default: 5)" }
149 },
150 "required": ["query"]
151 })
152 }
153 fn risk(&self) -> RiskLevel {
154 RiskLevel::Network
155 }
156 async fn call(&self, args: serde_json::Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
157 let query = args["query"].as_str().unwrap_or("");
158 let num = args["num_results"].as_u64().unwrap_or(5);
159
160 let client = reqwest::Client::builder()
161 .user_agent("sparrow/0.1")
162 .build()?;
163
164 let resp = client
166 .get("https://lite.duckduckgo.com/lite/")
167 .query(&[("q", query)])
168 .send()
169 .await?;
170
171 let html = resp.text().await?;
172
173 let mut results = Vec::new();
175 for line in html.lines() {
176 let trimmed = line.trim();
177 if trimmed.starts_with("<a") && trimmed.contains("class=\"result-link\"") {
178 if let Some(url) = extract_href(trimmed) {
179 if results.len() < num as usize {
180 results.push(format!("🔗 {}", url));
181 }
182 }
183 }
184 if trimmed.starts_with("<td") && trimmed.contains("class=\"result-snippet\"") {
185 let snippet = strip_html(trimmed);
186 if !snippet.is_empty() && results.len() <= num as usize {
187 results.push(format!(" {}", snippet));
188 }
189 }
190 }
191
192 if results.is_empty() {
193 Ok(ToolResult::text(format!(
194 "No web results for: {}. Try a more specific query.",
195 query
196 )))
197 } else {
198 Ok(ToolResult::text(results.join("\n")))
199 }
200 }
201}
202
203fn extract_href(line: &str) -> Option<String> {
204 let start = line.find("href=\"")? + 6;
205 let end = line[start..].find('"')?;
206 let mut url = line[start..start + end].to_string();
207 if url.starts_with("//") {
209 url = format!("https:{}", url);
210 }
211 if url.contains("duckduckgo.com/l/?uddg=") {
212 if let Some(real) = url.split("uddg=").nth(1) {
213 if let Ok(decoded) = urlencoding(&real) {
214 url = decoded;
215 }
216 }
217 }
218 Some(url)
219}
220
221pub(crate) fn validate_public_url(url: &str) -> Result<(), &'static str> {
225 let parsed = url::Url::parse(url).map_err(|_| "invalid URL")?;
226 match parsed.scheme() {
227 "http" | "https" => {}
228 _ => return Err("only http(s) is allowed"),
229 }
230 let host = parsed.host_str().ok_or("missing host")?;
231
232 let lc = host.to_ascii_lowercase();
234 if matches!(
235 lc.as_str(),
236 "localhost" | "ip6-localhost" | "ip6-loopback" | "metadata.google.internal" | "metadata"
237 ) || lc.ends_with(".localhost")
238 || lc.ends_with(".local")
239 || lc.ends_with(".internal")
240 {
241 return Err("host points to local/internal network");
242 }
243
244 if let Ok(ip) = host.parse::<std::net::IpAddr>() {
246 if is_blocked_ip(&ip) {
247 return Err("IP belongs to a private/loopback/link-local range");
248 }
249 return Ok(());
250 }
251
252 let port = parsed.port_or_known_default().unwrap_or(0);
257 if let Ok(addrs) = std::net::ToSocketAddrs::to_socket_addrs(&(host, port)) {
258 for sa in addrs {
259 if is_blocked_ip(&sa.ip()) {
260 return Err("hostname resolves to a private/loopback IP");
261 }
262 }
263 }
264 Ok(())
265}
266
/// Returns `true` when `ip` must not be contacted by the fetch tools.
///
/// IPv4 rules are listed on [`is_blocked_ipv4`]. For IPv6 the function blocks
/// loopback, unspecified, multicast, unique-local (`fc00::/7`) and link-local
/// (`fe80::/10`) addresses. IPv6 forms that carry an IPv4 address
/// (IPv4-mapped `::ffff:a.b.c.d`, NAT64 `64:ff9b::/96`, IPv4-compatible
/// `::a.b.c.d`) are judged by the IPv4 rules applied to the embedded address.
fn is_blocked_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => is_blocked_ipv4(*v4),
        std::net::IpAddr::V6(v6) => {
            let seg = v6.segments();
            if v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (seg[0] & 0xfe00) == 0xfc00
                || (seg[0] & 0xffc0) == 0xfe80
            {
                return true;
            }
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_blocked_ipv4(mapped);
            }

            let nat64 = seg[0] == 0x0064 && seg[1] == 0xff9b && seg[2..6].iter().all(|&s| s == 0);
            let v4_compatible = seg[..6].iter().all(|&s| s == 0);
            if nat64 || v4_compatible {
                // The low 32 bits hold the embedded IPv4 address.
                let embedded =
                    std::net::Ipv4Addr::from((u32::from(seg[6]) << 16) | u32::from(seg[7]));
                return is_blocked_ipv4(embedded);
            }
            false
        }
    }
}

/// IPv4 half of [`is_blocked_ip`].
///
/// Blocks loopback, RFC 1918 private, link-local (which includes the
/// `169.254.169.254` metadata address), broadcast, multicast, unspecified,
/// `0.0.0.0/8`, carrier-grade NAT `100.64.0.0/10`, benchmarking
/// `198.18.0.0/15` and reserved `240.0.0.0/4`.
fn is_blocked_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, _, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || v4.is_multicast()
        || v4.is_unspecified()
        || a == 0
        || (a == 100 && (b & 0xC0) == 0x40)
        || (a == 198 && (b & 0xFE) == 18)
        || a >= 240
}
292
/// Removes everything between `<` and `>` (inclusive) from `s` and trims the
/// surrounding whitespace of what is left.
///
/// This is a plain character filter, not an HTML parser: a `>` outside of a
/// tag is dropped as well, text after an unclosed `<` is discarded, and
/// character entities such as `&amp;` are passed through untouched.
fn strip_html(s: &str) -> String {
    let mut inside_markup = false;
    let visible: String = s
        .chars()
        .filter(|&ch| match ch {
            '<' => {
                inside_markup = true;
                false
            }
            '>' => {
                inside_markup = false;
                false
            }
            _ => !inside_markup,
        })
        .collect();
    visible.trim().to_string()
}
307
/// Percent-DECODES `s` (despite the name) using form-encoding rules.
///
/// * `%XX` with two hexadecimal digits becomes the byte `0xXX`.
/// * `+` becomes a space.
/// * A `%` that is not followed by two hex digits is kept literally.
///
/// Decoding works on bytes and the result is validated as UTF-8, so
/// multi-byte sequences such as `%C3%A9` decode to a single character and
/// non-ASCII input cannot cause mis-aligned slicing.
///
/// # Errors
/// Returns `Err(())` when the decoded bytes are not valid UTF-8.
fn urlencoding(s: &str) -> Result<String, ()> {
    /// Value of an ASCII hex digit, or `None` for anything else.
    fn hex_value(byte: Option<&u8>) -> Option<u8> {
        byte.and_then(|&b| (b as char).to_digit(16)).map(|d| d as u8)
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let pair = (hex_value(bytes.get(i + 1)), hex_value(bytes.get(i + 2)));
                if let (Some(hi), Some(lo)) = pair {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Not a valid escape: keep the percent sign as-is.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }
    String::from_utf8(decoded).map_err(|_| ())
}
330
331pub struct WebFetch;
334
335#[async_trait]
336impl Tool for WebFetch {
337 fn name(&self) -> &str {
338 "web_fetch"
339 }
340 fn description(&self) -> &str {
341 "Fetch and read content from a URL"
342 }
343 fn schema(&self) -> serde_json::Value {
344 json!({
345 "type": "object",
346 "properties": {
347 "url": { "type": "string", "description": "URL to fetch" },
348 "format": { "type": "string", "enum": ["text", "markdown", "html"], "description": "Output format (default: text)" }
349 },
350 "required": ["url"]
351 })
352 }
353 fn risk(&self) -> RiskLevel {
354 RiskLevel::Network
355 }
356 async fn call(&self, args: serde_json::Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
357 let url = args["url"].as_str().unwrap_or("");
358 let format = args["format"].as_str().unwrap_or("text");
359
360 if let Err(why) = validate_public_url(url) {
361 return Ok(ToolResult::error(format!("Refused URL ({}): {}", why, url)));
362 }
363
364 let client = reqwest::Client::builder()
365 .user_agent("sparrow/0.1")
366 .timeout(std::time::Duration::from_secs(30))
367 .redirect(reqwest::redirect::Policy::custom(|attempt| {
369 if validate_public_url(attempt.url().as_str()).is_err() {
370 attempt.stop()
371 } else if attempt.previous().len() >= 5 {
372 attempt.stop()
373 } else {
374 attempt.follow()
375 }
376 }))
377 .build()?;
378
379 let resp = client.get(url).send().await?;
380 let status = resp.status();
381 let content_type = resp
382 .headers()
383 .get("content-type")
384 .and_then(|v| v.to_str().ok())
385 .unwrap_or("unknown")
386 .to_string();
387
388 let bytes = resp.bytes().await?;
389
390 let text = match format {
391 "html" => String::from_utf8_lossy(&bytes).to_string(),
392 _ => {
393 let raw = String::from_utf8_lossy(&bytes).to_string();
395 let stripped = strip_html(&raw);
396 if stripped.len() > 50_000 {
398 format!(
399 "{}...\n\n[truncated: {} bytes total]",
400 &stripped[..50_000],
401 stripped.len()
402 )
403 } else {
404 stripped
405 }
406 }
407 };
408
409 Ok(ToolResult::ok(vec![Block::Text(format!(
410 "URL: {}\nStatus: {}\nType: {}\n\n{}",
411 url, status, content_type, text
412 ))]))
413 }
414}