Skip to main content

sparrow/tools/
search_and_web.rs

1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{Tool, ToolCtx, ToolResult, resolve_workspace_path};
5use crate::event::{Block, RiskLevel};
6
7// ─── Ripgrep search ─────────────────────────────────────────────────────────────
8
9pub struct Search;
10
11#[async_trait]
12impl Tool for Search {
13    fn name(&self) -> &str {
14        "search"
15    }
16    fn description(&self) -> &str {
17        "Search code in the workspace using ripgrep (regex patterns)"
18    }
19    fn schema(&self) -> serde_json::Value {
20        json!({
21            "type": "object",
22            "properties": {
23                "pattern": { "type": "string", "description": "Regex pattern to search for" },
24                "path": { "type": "string", "description": "Directory or file to search (default: workspace root)" },
25                "include": { "type": "string", "description": "File pattern filter (e.g. '*.rs')" },
26                "max_results": { "type": "integer", "description": "Max results (default: 50)" }
27            },
28            "required": ["pattern"]
29        })
30    }
31    fn risk(&self) -> RiskLevel {
32        RiskLevel::ReadOnly
33    }
34    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
35        let pattern = args["pattern"].as_str().unwrap_or("");
36        let path = args["path"].as_str().unwrap_or(".");
37        let include = args["include"].as_str();
38        let max_results = args["max_results"].as_u64().unwrap_or(50) as usize;
39
40        let search_path = resolve_workspace_path(&ctx.workspace_root, path)?;
41
42        let mut cmd = std::process::Command::new("rg");
43        cmd.arg("--line-number")
44            .arg("--no-heading")
45            .arg("--color=never")
46            .arg("-M")
47            .arg(max_results.to_string())
48            .arg(pattern)
49            .arg(&search_path);
50
51        if let Some(inc) = include {
52            cmd.arg("--glob").arg(inc);
53        }
54
55        match cmd.output() {
56            Ok(output) => {
57                let stdout = String::from_utf8_lossy(&output.stdout).to_string();
58                if stdout.is_empty() {
59                    Ok(ToolResult::text("No matches found."))
60                } else {
61                    Ok(ToolResult::text(stdout))
62                }
63            }
64            Err(e) => {
65                // Fallback: basic string search
66                if e.kind() == std::io::ErrorKind::NotFound {
67                    let mut results = Vec::new();
68                    basic_grep(&search_path, pattern, include, &mut results, 0, max_results)?;
69                    if results.is_empty() {
70                        Ok(ToolResult::text(
71                            "No matches found (rg not installed, used basic search).",
72                        ))
73                    } else {
74                        Ok(ToolResult::text(results.join("\n")))
75                    }
76                } else {
77                    Err(e.into())
78                }
79            }
80        }
81    }
82}
83
/// Fallback search used when ripgrep is unavailable: a case-insensitive
/// substring scan (not a regex search).
///
/// * `dir` - directory to walk, or a single file to scan.
/// * `pattern` - text to look for; compared case-insensitively.
/// * `include` - optional file-name filter. Values containing `*` or `?` are
///   matched as wildcards against the file name (only the last `/`-separated
///   component of the filter is used); values without wildcards match as a
///   substring of the file name. Filters using `{}`/`[]`/leading `!` are not
///   understood and do not filter anything.
/// * `results` - receives `path:line: text` entries, with `path` relative to `dir`.
/// * `depth` - current recursion depth; directories deeper than 10 are skipped.
/// * `max` - stop once `results` holds this many entries.
///
/// Entries whose name starts with `.` or equals `target`/`node_modules` are
/// skipped, as are files that cannot be read as UTF-8.
///
/// # Errors
/// Returns the I/O error if a directory cannot be listed.
fn basic_grep(
    dir: &std::path::Path,
    pattern: &str,
    include: Option<&str>,
    results: &mut Vec<String>,
    depth: usize,
    max: usize,
) -> std::io::Result<()> {
    // Lower-case the needle once instead of once per scanned line.
    let needle = pattern.to_lowercase();
    basic_grep_walk(dir, dir, &needle, include, results, depth, max)
}

/// Recursive worker for `basic_grep`; `root` stays fixed so reported paths are
/// relative to the directory the search started from.
fn basic_grep_walk(
    root: &std::path::Path,
    current: &std::path::Path,
    needle: &str,
    include: Option<&str>,
    results: &mut Vec<String>,
    depth: usize,
    max: usize,
) -> std::io::Result<()> {
    if depth > 10 || results.len() >= max {
        return Ok(());
    }
    if current.is_file() {
        // An explicitly named file is always scanned, regardless of `include`.
        basic_grep_file(root, current, needle, results, max);
        return Ok(());
    }
    if !current.is_dir() {
        return Ok(());
    }
    for entry in std::fs::read_dir(current)?.flatten() {
        if results.len() >= max {
            break;
        }
        let path = entry.path();
        let name = path.file_name().unwrap_or_default().to_string_lossy();
        if name.starts_with('.') || name == "target" || name == "node_modules" {
            continue;
        }
        if path.is_dir() {
            basic_grep_walk(root, &path, needle, include, results, depth + 1, max)?;
        } else if path.is_file() {
            if let Some(filter) = include {
                if !basic_grep_include_matches(filter, &name) {
                    continue;
                }
            }
            basic_grep_file(root, &path, needle, results, max);
        }
    }
    Ok(())
}

