Skip to main content

vtcode_core/tools/web_fetch/
mod.rs

1//! WebFetch tool for fetching and analyzing web content using AI
2//!
3//! Supports both restricted (blocklist) and whitelist (allowlist) modes
4//! with dynamic configuration loading from vtcode.toml
5
6use super::traits::Tool;
7use crate::config::constants::tools;
8use crate::tools::error_helpers::with_path_context;
9use anyhow::{Context, Result, anyhow, bail};
10use async_trait::async_trait;
11use hashbrown::HashSet;
12use reqwest::header::{ACCEPT, HeaderMap, HeaderValue, USER_AGENT};
13use serde::Deserialize;
14use serde_json::{Value, json};
15use std::fs;
16use std::net::IpAddr;
17use std::path::Path;
18use url::Url;
19
20pub mod domains;
21pub use domains::{BUILTIN_BLOCKED_DOMAINS, BUILTIN_BLOCKED_PATTERNS, MALICIOUS_PATTERNS};
22
/// Request timeout, in seconds, used when the caller does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Response-body size limit in bytes (500 KB), used when the caller does not supply `max_bytes`.
const MAX_CONTENT_SIZE: usize = 500 * 1_000;
25
26#[derive(Debug, Deserialize)]
27struct WebFetchArgs {
28    url: String,
29    prompt: String,
30    #[serde(default)]
31    max_bytes: Option<usize>,
32    #[serde(default)]
33    timeout_secs: Option<u64>,
34}
35
36/// WebFetch tool that fetches URL content and processes it with AI
37#[derive(Clone)]
38pub struct WebFetchTool {
39    /// Security mode: "restricted" (blocklist) or "whitelist" (allowlist)
40    pub mode: String,
41    /// Additional blocked domains (merged with builtin)
42    pub blocked_domains: HashSet<String>,
43    /// Additional blocked patterns (merged with builtin)
44    pub blocked_patterns: Vec<String>,
45    /// Allowed domains (for exemptions in restricted mode or primary list in whitelist mode)
46    pub allowed_domains: HashSet<String>,
47    /// Strict HTTPS-only mode
48    pub strict_https_only: bool,
49}
50
51impl WebFetchTool {
52    pub fn new() -> Self {
53        Self {
54            mode: "restricted".to_string(),
55            blocked_domains: HashSet::new(),
56            blocked_patterns: Vec::new(),
57            allowed_domains: HashSet::new(),
58            strict_https_only: true,
59        }
60    }
61
62    /// Create a WebFetchTool with custom configuration
63    pub fn with_config(
64        mode: String,
65        blocked_domains: Vec<String>,
66        blocked_patterns: Vec<String>,
67        allowed_domains: Vec<String>,
68        strict_https_only: bool,
69    ) -> Self {
70        Self {
71            mode,
72            blocked_domains: blocked_domains.into_iter().collect(),
73            blocked_patterns,
74            allowed_domains: allowed_domains.into_iter().collect(),
75            strict_https_only,
76        }
77    }
78
79    async fn fetch_url_content(
80        &self,
81        url: &str,
82        max_bytes: usize,
83        timeout_secs: u64,
84    ) -> Result<String> {
85        // Validate URL
86        self.validate_url(url)?;
87
88        let default_headers = Self::default_headers();
89
90        let client = reqwest::Client::builder()
91            .default_headers(default_headers)
92            .timeout(std::time::Duration::from_secs(timeout_secs))
93            .build()?;
94
95        let response = client.get(url).send().await?;
96
97        if !response.status().is_success() {
98            return Err(anyhow!(
99                "HTTP request failed with status: {}",
100                response.status()
101            ));
102        }
103
104        let content_type = response
105            .headers()
106            .get("content-type")
107            .and_then(|h| h.to_str().ok())
108            .unwrap_or("")
109            .to_string();
110
111        // Validate content type
112        self.validate_content_type(&content_type)?;
113
114        // Limit response body to max_bytes
115        let bytes = response.bytes().await?;
116        if bytes.len() > max_bytes {
117            return Err(anyhow!(
118                "Response body size {} bytes exceeds maximum allowed size of {} bytes",
119                bytes.len(),
120                max_bytes
121            ));
122        }
123
124        String::from_utf8(bytes.to_vec()).context("Response body is not valid UTF-8")
125    }
126
127    fn validate_url(&self, url: &str) -> Result<()> {
128        // HTTPS enforcement (can be disabled only for testing)
129        if self.strict_https_only && !url.starts_with("https://") {
130            return Err(anyhow!("Only HTTPS URLs are allowed for security"));
131        }
132
133        // Parse the URL to extract the real host, which correctly separates
134        // userinfo credentials (http://evil@127.0.0.1/) from the host.
135        let domain = extract_domain(url)
136            .map_err(|e| anyhow!("Failed to parse URL for security validation: {e}"))?;
137
138        // Reject private, loopback, link-local, and reserved IPs.
139        if is_private_host(&domain) {
140            return Err(anyhow!("Access to local/private networks is blocked"));
141        }
142
143        // Block .local and .internal TLDs (mDNS / split-DNS).
144        let domain_lower = domain.to_ascii_lowercase();
145        if domain_lower.ends_with(".local") || domain_lower.ends_with(".internal") {
146            return Err(anyhow!("Access to local/private networks is blocked"));
147        }
148
149        let url_lower = url.to_lowercase();
150
151        // Apply security policy based on mode
152        match self.mode.as_str() {
153            "whitelist" => self.validate_whitelist_mode(&url_lower)?,
154            "restricted" => self.validate_restricted_mode(&url_lower)?,
155            _ => return Err(anyhow!("Unknown web_fetch security mode: {}", self.mode)),
156        }
157
158        Ok(())
159    }
160
161    fn validate_whitelist_mode(&self, url: &str) -> Result<()> {
162        // In whitelist mode, only explicitly allowed domains are permitted
163        let domain = extract_domain(url)?;
164
165        if self.allowed_domains.is_empty() {
166            return Err(anyhow!(
167                "Whitelist mode enabled but no domains are whitelisted. Configure allowed_domains in web_fetch settings."
168            ));
169        }
170
171        // Check if domain matches any whitelisted domain or pattern
172        for allowed in &self.allowed_domains {
173            if domain_matches_allowed(&domain, allowed) {
174                return Ok(());
175            }
176        }
177
178        Err(anyhow!(
179            "Domain '{}' is not in the whitelist. Only explicitly allowed domains are permitted in whitelist mode.",
180            domain
181        ))
182    }
183
184    fn validate_restricted_mode(&self, url: &str) -> Result<()> {
185        // In restricted mode, use a blocklist of known dangerous/sensitive domains
186        let url_lower = url.to_lowercase();
187
188        // Check against allowed exemptions first (exemptions override blocklist)
189        let domain = extract_domain(url)?;
190        for allowed in &self.allowed_domains {
191            if domain_matches_allowed(&domain, allowed) {
192                return Ok(());
193            }
194        }
195
196        // Check for malicious and sensitive URL patterns
197        self.validate_url_safety(&url_lower)?;
198
199        Ok(())
200    }
201
202    fn validate_url_safety(&self, url: &str) -> Result<()> {
203        // Combine built-in and custom blocked domains
204        let mut all_blocked_domains = BUILTIN_BLOCKED_DOMAINS.to_vec();
205        all_blocked_domains.extend(self.blocked_domains.iter().map(|s| s.as_str()));
206
207        // Combine built-in and custom blocked patterns
208        let mut all_blocked_patterns = BUILTIN_BLOCKED_PATTERNS.to_vec();
209        all_blocked_patterns.extend(self.blocked_patterns.iter().map(|s| s.as_str()));
210
211        // Check blocked domains
212        for domain in &all_blocked_domains {
213            if url.contains(domain) {
214                return Err(anyhow!(
215                    "Access to sensitive domain '{}' is blocked for privacy and security reasons",
216                    domain
217                ));
218            }
219        }
220
221        // Check for sensitive patterns in URL
222        for pattern in &all_blocked_patterns {
223            if url.contains(pattern) {
224                return Err(anyhow!(
225                    "URL contains sensitive pattern '{}'. Fetching URLs with credentials or sensitive data is blocked",
226                    pattern
227                ));
228            }
229        }
230
231        // Check for common malware/phishing indicators
232        self.check_malicious_indicators(url)?;
233
234        Ok(())
235    }
236
237    fn check_malicious_indicators(&self, url: &str) -> Result<()> {
238        for pattern in MALICIOUS_PATTERNS {
239            if url.contains(pattern) {
240                return Err(anyhow!(
241                    "URL contains potentially malicious pattern. Access blocked for safety"
242                ));
243            }
244        }
245
246        Ok(())
247    }
248
249    /// Expand ~ to home directory
250    fn expand_home_path(path: &str) -> String {
251        if path.starts_with("~/")
252            && let Ok(home) = std::env::var("HOME")
253        {
254            return path.replace("~/", &format!("{}/", home));
255        }
256        path.to_string()
257    }
258
259    /// Load blocklist from external JSON file
260    #[expect(dead_code)]
261    async fn load_dynamic_blocklist(&self, path: &str) -> Result<(Vec<String>, Vec<String>)> {
262        let expanded_path = Self::expand_home_path(path);
263        if !Path::new(&expanded_path).exists() {
264            return Ok((Vec::new(), Vec::new()));
265        }
266
267        let content = with_path_context(
268            fs::read_to_string(&expanded_path),
269            "read blocklist from",
270            path,
271        )?;
272
273        #[derive(Deserialize)]
274        struct BlocklistFile {
275            blocked_domains: Option<Vec<String>>,
276            blocked_patterns: Option<Vec<String>>,
277        }
278
279        let blocklist: BlocklistFile = with_path_context(
280            serde_json::from_str(&content),
281            "parse blocklist JSON from",
282            path,
283        )?;
284
285        Ok((
286            blocklist.blocked_domains.unwrap_or_default(),
287            blocklist.blocked_patterns.unwrap_or_default(),
288        ))
289    }
290
291    /// Load whitelist from external JSON file
292    #[expect(dead_code)]
293    async fn load_dynamic_whitelist(&self, path: &str) -> Result<Vec<String>> {
294        let expanded_path = Self::expand_home_path(path);
295        if !Path::new(&expanded_path).exists() {
296            return Ok(Vec::new());
297        }
298
299        let content = with_path_context(
300            fs::read_to_string(&expanded_path),
301            "read whitelist from",
302            path,
303        )?;
304
305        #[derive(Deserialize)]
306        struct WhitelistFile {
307            allowed_domains: Option<Vec<String>>,
308        }
309
310        let whitelist: WhitelistFile = with_path_context(
311            serde_json::from_str(&content),
312            "parse whitelist JSON from",
313            path,
314        )?;
315
316        Ok(whitelist.allowed_domains.unwrap_or_default())
317    }
318
319    fn validate_content_type(&self, content_type: &str) -> Result<()> {
320        if content_type.is_empty() {
321            return Ok(());
322        }
323
324        let allowed_types = [
325            "text/html",
326            "text/plain",
327            "text/markdown",
328            "application/json",
329            "application/xml",
330            "text/xml",
331            "application/javascript",
332            "text/css",
333            "text/javascript",
334            "application/xhtml+xml",
335        ];
336
337        let content_type_lower = content_type.to_lowercase();
338        if allowed_types
339            .iter()
340            .any(|&t| content_type_lower.contains(t))
341        {
342            Ok(())
343        } else {
344            Err(anyhow!(
345                "Content type '{}' is not supported. Only text-based content types are allowed.",
346                content_type
347            ))
348        }
349    }
350
351    async fn run(&self, raw_args: Value) -> Result<Value> {
352        let args: WebFetchArgs = serde_json::from_value(raw_args)
353            .context("Invalid arguments for web_fetch tool. Provide 'url' and 'prompt'.")?;
354
355        let max_bytes = args.max_bytes.unwrap_or(MAX_CONTENT_SIZE);
356        let timeout_secs = args.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
357
358        // Fetch the URL content with detailed error handling
359        let content = match self
360            .fetch_url_content(&args.url, max_bytes, timeout_secs)
361            .await
362        {
363            Ok(content) => content,
364            Err(e) => {
365                // Structured, readable error. The agent can surface this or fall back to other tools.
366                return Ok(json!({
367                    "error": format!("web_fetch: failed to fetch URL '{}': {}", args.url, e),
368                    "url": args.url,
369                    "max_bytes": max_bytes,
370                    "timeout_secs": timeout_secs
371                }));
372            }
373        };
374
375        let content_length = content.len();
376
377        if content_length == 0 {
378            return Ok(json!({
379                "error": format!(
380                    "web_fetch: no content fetched from '{}'. The URL may be unreachable, returned empty content, or used an unsupported content-type.",
381                    args.url
382                ),
383                "url": args.url
384            }));
385        }
386
387        // Truncate preview for UI; keep full content available for reasoning.
388        let preview_limit = 8000;
389        let (preview, truncated, overflow_info) = if content_length > preview_limit {
390            let truncated_content =
391                vtcode_commons::formatting::truncate_byte_budget(&content, preview_limit, "...");
392            let overflow = format!("[+{} more characters]", content_length - preview_limit);
393            (truncated_content, true, Some(overflow))
394        } else {
395            (content.clone(), false, None)
396        };
397
398        // Canonical response shape:
399        // - `content`: full fetched body
400        // - `preview`: truncated snippet for display
401        // - `prompt`: what the user/model wants to know
402        // - `next_action_hint`: explicit instruction so the agent continues the loop correctly
403        let mut response = json!({
404            "url": args.url,
405            "prompt": args.prompt,
406            "content": content,
407            "preview": preview,
408            "content_length": content_length,
409            "truncated": truncated,
410            "next_action_hint": "Analyze `content` using `prompt` and answer the user in natural language based on the fetched page."
411        });
412
413        // Add overflow indicator if content was truncated
414        if let Some(overflow) = overflow_info {
415            response["overflow"] = json!(overflow);
416        }
417
418        Ok(response)
419    }
420}
421
422impl WebFetchTool {
423    /// Returns default headers used by the WebFetch client. This keeps Accept set to
424    /// prefer 'text/markdown' so documentation sites can provide token-efficient markdown
425    /// content as a preference.
426    fn default_headers() -> HeaderMap {
427        let mut headers = HeaderMap::new();
428        headers.insert(ACCEPT, HeaderValue::from_static("text/markdown, */*"));
429        headers.insert(
430            USER_AGENT,
431            HeaderValue::from_static("VT Code/1.0 (compatible; web-fetch tool)"),
432        );
433        headers
434    }
435}
436
437/// Helper function to extract domain from URL
438///
439/// Uses proper URL parsing to correctly handle:
440/// - User credentials in URLs (`http://user@host/`) — the host is properly
441///   separated from the userinfo, preventing SSRF bypass
442/// - Port numbers, paths, and query strings
443fn extract_domain(url: &str) -> Result<String> {
444    let parsed = Url::parse(url).with_context(|| format!("Failed to parse URL: {url}"))?;
445    let host = parsed
446        .host_str()
447        .ok_or_else(|| anyhow!("URL has no host: {url}"))?;
448    if host.is_empty() {
449        bail!("URL has empty host: {url}");
450    }
451    Ok(host.to_string())
452}
453
/// Returns `true` when `host` names a loopback, private, link-local, or otherwise
/// non-public address that must never be fetched.
///
/// `host` is expected in the form produced by `Url::host_str`: a domain name, a dotted
/// IPv4 literal, or a bracketed IPv6 literal such as `[::1]`. Unbracketed IPv6 literals
/// are accepted as well.
///
/// Domain names are only matched against the `localhost` family. No DNS resolution
/// happens here, so public names that resolve to private addresses are not detected.
fn is_private_host(host: &str) -> bool {
    /// IPv4 ranges that must not be contacted.
    fn is_private_v4(ip: std::net::Ipv4Addr) -> bool {
        let [first, second, ..] = ip.octets();
        ip.is_private() // 10/8, 172.16/12, 192.168/16
            || ip.is_loopback() // the whole 127/8 block
            || ip.is_link_local() // 169.254/16 (includes cloud metadata 169.254.169.254)
            || first == 0 // 0/8 "this network"
            || (first == 100 && (second & 0xc0) == 64) // 100.64/10 shared/CGNAT space
            || (first == 198 && (second & 0xfe) == 18) // 198.18/15 benchmarking
            || first >= 224 // 224/4 multicast, 240/4 reserved, broadcast
    }

    // `Url::host_str` wraps IPv6 literals in brackets, which `IpAddr` will not parse.
    let bare = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    match bare.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => is_private_v4(v4),
        Ok(IpAddr::V6(v6)) => {
            let segments = v6.segments();
            // 64:ff9b::/96 is the NAT64 well-known prefix; the low 32 bits are an IPv4 address.
            let nat64_embedded = (segments[..6] == [0x64_u16, 0xff9b, 0, 0, 0, 0]).then(|| {
                std::net::Ipv4Addr::from((u32::from(segments[6]) << 16) | u32::from(segments[7]))
            });
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                // fc00::/7 — unique local unicast
                || (segments[0] & 0xfe00) == 0xfc00
                // fe80::/10 — link-local unicast
                || (segments[0] & 0xffc0) == 0xfe80
                // ::ffff:a.b.c.d — IPv4-mapped addresses reach the embedded IPv4 host
                || v6.to_ipv4_mapped().is_some_and(is_private_v4)
                || nat64_embedded.is_some_and(is_private_v4)
        }
        Err(_) => {
            // DNS names that resolve to loopback, including the trailing-dot form and
            // `*.localhost` subdomains.
            let name = bare.trim_end_matches('.').to_ascii_lowercase();
            name == "localhost" || name == "localhost.localdomain" || name.ends_with(".localhost")
        }
    }
}
494
/// Returns `true` when `domain` equals `allowed` or is a subdomain of it.
///
/// Comparison is ASCII case-insensitive and ignores trailing dots. An `allowed` entry
/// may be written as:
/// * `example.com`;
/// * `.example.com`;
/// * the wildcard form `*.example.com`.
///
/// All three forms match `example.com` and any of its subdomains. Surrounding whitespace
/// is ignored, and an entry that is empty after normalization matches nothing.
fn domain_matches_allowed(domain: &str, allowed: &str) -> bool {
    let domain = domain.trim_end_matches('.').to_ascii_lowercase();
    let allowed = allowed.trim();
    let allowed = allowed
        .strip_prefix("*.")
        .unwrap_or(allowed)
        .trim_start_matches('.')
        .trim_end_matches('.')
        .to_ascii_lowercase();

    if allowed.is_empty() {
        return false;
    }

    // A subdomain must end with ".<allowed>" so "notexample.com" does not match "example.com".
    domain == allowed
        || domain
            .strip_suffix(allowed.as_str())
            .is_some_and(|label_prefix| label_prefix.ends_with('.'))
}
505
506impl Default for WebFetchTool {
507    fn default() -> Self {
508        Self::new()
509    }
510}
511
512#[async_trait]
513impl Tool for WebFetchTool {
514    async fn execute(&self, mut args: Value) -> Result<Value> {
515        // Backwards-compatible argument normalization:
516        // - If called with only { "url": "..." } (no prompt), interpret as:
517        //   "Fetch this URL and return a concise natural language summary."
518        //
519        // This ensures:
520        // - Simple "fetch https://..." style calls are handled natively by VT Code.
521        // - We do not force upstream agents or MCP tools to construct a full prompt.
522        // - MCP tools like `get_current_time` remain unaffected (they are separate).
523        if let Some(obj) = args.as_object_mut() {
524            let has_url = obj.get("url").is_some_and(Value::is_string);
525            let has_prompt = obj.get("prompt").is_some_and(Value::is_string);
526
527            if has_url && !has_prompt {
528                obj.insert(
529                    "prompt".to_string(),
530                    json!("Briefly summarize what this page is and what it represents. Focus on the owner/profile, primary purpose, and any notable repositories or projects."),
531                );
532            }
533        }
534
535        self.run(args).await
536    }
537
538    fn name(&self) -> &str {
539        tools::WEB_FETCH
540    }
541
542    fn description(&self) -> &str {
543        "Fetches content from a specified URL and returns an analyzed summary. Accepts: { url: string, prompt?: string, max_bytes?: number, timeout_secs?: number }. If 'prompt' is omitted, VT Code uses a safe default summary prompt so that simple 'fetch https://…' requests are handled by this built-in tool instead of delegating to external MCP tools."
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Prompt sent by every test that supplies one explicitly.
    const PROMPT: &str = "Extract the main content";

    async fn execute_json(tool: &WebFetchTool, args: Value) -> Value {
        tool.execute(args)
            .await
            .expect("web_fetch should return structured JSON output")
    }

    fn error_text(result: &Value) -> Option<&str> {
        result.get("error").and_then(Value::as_str)
    }

    /// Runs `tool` on `url` with the standard prompt. Returns the reported error text,
    /// or an empty string when the result has no `error` field.
    async fn fetch_error(tool: &WebFetchTool, url: &str) -> String {
        let result = execute_json(tool, json!({ "url": url, "prompt": PROMPT })).await;
        error_text(&result).unwrap_or_default().to_string()
    }

    /// Builds a tool through `WebFetchTool::with_config` from borrowed string lists.
    fn tool_with(
        mode: &str,
        blocked_domains: &[&str],
        blocked_patterns: &[&str],
        allowed_domains: &[&str],
        strict_https_only: bool,
    ) -> WebFetchTool {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_string()).collect()
        }
        WebFetchTool::with_config(
            mode.to_string(),
            owned(blocked_domains),
            owned(blocked_patterns),
            owned(allowed_domains),
            strict_https_only,
        )
    }

    #[tokio::test]
    async fn rejects_non_https_urls() {
        let error = fetch_error(&WebFetchTool::new(), "http://example.com").await;
        assert!(error.contains("Only HTTPS URLs are allowed"));
    }

    #[tokio::test]
    async fn allows_http_when_https_disabled() {
        // strict_https_only = false
        let tool = tool_with("restricted", &[], &[], &[], false);
        let error = fetch_error(&tool, "http://example.com").await;
        assert!(!error.contains("Only HTTPS URLs are allowed"));
    }

    #[tokio::test]
    async fn rejects_localhost_urls() {
        let error = fetch_error(&WebFetchTool::new(), "https://localhost:8080").await;
        assert!(error.contains("local/private networks"));
    }

    #[tokio::test]
    async fn requires_both_url_and_prompt() {
        let result = execute_json(
            &WebFetchTool::new(),
            json!({ "url": "http://example.com" }),
        )
        .await;
        // Prompt should be auto-filled; URL then fails HTTPS policy.
        let error = error_text(&result).unwrap_or_default();
        assert!(error.contains("Only HTTPS URLs are allowed"));
    }

    #[tokio::test]
    async fn rejects_sensitive_banking_domains() {
        let error = fetch_error(&WebFetchTool::new(), "https://paypal.com/login").await;
        assert!(error.contains("blocked for privacy and security reasons"));
    }

    #[tokio::test]
    async fn rejects_sensitive_auth_domains() {
        let error = fetch_error(&WebFetchTool::new(), "https://accounts.google.com").await;
        assert!(error.contains("blocked for privacy and security reasons"));
    }

    #[tokio::test]
    async fn rejects_urls_with_credentials() {
        let error = fetch_error(
            &WebFetchTool::new(),
            "https://example.com?password=secret123",
        )
        .await;
        assert!(error.contains("sensitive pattern"));
    }

    #[tokio::test]
    async fn rejects_urls_with_api_keys() {
        let error = fetch_error(
            &WebFetchTool::new(),
            "https://api.example.com?api_key=sk_live_123456",
        )
        .await;
        assert!(error.contains("sensitive pattern"));
    }

    #[tokio::test]
    async fn rejects_urls_with_tokens() {
        let error = fetch_error(&WebFetchTool::new(), "https://example.com?token=xyz123").await;
        assert!(error.contains("sensitive pattern"));
    }

    #[tokio::test]
    async fn rejects_malicious_url_patterns() {
        let error = fetch_error(
            &WebFetchTool::new(),
            "https://example.com/malware.exe\"",
        )
        .await;
        assert!(error.contains("potentially malicious pattern"));
    }

    #[tokio::test]
    async fn rejects_typosquatting_domains() {
        let error = fetch_error(&WebFetchTool::new(), "https://g00gle.com").await;
        assert!(error.contains("potentially malicious pattern"));
    }

    #[tokio::test]
    async fn rejects_url_shorteners() {
        let error = fetch_error(&WebFetchTool::new(), "https://bit.ly/xyz123").await;
        assert!(error.contains("potentially malicious pattern"));
    }

    #[tokio::test]
    async fn whitelist_mode_requires_allowed_domains() {
        // No allowed domains
        let tool = tool_with("whitelist", &[], &[], &[], true);
        let error = fetch_error(&tool, "https://example.com").await;
        assert!(error.contains("whitelist") || error.contains("whitelisted"));
    }

    #[tokio::test]
    async fn whitelist_mode_allows_whitelisted_domains() {
        // Only example.com allowed
        let tool = tool_with("whitelist", &[], &[], &["example.com"], true);
        let error = fetch_error(&tool, "https://example.com/path").await;
        assert!(!error.contains("not in the whitelist"));
    }

    #[tokio::test]
    async fn whitelist_mode_rejects_non_whitelisted_domains() {
        let tool = tool_with("whitelist", &[], &[], &["allowed.com"], true);
        let error = fetch_error(&tool, "https://notallowed.com").await;
        assert!(error.contains("not in the whitelist"));
    }

    #[tokio::test]
    async fn restricted_mode_allows_exemptions() {
        // Exempt from blocklist
        let tool = tool_with("restricted", &[], &[], &["paypal.com"], true);
        let error = fetch_error(&tool, "https://paypal.com/login").await;
        assert!(!error.contains("blocked for privacy"));
    }

    #[tokio::test]
    async fn custom_blocked_domains_work() {
        let tool = tool_with("restricted", &["custom-blocked.com"], &[], &[], true);
        let error = fetch_error(&tool, "https://custom-blocked.com/page").await;
        assert!(error.contains("blocked for privacy and security reasons"));
    }

    #[tokio::test]
    async fn custom_blocked_patterns_work() {
        let tool = tool_with("restricted", &[], &["custom_secret="], &[], true);
        let error = fetch_error(&tool, "https://example.com?custom_secret=abc123").await;
        assert!(error.contains("sensitive pattern"));
    }

    #[test]
    fn default_headers_contain_text_markdown_accept() {
        let headers = WebFetchTool::default_headers();
        assert!(headers.contains_key(ACCEPT));
        let accept = headers
            .get(ACCEPT)
            .and_then(|value| value.to_str().ok())
            .unwrap();
        assert!(accept.contains("text/markdown"));
    }
}