Skip to main content

zeph_tools/
scrape.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{
15    ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
16};
17use crate::net::is_private_ip;
18
19#[derive(Debug, Deserialize, JsonSchema)]
20struct FetchParams {
21    /// HTTPS URL to fetch
22    url: String,
23}
24
25#[derive(Debug, Deserialize, JsonSchema)]
26struct ScrapeInstruction {
27    /// HTTPS URL to scrape
28    url: String,
29    /// CSS selector
30    select: String,
31    /// Extract mode: text, html, or attr:<name>
32    #[serde(default = "default_extract")]
33    extract: String,
34    /// Max results to return
35    limit: Option<usize>,
36}
37
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
41
/// What to pull out of each element matched by the CSS selector.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Interprets the `extract` field of a scrape instruction.
    ///
    /// `"html"` selects [`ExtractMode::Html`] and `"attr:<name>"` selects
    /// [`ExtractMode::Attr`] with everything after the prefix (possibly empty) as the
    /// attribute name. Every other input, including `"text"` and differently-cased
    /// spellings such as `"HTML"`, yields [`ExtractMode::Text`].
    fn parse(s: &str) -> Self {
        if let Some(attr_name) = s.strip_prefix("attr:") {
            return Self::Attr(attr_name.to_owned());
        }
        match s {
            "html" => Self::Html,
            _ => Self::Text,
        }
    }
}
61
62/// Extracts data from web pages via CSS selectors.
63///
64/// Detects ` ```scrape ` blocks in LLM responses containing JSON instructions,
65/// fetches the URL, and parses HTML with `scrape-core`.
66#[derive(Debug)]
67pub struct WebScrapeExecutor {
68    timeout: Duration,
69    max_body_bytes: usize,
70    allowed_domains: Vec<String>,
71    denied_domains: Vec<String>,
72    audit_logger: Option<Arc<AuditLogger>>,
73}
74
75impl WebScrapeExecutor {
76    #[must_use]
77    pub fn new(config: &ScrapeConfig) -> Self {
78        Self {
79            timeout: Duration::from_secs(config.timeout),
80            max_body_bytes: config.max_body_bytes,
81            allowed_domains: config.allowed_domains.clone(),
82            denied_domains: config.denied_domains.clone(),
83            audit_logger: None,
84        }
85    }
86
87    #[must_use]
88    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
89        self.audit_logger = Some(logger);
90        self
91    }
92
93    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
94        let mut builder = reqwest::Client::builder()
95            .timeout(self.timeout)
96            .redirect(reqwest::redirect::Policy::none());
97        builder = builder.resolve_to_addrs(host, addrs);
98        builder.build().unwrap_or_default()
99    }
100}
101
102impl ToolExecutor for WebScrapeExecutor {
103    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
104        use crate::registry::{InvocationHint, ToolDef};
105        vec![
106            ToolDef {
107                id: "web_scrape".into(),
108                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
109                schema: schemars::schema_for!(ScrapeInstruction),
110                invocation: InvocationHint::FencedBlock("scrape"),
111            },
112            ToolDef {
113                id: "fetch".into(),
114                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
115                schema: schemars::schema_for!(FetchParams),
116                invocation: InvocationHint::ToolCall,
117            },
118        ]
119    }
120
121    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
122        let blocks = extract_scrape_blocks(response);
123        if blocks.is_empty() {
124            return Ok(None);
125        }
126
127        let mut outputs = Vec::with_capacity(blocks.len());
128        #[allow(clippy::cast_possible_truncation)]
129        let blocks_executed = blocks.len() as u32;
130
131        for block in &blocks {
132            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
133                ToolError::Execution(std::io::Error::new(
134                    std::io::ErrorKind::InvalidData,
135                    e.to_string(),
136                ))
137            })?;
138            let start = Instant::now();
139            let scrape_result = self.scrape_instruction(&instruction).await;
140            #[allow(clippy::cast_possible_truncation)]
141            let duration_ms = start.elapsed().as_millis() as u64;
142            match scrape_result {
143                Ok(output) => {
144                    self.log_audit(
145                        "web_scrape",
146                        &instruction.url,
147                        AuditResult::Success,
148                        duration_ms,
149                        None,
150                    )
151                    .await;
152                    outputs.push(output);
153                }
154                Err(e) => {
155                    let audit_result = tool_error_to_audit_result(&e);
156                    self.log_audit(
157                        "web_scrape",
158                        &instruction.url,
159                        audit_result,
160                        duration_ms,
161                        Some(&e),
162                    )
163                    .await;
164                    return Err(e);
165                }
166            }
167        }
168
169        Ok(Some(ToolOutput {
170            tool_name: "web-scrape".to_owned(),
171            summary: outputs.join("\n\n"),
172            blocks_executed,
173            filter_stats: None,
174            diff: None,
175            streamed: false,
176            terminal_id: None,
177            locations: None,
178            raw_response: None,
179            claim_source: Some(ClaimSource::WebScrape),
180        }))
181    }
182
183    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
184        match call.tool_id.as_str() {
185            "web_scrape" => {
186                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
187                let start = Instant::now();
188                let result = self.scrape_instruction(&instruction).await;
189                #[allow(clippy::cast_possible_truncation)]
190                let duration_ms = start.elapsed().as_millis() as u64;
191                match result {
192                    Ok(output) => {
193                        self.log_audit(
194                            "web_scrape",
195                            &instruction.url,
196                            AuditResult::Success,
197                            duration_ms,
198                            None,
199                        )
200                        .await;
201                        Ok(Some(ToolOutput {
202                            tool_name: "web-scrape".to_owned(),
203                            summary: output,
204                            blocks_executed: 1,
205                            filter_stats: None,
206                            diff: None,
207                            streamed: false,
208                            terminal_id: None,
209                            locations: None,
210                            raw_response: None,
211                            claim_source: Some(ClaimSource::WebScrape),
212                        }))
213                    }
214                    Err(e) => {
215                        let audit_result = tool_error_to_audit_result(&e);
216                        self.log_audit(
217                            "web_scrape",
218                            &instruction.url,
219                            audit_result,
220                            duration_ms,
221                            Some(&e),
222                        )
223                        .await;
224                        Err(e)
225                    }
226                }
227            }
228            "fetch" => {
229                let p: FetchParams = deserialize_params(&call.params)?;
230                let start = Instant::now();
231                let result = self.handle_fetch(&p).await;
232                #[allow(clippy::cast_possible_truncation)]
233                let duration_ms = start.elapsed().as_millis() as u64;
234                match result {
235                    Ok(output) => {
236                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
237                            .await;
238                        Ok(Some(ToolOutput {
239                            tool_name: "fetch".to_owned(),
240                            summary: output,
241                            blocks_executed: 1,
242                            filter_stats: None,
243                            diff: None,
244                            streamed: false,
245                            terminal_id: None,
246                            locations: None,
247                            raw_response: None,
248                            claim_source: Some(ClaimSource::WebScrape),
249                        }))
250                    }
251                    Err(e) => {
252                        let audit_result = tool_error_to_audit_result(&e);
253                        self.log_audit("fetch", &p.url, audit_result, duration_ms, Some(&e))
254                            .await;
255                        Err(e)
256                    }
257                }
258            }
259            _ => Ok(None),
260        }
261    }
262
263    fn is_tool_retryable(&self, tool_id: &str) -> bool {
264        matches!(tool_id, "web_scrape" | "fetch")
265    }
266}
267
268fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
269    match e {
270        ToolError::Blocked { command } => AuditResult::Blocked {
271            reason: command.clone(),
272        },
273        ToolError::Timeout { .. } => AuditResult::Timeout,
274        _ => AuditResult::Error {
275            message: e.to_string(),
276        },
277    }
278}
279
280impl WebScrapeExecutor {
281    async fn log_audit(
282        &self,
283        tool: &str,
284        command: &str,
285        result: AuditResult,
286        duration_ms: u64,
287        error: Option<&ToolError>,
288    ) {
289        if let Some(ref logger) = self.audit_logger {
290            let (error_category, error_domain, error_phase) =
291                error.map_or((None, None, None), |e| {
292                    let cat = e.category();
293                    (
294                        Some(cat.label().to_owned()),
295                        Some(cat.domain().label().to_owned()),
296                        Some(cat.phase().label().to_owned()),
297                    )
298                });
299            let entry = AuditEntry {
300                timestamp: chrono_now(),
301                tool: tool.into(),
302                command: command.into(),
303                result,
304                duration_ms,
305                error_category,
306                error_domain,
307                error_phase,
308                claim_source: Some(ClaimSource::WebScrape),
309                mcp_server_id: None,
310                injection_flagged: false,
311                embedding_anomalous: false,
312                cross_boundary_mcp_to_acp: false,
313                adversarial_policy_decision: None,
314                exit_code: None,
315                truncated: false,
316            };
317            logger.log(&entry).await;
318        }
319    }
320
321    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
322        let parsed = validate_url(&params.url)?;
323        check_domain_policy(
324            parsed.host_str().unwrap_or(""),
325            &self.allowed_domains,
326            &self.denied_domains,
327        )?;
328        let (host, addrs) = resolve_and_validate(&parsed).await?;
329        self.fetch_html(&params.url, &host, &addrs).await
330    }
331
332    async fn scrape_instruction(
333        &self,
334        instruction: &ScrapeInstruction,
335    ) -> Result<String, ToolError> {
336        let parsed = validate_url(&instruction.url)?;
337        check_domain_policy(
338            parsed.host_str().unwrap_or(""),
339            &self.allowed_domains,
340            &self.denied_domains,
341        )?;
342        let (host, addrs) = resolve_and_validate(&parsed).await?;
343        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
344        let selector = instruction.select.clone();
345        let extract = ExtractMode::parse(&instruction.extract);
346        let limit = instruction.limit.unwrap_or(10);
347        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
348            .await
349            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
350    }
351
352    /// Fetches the HTML at `url`, manually following up to 3 redirects.
353    ///
354    /// Each redirect target is validated with `validate_url` and `resolve_and_validate`
355    /// before following, preventing SSRF via redirect chains.
356    ///
357    /// # Errors
358    ///
359    /// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
360    /// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
361    async fn fetch_html(
362        &self,
363        url: &str,
364        host: &str,
365        addrs: &[SocketAddr],
366    ) -> Result<String, ToolError> {
367        const MAX_REDIRECTS: usize = 3;
368
369        let mut current_url = url.to_owned();
370        let mut current_host = host.to_owned();
371        let mut current_addrs = addrs.to_vec();
372
373        for hop in 0..=MAX_REDIRECTS {
374            // Build a per-hop client pinned to the current hop's validated addresses.
375            let client = self.build_client(&current_host, &current_addrs);
376            let resp = client
377                .get(&current_url)
378                .send()
379                .await
380                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
381
382            let status = resp.status();
383
384            if status.is_redirection() {
385                if hop == MAX_REDIRECTS {
386                    return Err(ToolError::Execution(std::io::Error::other(
387                        "too many redirects",
388                    )));
389                }
390
391                let location = resp
392                    .headers()
393                    .get(reqwest::header::LOCATION)
394                    .and_then(|v| v.to_str().ok())
395                    .ok_or_else(|| {
396                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
397                    })?;
398
399                // Resolve relative redirect URLs against the current URL.
400                let base = Url::parse(&current_url)
401                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
402                let next_url = base
403                    .join(location)
404                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
405
406                let validated = validate_url(next_url.as_str())?;
407                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
408
409                current_url = next_url.to_string();
410                current_host = next_host;
411                current_addrs = next_addrs;
412                continue;
413            }
414
415            if !status.is_success() {
416                return Err(ToolError::Http {
417                    status: status.as_u16(),
418                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
419                });
420            }
421
422            let bytes = resp
423                .bytes()
424                .await
425                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
426
427            if bytes.len() > self.max_body_bytes {
428                return Err(ToolError::Execution(std::io::Error::other(format!(
429                    "response too large: {} bytes (max: {})",
430                    bytes.len(),
431                    self.max_body_bytes,
432                ))));
433            }
434
435            return String::from_utf8(bytes.to_vec())
436                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
437        }
438
439        Err(ToolError::Execution(std::io::Error::other(
440            "too many redirects",
441        )))
442    }
443}
444
445fn extract_scrape_blocks(text: &str) -> Vec<&str> {
446    crate::executor::extract_fenced_blocks(text, "scrape")
447}
448
449/// Check host against the domain allowlist/denylist from `ScrapeConfig`.
450///
451/// Logic:
452/// 1. If `denied_domains` matches the host → block.
453/// 2. If `allowed_domains` is non-empty:
454///    a. IP address hosts are always rejected (no pattern can match a bare IP).
455///    b. Hosts not matching any entry → block.
456/// 3. Otherwise → allow.
457///
458/// Wildcard prefix matching: `*.example.com` matches `sub.example.com` but NOT `example.com`.
459/// Multiple wildcards are not supported; patterns with more than one `*` are treated as exact.
460fn check_domain_policy(
461    host: &str,
462    allowed_domains: &[String],
463    denied_domains: &[String],
464) -> Result<(), ToolError> {
465    if denied_domains.iter().any(|p| domain_matches(p, host)) {
466        return Err(ToolError::Blocked {
467            command: format!("domain blocked by denylist: {host}"),
468        });
469    }
470    if !allowed_domains.is_empty() {
471        // Bare IP addresses cannot match any domain pattern — reject when allowlist is active.
472        let is_ip = host.parse::<std::net::IpAddr>().is_ok()
473            || (host.starts_with('[') && host.ends_with(']'));
474        if is_ip {
475            return Err(ToolError::Blocked {
476                command: format!(
477                    "bare IP address not allowed when domain allowlist is active: {host}"
478                ),
479            });
480        }
481        if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
482            return Err(ToolError::Blocked {
483                command: format!("domain not in allowlist: {host}"),
484            });
485        }
486    }
487    Ok(())
488}
489
/// Match a domain pattern against a hostname.
///
/// Pattern `*.example.com` matches `sub.example.com` but not `example.com` or
/// `sub.sub.example.com`. Exact patterns match only themselves.
/// Patterns with multiple `*` characters are not supported and never match a real host.
///
/// Both sides are compared in canonical DNS form. ASCII case is ignored, and a single
/// trailing root dot is dropped, because `example.com.` names the same host as
/// `example.com`. Without this, a fully-qualified URL such as `https://denied.example./`
/// would slip past a `denied.example` denylist entry.
fn domain_matches(pattern: &str, host: &str) -> bool {
    /// Lower-cases `name` and strips one trailing root dot.
    fn canonical(name: &str) -> String {
        name.strip_suffix('.').unwrap_or(name).to_ascii_lowercase()
    }

    let pattern = canonical(pattern);
    let host = canonical(host);

    match pattern.strip_prefix('*') {
        // `*.example.com` -> suffix `.example.com`: allow exactly one extra DNS label.
        Some(suffix) if suffix.starts_with('.') => host
            .strip_suffix(suffix)
            .is_some_and(|label| !label.is_empty() && !label.contains('.')),
        _ => pattern == host,
    }
}
509
510fn validate_url(raw: &str) -> Result<Url, ToolError> {
511    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
512        command: format!("invalid URL: {raw}"),
513    })?;
514
515    if parsed.scheme() != "https" {
516        return Err(ToolError::Blocked {
517            command: format!("scheme not allowed: {}", parsed.scheme()),
518        });
519    }
520
521    if let Some(host) = parsed.host()
522        && is_private_host(&host)
523    {
524        return Err(ToolError::Blocked {
525            command: format!(
526                "private/local host blocked: {}",
527                parsed.host_str().unwrap_or("")
528            ),
529        });
530    }
531
532    Ok(parsed)
533}
534
535fn is_private_host(host: &url::Host<&str>) -> bool {
536    match host {
537        url::Host::Domain(d) => {
538            // Exact match or subdomain of localhost (e.g. foo.localhost)
539            // and .internal/.local TLDs used in cloud/k8s environments.
540            #[allow(clippy::case_sensitive_file_extension_comparisons)]
541            {
542                *d == "localhost"
543                    || d.ends_with(".localhost")
544                    || d.ends_with(".internal")
545                    || d.ends_with(".local")
546            }
547        }
548        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
549        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
550    }
551}
552
553/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
554/// and returns the hostname and validated socket addresses.
555///
556/// Returning the addresses allows the caller to pin the HTTP client to these exact
557/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
558async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
559    let Some(host) = url.host_str() else {
560        return Ok((String::new(), vec![]));
561    };
562    let port = url.port_or_known_default().unwrap_or(443);
563    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
564        .await
565        .map_err(|e| ToolError::Blocked {
566            command: format!("DNS resolution failed: {e}"),
567        })?
568        .collect();
569    for addr in &addrs {
570        if is_private_ip(addr.ip()) {
571            return Err(ToolError::Blocked {
572                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
573            });
574        }
575    }
576    Ok((host.to_owned(), addrs))
577}
578
579fn parse_and_extract(
580    html: &str,
581    selector: &str,
582    extract: &ExtractMode,
583    limit: usize,
584) -> Result<String, ToolError> {
585    let soup = scrape_core::Soup::parse(html);
586
587    let tags = soup.find_all(selector).map_err(|e| {
588        ToolError::Execution(std::io::Error::new(
589            std::io::ErrorKind::InvalidData,
590            format!("invalid selector: {e}"),
591        ))
592    })?;
593
594    let mut results = Vec::new();
595
596    for tag in tags.into_iter().take(limit) {
597        let value = match extract {
598            ExtractMode::Text => tag.text(),
599            ExtractMode::Html => tag.inner_html(),
600            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
601        };
602        if !value.trim().is_empty() {
603            results.push(value.trim().to_owned());
604        }
605    }
606
607    if results.is_empty() {
608        Ok(format!("No results for selector: {selector}"))
609    } else {
610        Ok(results.join("\n"))
611    }
612}
613
614#[cfg(test)]
615mod tests {
616    use super::*;
617
618    // --- extract_scrape_blocks ---
619
620    #[test]
621    fn extract_single_block() {
622        let text =
623            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
624        let blocks = extract_scrape_blocks(text);
625        assert_eq!(blocks.len(), 1);
626        assert!(blocks[0].contains("example.com"));
627    }
628
629    #[test]
630    fn extract_multiple_blocks() {
631        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
632        let blocks = extract_scrape_blocks(text);
633        assert_eq!(blocks.len(), 2);
634    }
635
636    #[test]
637    fn no_blocks_returns_empty() {
638        let blocks = extract_scrape_blocks("plain text, no code blocks");
639        assert!(blocks.is_empty());
640    }
641
642    #[test]
643    fn unclosed_block_ignored() {
644        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
645        assert!(blocks.is_empty());
646    }
647
648    #[test]
649    fn non_scrape_block_ignored() {
650        let text =
651            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
652        let blocks = extract_scrape_blocks(text);
653        assert_eq!(blocks.len(), 1);
654        assert!(blocks[0].contains("x.com"));
655    }
656
657    #[test]
658    fn multiline_json_block() {
659        let text =
660            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
661        let blocks = extract_scrape_blocks(text);
662        assert_eq!(blocks.len(), 1);
663        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
664        assert_eq!(instr.url, "https://example.com");
665    }
666
667    // --- ScrapeInstruction parsing ---
668
669    #[test]
670    fn parse_valid_instruction() {
671        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
672        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
673        assert_eq!(instr.url, "https://example.com");
674        assert_eq!(instr.select, "h1");
675        assert_eq!(instr.extract, "text");
676        assert_eq!(instr.limit, Some(5));
677    }
678
679    #[test]
680    fn parse_minimal_instruction() {
681        let json = r#"{"url":"https://example.com","select":"p"}"#;
682        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
683        assert_eq!(instr.extract, "text");
684        assert!(instr.limit.is_none());
685    }
686
687    #[test]
688    fn parse_attr_extract() {
689        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
690        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
691        assert_eq!(instr.extract, "attr:href");
692    }
693
694    #[test]
695    fn parse_invalid_json_errors() {
696        let result = serde_json::from_str::<ScrapeInstruction>("not json");
697        assert!(result.is_err());
698    }
699
700    // --- ExtractMode ---
701
702    #[test]
703    fn extract_mode_text() {
704        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
705    }
706
707    #[test]
708    fn extract_mode_html() {
709        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
710    }
711
712    #[test]
713    fn extract_mode_attr() {
714        let mode = ExtractMode::parse("attr:href");
715        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
716    }
717
718    #[test]
719    fn extract_mode_unknown_defaults_to_text() {
720        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
721    }
722
723    // --- validate_url ---
724
725    #[test]
726    fn valid_https_url() {
727        assert!(validate_url("https://example.com").is_ok());
728    }
729
730    #[test]
731    fn http_rejected() {
732        let err = validate_url("http://example.com").unwrap_err();
733        assert!(matches!(err, ToolError::Blocked { .. }));
734    }
735
736    #[test]
737    fn ftp_rejected() {
738        let err = validate_url("ftp://files.example.com").unwrap_err();
739        assert!(matches!(err, ToolError::Blocked { .. }));
740    }
741
742    #[test]
743    fn file_rejected() {
744        let err = validate_url("file:///etc/passwd").unwrap_err();
745        assert!(matches!(err, ToolError::Blocked { .. }));
746    }
747
748    #[test]
749    fn invalid_url_rejected() {
750        let err = validate_url("not a url").unwrap_err();
751        assert!(matches!(err, ToolError::Blocked { .. }));
752    }
753
754    #[test]
755    fn localhost_blocked() {
756        let err = validate_url("https://localhost/path").unwrap_err();
757        assert!(matches!(err, ToolError::Blocked { .. }));
758    }
759
760    #[test]
761    fn loopback_ip_blocked() {
762        let err = validate_url("https://127.0.0.1/path").unwrap_err();
763        assert!(matches!(err, ToolError::Blocked { .. }));
764    }
765
766    #[test]
767    fn private_10_blocked() {
768        let err = validate_url("https://10.0.0.1/api").unwrap_err();
769        assert!(matches!(err, ToolError::Blocked { .. }));
770    }
771
772    #[test]
773    fn private_172_blocked() {
774        let err = validate_url("https://172.16.0.1/api").unwrap_err();
775        assert!(matches!(err, ToolError::Blocked { .. }));
776    }
777
778    #[test]
779    fn private_192_blocked() {
780        let err = validate_url("https://192.168.1.1/api").unwrap_err();
781        assert!(matches!(err, ToolError::Blocked { .. }));
782    }
783
784    #[test]
785    fn ipv6_loopback_blocked() {
786        let err = validate_url("https://[::1]/path").unwrap_err();
787        assert!(matches!(err, ToolError::Blocked { .. }));
788    }
789
790    #[test]
791    fn public_ip_allowed() {
792        assert!(validate_url("https://93.184.216.34/page").is_ok());
793    }
794
795    // --- parse_and_extract ---
796
797    #[test]
798    fn extract_text_from_html() {
799        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
800        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
801        assert_eq!(result, "Hello World");
802    }
803
804    #[test]
805    fn extract_multiple_elements() {
806        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
807        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
808        assert_eq!(result, "A\nB\nC");
809    }
810
811    #[test]
812    fn extract_with_limit() {
813        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
814        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
815        assert_eq!(result, "A\nB");
816    }
817
818    #[test]
819    fn extract_attr_href() {
820        let html = r#"<a href="https://example.com">Link</a>"#;
821        let result =
822            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
823        assert_eq!(result, "https://example.com");
824    }
825
826    #[test]
827    fn extract_inner_html() {
828        let html = "<div><span>inner</span></div>";
829        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
830        assert!(result.contains("<span>inner</span>"));
831    }
832
833    #[test]
834    fn no_matches_returns_message() {
835        let html = "<html><body><p>text</p></body></html>";
836        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
837        assert!(result.starts_with("No results for selector:"));
838    }
839
840    #[test]
841    fn empty_text_skipped() {
842        let html = "<ul><li>  </li><li>A</li></ul>";
843        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
844        assert_eq!(result, "A");
845    }
846
847    #[test]
848    fn invalid_selector_errors() {
849        let html = "<html><body></body></html>";
850        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
851        assert!(result.is_err());
852    }
853
854    #[test]
855    fn empty_html_returns_no_results() {
856        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
857        assert!(result.starts_with("No results for selector:"));
858    }
859
860    #[test]
861    fn nested_selector() {
862        let html = "<div><span>inner</span></div><span>outer</span>";
863        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
864        assert_eq!(result, "inner");
865    }
866
867    #[test]
868    fn attr_missing_returns_empty() {
869        let html = r"<a>No href</a>";
870        let result =
871            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
872        assert!(result.starts_with("No results for selector:"));
873    }
874
875    #[test]
876    fn extract_html_mode() {
877        let html = "<div><b>bold</b> text</div>";
878        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
879        assert!(result.contains("<b>bold</b>"));
880    }
881
882    #[test]
883    fn limit_zero_returns_no_results() {
884        let html = "<ul><li>A</li><li>B</li></ul>";
885        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
886        assert!(result.starts_with("No results for selector:"));
887    }
888
889    // --- validate_url edge cases ---
890
891    #[test]
892    fn url_with_port_allowed() {
893        assert!(validate_url("https://example.com:8443/path").is_ok());
894    }
895
896    #[test]
897    fn link_local_ip_blocked() {
898        let err = validate_url("https://169.254.1.1/path").unwrap_err();
899        assert!(matches!(err, ToolError::Blocked { .. }));
900    }
901
902    #[test]
903    fn url_no_scheme_rejected() {
904        let err = validate_url("example.com/path").unwrap_err();
905        assert!(matches!(err, ToolError::Blocked { .. }));
906    }
907
908    #[test]
909    fn unspecified_ipv4_blocked() {
910        let err = validate_url("https://0.0.0.0/path").unwrap_err();
911        assert!(matches!(err, ToolError::Blocked { .. }));
912    }
913
914    #[test]
915    fn broadcast_ipv4_blocked() {
916        let err = validate_url("https://255.255.255.255/path").unwrap_err();
917        assert!(matches!(err, ToolError::Blocked { .. }));
918    }
919
920    #[test]
921    fn ipv6_link_local_blocked() {
922        let err = validate_url("https://[fe80::1]/path").unwrap_err();
923        assert!(matches!(err, ToolError::Blocked { .. }));
924    }
925
926    #[test]
927    fn ipv6_unique_local_blocked() {
928        let err = validate_url("https://[fd12::1]/path").unwrap_err();
929        assert!(matches!(err, ToolError::Blocked { .. }));
930    }
931
932    #[test]
933    fn ipv4_mapped_ipv6_loopback_blocked() {
934        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
935        assert!(matches!(err, ToolError::Blocked { .. }));
936    }
937
938    #[test]
939    fn ipv4_mapped_ipv6_private_blocked() {
940        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
941        assert!(matches!(err, ToolError::Blocked { .. }));
942    }
943
944    // --- WebScrapeExecutor (no-network) ---
945
946    #[tokio::test]
947    async fn executor_no_blocks_returns_none() {
948        let config = ScrapeConfig::default();
949        let executor = WebScrapeExecutor::new(&config);
950        let result = executor.execute("plain text").await;
951        assert!(result.unwrap().is_none());
952    }
953
954    #[tokio::test]
955    async fn executor_invalid_json_errors() {
956        let config = ScrapeConfig::default();
957        let executor = WebScrapeExecutor::new(&config);
958        let response = "```scrape\nnot json\n```";
959        let result = executor.execute(response).await;
960        assert!(matches!(result, Err(ToolError::Execution(_))));
961    }
962
963    #[tokio::test]
964    async fn executor_blocked_url_errors() {
965        let config = ScrapeConfig::default();
966        let executor = WebScrapeExecutor::new(&config);
967        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
968        let result = executor.execute(response).await;
969        assert!(matches!(result, Err(ToolError::Blocked { .. })));
970    }
971
972    #[tokio::test]
973    async fn executor_private_ip_blocked() {
974        let config = ScrapeConfig::default();
975        let executor = WebScrapeExecutor::new(&config);
976        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
977        let result = executor.execute(response).await;
978        assert!(matches!(result, Err(ToolError::Blocked { .. })));
979    }
980
981    #[tokio::test]
982    async fn executor_unreachable_host_returns_error() {
983        let config = ScrapeConfig {
984            timeout: 1,
985            max_body_bytes: 1_048_576,
986            ..Default::default()
987        };
988        let executor = WebScrapeExecutor::new(&config);
989        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
990        let result = executor.execute(response).await;
991        assert!(matches!(result, Err(ToolError::Execution(_))));
992    }
993
994    #[tokio::test]
995    async fn executor_localhost_url_blocked() {
996        let config = ScrapeConfig::default();
997        let executor = WebScrapeExecutor::new(&config);
998        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
999        let result = executor.execute(response).await;
1000        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1001    }
1002
1003    #[tokio::test]
1004    async fn executor_empty_text_returns_none() {
1005        let config = ScrapeConfig::default();
1006        let executor = WebScrapeExecutor::new(&config);
1007        let result = executor.execute("").await;
1008        assert!(result.unwrap().is_none());
1009    }
1010
1011    #[tokio::test]
1012    async fn executor_multiple_blocks_first_blocked() {
1013        let config = ScrapeConfig::default();
1014        let executor = WebScrapeExecutor::new(&config);
1015        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1016             ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1017        let result = executor.execute(response).await;
1018        assert!(result.is_err());
1019    }
1020
1021    #[test]
1022    fn validate_url_empty_string() {
1023        let err = validate_url("").unwrap_err();
1024        assert!(matches!(err, ToolError::Blocked { .. }));
1025    }
1026
1027    #[test]
1028    fn validate_url_javascript_scheme_blocked() {
1029        let err = validate_url("javascript:alert(1)").unwrap_err();
1030        assert!(matches!(err, ToolError::Blocked { .. }));
1031    }
1032
1033    #[test]
1034    fn validate_url_data_scheme_blocked() {
1035        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1036        assert!(matches!(err, ToolError::Blocked { .. }));
1037    }
1038
1039    #[test]
1040    fn is_private_host_public_domain_is_false() {
1041        let host: url::Host<&str> = url::Host::Domain("example.com");
1042        assert!(!is_private_host(&host));
1043    }
1044
1045    #[test]
1046    fn is_private_host_localhost_is_true() {
1047        let host: url::Host<&str> = url::Host::Domain("localhost");
1048        assert!(is_private_host(&host));
1049    }
1050
1051    #[test]
1052    fn is_private_host_ipv6_unspecified_is_true() {
1053        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1054        assert!(is_private_host(&host));
1055    }
1056
1057    #[test]
1058    fn is_private_host_public_ipv6_is_false() {
1059        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1060        assert!(!is_private_host(&host));
1061    }
1062
    // --- fetch_html redirect logic: wiremock HTTP server tests ---
    //
    // These tests use a local wiremock server to exercise the redirect-following logic
    // in `fetch_html` without requiring an external HTTPS connection. The server binds to
    // 127.0.0.1, so the SSRF guard would block loopback connections. To avoid that, the
    // tests either call `fetch_html` directly (bypassing `validate_url`) or, for
    // multi-hop redirect chains, use the test-only `follow_redirects_raw` mimic, which
    // also skips `resolve_and_validate`.
1069
1070    /// Helper: returns executor + (`server_url`, `server_addr`) from a running wiremock mock server.
1071    /// The server address is passed to `fetch_html` via `resolve_to_addrs` so the client
1072    /// connects to the mock instead of doing a real DNS lookup.
1073    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1074        let server = wiremock::MockServer::start().await;
1075        let executor = WebScrapeExecutor {
1076            timeout: Duration::from_secs(5),
1077            max_body_bytes: 1_048_576,
1078            allowed_domains: vec![],
1079            denied_domains: vec![],
1080            audit_logger: None,
1081        };
1082        (executor, server)
1083    }
1084
1085    /// Parses the mock server's URI into (`host_str`, `socket_addr`) for use with `build_client`.
1086    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1087        let uri = server.uri();
1088        let url = Url::parse(&uri).unwrap();
1089        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1090        let port = url.port().unwrap_or(80);
1091        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1092        (host, vec![addr])
1093    }
1094
1095    /// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
1096    /// `resolve_and_validate`. This lets us exercise the redirect-counting and
1097    /// missing-Location logic against a plain HTTP wiremock server.
1098    async fn follow_redirects_raw(
1099        executor: &WebScrapeExecutor,
1100        start_url: &str,
1101        host: &str,
1102        addrs: &[std::net::SocketAddr],
1103    ) -> Result<String, ToolError> {
1104        const MAX_REDIRECTS: usize = 3;
1105        let mut current_url = start_url.to_owned();
1106        let mut current_host = host.to_owned();
1107        let mut current_addrs = addrs.to_vec();
1108
1109        for hop in 0..=MAX_REDIRECTS {
1110            let client = executor.build_client(&current_host, &current_addrs);
1111            let resp = client
1112                .get(&current_url)
1113                .send()
1114                .await
1115                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1116
1117            let status = resp.status();
1118
1119            if status.is_redirection() {
1120                if hop == MAX_REDIRECTS {
1121                    return Err(ToolError::Execution(std::io::Error::other(
1122                        "too many redirects",
1123                    )));
1124                }
1125
1126                let location = resp
1127                    .headers()
1128                    .get(reqwest::header::LOCATION)
1129                    .and_then(|v| v.to_str().ok())
1130                    .ok_or_else(|| {
1131                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
1132                    })?;
1133
1134                let base = Url::parse(&current_url)
1135                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1136                let next_url = base
1137                    .join(location)
1138                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1139
1140                // Re-use same host/addrs (mock server is always the same endpoint).
1141                current_url = next_url.to_string();
1142                // Preserve host/addrs as-is since the mock server doesn't change.
1143                let _ = &mut current_host;
1144                let _ = &mut current_addrs;
1145                continue;
1146            }
1147
1148            if !status.is_success() {
1149                return Err(ToolError::Execution(std::io::Error::other(format!(
1150                    "HTTP {status}",
1151                ))));
1152            }
1153
1154            let bytes = resp
1155                .bytes()
1156                .await
1157                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1158
1159            if bytes.len() > executor.max_body_bytes {
1160                return Err(ToolError::Execution(std::io::Error::other(format!(
1161                    "response too large: {} bytes (max: {})",
1162                    bytes.len(),
1163                    executor.max_body_bytes,
1164                ))));
1165            }
1166
1167            return String::from_utf8(bytes.to_vec())
1168                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1169        }
1170
1171        Err(ToolError::Execution(std::io::Error::other(
1172            "too many redirects",
1173        )))
1174    }
1175
1176    #[tokio::test]
1177    async fn fetch_html_success_returns_body() {
1178        use wiremock::matchers::{method, path};
1179        use wiremock::{Mock, ResponseTemplate};
1180
1181        let (executor, server) = mock_server_executor().await;
1182        Mock::given(method("GET"))
1183            .and(path("/page"))
1184            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1185            .mount(&server)
1186            .await;
1187
1188        let (host, addrs) = server_host_and_addr(&server);
1189        let url = format!("{}/page", server.uri());
1190        let result = executor.fetch_html(&url, &host, &addrs).await;
1191        assert!(result.is_ok(), "expected Ok, got: {result:?}");
1192        assert_eq!(result.unwrap(), "<h1>OK</h1>");
1193    }
1194
1195    #[tokio::test]
1196    async fn fetch_html_non_2xx_returns_error() {
1197        use wiremock::matchers::{method, path};
1198        use wiremock::{Mock, ResponseTemplate};
1199
1200        let (executor, server) = mock_server_executor().await;
1201        Mock::given(method("GET"))
1202            .and(path("/forbidden"))
1203            .respond_with(ResponseTemplate::new(403))
1204            .mount(&server)
1205            .await;
1206
1207        let (host, addrs) = server_host_and_addr(&server);
1208        let url = format!("{}/forbidden", server.uri());
1209        let result = executor.fetch_html(&url, &host, &addrs).await;
1210        assert!(result.is_err());
1211        let msg = result.unwrap_err().to_string();
1212        assert!(msg.contains("403"), "expected 403 in error: {msg}");
1213    }
1214
1215    #[tokio::test]
1216    async fn fetch_html_404_returns_error() {
1217        use wiremock::matchers::{method, path};
1218        use wiremock::{Mock, ResponseTemplate};
1219
1220        let (executor, server) = mock_server_executor().await;
1221        Mock::given(method("GET"))
1222            .and(path("/missing"))
1223            .respond_with(ResponseTemplate::new(404))
1224            .mount(&server)
1225            .await;
1226
1227        let (host, addrs) = server_host_and_addr(&server);
1228        let url = format!("{}/missing", server.uri());
1229        let result = executor.fetch_html(&url, &host, &addrs).await;
1230        assert!(result.is_err());
1231        let msg = result.unwrap_err().to_string();
1232        assert!(msg.contains("404"), "expected 404 in error: {msg}");
1233    }
1234
1235    #[tokio::test]
1236    async fn fetch_html_redirect_no_location_returns_error() {
1237        use wiremock::matchers::{method, path};
1238        use wiremock::{Mock, ResponseTemplate};
1239
1240        let (executor, server) = mock_server_executor().await;
1241        // 302 with no Location header
1242        Mock::given(method("GET"))
1243            .and(path("/redirect-no-loc"))
1244            .respond_with(ResponseTemplate::new(302))
1245            .mount(&server)
1246            .await;
1247
1248        let (host, addrs) = server_host_and_addr(&server);
1249        let url = format!("{}/redirect-no-loc", server.uri());
1250        let result = executor.fetch_html(&url, &host, &addrs).await;
1251        assert!(result.is_err());
1252        let msg = result.unwrap_err().to_string();
1253        assert!(
1254            msg.contains("Location") || msg.contains("location"),
1255            "expected Location-related error: {msg}"
1256        );
1257    }
1258
1259    #[tokio::test]
1260    async fn fetch_html_single_redirect_followed() {
1261        use wiremock::matchers::{method, path};
1262        use wiremock::{Mock, ResponseTemplate};
1263
1264        let (executor, server) = mock_server_executor().await;
1265        let final_url = format!("{}/final", server.uri());
1266
1267        Mock::given(method("GET"))
1268            .and(path("/start"))
1269            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1270            .mount(&server)
1271            .await;
1272
1273        Mock::given(method("GET"))
1274            .and(path("/final"))
1275            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1276            .mount(&server)
1277            .await;
1278
1279        let (host, addrs) = server_host_and_addr(&server);
1280        let url = format!("{}/start", server.uri());
1281        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1282        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1283        assert_eq!(result.unwrap(), "<p>final</p>");
1284    }
1285
1286    #[tokio::test]
1287    async fn fetch_html_three_redirects_allowed() {
1288        use wiremock::matchers::{method, path};
1289        use wiremock::{Mock, ResponseTemplate};
1290
1291        let (executor, server) = mock_server_executor().await;
1292        let hop2 = format!("{}/hop2", server.uri());
1293        let hop3 = format!("{}/hop3", server.uri());
1294        let final_dest = format!("{}/done", server.uri());
1295
1296        Mock::given(method("GET"))
1297            .and(path("/hop1"))
1298            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1299            .mount(&server)
1300            .await;
1301        Mock::given(method("GET"))
1302            .and(path("/hop2"))
1303            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1304            .mount(&server)
1305            .await;
1306        Mock::given(method("GET"))
1307            .and(path("/hop3"))
1308            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1309            .mount(&server)
1310            .await;
1311        Mock::given(method("GET"))
1312            .and(path("/done"))
1313            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1314            .mount(&server)
1315            .await;
1316
1317        let (host, addrs) = server_host_and_addr(&server);
1318        let url = format!("{}/hop1", server.uri());
1319        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1320        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1321        assert_eq!(result.unwrap(), "<p>done</p>");
1322    }
1323
1324    #[tokio::test]
1325    async fn fetch_html_four_redirects_rejected() {
1326        use wiremock::matchers::{method, path};
1327        use wiremock::{Mock, ResponseTemplate};
1328
1329        let (executor, server) = mock_server_executor().await;
1330        let hop2 = format!("{}/r2", server.uri());
1331        let hop3 = format!("{}/r3", server.uri());
1332        let hop4 = format!("{}/r4", server.uri());
1333        let hop5 = format!("{}/r5", server.uri());
1334
1335        for (from, to) in [
1336            ("/r1", &hop2),
1337            ("/r2", &hop3),
1338            ("/r3", &hop4),
1339            ("/r4", &hop5),
1340        ] {
1341            Mock::given(method("GET"))
1342                .and(path(from))
1343                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1344                .mount(&server)
1345                .await;
1346        }
1347
1348        let (host, addrs) = server_host_and_addr(&server);
1349        let url = format!("{}/r1", server.uri());
1350        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1351        assert!(result.is_err(), "4 redirects should be rejected");
1352        let msg = result.unwrap_err().to_string();
1353        assert!(
1354            msg.contains("redirect"),
1355            "expected redirect-related error: {msg}"
1356        );
1357    }
1358
1359    #[tokio::test]
1360    async fn fetch_html_body_too_large_returns_error() {
1361        use wiremock::matchers::{method, path};
1362        use wiremock::{Mock, ResponseTemplate};
1363
1364        let small_limit_executor = WebScrapeExecutor {
1365            timeout: Duration::from_secs(5),
1366            max_body_bytes: 10,
1367            allowed_domains: vec![],
1368            denied_domains: vec![],
1369            audit_logger: None,
1370        };
1371        let server = wiremock::MockServer::start().await;
1372        Mock::given(method("GET"))
1373            .and(path("/big"))
1374            .respond_with(
1375                ResponseTemplate::new(200)
1376                    .set_body_string("this body is definitely longer than ten bytes"),
1377            )
1378            .mount(&server)
1379            .await;
1380
1381        let (host, addrs) = server_host_and_addr(&server);
1382        let url = format!("{}/big", server.uri());
1383        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1384        assert!(result.is_err());
1385        let msg = result.unwrap_err().to_string();
1386        assert!(msg.contains("too large"), "expected too-large error: {msg}");
1387    }
1388
1389    #[test]
1390    fn extract_scrape_blocks_empty_block_content() {
1391        let text = "```scrape\n\n```";
1392        let blocks = extract_scrape_blocks(text);
1393        assert_eq!(blocks.len(), 1);
1394        assert!(blocks[0].is_empty());
1395    }
1396
1397    #[test]
1398    fn extract_scrape_blocks_whitespace_only() {
1399        let text = "```scrape\n   \n```";
1400        let blocks = extract_scrape_blocks(text);
1401        assert_eq!(blocks.len(), 1);
1402    }
1403
1404    #[test]
1405    fn parse_and_extract_multiple_selectors() {
1406        let html = "<div><h1>Title</h1><p>Para</p></div>";
1407        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1408        assert!(result.contains("Title"));
1409        assert!(result.contains("Para"));
1410    }
1411
1412    #[test]
1413    fn webscrape_executor_new_with_custom_config() {
1414        let config = ScrapeConfig {
1415            timeout: 60,
1416            max_body_bytes: 512,
1417            ..Default::default()
1418        };
1419        let executor = WebScrapeExecutor::new(&config);
1420        assert_eq!(executor.max_body_bytes, 512);
1421    }
1422
1423    #[test]
1424    fn webscrape_executor_debug() {
1425        let config = ScrapeConfig::default();
1426        let executor = WebScrapeExecutor::new(&config);
1427        let dbg = format!("{executor:?}");
1428        assert!(dbg.contains("WebScrapeExecutor"));
1429    }
1430
1431    #[test]
1432    fn extract_mode_attr_empty_name() {
1433        let mode = ExtractMode::parse("attr:");
1434        assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1435    }
1436
1437    #[test]
1438    fn default_extract_returns_text() {
1439        assert_eq!(default_extract(), "text");
1440    }
1441
1442    #[test]
1443    fn scrape_instruction_debug() {
1444        let json = r#"{"url":"https://example.com","select":"h1"}"#;
1445        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1446        let dbg = format!("{instr:?}");
1447        assert!(dbg.contains("ScrapeInstruction"));
1448    }
1449
1450    #[test]
1451    fn extract_mode_debug() {
1452        let mode = ExtractMode::Text;
1453        let dbg = format!("{mode:?}");
1454        assert!(dbg.contains("Text"));
1455    }
1456
1457    // --- fetch_html redirect logic: constant and validation unit tests ---
1458
1459    /// `MAX_REDIRECTS` is 3; the 4th redirect attempt must be rejected.
1460    /// Verify the boundary is correct by inspecting the constant value.
1461    #[test]
1462    fn max_redirects_constant_is_three() {
1463        // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
1464        // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
1465        // This test documents the expected limit.
1466        const MAX_REDIRECTS: usize = 3;
1467        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1468    }
1469
1470    /// Verifies that a Location-less redirect would produce an error string containing the
1471    /// expected message, matching the error path in `fetch_html`.
1472    #[test]
1473    fn redirect_no_location_error_message() {
1474        let err = std::io::Error::other("redirect with no Location");
1475        assert!(err.to_string().contains("redirect with no Location"));
1476    }
1477
1478    /// Verifies that a too-many-redirects condition produces the expected error string.
1479    #[test]
1480    fn too_many_redirects_error_message() {
1481        let err = std::io::Error::other("too many redirects");
1482        assert!(err.to_string().contains("too many redirects"));
1483    }
1484
1485    /// Verifies that a non-2xx HTTP status produces an error message with the status code.
1486    #[test]
1487    fn non_2xx_status_error_format() {
1488        let status = reqwest::StatusCode::FORBIDDEN;
1489        let msg = format!("HTTP {status}");
1490        assert!(msg.contains("403"));
1491    }
1492
1493    /// Verifies that a 404 response status code formats into the expected error message.
1494    #[test]
1495    fn not_found_status_error_format() {
1496        let status = reqwest::StatusCode::NOT_FOUND;
1497        let msg = format!("HTTP {status}");
1498        assert!(msg.contains("404"));
1499    }
1500
1501    /// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
1502    #[test]
1503    fn relative_redirect_same_host_path() {
1504        let base = Url::parse("https://example.com/current").unwrap();
1505        let resolved = base.join("/other").unwrap();
1506        assert_eq!(resolved.as_str(), "https://example.com/other");
1507    }
1508
1509    /// Verifies relative redirect resolution preserves scheme and host.
1510    #[test]
1511    fn relative_redirect_relative_path() {
1512        let base = Url::parse("https://example.com/a/b").unwrap();
1513        let resolved = base.join("c").unwrap();
1514        assert_eq!(resolved.as_str(), "https://example.com/a/c");
1515    }
1516
1517    /// Verifies that an absolute redirect URL overrides base URL completely.
1518    #[test]
1519    fn absolute_redirect_overrides_base() {
1520        let base = Url::parse("https://example.com/page").unwrap();
1521        let resolved = base.join("https://other.com/target").unwrap();
1522        assert_eq!(resolved.as_str(), "https://other.com/target");
1523    }
1524
1525    /// Verifies that a redirect Location of http:// (downgrade) is rejected.
1526    #[test]
1527    fn redirect_http_downgrade_rejected() {
1528        let location = "http://example.com/page";
1529        let base = Url::parse("https://example.com/start").unwrap();
1530        let next = base.join(location).unwrap();
1531        let err = validate_url(next.as_str()).unwrap_err();
1532        assert!(matches!(err, ToolError::Blocked { .. }));
1533    }
1534
1535    /// Verifies that a redirect to a private IP literal is blocked.
1536    #[test]
1537    fn redirect_location_private_ip_blocked() {
1538        let location = "https://192.168.100.1/admin";
1539        let base = Url::parse("https://example.com/start").unwrap();
1540        let next = base.join(location).unwrap();
1541        let err = validate_url(next.as_str()).unwrap_err();
1542        assert!(matches!(err, ToolError::Blocked { .. }));
1543        let ToolError::Blocked { command: cmd } = err else {
1544            panic!("expected Blocked");
1545        };
1546        assert!(
1547            cmd.contains("private") || cmd.contains("scheme"),
1548            "error message should describe the block reason: {cmd}"
1549        );
1550    }
1551
1552    /// Verifies that a redirect to a .internal domain is blocked.
1553    #[test]
1554    fn redirect_location_internal_domain_blocked() {
1555        let location = "https://metadata.internal/latest/meta-data/";
1556        let base = Url::parse("https://example.com/start").unwrap();
1557        let next = base.join(location).unwrap();
1558        let err = validate_url(next.as_str()).unwrap_err();
1559        assert!(matches!(err, ToolError::Blocked { .. }));
1560    }
1561
1562    /// Verifies that a chain of 3 valid public redirects passes `validate_url` at every hop.
1563    #[test]
1564    fn redirect_chain_three_hops_all_public() {
1565        let hops = [
1566            "https://redirect1.example.com/hop1",
1567            "https://redirect2.example.com/hop2",
1568            "https://destination.example.com/final",
1569        ];
1570        for hop in hops {
1571            assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1572        }
1573    }
1574
1575    // --- SSRF redirect chain defense ---
1576
1577    /// Verifies that a redirect Location pointing to a private IP is rejected by `validate_url`
1578    /// before any connection attempt — simulating the validation step inside `fetch_html`.
1579    #[test]
1580    fn redirect_to_private_ip_rejected_by_validate_url() {
1581        // These would appear as Location headers in a redirect response.
1582        let private_targets = [
1583            "https://127.0.0.1/secret",
1584            "https://10.0.0.1/internal",
1585            "https://192.168.1.1/admin",
1586            "https://172.16.0.1/data",
1587            "https://[::1]/path",
1588            "https://[fe80::1]/path",
1589            "https://localhost/path",
1590            "https://service.internal/api",
1591        ];
1592        for target in private_targets {
1593            let result = validate_url(target);
1594            assert!(result.is_err(), "expected error for {target}");
1595            assert!(
1596                matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1597                "expected Blocked for {target}"
1598            );
1599        }
1600    }
1601
1602    /// Verifies that relative redirect URLs are resolved correctly before validation.
1603    #[test]
1604    fn redirect_relative_url_resolves_correctly() {
1605        let base = Url::parse("https://example.com/page").unwrap();
1606        let relative = "/other";
1607        let resolved = base.join(relative).unwrap();
1608        assert_eq!(resolved.as_str(), "https://example.com/other");
1609    }
1610
1611    /// Verifies that a protocol-relative redirect to http:// is rejected (scheme check).
1612    #[test]
1613    fn redirect_to_http_rejected() {
1614        let err = validate_url("http://example.com/page").unwrap_err();
1615        assert!(matches!(err, ToolError::Blocked { .. }));
1616    }
1617
1618    #[test]
1619    fn ipv4_mapped_ipv6_link_local_blocked() {
1620        let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1621        assert!(matches!(err, ToolError::Blocked { .. }));
1622    }
1623
1624    #[test]
1625    fn ipv4_mapped_ipv6_public_allowed() {
1626        assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1627    }
1628
1629    // --- fetch tool ---
1630
1631    #[tokio::test]
1632    async fn fetch_http_scheme_blocked() {
1633        let config = ScrapeConfig::default();
1634        let executor = WebScrapeExecutor::new(&config);
1635        let call = crate::executor::ToolCall {
1636            tool_id: "fetch".to_owned(),
1637            params: {
1638                let mut m = serde_json::Map::new();
1639                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1640                m
1641            },
1642        };
1643        let result = executor.execute_tool_call(&call).await;
1644        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1645    }
1646
1647    #[tokio::test]
1648    async fn fetch_private_ip_blocked() {
1649        let config = ScrapeConfig::default();
1650        let executor = WebScrapeExecutor::new(&config);
1651        let call = crate::executor::ToolCall {
1652            tool_id: "fetch".to_owned(),
1653            params: {
1654                let mut m = serde_json::Map::new();
1655                m.insert(
1656                    "url".to_owned(),
1657                    serde_json::json!("https://192.168.1.1/secret"),
1658                );
1659                m
1660            },
1661        };
1662        let result = executor.execute_tool_call(&call).await;
1663        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1664    }
1665
1666    #[tokio::test]
1667    async fn fetch_localhost_blocked() {
1668        let config = ScrapeConfig::default();
1669        let executor = WebScrapeExecutor::new(&config);
1670        let call = crate::executor::ToolCall {
1671            tool_id: "fetch".to_owned(),
1672            params: {
1673                let mut m = serde_json::Map::new();
1674                m.insert(
1675                    "url".to_owned(),
1676                    serde_json::json!("https://localhost/page"),
1677                );
1678                m
1679            },
1680        };
1681        let result = executor.execute_tool_call(&call).await;
1682        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1683    }
1684
1685    #[tokio::test]
1686    async fn fetch_unknown_tool_returns_none() {
1687        let config = ScrapeConfig::default();
1688        let executor = WebScrapeExecutor::new(&config);
1689        let call = crate::executor::ToolCall {
1690            tool_id: "unknown_tool".to_owned(),
1691            params: serde_json::Map::new(),
1692        };
1693        let result = executor.execute_tool_call(&call).await;
1694        assert!(result.unwrap().is_none());
1695    }
1696
1697    #[tokio::test]
1698    async fn fetch_returns_body_via_mock() {
1699        use wiremock::matchers::{method, path};
1700        use wiremock::{Mock, ResponseTemplate};
1701
1702        let (executor, server) = mock_server_executor().await;
1703        Mock::given(method("GET"))
1704            .and(path("/content"))
1705            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1706            .mount(&server)
1707            .await;
1708
1709        let (host, addrs) = server_host_and_addr(&server);
1710        let url = format!("{}/content", server.uri());
1711        let result = executor.fetch_html(&url, &host, &addrs).await;
1712        assert!(result.is_ok());
1713        assert_eq!(result.unwrap(), "plain text content");
1714    }
1715
1716    #[test]
1717    fn tool_definitions_returns_web_scrape_and_fetch() {
1718        let config = ScrapeConfig::default();
1719        let executor = WebScrapeExecutor::new(&config);
1720        let defs = executor.tool_definitions();
1721        assert_eq!(defs.len(), 2);
1722        assert_eq!(defs[0].id, "web_scrape");
1723        assert_eq!(
1724            defs[0].invocation,
1725            crate::registry::InvocationHint::FencedBlock("scrape")
1726        );
1727        assert_eq!(defs[1].id, "fetch");
1728        assert_eq!(
1729            defs[1].invocation,
1730            crate::registry::InvocationHint::ToolCall
1731        );
1732    }
1733
1734    #[test]
1735    fn tool_definitions_schema_has_all_params() {
1736        let config = ScrapeConfig::default();
1737        let executor = WebScrapeExecutor::new(&config);
1738        let defs = executor.tool_definitions();
1739        let obj = defs[0].schema.as_object().unwrap();
1740        let props = obj["properties"].as_object().unwrap();
1741        assert!(props.contains_key("url"));
1742        assert!(props.contains_key("select"));
1743        assert!(props.contains_key("extract"));
1744        assert!(props.contains_key("limit"));
1745        let req = obj["required"].as_array().unwrap();
1746        assert!(req.iter().any(|v| v.as_str() == Some("url")));
1747        assert!(req.iter().any(|v| v.as_str() == Some("select")));
1748        assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1749    }
1750
1751    // --- is_private_host: new domain checks (AUD-02) ---
1752
1753    #[test]
1754    fn subdomain_localhost_blocked() {
1755        let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1756        assert!(is_private_host(&host));
1757    }
1758
1759    #[test]
1760    fn internal_tld_blocked() {
1761        let host: url::Host<&str> = url::Host::Domain("service.internal");
1762        assert!(is_private_host(&host));
1763    }
1764
1765    #[test]
1766    fn local_tld_blocked() {
1767        let host: url::Host<&str> = url::Host::Domain("printer.local");
1768        assert!(is_private_host(&host));
1769    }
1770
1771    #[test]
1772    fn public_domain_not_blocked() {
1773        let host: url::Host<&str> = url::Host::Domain("example.com");
1774        assert!(!is_private_host(&host));
1775    }
1776
1777    // --- resolve_and_validate: private IP rejection ---
1778
1779    #[tokio::test]
1780    async fn resolve_loopback_rejected() {
1781        // 127.0.0.1 resolves directly (literal IP in DNS query)
1782        let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1783        // validate_url catches this before resolve_and_validate, but test directly
1784        let result = resolve_and_validate(&url).await;
1785        assert!(
1786            result.is_err(),
1787            "loopback IP must be rejected by resolve_and_validate"
1788        );
1789        let err = result.unwrap_err();
1790        assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1791    }
1792
1793    #[tokio::test]
1794    async fn resolve_private_10_rejected() {
1795        let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1796        let result = resolve_and_validate(&url).await;
1797        assert!(result.is_err());
1798        assert!(matches!(
1799            result.unwrap_err(),
1800            crate::executor::ToolError::Blocked { .. }
1801        ));
1802    }
1803
1804    #[tokio::test]
1805    async fn resolve_private_192_rejected() {
1806        let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1807        let result = resolve_and_validate(&url).await;
1808        assert!(result.is_err());
1809        assert!(matches!(
1810            result.unwrap_err(),
1811            crate::executor::ToolError::Blocked { .. }
1812        ));
1813    }
1814
1815    #[tokio::test]
1816    async fn resolve_ipv6_loopback_rejected() {
1817        let url = url::Url::parse("https://[::1]/path").unwrap();
1818        let result = resolve_and_validate(&url).await;
1819        assert!(result.is_err());
1820        assert!(matches!(
1821            result.unwrap_err(),
1822            crate::executor::ToolError::Blocked { .. }
1823        ));
1824    }
1825
1826    #[tokio::test]
1827    async fn resolve_no_host_returns_ok() {
1828        // URL without a resolvable host — should pass through
1829        let url = url::Url::parse("https://example.com/path").unwrap();
1830        // We can't do a live DNS test, but we can verify a URL with no host
1831        let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1832        // data: URLs have no host; resolve_and_validate should return Ok with empty addrs
1833        let result = resolve_and_validate(&url_no_host).await;
1834        assert!(result.is_ok());
1835        let (host, addrs) = result.unwrap();
1836        assert!(host.is_empty());
1837        assert!(addrs.is_empty());
1838        drop(url);
1839        drop(url_no_host);
1840    }
1841
1842    // --- audit logging ---
1843
1844    /// Helper: build an `AuditLogger` writing to a temp file, and return the logger + path.
1845    async fn make_file_audit_logger(
1846        dir: &tempfile::TempDir,
1847    ) -> (
1848        std::sync::Arc<crate::audit::AuditLogger>,
1849        std::path::PathBuf,
1850    ) {
1851        use crate::audit::AuditLogger;
1852        use crate::config::AuditConfig;
1853        let path = dir.path().join("audit.log");
1854        let config = AuditConfig {
1855            enabled: true,
1856            destination: path.display().to_string(),
1857            ..Default::default()
1858        };
1859        let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1860        (logger, path)
1861    }
1862
1863    #[tokio::test]
1864    async fn with_audit_sets_logger() {
1865        let config = ScrapeConfig::default();
1866        let executor = WebScrapeExecutor::new(&config);
1867        assert!(executor.audit_logger.is_none());
1868
1869        let dir = tempfile::tempdir().unwrap();
1870        let (logger, _path) = make_file_audit_logger(&dir).await;
1871        let executor = executor.with_audit(logger);
1872        assert!(executor.audit_logger.is_some());
1873    }
1874
1875    #[test]
1876    fn tool_error_to_audit_result_blocked_maps_correctly() {
1877        let err = ToolError::Blocked {
1878            command: "scheme not allowed: http".into(),
1879        };
1880        let result = tool_error_to_audit_result(&err);
1881        assert!(
1882            matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1883        );
1884    }
1885
1886    #[test]
1887    fn tool_error_to_audit_result_timeout_maps_correctly() {
1888        let err = ToolError::Timeout { timeout_secs: 15 };
1889        let result = tool_error_to_audit_result(&err);
1890        assert!(matches!(result, AuditResult::Timeout));
1891    }
1892
1893    #[test]
1894    fn tool_error_to_audit_result_execution_error_maps_correctly() {
1895        let err = ToolError::Execution(std::io::Error::other("connection refused"));
1896        let result = tool_error_to_audit_result(&err);
1897        assert!(
1898            matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1899        );
1900    }
1901
1902    #[tokio::test]
1903    async fn fetch_audit_blocked_url_logged() {
1904        let dir = tempfile::tempdir().unwrap();
1905        let (logger, log_path) = make_file_audit_logger(&dir).await;
1906
1907        let config = ScrapeConfig::default();
1908        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1909
1910        let call = crate::executor::ToolCall {
1911            tool_id: "fetch".to_owned(),
1912            params: {
1913                let mut m = serde_json::Map::new();
1914                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1915                m
1916            },
1917        };
1918        let result = executor.execute_tool_call(&call).await;
1919        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1920
1921        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1922        assert!(
1923            content.contains("\"tool\":\"fetch\""),
1924            "expected tool=fetch in audit: {content}"
1925        );
1926        assert!(
1927            content.contains("\"type\":\"blocked\""),
1928            "expected type=blocked in audit: {content}"
1929        );
1930        assert!(
1931            content.contains("http://example.com"),
1932            "expected URL in audit command field: {content}"
1933        );
1934    }
1935
1936    #[tokio::test]
1937    async fn log_audit_success_writes_to_file() {
1938        let dir = tempfile::tempdir().unwrap();
1939        let (logger, log_path) = make_file_audit_logger(&dir).await;
1940
1941        let config = ScrapeConfig::default();
1942        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1943
1944        executor
1945            .log_audit(
1946                "fetch",
1947                "https://example.com/page",
1948                AuditResult::Success,
1949                42,
1950                None,
1951            )
1952            .await;
1953
1954        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1955        assert!(
1956            content.contains("\"tool\":\"fetch\""),
1957            "expected tool=fetch in audit: {content}"
1958        );
1959        assert!(
1960            content.contains("\"type\":\"success\""),
1961            "expected type=success in audit: {content}"
1962        );
1963        assert!(
1964            content.contains("\"command\":\"https://example.com/page\""),
1965            "expected command URL in audit: {content}"
1966        );
1967        assert!(
1968            content.contains("\"duration_ms\":42"),
1969            "expected duration_ms in audit: {content}"
1970        );
1971    }
1972
1973    #[tokio::test]
1974    async fn log_audit_blocked_writes_to_file() {
1975        let dir = tempfile::tempdir().unwrap();
1976        let (logger, log_path) = make_file_audit_logger(&dir).await;
1977
1978        let config = ScrapeConfig::default();
1979        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1980
1981        executor
1982            .log_audit(
1983                "web_scrape",
1984                "http://evil.com/page",
1985                AuditResult::Blocked {
1986                    reason: "scheme not allowed: http".into(),
1987                },
1988                0,
1989                None,
1990            )
1991            .await;
1992
1993        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1994        assert!(
1995            content.contains("\"tool\":\"web_scrape\""),
1996            "expected tool=web_scrape in audit: {content}"
1997        );
1998        assert!(
1999            content.contains("\"type\":\"blocked\""),
2000            "expected type=blocked in audit: {content}"
2001        );
2002        assert!(
2003            content.contains("scheme not allowed"),
2004            "expected block reason in audit: {content}"
2005        );
2006    }
2007
2008    #[tokio::test]
2009    async fn web_scrape_audit_blocked_url_logged() {
2010        let dir = tempfile::tempdir().unwrap();
2011        let (logger, log_path) = make_file_audit_logger(&dir).await;
2012
2013        let config = ScrapeConfig::default();
2014        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2015
2016        let call = crate::executor::ToolCall {
2017            tool_id: "web_scrape".to_owned(),
2018            params: {
2019                let mut m = serde_json::Map::new();
2020                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2021                m.insert("select".to_owned(), serde_json::json!("h1"));
2022                m
2023            },
2024        };
2025        let result = executor.execute_tool_call(&call).await;
2026        assert!(matches!(result, Err(ToolError::Blocked { .. })));
2027
2028        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2029        assert!(
2030            content.contains("\"tool\":\"web_scrape\""),
2031            "expected tool=web_scrape in audit: {content}"
2032        );
2033        assert!(
2034            content.contains("\"type\":\"blocked\""),
2035            "expected type=blocked in audit: {content}"
2036        );
2037    }
2038
2039    #[tokio::test]
2040    async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2041        let config = ScrapeConfig::default();
2042        let executor = WebScrapeExecutor::new(&config);
2043        assert!(executor.audit_logger.is_none());
2044
2045        let call = crate::executor::ToolCall {
2046            tool_id: "fetch".to_owned(),
2047            params: {
2048                let mut m = serde_json::Map::new();
2049                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2050                m
2051            },
2052        };
2053        // Must not panic even without an audit logger
2054        let result = executor.execute_tool_call(&call).await;
2055        assert!(matches!(result, Err(ToolError::Blocked { .. })));
2056    }
2057
2058    // CR-10: fetch end-to-end via execute_tool_call -> handle_fetch -> fetch_html
2059    #[tokio::test]
2060    async fn fetch_execute_tool_call_end_to_end() {
2061        use wiremock::matchers::{method, path};
2062        use wiremock::{Mock, ResponseTemplate};
2063
2064        let (executor, server) = mock_server_executor().await;
2065        Mock::given(method("GET"))
2066            .and(path("/e2e"))
2067            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2068            .mount(&server)
2069            .await;
2070
2071        let (host, addrs) = server_host_and_addr(&server);
2072        // Call fetch_html directly (bypassing SSRF guard for loopback mock server)
2073        let result = executor
2074            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
2075            .await;
2076        assert!(result.is_ok());
2077        assert!(result.unwrap().contains("end-to-end"));
2078    }
2079
2080    // --- domain_matches ---
2081
2082    #[test]
2083    fn domain_matches_exact() {
2084        assert!(domain_matches("example.com", "example.com"));
2085        assert!(!domain_matches("example.com", "other.com"));
2086        assert!(!domain_matches("example.com", "sub.example.com"));
2087    }
2088
2089    #[test]
2090    fn domain_matches_wildcard_single_subdomain() {
2091        assert!(domain_matches("*.example.com", "sub.example.com"));
2092        assert!(!domain_matches("*.example.com", "example.com"));
2093        assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2094    }
2095
2096    #[test]
2097    fn domain_matches_wildcard_does_not_match_empty_label() {
2098        // Pattern "*.example.com" requires a non-empty label before ".example.com"
2099        assert!(!domain_matches("*.example.com", ".example.com"));
2100    }
2101
2102    #[test]
2103    fn domain_matches_multi_wildcard_treated_as_exact() {
2104        // Multiple wildcards are unsupported — treated as literal pattern
2105        assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2106    }
2107
2108    // --- check_domain_policy ---
2109
2110    #[test]
2111    fn check_domain_policy_empty_lists_allow_all() {
2112        assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2113        assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2114    }
2115
2116    #[test]
2117    fn check_domain_policy_denylist_blocks() {
2118        let denied = vec!["evil.com".to_string()];
2119        let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2120        assert!(matches!(err, ToolError::Blocked { .. }));
2121    }
2122
2123    #[test]
2124    fn check_domain_policy_denylist_does_not_block_other_domains() {
2125        let denied = vec!["evil.com".to_string()];
2126        assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2127    }
2128
2129    #[test]
2130    fn check_domain_policy_allowlist_permits_matching() {
2131        let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2132        assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2133        assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2134    }
2135
2136    #[test]
2137    fn check_domain_policy_allowlist_blocks_unknown() {
2138        let allowed = vec!["docs.rs".to_string()];
2139        let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2140        assert!(matches!(err, ToolError::Blocked { .. }));
2141    }
2142
2143    #[test]
2144    fn check_domain_policy_deny_overrides_allow() {
2145        let allowed = vec!["example.com".to_string()];
2146        let denied = vec!["example.com".to_string()];
2147        let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2148        assert!(matches!(err, ToolError::Blocked { .. }));
2149    }
2150
2151    #[test]
2152    fn check_domain_policy_wildcard_in_denylist() {
2153        let denied = vec!["*.evil.com".to_string()];
2154        let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2155        assert!(matches!(err, ToolError::Blocked { .. }));
2156        // parent domain not blocked
2157        assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2158    }
2159}