/// Appends every line of `file` containing `needle` (already lower-cased) to
/// `results`, stopping at `max` entries. Unreadable files are skipped.
fn basic_grep_file(
    root: &std::path::Path,
    file: &std::path::Path,
    needle: &str,
    results: &mut Vec<String>,
    max: usize,
) {
    let content = match std::fs::read_to_string(file) {
        Ok(text) => text,
        Err(_) => return,
    };
    // When `root` is the file itself the relative path is empty; show the full path then.
    let shown = match file.strip_prefix(root) {
        Ok(rel) if !rel.as_os_str().is_empty() => rel,
        _ => file,
    };
    for (idx, line) in content.lines().enumerate() {
        if results.len() >= max {
            break;
        }
        if line.to_lowercase().contains(needle) {
            results.push(format!("{}:{}: {}", shown.display(), idx + 1, line.trim()));
        }
    }
}

/// Decides whether a file name passes the `include` filter of `basic_grep`.
fn basic_grep_include_matches(include: &str, name: &str) -> bool {
    // Syntax this simple matcher does not understand: stay best-effort and keep the file.
    if include.starts_with('!') || include.contains('{') || include.contains('[') {
        return true;
    }
    if !include.contains('*') && !include.contains('?') {
        return name.contains(include);
    }
    // Only a bare file name is available, so compare against the last glob component.
    let glob = include.rsplit('/').next().unwrap_or(include);
    basic_grep_wildcard(glob, name)
}

/// Matches `text` against `pattern`, where `*` matches any run of characters
/// (including none) and `?` matches exactly one character.
fn basic_grep_wildcard(pattern: &str, text: &str) -> bool {
    let p: Vec<char> = pattern.chars().collect();
    let t: Vec<char> = text.chars().collect();
    let (mut pi, mut ti) = (0usize, 0usize);
    // Position of the most recent `*` and the text index it was tried at.
    let mut star: Option<(usize, usize)> = None;
    while ti < t.len() {
        if pi < p.len() && p[pi] == '*' {
            star = Some((pi, ti));
            pi += 1;
        } else if pi < p.len() && (p[pi] == '?' || p[pi] == t[ti]) {
            pi += 1;
            ti += 1;
        } else if let Some((star_p, star_t)) = star {
            // Backtrack: let the last `*` absorb one more character.
            pi = star_p + 1;
            ti = star_t + 1;
            star = Some((star_p, star_t + 1));
        } else {
            return false;
        }
    }
    p[pi..].iter().all(|&c| c == '*')
}
130
131// ─── Web search ─────────────────────────────────────────────────────────────────
132
133pub struct WebSearch;
134
135#[async_trait]
136impl Tool for WebSearch {
137    fn name(&self) -> &str {
138        "web_search"
139    }
140    fn description(&self) -> &str {
141        "Search the web for information"
142    }
143    fn schema(&self) -> serde_json::Value {
144        json!({
145            "type": "object",
146            "properties": {
147                "query": { "type": "string", "description": "Search query" },
148                "num_results": { "type": "integer", "description": "Number of results (default: 5)" }
149            },
150            "required": ["query"]
151        })
152    }
153    fn risk(&self) -> RiskLevel {
154        RiskLevel::Network
155    }
156    async fn call(&self, args: serde_json::Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
157        let query = args["query"].as_str().unwrap_or("");
158        let num = args["num_results"].as_u64().unwrap_or(5);
159
160        let client = reqwest::Client::builder()
161            .user_agent("sparrow/0.1")
162            .build()?;
163
164        // Use DuckDuckGo Lite as a free search backend
165        let resp = client
166            .get("https://lite.duckduckgo.com/lite/")
167            .query(&[("q", query)])
168            .send()
169            .await?;
170
171        let html = resp.text().await?;
172
173        // Simple extraction of result snippets
174        let mut results = Vec::new();
175        for line in html.lines() {
176            let trimmed = line.trim();
177            if trimmed.starts_with("<a") && trimmed.contains("class=\"result-link\"") {
178                if let Some(url) = extract_href(trimmed) {
179                    if results.len() < num as usize {
180                        results.push(format!("🔗 {}", url));
181                    }
182                }
183            }
184            if trimmed.starts_with("<td") && trimmed.contains("class=\"result-snippet\"") {
185                let snippet = strip_html(trimmed);
186                if !snippet.is_empty() && results.len() <= num as usize {
187                    results.push(format!("   {}", snippet));
188                }
189            }
190        }
191
192        if results.is_empty() {
193            Ok(ToolResult::text(format!(
194                "No web results for: {}. Try a more specific query.",
195                query
196            )))
197        } else {
198            Ok(ToolResult::text(results.join("\n")))
199        }
200    }
201}
202
203fn extract_href(line: &str) -> Option<String> {
204    let start = line.find("href=\"")? + 6;
205    let end = line[start..].find('"')?;
206    let mut url = line[start..start + end].to_string();
207    // Clean up DuckDuckGo redirect URLs
208    if url.starts_with("//") {
209        url = format!("https:{}", url);
210    }
211    if url.contains("duckduckgo.com/l/?uddg=") {
212        if let Some(real) = url.split("uddg=").nth(1) {
213            if let Ok(decoded) = urlencoding(&real) {
214                url = decoded;
215            }
216        }
217    }
218    Some(url)
219}
220
221/// Reject non-http(s) schemes and any URL whose host resolves to a private,
222/// loopback, link-local, multicast, or unspecified address (SSRF defence).
223/// Also rejects bare IPs in those ranges and the AWS/GCP metadata endpoints.
224pub(crate) fn validate_public_url(url: &str) -> Result<(), &'static str> {
225    let parsed = url::Url::parse(url).map_err(|_| "invalid URL")?;
226    match parsed.scheme() {
227        "http" | "https" => {}
228        _ => return Err("only http(s) is allowed"),
229    }
230    let host = parsed.host_str().ok_or("missing host")?;
231
232    // Block obvious well-known metadata / loopback hostnames.
233    let lc = host.to_ascii_lowercase();
234    if matches!(
235        lc.as_str(),
236        "localhost" | "ip6-localhost" | "ip6-loopback" | "metadata.google.internal" | "metadata"
237    ) || lc.ends_with(".localhost")
238        || lc.ends_with(".local")
239        || lc.ends_with(".internal")
240    {
241        return Err("host points to local/internal network");
242    }
243
244    // If the host parses as an IP literal, check the ranges directly.
245    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
246        if is_blocked_ip(&ip) {
247            return Err("IP belongs to a private/loopback/link-local range");
248        }
249        return Ok(());
250    }
251
252    // Hostname: best-effort DNS check. We can't await here without making the
253    // fn async, so we resolve synchronously via the std API. A single resolution
254    // is cheap and prevents the most common SSRF payloads (`127.0.0.1`-aliasing
255    // domains, hosts file tricks, etc.).
256    let port = parsed.port_or_known_default().unwrap_or(0);
257    if let Ok(addrs) = std::net::ToSocketAddrs::to_socket_addrs(&(host, port)) {
258        for sa in addrs {
259            if is_blocked_ip(&sa.ip()) {
260                return Err("hostname resolves to a private/loopback IP");
261            }
262        }
263    }
264    Ok(())
265}
266
/// Returns `true` when `ip` must not be contacted by the web tools.
///
/// IPv4: loopback, private, link-local (which includes the
/// `169.254.169.254` metadata address), broadcast, multicast, unspecified and
/// carrier-grade NAT (`100.64.0.0/10`).
/// IPv6: loopback, unspecified, multicast, unique-local (`fc00::/7`),
/// link-local (`fe80::/10`), and IPv4-mapped addresses whose embedded IPv4
/// address is itself blocked.
fn is_blocked_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::IpAddr;

    match *ip {
        IpAddr::V4(addr) => {
            let [first, second, ..] = addr.octets();
            // Carrier-grade NAT 100.64.0.0/10 spans 100.64.x.x through 100.127.x.x.
            let carrier_nat = first == 100 && (64..=127).contains(&second);
            // AWS/GCP/Azure metadata endpoint.
            let metadata = addr.octets() == [169, 254, 169, 254];
            let special = addr.is_loopback()
                || addr.is_private()
                || addr.is_link_local()
                || addr.is_broadcast()
                || addr.is_multicast()
                || addr.is_unspecified();
            special || metadata || carrier_nat
        }
        IpAddr::V6(addr) => {
            if addr.is_loopback() || addr.is_unspecified() || addr.is_multicast() {
                return true;
            }
            let head = addr.segments()[0];
            let unique_local = (head & 0xfe00) == 0xfc00; // fc00::/7
            let link_local = (head & 0xffc0) == 0xfe80; // fe80::/10
            if unique_local || link_local {
                return true;
            }
            // IPv4-mapped (::ffff:a.b.c.d): judge the embedded IPv4 address.
            match addr.to_ipv4_mapped() {
                Some(mapped) => is_blocked_ip(&IpAddr::V4(mapped)),
                None => false,
            }
        }
    }
}
292
/// Removes everything between `<` and `>` from `s` and trims surrounding
/// whitespace from what is left.
///
/// This is a character-level filter, not an HTML parser: entities are not
/// decoded and the contents of elements such as `<script>` are kept as text.
/// A `>` that appears outside any tag is ordinary text and is preserved.
fn strip_html(s: &str) -> String {
    let mut text = String::with_capacity(s.len());
    let mut in_tag = false;
    for c in s.chars() {
        match c {
            '<' => in_tag = true,
            // Only a `>` that closes an open tag is markup.
            '>' if in_tag => in_tag = false,
            _ if in_tag => {}
            _ => text.push(c),
        }
    }
    text.trim().to_string()
}
307
/// Percent-decodes a URL query value: `%XX` escapes become the byte `0xXX`,
/// `+` becomes a space, and everything else is copied through.
///
/// A `%` that is not followed by two hexadecimal digits is kept literally.
///
/// # Errors
/// Returns `Err(())` when the decoded bytes are not valid UTF-8.
fn urlencoding(s: &str) -> Result<String, ()> {
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    // Work on bytes throughout so offsets stay valid when the input contains
    // multi-byte characters, and so multi-byte escapes (e.g. `%C3%A9`) are
    // reassembled into one character instead of one `char` per byte.
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let escaped = bytes.get(i + 1..i + 3).and_then(|pair| {
                    let hi = (pair[0] as char).to_digit(16)?;
                    let lo = (pair[1] as char).to_digit(16)?;
                    u8::try_from(hi * 16 + lo).ok()
                });
                if let Some(byte) = escaped {
                    decoded.push(byte);
                    i += 3;
                    continue;
                }
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }
    String::from_utf8(decoded).map_err(|_| ())
}
330
331// ─── Web fetch ──────────────────────────────────────────────────────────────────
332
333pub struct WebFetch;
334
335#[async_trait]
336impl Tool for WebFetch {
337    fn name(&self) -> &str {
338        "web_fetch"
339    }
340    fn description(&self) -> &str {
341        "Fetch and read content from a URL"
342    }
343    fn schema(&self) -> serde_json::Value {
344        json!({
345            "type": "object",
346            "properties": {
347                "url": { "type": "string", "description": "URL to fetch" },
348                "format": { "type": "string", "enum": ["text", "markdown", "html"], "description": "Output format (default: text)" }
349            },
350            "required": ["url"]
351        })
352    }
353    fn risk(&self) -> RiskLevel {
354        RiskLevel::Network
355    }
356    async fn call(&self, args: serde_json::Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
357        let url = args["url"].as_str().unwrap_or("");
358        let format = args["format"].as_str().unwrap_or("text");
359
360        if let Err(why) = validate_public_url(url) {
361            return Ok(ToolResult::error(format!("Refused URL ({}): {}", why, url)));
362        }
363
364        let client = reqwest::Client::builder()
365            .user_agent("sparrow/0.1")
366            .timeout(std::time::Duration::from_secs(30))
367            // Belt-and-suspenders: re-validate after redirects to block private-IP redirect attacks.
368            .redirect(reqwest::redirect::Policy::custom(|attempt| {
369                if validate_public_url(attempt.url().as_str()).is_err() {
370                    attempt.stop()
371                } else if attempt.previous().len() >= 5 {
372                    attempt.stop()
373                } else {
374                    attempt.follow()
375                }
376            }))
377            .build()?;
378
379        let resp = client.get(url).send().await?;
380        let status = resp.status();
381        let content_type = resp
382            .headers()
383            .get("content-type")
384            .and_then(|v| v.to_str().ok())
385            .unwrap_or("unknown")
386            .to_string();
387
388        let bytes = resp.bytes().await?;
389
390        let text = match format {
391            "html" => String::from_utf8_lossy(&bytes).to_string(),
392            _ => {
393                // Simple HTML to text conversion
394                let raw = String::from_utf8_lossy(&bytes).to_string();
395                let stripped = strip_html(&raw);
396                // Truncate very long content
397                if stripped.len() > 50_000 {
398                    format!(
399                        "{}...\n\n[truncated: {} bytes total]",
400                        &stripped[..50_000],
401                        stripped.len()
402                    )
403                } else {
404                    stripped
405                }
406            }
407        };
408
409        Ok(ToolResult::ok(vec![Block::Text(format!(
410            "URL: {}\nStatus: {}\nType: {}\n\n{}",
411            url, status, content_type, text
412        ))]))
413    }
414}