Skip to main content

zeph_tools/
scrape.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Web scraping executor with SSRF protection and domain policy enforcement.
5//!
6//! Exposes two tools to the LLM:
7//!
8//! - **`web_scrape`** — fetches a URL and extracts elements matching a CSS selector.
9//! - **`fetch`** — fetches a URL and returns the raw response body as UTF-8 text.
10//!
11//! Both tools enforce:
12//!
13//! - HTTPS-only URLs (HTTP and other schemes are rejected).
14//! - DNS resolution followed by a private-IP check to prevent SSRF.
15//! - Optional domain allowlist and denylist from [`ScrapeConfig`].
16//! - Configurable timeout and maximum response body size.
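//! - Automatic redirect following is disabled; redirects are followed manually (at most 3
//!   hops) and each target is re-validated first, to prevent open-redirect SSRF bypasses.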
18
19use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22
23use schemars::JsonSchema;
24use serde::Deserialize;
25use url::Url;
26
27use zeph_common::ToolName;
28
29use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
30use crate::config::ScrapeConfig;
31use crate::executor::{
32    ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
33};
34use crate::net::is_private_ip;
35
// Parameters of the `fetch` tool. Deserialized from the structured tool-call payload and
// also used to derive the JSON schema advertised for that tool.
// NOTE(review): schemars typically surfaces `///` docs as schema descriptions — confirm
// before rewording the field docs below, since that would change the advertised schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// HTTPS URL to fetch
    url: String,
}
41
// One scrape request. Parsed either from the JSON inside a legacy fenced `scrape` block or
// from the params of a structured `web_scrape` tool call; also used to derive that tool's
// JSON schema.
// NOTE(review): schemars typically surfaces `///` docs as schema descriptions — confirm
// before rewording the field docs below, since that would change the advertised schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// HTTPS URL to scrape
    url: String,
    /// CSS selector
    select: String,
    /// Extract mode: text, html, or attr:<name>
    // Falls back to `default_extract()` when the field is absent from the input.
    #[serde(default = "default_extract")]
    extract: String,
    /// Max results to return
    limit: Option<usize>,
}
54
/// Serde default for the `extract` field of a scrape instruction: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
58
/// What to pull out of each element matched by the CSS selector.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parses the textual extract mode of a scrape instruction.
    ///
    /// Recognised forms are `"text"`, `"html"` and `"attr:<name>"`. Anything else —
    /// including differently-cased spellings — deliberately falls back to [`Self::Text`]
    /// rather than failing.
    fn parse(s: &str) -> Self {
        if let Some(attr_name) = s.strip_prefix("attr:") {
            return Self::Attr(attr_name.to_owned());
        }
        if s == "html" { Self::Html } else { Self::Text }
    }
}
78
/// Extracts data from web pages via CSS selectors.
///
/// Handles two invocation paths:
///
/// 1. **Legacy fenced blocks** — detects ` ```scrape ` blocks in the LLM response, each
///    containing a JSON scrape instruction object. Dispatched via [`ToolExecutor::execute`].
/// 2. **Structured tool calls** — dispatched via [`ToolExecutor::execute_tool_call`] for
///    tool IDs `"web_scrape"` and `"fetch"`.
///
/// # Security
///
/// - Only HTTPS URLs are accepted. HTTP and other schemes are rejected with
///   [`ToolError::Blocked`].
/// - DNS is resolved asynchronously before a request is sent and each resolved address is
///   checked against [`is_private_ip`]. Private addresses are rejected to prevent SSRF.
/// - Automatic redirect handling is disabled (`Policy::none()`); redirects are instead
///   followed manually, up to 3 hops, and each target is validated before it is requested.
/// - Domain allowlists and denylists from config are enforced on the requested URL before
///   DNS resolution.
///
/// # Example
///
/// ```rust,no_run
/// use zeph_tools::{WebScrapeExecutor, ToolExecutor, ToolCall, config::ScrapeConfig};
/// use zeph_common::ToolName;
///
/// # async fn example() {
/// let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
///
/// let call = ToolCall {
///     tool_id: ToolName::new("fetch"),
///     params: {
///         let mut m = serde_json::Map::new();
///         m.insert("url".to_owned(), serde_json::json!("https://example.com"));
///         m
///     },
///     caller_id: None,
/// };
/// let _ = executor.execute_tool_call(&call).await;
/// # }
/// ```
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Timeout applied to every HTTP client built by this executor.
    timeout: Duration,
    /// Largest response body, in bytes, that is accepted; larger bodies yield an error.
    max_body_bytes: usize,
    /// Domain patterns a host must match when this list is non-empty.
    allowed_domains: Vec<String>,
    /// Domain patterns that are always rejected.
    denied_domains: Vec<String>,
    /// Optional sink for per-invocation audit entries; `None` disables auditing.
    audit_logger: Option<Arc<AuditLogger>>,
}
125
126impl WebScrapeExecutor {
127    /// Create a new `WebScrapeExecutor` from configuration.
128    ///
129    /// No network connections are made at construction time.
130    #[must_use]
131    pub fn new(config: &ScrapeConfig) -> Self {
132        Self {
133            timeout: Duration::from_secs(config.timeout),
134            max_body_bytes: config.max_body_bytes,
135            allowed_domains: config.allowed_domains.clone(),
136            denied_domains: config.denied_domains.clone(),
137            audit_logger: None,
138        }
139    }
140
141    /// Attach an audit logger. Each tool invocation will emit an [`AuditEntry`].
142    #[must_use]
143    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
144        self.audit_logger = Some(logger);
145        self
146    }
147
148    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
149        let mut builder = reqwest::Client::builder()
150            .timeout(self.timeout)
151            .redirect(reqwest::redirect::Policy::none());
152        builder = builder.resolve_to_addrs(host, addrs);
153        builder.build().unwrap_or_default()
154    }
155}
156
impl ToolExecutor for WebScrapeExecutor {
    /// Advertises the two tools implemented here: `web_scrape` (invoked through a fenced
    /// `scrape` block) and `fetch` (invoked through a structured tool call).
    ///
    /// The description strings are part of what is presented to the model; treat them as
    /// runtime data rather than documentation.
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Legacy path: runs every fenced `scrape` block found in `response`, in order.
    ///
    /// Returns `Ok(None)` when the response contains no such block. Otherwise the
    /// per-block results are joined with blank lines into a single [`ToolOutput`].
    ///
    /// # Errors
    ///
    /// Stops at the first failing block and returns its error. A block whose JSON does not
    /// parse yields `ToolError::Execution` (`InvalidData`) and is not audited; a block whose
    /// scrape fails is audited before the error is returned.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        // Reports the number of blocks found, computed up front.
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            // Time only the scrape itself so the audit entry reflects network + parse cost.
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Audit the failure first, then abort the remaining blocks.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        // NOTE(review): the output tool name is "web-scrape" (hyphen) while the tool id and
        // audit tool label are "web_scrape" (underscore) — confirm this is intentional.
        Ok(Some(ToolOutput {
            tool_name: ToolName::new("web-scrape"),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Structured path: handles tool calls whose id is `"web_scrape"` or `"fetch"`.
    ///
    /// Any other tool id yields `Ok(None)`. Each handled call is timed and audited, on
    /// success and on failure alike.
    ///
    /// # Errors
    ///
    /// Returns the parameter-deserialization error (before any audit entry is written) or
    /// the error produced by the scrape/fetch itself.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "tool.web_scrape", skip_all)
    )]
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: ToolName::new("web-scrape"),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: ToolName::new("fetch"),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms, Some(&e))
                            .await;
                        Err(e)
                    }
                }
            }
            // Not one of ours: let another executor handle it.
            _ => Ok(None),
        }
    }

    /// Reports both tools implemented by this executor as retryable.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
326
327fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
328    match e {
329        ToolError::Blocked { command } => AuditResult::Blocked {
330            reason: command.clone(),
331        },
332        ToolError::Timeout { .. } => AuditResult::Timeout,
333        _ => AuditResult::Error {
334            message: e.to_string(),
335        },
336    }
337}
338
339impl WebScrapeExecutor {
340    async fn log_audit(
341        &self,
342        tool: &str,
343        command: &str,
344        result: AuditResult,
345        duration_ms: u64,
346        error: Option<&ToolError>,
347    ) {
348        if let Some(ref logger) = self.audit_logger {
349            let (error_category, error_domain, error_phase) =
350                error.map_or((None, None, None), |e| {
351                    let cat = e.category();
352                    (
353                        Some(cat.label().to_owned()),
354                        Some(cat.domain().label().to_owned()),
355                        Some(cat.phase().label().to_owned()),
356                    )
357                });
358            let entry = AuditEntry {
359                timestamp: chrono_now(),
360                tool: tool.into(),
361                command: command.into(),
362                result,
363                duration_ms,
364                error_category,
365                error_domain,
366                error_phase,
367                claim_source: Some(ClaimSource::WebScrape),
368                mcp_server_id: None,
369                injection_flagged: false,
370                embedding_anomalous: false,
371                cross_boundary_mcp_to_acp: false,
372                adversarial_policy_decision: None,
373                exit_code: None,
374                truncated: false,
375                caller_id: None,
376                policy_match: None,
377            };
378            logger.log(&entry).await;
379        }
380    }
381
382    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
383        let parsed = validate_url(&params.url)?;
384        check_domain_policy(
385            parsed.host_str().unwrap_or(""),
386            &self.allowed_domains,
387            &self.denied_domains,
388        )?;
389        let (host, addrs) = resolve_and_validate(&parsed).await?;
390        self.fetch_html(&params.url, &host, &addrs).await
391    }
392
393    async fn scrape_instruction(
394        &self,
395        instruction: &ScrapeInstruction,
396    ) -> Result<String, ToolError> {
397        let parsed = validate_url(&instruction.url)?;
398        check_domain_policy(
399            parsed.host_str().unwrap_or(""),
400            &self.allowed_domains,
401            &self.denied_domains,
402        )?;
403        let (host, addrs) = resolve_and_validate(&parsed).await?;
404        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
405        let selector = instruction.select.clone();
406        let extract = ExtractMode::parse(&instruction.extract);
407        let limit = instruction.limit.unwrap_or(10);
408        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
409            .await
410            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
411    }
412
413    /// Fetches the HTML at `url`, manually following up to 3 redirects.
414    ///
415    /// Each redirect target is validated with `validate_url` and `resolve_and_validate`
416    /// before following, preventing SSRF via redirect chains.
417    ///
418    /// # Errors
419    ///
420    /// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
421    /// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
422    async fn fetch_html(
423        &self,
424        url: &str,
425        host: &str,
426        addrs: &[SocketAddr],
427    ) -> Result<String, ToolError> {
428        const MAX_REDIRECTS: usize = 3;
429
430        let mut current_url = url.to_owned();
431        let mut current_host = host.to_owned();
432        let mut current_addrs = addrs.to_vec();
433
434        for hop in 0..=MAX_REDIRECTS {
435            // Build a per-hop client pinned to the current hop's validated addresses.
436            let client = self.build_client(&current_host, &current_addrs);
437            let resp = client
438                .get(&current_url)
439                .send()
440                .await
441                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
442
443            let status = resp.status();
444
445            if status.is_redirection() {
446                if hop == MAX_REDIRECTS {
447                    return Err(ToolError::Execution(std::io::Error::other(
448                        "too many redirects",
449                    )));
450                }
451
452                let location = resp
453                    .headers()
454                    .get(reqwest::header::LOCATION)
455                    .and_then(|v| v.to_str().ok())
456                    .ok_or_else(|| {
457                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
458                    })?;
459
460                // Resolve relative redirect URLs against the current URL.
461                let base = Url::parse(&current_url)
462                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
463                let next_url = base
464                    .join(location)
465                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
466
467                let validated = validate_url(next_url.as_str())?;
468                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
469
470                current_url = next_url.to_string();
471                current_host = next_host;
472                current_addrs = next_addrs;
473                continue;
474            }
475
476            if !status.is_success() {
477                return Err(ToolError::Http {
478                    status: status.as_u16(),
479                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
480                });
481            }
482
483            let bytes = resp
484                .bytes()
485                .await
486                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
487
488            if bytes.len() > self.max_body_bytes {
489                return Err(ToolError::Execution(std::io::Error::other(format!(
490                    "response too large: {} bytes (max: {})",
491                    bytes.len(),
492                    self.max_body_bytes,
493                ))));
494            }
495
496            return String::from_utf8(bytes.to_vec())
497                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
498        }
499
500        Err(ToolError::Execution(std::io::Error::other(
501            "too many redirects",
502        )))
503    }
504}
505
506fn extract_scrape_blocks(text: &str) -> Vec<&str> {
507    crate::executor::extract_fenced_blocks(text, "scrape")
508}
509
510/// Check host against the domain allowlist/denylist from `ScrapeConfig`.
511///
512/// Logic:
513/// 1. If `denied_domains` matches the host → block.
514/// 2. If `allowed_domains` is non-empty:
515///    a. IP address hosts are always rejected (no pattern can match a bare IP).
516///    b. Hosts not matching any entry → block.
517/// 3. Otherwise → allow.
518///
519/// Wildcard prefix matching: `*.example.com` matches `sub.example.com` but NOT `example.com`.
520/// Multiple wildcards are not supported; patterns with more than one `*` are treated as exact.
521fn check_domain_policy(
522    host: &str,
523    allowed_domains: &[String],
524    denied_domains: &[String],
525) -> Result<(), ToolError> {
526    if denied_domains.iter().any(|p| domain_matches(p, host)) {
527        return Err(ToolError::Blocked {
528            command: format!("domain blocked by denylist: {host}"),
529        });
530    }
531    if !allowed_domains.is_empty() {
532        // Bare IP addresses cannot match any domain pattern — reject when allowlist is active.
533        let is_ip = host.parse::<std::net::IpAddr>().is_ok()
534            || (host.starts_with('[') && host.ends_with(']'));
535        if is_ip {
536            return Err(ToolError::Blocked {
537                command: format!(
538                    "bare IP address not allowed when domain allowlist is active: {host}"
539                ),
540            });
541        }
542        if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
543            return Err(ToolError::Blocked {
544                command: format!("domain not in allowlist: {host}"),
545            });
546        }
547    }
548    Ok(())
549}
550
/// Match a domain pattern against a hostname.
///
/// Pattern `*.example.com` matches `sub.example.com` but not `example.com` or
/// `sub.sub.example.com`. Any other pattern matches only the identical host name.
/// Patterns with multiple `*` characters are not supported.
///
/// Comparison follows DNS conventions: it ignores ASCII case and a single trailing root dot
/// on either side, so `Example.com` and `example.com.` both match the pattern
/// `example.com`. Without this, the fully-qualified spelling of a host would slip past a
/// denylist entry.
fn domain_matches(pattern: &str, host: &str) -> bool {
    // "example.com." is the fully-qualified spelling of "example.com".
    let pattern = pattern.strip_suffix('.').unwrap_or(pattern);
    let host = host.strip_suffix('.').unwrap_or(host);

    // A wildcard pattern is "*" followed by ".suffix"; keep the leading dot in the suffix.
    let Some(suffix) = pattern
        .strip_prefix('*')
        .filter(|rest| rest.starts_with('.'))
    else {
        return pattern.eq_ignore_ascii_case(host);
    };

    // Split the host into "<label>" and ".suffix"; `get` returns `None` instead of
    // panicking when the split point is out of range or inside a multi-byte character.
    let Some(split) = host.len().checked_sub(suffix.len()) else {
        return false;
    };
    let (Some(label), Some(tail)) = (host.get(..split), host.get(split..)) else {
        return false;
    };

    // Allow only a single subdomain level: the label must be non-empty and contain no dots.
    tail.eq_ignore_ascii_case(suffix) && !label.is_empty() && !label.contains('.')
}
570
571fn validate_url(raw: &str) -> Result<Url, ToolError> {
572    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
573        command: format!("invalid URL: {raw}"),
574    })?;
575
576    if parsed.scheme() != "https" {
577        return Err(ToolError::Blocked {
578            command: format!("scheme not allowed: {}", parsed.scheme()),
579        });
580    }
581
582    if let Some(host) = parsed.host()
583        && is_private_host(&host)
584    {
585        return Err(ToolError::Blocked {
586            command: format!(
587                "private/local host blocked: {}",
588                parsed.host_str().unwrap_or("")
589            ),
590        });
591    }
592
593    Ok(parsed)
594}
595
596fn is_private_host(host: &url::Host<&str>) -> bool {
597    match host {
598        url::Host::Domain(d) => {
599            // Exact match or subdomain of localhost (e.g. foo.localhost)
600            // and .internal/.local TLDs used in cloud/k8s environments.
601            #[allow(clippy::case_sensitive_file_extension_comparisons)]
602            {
603                *d == "localhost"
604                    || d.ends_with(".localhost")
605                    || d.ends_with(".internal")
606                    || d.ends_with(".local")
607            }
608        }
609        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
610        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
611    }
612}
613
614/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
615/// and returns the hostname and validated socket addresses.
616///
617/// Returning the addresses allows the caller to pin the HTTP client to these exact
618/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
619async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
620    let Some(host) = url.host_str() else {
621        return Ok((String::new(), vec![]));
622    };
623    let port = url.port_or_known_default().unwrap_or(443);
624    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
625        .await
626        .map_err(|e| ToolError::Blocked {
627            command: format!("DNS resolution failed: {e}"),
628        })?
629        .collect();
630    for addr in &addrs {
631        if is_private_ip(addr.ip()) {
632            return Err(ToolError::Blocked {
633                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
634            });
635        }
636    }
637    Ok((host.to_owned(), addrs))
638}
639
640fn parse_and_extract(
641    html: &str,
642    selector: &str,
643    extract: &ExtractMode,
644    limit: usize,
645) -> Result<String, ToolError> {
646    let soup = scrape_core::Soup::parse(html);
647
648    let tags = soup.find_all(selector).map_err(|e| {
649        ToolError::Execution(std::io::Error::new(
650            std::io::ErrorKind::InvalidData,
651            format!("invalid selector: {e}"),
652        ))
653    })?;
654
655    let mut results = Vec::new();
656
657    for tag in tags.into_iter().take(limit) {
658        let value = match extract {
659            ExtractMode::Text => tag.text(),
660            ExtractMode::Html => tag.inner_html(),
661            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
662        };
663        if !value.trim().is_empty() {
664            results.push(value.trim().to_owned());
665        }
666    }
667
668    if results.is_empty() {
669        Ok(format!("No results for selector: {selector}"))
670    } else {
671        Ok(results.join("\n"))
672    }
673}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678
679    // --- extract_scrape_blocks ---
680
681    #[test]
682    fn extract_single_block() {
683        let text =
684            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
685        let blocks = extract_scrape_blocks(text);
686        assert_eq!(blocks.len(), 1);
687        assert!(blocks[0].contains("example.com"));
688    }
689
690    #[test]
691    fn extract_multiple_blocks() {
692        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
693        let blocks = extract_scrape_blocks(text);
694        assert_eq!(blocks.len(), 2);
695    }
696
697    #[test]
698    fn no_blocks_returns_empty() {
699        let blocks = extract_scrape_blocks("plain text, no code blocks");
700        assert!(blocks.is_empty());
701    }
702
703    #[test]
704    fn unclosed_block_ignored() {
705        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
706        assert!(blocks.is_empty());
707    }
708
709    #[test]
710    fn non_scrape_block_ignored() {
711        let text =
712            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
713        let blocks = extract_scrape_blocks(text);
714        assert_eq!(blocks.len(), 1);
715        assert!(blocks[0].contains("x.com"));
716    }
717
718    #[test]
719    fn multiline_json_block() {
720        let text =
721            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
722        let blocks = extract_scrape_blocks(text);
723        assert_eq!(blocks.len(), 1);
724        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
725        assert_eq!(instr.url, "https://example.com");
726    }
727
728    // --- ScrapeInstruction parsing ---
729
730    #[test]
731    fn parse_valid_instruction() {
732        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
733        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
734        assert_eq!(instr.url, "https://example.com");
735        assert_eq!(instr.select, "h1");
736        assert_eq!(instr.extract, "text");
737        assert_eq!(instr.limit, Some(5));
738    }
739
740    #[test]
741    fn parse_minimal_instruction() {
742        let json = r#"{"url":"https://example.com","select":"p"}"#;
743        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
744        assert_eq!(instr.extract, "text");
745        assert!(instr.limit.is_none());
746    }
747
748    #[test]
749    fn parse_attr_extract() {
750        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
751        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
752        assert_eq!(instr.extract, "attr:href");
753    }
754
755    #[test]
756    fn parse_invalid_json_errors() {
757        let result = serde_json::from_str::<ScrapeInstruction>("not json");
758        assert!(result.is_err());
759    }
760
761    // --- ExtractMode ---
762
763    #[test]
764    fn extract_mode_text() {
765        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
766    }
767
768    #[test]
769    fn extract_mode_html() {
770        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
771    }
772
773    #[test]
774    fn extract_mode_attr() {
775        let mode = ExtractMode::parse("attr:href");
776        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
777    }
778
779    #[test]
780    fn extract_mode_unknown_defaults_to_text() {
781        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
782    }
783
784    // --- validate_url ---
785
786    #[test]
787    fn valid_https_url() {
788        assert!(validate_url("https://example.com").is_ok());
789    }
790
791    #[test]
792    fn http_rejected() {
793        let err = validate_url("http://example.com").unwrap_err();
794        assert!(matches!(err, ToolError::Blocked { .. }));
795    }
796
797    #[test]
798    fn ftp_rejected() {
799        let err = validate_url("ftp://files.example.com").unwrap_err();
800        assert!(matches!(err, ToolError::Blocked { .. }));
801    }
802
803    #[test]
804    fn file_rejected() {
805        let err = validate_url("file:///etc/passwd").unwrap_err();
806        assert!(matches!(err, ToolError::Blocked { .. }));
807    }
808
809    #[test]
810    fn invalid_url_rejected() {
811        let err = validate_url("not a url").unwrap_err();
812        assert!(matches!(err, ToolError::Blocked { .. }));
813    }
814
815    #[test]
816    fn localhost_blocked() {
817        let err = validate_url("https://localhost/path").unwrap_err();
818        assert!(matches!(err, ToolError::Blocked { .. }));
819    }
820
821    #[test]
822    fn loopback_ip_blocked() {
823        let err = validate_url("https://127.0.0.1/path").unwrap_err();
824        assert!(matches!(err, ToolError::Blocked { .. }));
825    }
826
827    #[test]
828    fn private_10_blocked() {
829        let err = validate_url("https://10.0.0.1/api").unwrap_err();
830        assert!(matches!(err, ToolError::Blocked { .. }));
831    }
832
833    #[test]
834    fn private_172_blocked() {
835        let err = validate_url("https://172.16.0.1/api").unwrap_err();
836        assert!(matches!(err, ToolError::Blocked { .. }));
837    }
838
839    #[test]
840    fn private_192_blocked() {
841        let err = validate_url("https://192.168.1.1/api").unwrap_err();
842        assert!(matches!(err, ToolError::Blocked { .. }));
843    }
844
845    #[test]
846    fn ipv6_loopback_blocked() {
847        let err = validate_url("https://[::1]/path").unwrap_err();
848        assert!(matches!(err, ToolError::Blocked { .. }));
849    }
850
851    #[test]
852    fn public_ip_allowed() {
853        assert!(validate_url("https://93.184.216.34/page").is_ok());
854    }
855
856    // --- parse_and_extract ---
857
858    #[test]
859    fn extract_text_from_html() {
860        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
861        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
862        assert_eq!(result, "Hello World");
863    }
864
865    #[test]
866    fn extract_multiple_elements() {
867        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
868        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
869        assert_eq!(result, "A\nB\nC");
870    }
871
872    #[test]
873    fn extract_with_limit() {
874        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
875        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
876        assert_eq!(result, "A\nB");
877    }
878
879    #[test]
880    fn extract_attr_href() {
881        let html = r#"<a href="https://example.com">Link</a>"#;
882        let result =
883            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
884        assert_eq!(result, "https://example.com");
885    }
886
887    #[test]
888    fn extract_inner_html() {
889        let html = "<div><span>inner</span></div>";
890        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
891        assert!(result.contains("<span>inner</span>"));
892    }
893
894    #[test]
895    fn no_matches_returns_message() {
896        let html = "<html><body><p>text</p></body></html>";
897        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
898        assert!(result.starts_with("No results for selector:"));
899    }
900
901    #[test]
902    fn empty_text_skipped() {
903        let html = "<ul><li>  </li><li>A</li></ul>";
904        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
905        assert_eq!(result, "A");
906    }
907
908    #[test]
909    fn invalid_selector_errors() {
910        let html = "<html><body></body></html>";
911        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
912        assert!(result.is_err());
913    }
914
915    #[test]
916    fn empty_html_returns_no_results() {
917        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
918        assert!(result.starts_with("No results for selector:"));
919    }
920
921    #[test]
922    fn nested_selector() {
923        let html = "<div><span>inner</span></div><span>outer</span>";
924        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
925        assert_eq!(result, "inner");
926    }
927
928    #[test]
929    fn attr_missing_returns_empty() {
930        let html = r"<a>No href</a>";
931        let result =
932            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
933        assert!(result.starts_with("No results for selector:"));
934    }
935
936    #[test]
937    fn extract_html_mode() {
938        let html = "<div><b>bold</b> text</div>";
939        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
940        assert!(result.contains("<b>bold</b>"));
941    }
942
943    #[test]
944    fn limit_zero_returns_no_results() {
945        let html = "<ul><li>A</li><li>B</li></ul>";
946        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
947        assert!(result.starts_with("No results for selector:"));
948    }
949
950    // --- validate_url edge cases ---
951
952    #[test]
953    fn url_with_port_allowed() {
954        assert!(validate_url("https://example.com:8443/path").is_ok());
955    }
956
957    #[test]
958    fn link_local_ip_blocked() {
959        let err = validate_url("https://169.254.1.1/path").unwrap_err();
960        assert!(matches!(err, ToolError::Blocked { .. }));
961    }
962
963    #[test]
964    fn url_no_scheme_rejected() {
965        let err = validate_url("example.com/path").unwrap_err();
966        assert!(matches!(err, ToolError::Blocked { .. }));
967    }
968
969    #[test]
970    fn unspecified_ipv4_blocked() {
971        let err = validate_url("https://0.0.0.0/path").unwrap_err();
972        assert!(matches!(err, ToolError::Blocked { .. }));
973    }
974
975    #[test]
976    fn broadcast_ipv4_blocked() {
977        let err = validate_url("https://255.255.255.255/path").unwrap_err();
978        assert!(matches!(err, ToolError::Blocked { .. }));
979    }
980
981    #[test]
982    fn ipv6_link_local_blocked() {
983        let err = validate_url("https://[fe80::1]/path").unwrap_err();
984        assert!(matches!(err, ToolError::Blocked { .. }));
985    }
986
987    #[test]
988    fn ipv6_unique_local_blocked() {
989        let err = validate_url("https://[fd12::1]/path").unwrap_err();
990        assert!(matches!(err, ToolError::Blocked { .. }));
991    }
992
993    #[test]
994    fn ipv4_mapped_ipv6_loopback_blocked() {
995        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
996        assert!(matches!(err, ToolError::Blocked { .. }));
997    }
998
999    #[test]
1000    fn ipv4_mapped_ipv6_private_blocked() {
1001        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
1002        assert!(matches!(err, ToolError::Blocked { .. }));
1003    }
1004
1005    // --- WebScrapeExecutor (no-network) ---
1006
1007    #[tokio::test]
1008    async fn executor_no_blocks_returns_none() {
1009        let config = ScrapeConfig::default();
1010        let executor = WebScrapeExecutor::new(&config);
1011        let result = executor.execute("plain text").await;
1012        assert!(result.unwrap().is_none());
1013    }
1014
1015    #[tokio::test]
1016    async fn executor_invalid_json_errors() {
1017        let config = ScrapeConfig::default();
1018        let executor = WebScrapeExecutor::new(&config);
1019        let response = "```scrape\nnot json\n```";
1020        let result = executor.execute(response).await;
1021        assert!(matches!(result, Err(ToolError::Execution(_))));
1022    }
1023
1024    #[tokio::test]
1025    async fn executor_blocked_url_errors() {
1026        let config = ScrapeConfig::default();
1027        let executor = WebScrapeExecutor::new(&config);
1028        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
1029        let result = executor.execute(response).await;
1030        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1031    }
1032
1033    #[tokio::test]
1034    async fn executor_private_ip_blocked() {
1035        let config = ScrapeConfig::default();
1036        let executor = WebScrapeExecutor::new(&config);
1037        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
1038        let result = executor.execute(response).await;
1039        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1040    }
1041
1042    #[tokio::test]
1043    async fn executor_unreachable_host_returns_error() {
1044        let config = ScrapeConfig {
1045            timeout: 1,
1046            max_body_bytes: 1_048_576,
1047            ..Default::default()
1048        };
1049        let executor = WebScrapeExecutor::new(&config);
1050        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
1051        let result = executor.execute(response).await;
1052        assert!(matches!(result, Err(ToolError::Execution(_))));
1053    }
1054
1055    #[tokio::test]
1056    async fn executor_localhost_url_blocked() {
1057        let config = ScrapeConfig::default();
1058        let executor = WebScrapeExecutor::new(&config);
1059        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
1060        let result = executor.execute(response).await;
1061        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1062    }
1063
1064    #[tokio::test]
1065    async fn executor_empty_text_returns_none() {
1066        let config = ScrapeConfig::default();
1067        let executor = WebScrapeExecutor::new(&config);
1068        let result = executor.execute("").await;
1069        assert!(result.unwrap().is_none());
1070    }
1071
1072    #[tokio::test]
1073    async fn executor_multiple_blocks_first_blocked() {
1074        let config = ScrapeConfig::default();
1075        let executor = WebScrapeExecutor::new(&config);
1076        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1077             ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1078        let result = executor.execute(response).await;
1079        assert!(result.is_err());
1080    }
1081
1082    #[test]
1083    fn validate_url_empty_string() {
1084        let err = validate_url("").unwrap_err();
1085        assert!(matches!(err, ToolError::Blocked { .. }));
1086    }
1087
1088    #[test]
1089    fn validate_url_javascript_scheme_blocked() {
1090        let err = validate_url("javascript:alert(1)").unwrap_err();
1091        assert!(matches!(err, ToolError::Blocked { .. }));
1092    }
1093
1094    #[test]
1095    fn validate_url_data_scheme_blocked() {
1096        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1097        assert!(matches!(err, ToolError::Blocked { .. }));
1098    }
1099
1100    #[test]
1101    fn is_private_host_public_domain_is_false() {
1102        let host: url::Host<&str> = url::Host::Domain("example.com");
1103        assert!(!is_private_host(&host));
1104    }
1105
1106    #[test]
1107    fn is_private_host_localhost_is_true() {
1108        let host: url::Host<&str> = url::Host::Domain("localhost");
1109        assert!(is_private_host(&host));
1110    }
1111
1112    #[test]
1113    fn is_private_host_ipv6_unspecified_is_true() {
1114        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1115        assert!(is_private_host(&host));
1116    }
1117
1118    #[test]
1119    fn is_private_host_public_ipv6_is_false() {
1120        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1121        assert!(!is_private_host(&host));
1122    }
1123
1124    // --- fetch_html redirect logic: wiremock HTTP server tests ---
1125    //
1126    // These tests use a local wiremock server to exercise the redirect-following logic
1127    // in `fetch_html` without requiring an external HTTPS connection. The server binds to
1128    // 127.0.0.1, and tests call `fetch_html` directly (bypassing `validate_url`) to avoid
1129    // the SSRF guard that would otherwise block loopback connections.
1130
1131    /// Helper: returns executor + (`server_url`, `server_addr`) from a running wiremock mock server.
1132    /// The server address is passed to `fetch_html` via `resolve_to_addrs` so the client
1133    /// connects to the mock instead of doing a real DNS lookup.
1134    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1135        let server = wiremock::MockServer::start().await;
1136        let executor = WebScrapeExecutor {
1137            timeout: Duration::from_secs(5),
1138            max_body_bytes: 1_048_576,
1139            allowed_domains: vec![],
1140            denied_domains: vec![],
1141            audit_logger: None,
1142        };
1143        (executor, server)
1144    }
1145
1146    /// Parses the mock server's URI into (`host_str`, `socket_addr`) for use with `build_client`.
1147    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1148        let uri = server.uri();
1149        let url = Url::parse(&uri).unwrap();
1150        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1151        let port = url.port().unwrap_or(80);
1152        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1153        (host, vec![addr])
1154    }
1155
1156    /// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
1157    /// `resolve_and_validate`. This lets us exercise the redirect-counting and
1158    /// missing-Location logic against a plain HTTP wiremock server.
1159    async fn follow_redirects_raw(
1160        executor: &WebScrapeExecutor,
1161        start_url: &str,
1162        host: &str,
1163        addrs: &[std::net::SocketAddr],
1164    ) -> Result<String, ToolError> {
1165        const MAX_REDIRECTS: usize = 3;
1166        let mut current_url = start_url.to_owned();
1167        let mut current_host = host.to_owned();
1168        let mut current_addrs = addrs.to_vec();
1169
1170        for hop in 0..=MAX_REDIRECTS {
1171            let client = executor.build_client(&current_host, &current_addrs);
1172            let resp = client
1173                .get(&current_url)
1174                .send()
1175                .await
1176                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1177
1178            let status = resp.status();
1179
1180            if status.is_redirection() {
1181                if hop == MAX_REDIRECTS {
1182                    return Err(ToolError::Execution(std::io::Error::other(
1183                        "too many redirects",
1184                    )));
1185                }
1186
1187                let location = resp
1188                    .headers()
1189                    .get(reqwest::header::LOCATION)
1190                    .and_then(|v| v.to_str().ok())
1191                    .ok_or_else(|| {
1192                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
1193                    })?;
1194
1195                let base = Url::parse(&current_url)
1196                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1197                let next_url = base
1198                    .join(location)
1199                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1200
1201                // Re-use same host/addrs (mock server is always the same endpoint).
1202                current_url = next_url.to_string();
1203                // Preserve host/addrs as-is since the mock server doesn't change.
1204                let _ = &mut current_host;
1205                let _ = &mut current_addrs;
1206                continue;
1207            }
1208
1209            if !status.is_success() {
1210                return Err(ToolError::Execution(std::io::Error::other(format!(
1211                    "HTTP {status}",
1212                ))));
1213            }
1214
1215            let bytes = resp
1216                .bytes()
1217                .await
1218                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1219
1220            if bytes.len() > executor.max_body_bytes {
1221                return Err(ToolError::Execution(std::io::Error::other(format!(
1222                    "response too large: {} bytes (max: {})",
1223                    bytes.len(),
1224                    executor.max_body_bytes,
1225                ))));
1226            }
1227
1228            return String::from_utf8(bytes.to_vec())
1229                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1230        }
1231
1232        Err(ToolError::Execution(std::io::Error::other(
1233            "too many redirects",
1234        )))
1235    }
1236
1237    #[tokio::test]
1238    async fn fetch_html_success_returns_body() {
1239        use wiremock::matchers::{method, path};
1240        use wiremock::{Mock, ResponseTemplate};
1241
1242        let (executor, server) = mock_server_executor().await;
1243        Mock::given(method("GET"))
1244            .and(path("/page"))
1245            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1246            .mount(&server)
1247            .await;
1248
1249        let (host, addrs) = server_host_and_addr(&server);
1250        let url = format!("{}/page", server.uri());
1251        let result = executor.fetch_html(&url, &host, &addrs).await;
1252        assert!(result.is_ok(), "expected Ok, got: {result:?}");
1253        assert_eq!(result.unwrap(), "<h1>OK</h1>");
1254    }
1255
1256    #[tokio::test]
1257    async fn fetch_html_non_2xx_returns_error() {
1258        use wiremock::matchers::{method, path};
1259        use wiremock::{Mock, ResponseTemplate};
1260
1261        let (executor, server) = mock_server_executor().await;
1262        Mock::given(method("GET"))
1263            .and(path("/forbidden"))
1264            .respond_with(ResponseTemplate::new(403))
1265            .mount(&server)
1266            .await;
1267
1268        let (host, addrs) = server_host_and_addr(&server);
1269        let url = format!("{}/forbidden", server.uri());
1270        let result = executor.fetch_html(&url, &host, &addrs).await;
1271        assert!(result.is_err());
1272        let msg = result.unwrap_err().to_string();
1273        assert!(msg.contains("403"), "expected 403 in error: {msg}");
1274    }
1275
1276    #[tokio::test]
1277    async fn fetch_html_404_returns_error() {
1278        use wiremock::matchers::{method, path};
1279        use wiremock::{Mock, ResponseTemplate};
1280
1281        let (executor, server) = mock_server_executor().await;
1282        Mock::given(method("GET"))
1283            .and(path("/missing"))
1284            .respond_with(ResponseTemplate::new(404))
1285            .mount(&server)
1286            .await;
1287
1288        let (host, addrs) = server_host_and_addr(&server);
1289        let url = format!("{}/missing", server.uri());
1290        let result = executor.fetch_html(&url, &host, &addrs).await;
1291        assert!(result.is_err());
1292        let msg = result.unwrap_err().to_string();
1293        assert!(msg.contains("404"), "expected 404 in error: {msg}");
1294    }
1295
1296    #[tokio::test]
1297    async fn fetch_html_redirect_no_location_returns_error() {
1298        use wiremock::matchers::{method, path};
1299        use wiremock::{Mock, ResponseTemplate};
1300
1301        let (executor, server) = mock_server_executor().await;
1302        // 302 with no Location header
1303        Mock::given(method("GET"))
1304            .and(path("/redirect-no-loc"))
1305            .respond_with(ResponseTemplate::new(302))
1306            .mount(&server)
1307            .await;
1308
1309        let (host, addrs) = server_host_and_addr(&server);
1310        let url = format!("{}/redirect-no-loc", server.uri());
1311        let result = executor.fetch_html(&url, &host, &addrs).await;
1312        assert!(result.is_err());
1313        let msg = result.unwrap_err().to_string();
1314        assert!(
1315            msg.contains("Location") || msg.contains("location"),
1316            "expected Location-related error: {msg}"
1317        );
1318    }
1319
1320    #[tokio::test]
1321    async fn fetch_html_single_redirect_followed() {
1322        use wiremock::matchers::{method, path};
1323        use wiremock::{Mock, ResponseTemplate};
1324
1325        let (executor, server) = mock_server_executor().await;
1326        let final_url = format!("{}/final", server.uri());
1327
1328        Mock::given(method("GET"))
1329            .and(path("/start"))
1330            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1331            .mount(&server)
1332            .await;
1333
1334        Mock::given(method("GET"))
1335            .and(path("/final"))
1336            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1337            .mount(&server)
1338            .await;
1339
1340        let (host, addrs) = server_host_and_addr(&server);
1341        let url = format!("{}/start", server.uri());
1342        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1343        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1344        assert_eq!(result.unwrap(), "<p>final</p>");
1345    }
1346
1347    #[tokio::test]
1348    async fn fetch_html_three_redirects_allowed() {
1349        use wiremock::matchers::{method, path};
1350        use wiremock::{Mock, ResponseTemplate};
1351
1352        let (executor, server) = mock_server_executor().await;
1353        let hop2 = format!("{}/hop2", server.uri());
1354        let hop3 = format!("{}/hop3", server.uri());
1355        let final_dest = format!("{}/done", server.uri());
1356
1357        Mock::given(method("GET"))
1358            .and(path("/hop1"))
1359            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1360            .mount(&server)
1361            .await;
1362        Mock::given(method("GET"))
1363            .and(path("/hop2"))
1364            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1365            .mount(&server)
1366            .await;
1367        Mock::given(method("GET"))
1368            .and(path("/hop3"))
1369            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1370            .mount(&server)
1371            .await;
1372        Mock::given(method("GET"))
1373            .and(path("/done"))
1374            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1375            .mount(&server)
1376            .await;
1377
1378        let (host, addrs) = server_host_and_addr(&server);
1379        let url = format!("{}/hop1", server.uri());
1380        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1381        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1382        assert_eq!(result.unwrap(), "<p>done</p>");
1383    }
1384
1385    #[tokio::test]
1386    async fn fetch_html_four_redirects_rejected() {
1387        use wiremock::matchers::{method, path};
1388        use wiremock::{Mock, ResponseTemplate};
1389
1390        let (executor, server) = mock_server_executor().await;
1391        let hop2 = format!("{}/r2", server.uri());
1392        let hop3 = format!("{}/r3", server.uri());
1393        let hop4 = format!("{}/r4", server.uri());
1394        let hop5 = format!("{}/r5", server.uri());
1395
1396        for (from, to) in [
1397            ("/r1", &hop2),
1398            ("/r2", &hop3),
1399            ("/r3", &hop4),
1400            ("/r4", &hop5),
1401        ] {
1402            Mock::given(method("GET"))
1403                .and(path(from))
1404                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1405                .mount(&server)
1406                .await;
1407        }
1408
1409        let (host, addrs) = server_host_and_addr(&server);
1410        let url = format!("{}/r1", server.uri());
1411        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1412        assert!(result.is_err(), "4 redirects should be rejected");
1413        let msg = result.unwrap_err().to_string();
1414        assert!(
1415            msg.contains("redirect"),
1416            "expected redirect-related error: {msg}"
1417        );
1418    }
1419
1420    #[tokio::test]
1421    async fn fetch_html_body_too_large_returns_error() {
1422        use wiremock::matchers::{method, path};
1423        use wiremock::{Mock, ResponseTemplate};
1424
1425        let small_limit_executor = WebScrapeExecutor {
1426            timeout: Duration::from_secs(5),
1427            max_body_bytes: 10,
1428            allowed_domains: vec![],
1429            denied_domains: vec![],
1430            audit_logger: None,
1431        };
1432        let server = wiremock::MockServer::start().await;
1433        Mock::given(method("GET"))
1434            .and(path("/big"))
1435            .respond_with(
1436                ResponseTemplate::new(200)
1437                    .set_body_string("this body is definitely longer than ten bytes"),
1438            )
1439            .mount(&server)
1440            .await;
1441
1442        let (host, addrs) = server_host_and_addr(&server);
1443        let url = format!("{}/big", server.uri());
1444        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1445        assert!(result.is_err());
1446        let msg = result.unwrap_err().to_string();
1447        assert!(msg.contains("too large"), "expected too-large error: {msg}");
1448    }
1449
1450    #[test]
1451    fn extract_scrape_blocks_empty_block_content() {
1452        let text = "```scrape\n\n```";
1453        let blocks = extract_scrape_blocks(text);
1454        assert_eq!(blocks.len(), 1);
1455        assert!(blocks[0].is_empty());
1456    }
1457
1458    #[test]
1459    fn extract_scrape_blocks_whitespace_only() {
1460        let text = "```scrape\n   \n```";
1461        let blocks = extract_scrape_blocks(text);
1462        assert_eq!(blocks.len(), 1);
1463    }
1464
1465    #[test]
1466    fn parse_and_extract_multiple_selectors() {
1467        let html = "<div><h1>Title</h1><p>Para</p></div>";
1468        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1469        assert!(result.contains("Title"));
1470        assert!(result.contains("Para"));
1471    }
1472
1473    #[test]
1474    fn webscrape_executor_new_with_custom_config() {
1475        let config = ScrapeConfig {
1476            timeout: 60,
1477            max_body_bytes: 512,
1478            ..Default::default()
1479        };
1480        let executor = WebScrapeExecutor::new(&config);
1481        assert_eq!(executor.max_body_bytes, 512);
1482    }
1483
1484    #[test]
1485    fn webscrape_executor_debug() {
1486        let config = ScrapeConfig::default();
1487        let executor = WebScrapeExecutor::new(&config);
1488        let dbg = format!("{executor:?}");
1489        assert!(dbg.contains("WebScrapeExecutor"));
1490    }
1491
1492    #[test]
1493    fn extract_mode_attr_empty_name() {
1494        let mode = ExtractMode::parse("attr:");
1495        assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1496    }
1497
1498    #[test]
1499    fn default_extract_returns_text() {
1500        assert_eq!(default_extract(), "text");
1501    }
1502
1503    #[test]
1504    fn scrape_instruction_debug() {
1505        let json = r#"{"url":"https://example.com","select":"h1"}"#;
1506        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1507        let dbg = format!("{instr:?}");
1508        assert!(dbg.contains("ScrapeInstruction"));
1509    }
1510
1511    #[test]
1512    fn extract_mode_debug() {
1513        let mode = ExtractMode::Text;
1514        let dbg = format!("{mode:?}");
1515        assert!(dbg.contains("Text"));
1516    }
1517
1518    // --- fetch_html redirect logic: constant and validation unit tests ---
1519
1520    /// `MAX_REDIRECTS` is 3; the 4th redirect attempt must be rejected.
1521    /// Verify the boundary is correct by inspecting the constant value.
1522    #[test]
1523    fn max_redirects_constant_is_three() {
1524        // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
1525        // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
1526        // This test documents the expected limit.
1527        const MAX_REDIRECTS: usize = 3;
1528        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1529    }
1530
1531    /// Verifies that a Location-less redirect would produce an error string containing the
1532    /// expected message, matching the error path in `fetch_html`.
1533    #[test]
1534    fn redirect_no_location_error_message() {
1535        let err = std::io::Error::other("redirect with no Location");
1536        assert!(err.to_string().contains("redirect with no Location"));
1537    }
1538
1539    /// Verifies that a too-many-redirects condition produces the expected error string.
1540    #[test]
1541    fn too_many_redirects_error_message() {
1542        let err = std::io::Error::other("too many redirects");
1543        assert!(err.to_string().contains("too many redirects"));
1544    }
1545
1546    /// Verifies that a non-2xx HTTP status produces an error message with the status code.
1547    #[test]
1548    fn non_2xx_status_error_format() {
1549        let status = reqwest::StatusCode::FORBIDDEN;
1550        let msg = format!("HTTP {status}");
1551        assert!(msg.contains("403"));
1552    }
1553
1554    /// Verifies that a 404 response status code formats into the expected error message.
1555    #[test]
1556    fn not_found_status_error_format() {
1557        let status = reqwest::StatusCode::NOT_FOUND;
1558        let msg = format!("HTTP {status}");
1559        assert!(msg.contains("404"));
1560    }
1561
1562    /// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
1563    #[test]
1564    fn relative_redirect_same_host_path() {
1565        let base = Url::parse("https://example.com/current").unwrap();
1566        let resolved = base.join("/other").unwrap();
1567        assert_eq!(resolved.as_str(), "https://example.com/other");
1568    }
1569
1570    /// Verifies relative redirect resolution preserves scheme and host.
1571    #[test]
1572    fn relative_redirect_relative_path() {
1573        let base = Url::parse("https://example.com/a/b").unwrap();
1574        let resolved = base.join("c").unwrap();
1575        assert_eq!(resolved.as_str(), "https://example.com/a/c");
1576    }
1577
1578    /// Verifies that an absolute redirect URL overrides base URL completely.
1579    #[test]
1580    fn absolute_redirect_overrides_base() {
1581        let base = Url::parse("https://example.com/page").unwrap();
1582        let resolved = base.join("https://other.com/target").unwrap();
1583        assert_eq!(resolved.as_str(), "https://other.com/target");
1584    }
1585
1586    /// Verifies that a redirect Location of http:// (downgrade) is rejected.
1587    #[test]
1588    fn redirect_http_downgrade_rejected() {
1589        let location = "http://example.com/page";
1590        let base = Url::parse("https://example.com/start").unwrap();
1591        let next = base.join(location).unwrap();
1592        let err = validate_url(next.as_str()).unwrap_err();
1593        assert!(matches!(err, ToolError::Blocked { .. }));
1594    }
1595
/// A redirect whose `Location` is a private IPv4 literal must be blocked, and
/// the error text must name the reason for the block.
#[test]
fn redirect_location_private_ip_blocked() {
    let origin = Url::parse("https://example.com/start").unwrap();
    let target = origin.join("https://192.168.100.1/admin").unwrap();

    let failure = validate_url(target.as_str()).unwrap_err();
    assert!(matches!(failure, ToolError::Blocked { .. }));

    // Pull the message out of the variant so its wording can be inspected.
    let reason = match failure {
        ToolError::Blocked { command } => command,
        _ => panic!("expected Blocked"),
    };
    let names_cause = ["private", "scheme"]
        .iter()
        .any(|needle| reason.contains(*needle));
    assert!(
        names_cause,
        "error message should describe the block reason: {reason}"
    );
}
1612
/// A redirect `Location` on the `.internal` pseudo-TLD must be blocked.
#[test]
fn redirect_location_internal_domain_blocked() {
    let origin = Url::parse("https://example.com/start").unwrap();
    let target = origin
        .join("https://metadata.internal/latest/meta-data/")
        .unwrap();
    let outcome = validate_url(target.as_str());
    assert!(matches!(outcome.unwrap_err(), ToolError::Blocked { .. }));
}
1622
/// Every hop of a three-step redirect chain through public hosts must be
/// accepted by `validate_url`.
#[test]
fn redirect_chain_three_hops_all_public() {
    [
        "https://redirect1.example.com/hop1",
        "https://redirect2.example.com/hop2",
        "https://destination.example.com/final",
    ]
    .into_iter()
    .for_each(|hop| assert!(validate_url(hop).is_ok(), "expected ok for {hop}"));
}
1635
1636    // --- SSRF redirect chain defense ---
1637
/// Each of these URLs stands in for a `Location` header pointing at a
/// loopback, private, link-local, or internal-name target. `validate_url`
/// must refuse every one of them with `ToolError::Blocked`, mirroring the
/// per-hop validation performed while following redirects.
#[test]
fn redirect_to_private_ip_rejected_by_validate_url() {
    for target in [
        "https://127.0.0.1/secret",
        "https://10.0.0.1/internal",
        "https://192.168.1.1/admin",
        "https://172.16.0.1/data",
        "https://[::1]/path",
        "https://[fe80::1]/path",
        "https://localhost/path",
        "https://service.internal/api",
    ] {
        match validate_url(target) {
            Err(ToolError::Blocked { .. }) => {}
            Err(_) => panic!("expected Blocked for {target}"),
            Ok(_) => panic!("expected error for {target}"),
        }
    }
}
1662
/// A root-relative `Location` (`/other`) is resolved against the base URL's
/// origin before it is validated.
#[test]
fn redirect_relative_url_resolves_correctly() {
    let origin = Url::parse("https://example.com/page").unwrap();
    let target = origin.join("/other").unwrap();
    assert_eq!(target.as_str(), "https://example.com/other");
}
1671
/// Verifies that an absolute `http://` URL — the form a scheme-downgrading
/// redirect target takes — is rejected by `validate_url`.
///
/// The input here is a fully-qualified URL, not a protocol-relative
/// (`//host/path`) reference.
#[test]
fn redirect_to_http_rejected() {
    let err = validate_url("http://example.com/page").unwrap_err();
    assert!(
        matches!(err, ToolError::Blocked { .. }),
        "plain-HTTP URL must be reported as Blocked"
    );
}
1678
/// An IPv4-mapped IPv6 literal wrapping `169.254.0.1` must be blocked.
#[test]
fn ipv4_mapped_ipv6_link_local_blocked() {
    let outcome = validate_url("https://[::ffff:169.254.0.1]/path");
    assert!(matches!(outcome.unwrap_err(), ToolError::Blocked { .. }));
}
1684
/// An IPv4-mapped IPv6 literal wrapping a public address passes validation.
#[test]
fn ipv4_mapped_ipv6_public_allowed() {
    let outcome = validate_url("https://[::ffff:93.184.216.34]/path");
    assert!(outcome.is_ok());
}
1689
1690    // --- fetch tool ---
1691
1692    #[tokio::test]
1693    async fn fetch_http_scheme_blocked() {
1694        let config = ScrapeConfig::default();
1695        let executor = WebScrapeExecutor::new(&config);
1696        let call = crate::executor::ToolCall {
1697            tool_id: ToolName::new("fetch"),
1698            params: {
1699                let mut m = serde_json::Map::new();
1700                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1701                m
1702            },
1703            caller_id: None,
1704        };
1705        let result = executor.execute_tool_call(&call).await;
1706        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1707    }
1708
1709    #[tokio::test]
1710    async fn fetch_private_ip_blocked() {
1711        let config = ScrapeConfig::default();
1712        let executor = WebScrapeExecutor::new(&config);
1713        let call = crate::executor::ToolCall {
1714            tool_id: ToolName::new("fetch"),
1715            params: {
1716                let mut m = serde_json::Map::new();
1717                m.insert(
1718                    "url".to_owned(),
1719                    serde_json::json!("https://192.168.1.1/secret"),
1720                );
1721                m
1722            },
1723            caller_id: None,
1724        };
1725        let result = executor.execute_tool_call(&call).await;
1726        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1727    }
1728
1729    #[tokio::test]
1730    async fn fetch_localhost_blocked() {
1731        let config = ScrapeConfig::default();
1732        let executor = WebScrapeExecutor::new(&config);
1733        let call = crate::executor::ToolCall {
1734            tool_id: ToolName::new("fetch"),
1735            params: {
1736                let mut m = serde_json::Map::new();
1737                m.insert(
1738                    "url".to_owned(),
1739                    serde_json::json!("https://localhost/page"),
1740                );
1741                m
1742            },
1743            caller_id: None,
1744        };
1745        let result = executor.execute_tool_call(&call).await;
1746        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1747    }
1748
1749    #[tokio::test]
1750    async fn fetch_unknown_tool_returns_none() {
1751        let config = ScrapeConfig::default();
1752        let executor = WebScrapeExecutor::new(&config);
1753        let call = crate::executor::ToolCall {
1754            tool_id: ToolName::new("unknown_tool"),
1755            params: serde_json::Map::new(),
1756            caller_id: None,
1757        };
1758        let result = executor.execute_tool_call(&call).await;
1759        assert!(result.unwrap().is_none());
1760    }
1761
1762    #[tokio::test]
1763    async fn fetch_returns_body_via_mock() {
1764        use wiremock::matchers::{method, path};
1765        use wiremock::{Mock, ResponseTemplate};
1766
1767        let (executor, server) = mock_server_executor().await;
1768        Mock::given(method("GET"))
1769            .and(path("/content"))
1770            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1771            .mount(&server)
1772            .await;
1773
1774        let (host, addrs) = server_host_and_addr(&server);
1775        let url = format!("{}/content", server.uri());
1776        let result = executor.fetch_html(&url, &host, &addrs).await;
1777        assert!(result.is_ok());
1778        assert_eq!(result.unwrap(), "plain text content");
1779    }
1780
/// The executor advertises exactly two tools, in order: `web_scrape`
/// (invoked through a fenced `scrape` block) and `fetch` (invoked as a
/// structured tool call).
#[test]
fn tool_definitions_returns_web_scrape_and_fetch() {
    use crate::registry::InvocationHint;

    let config = ScrapeConfig::default();
    let defs = WebScrapeExecutor::new(&config).tool_definitions();
    assert_eq!(defs.len(), 2);

    let scrape = &defs[0];
    assert_eq!(scrape.id, "web_scrape");
    assert_eq!(scrape.invocation, InvocationHint::FencedBlock("scrape"));

    let fetch = &defs[1];
    assert_eq!(fetch.id, "fetch");
    assert_eq!(fetch.invocation, InvocationHint::ToolCall);
}
1798
/// The `web_scrape` schema exposes all four parameters; `url` and `select`
/// are required, while `extract` (which has a serde default) is not.
#[test]
fn tool_definitions_schema_has_all_params() {
    let config = ScrapeConfig::default();
    let defs = WebScrapeExecutor::new(&config).tool_definitions();
    let schema = defs[0].schema.as_object().unwrap();

    let properties = schema["properties"].as_object().unwrap();
    for key in ["url", "select", "extract", "limit"] {
        assert!(properties.contains_key(key));
    }

    let required = schema["required"].as_array().unwrap();
    let is_required = |name: &str| required.iter().any(|entry| entry.as_str() == Some(name));
    assert!(is_required("url"));
    assert!(is_required("select"));
    assert!(!is_required("extract"));
}
1815
1816    // --- is_private_host: new domain checks (AUD-02) ---
1817
/// Any subdomain of `localhost` counts as a private host.
#[test]
fn subdomain_localhost_blocked() {
    let candidate = url::Host::<&str>::Domain("foo.localhost");
    assert!(is_private_host(&candidate));
}
1823
/// Names under the `.internal` TLD count as private hosts.
#[test]
fn internal_tld_blocked() {
    let candidate = url::Host::<&str>::Domain("service.internal");
    assert!(is_private_host(&candidate));
}
1829
/// Names under the `.local` TLD count as private hosts.
#[test]
fn local_tld_blocked() {
    let candidate = url::Host::<&str>::Domain("printer.local");
    assert!(is_private_host(&candidate));
}
1835
/// An ordinary public domain is not classified as private.
#[test]
fn public_domain_not_blocked() {
    let candidate = url::Host::<&str>::Domain("example.com");
    assert!(!is_private_host(&candidate));
}
1841
1842    // --- resolve_and_validate: private IP rejection ---
1843
1844    #[tokio::test]
1845    async fn resolve_loopback_rejected() {
1846        // 127.0.0.1 resolves directly (literal IP in DNS query)
1847        let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1848        // validate_url catches this before resolve_and_validate, but test directly
1849        let result = resolve_and_validate(&url).await;
1850        assert!(
1851            result.is_err(),
1852            "loopback IP must be rejected by resolve_and_validate"
1853        );
1854        let err = result.unwrap_err();
1855        assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1856    }
1857
1858    #[tokio::test]
1859    async fn resolve_private_10_rejected() {
1860        let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1861        let result = resolve_and_validate(&url).await;
1862        assert!(result.is_err());
1863        assert!(matches!(
1864            result.unwrap_err(),
1865            crate::executor::ToolError::Blocked { .. }
1866        ));
1867    }
1868
1869    #[tokio::test]
1870    async fn resolve_private_192_rejected() {
1871        let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1872        let result = resolve_and_validate(&url).await;
1873        assert!(result.is_err());
1874        assert!(matches!(
1875            result.unwrap_err(),
1876            crate::executor::ToolError::Blocked { .. }
1877        ));
1878    }
1879
1880    #[tokio::test]
1881    async fn resolve_ipv6_loopback_rejected() {
1882        let url = url::Url::parse("https://[::1]/path").unwrap();
1883        let result = resolve_and_validate(&url).await;
1884        assert!(result.is_err());
1885        assert!(matches!(
1886            result.unwrap_err(),
1887            crate::executor::ToolError::Blocked { .. }
1888        ));
1889    }
1890
1891    #[tokio::test]
1892    async fn resolve_no_host_returns_ok() {
1893        // URL without a resolvable host — should pass through
1894        let url = url::Url::parse("https://example.com/path").unwrap();
1895        // We can't do a live DNS test, but we can verify a URL with no host
1896        let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1897        // data: URLs have no host; resolve_and_validate should return Ok with empty addrs
1898        let result = resolve_and_validate(&url_no_host).await;
1899        assert!(result.is_ok());
1900        let (host, addrs) = result.unwrap();
1901        assert!(host.is_empty());
1902        assert!(addrs.is_empty());
1903        drop(url);
1904        drop(url_no_host);
1905    }
1906
1907    // --- audit logging ---
1908
1909    /// Helper: build an `AuditLogger` writing to a temp file, and return the logger + path.
1910    async fn make_file_audit_logger(
1911        dir: &tempfile::TempDir,
1912    ) -> (
1913        std::sync::Arc<crate::audit::AuditLogger>,
1914        std::path::PathBuf,
1915    ) {
1916        use crate::audit::AuditLogger;
1917        use crate::config::AuditConfig;
1918        let path = dir.path().join("audit.log");
1919        let config = AuditConfig {
1920            enabled: true,
1921            destination: path.display().to_string(),
1922            ..Default::default()
1923        };
1924        let logger = std::sync::Arc::new(AuditLogger::from_config(&config, false).await.unwrap());
1925        (logger, path)
1926    }
1927
1928    #[tokio::test]
1929    async fn with_audit_sets_logger() {
1930        let config = ScrapeConfig::default();
1931        let executor = WebScrapeExecutor::new(&config);
1932        assert!(executor.audit_logger.is_none());
1933
1934        let dir = tempfile::tempdir().unwrap();
1935        let (logger, _path) = make_file_audit_logger(&dir).await;
1936        let executor = executor.with_audit(logger);
1937        assert!(executor.audit_logger.is_some());
1938    }
1939
/// `ToolError::Blocked` maps to `AuditResult::Blocked`, carrying the blocked
/// command text over as the reason.
#[test]
fn tool_error_to_audit_result_blocked_maps_correctly() {
    const REASON: &str = "scheme not allowed: http";

    let blocked = ToolError::Blocked {
        command: REASON.into(),
    };
    let mapped = tool_error_to_audit_result(&blocked);
    let carries_reason = matches!(mapped, AuditResult::Blocked { reason } if reason == REASON);
    assert!(carries_reason);
}
1950
/// `ToolError::Timeout` maps to `AuditResult::Timeout`.
#[test]
fn tool_error_to_audit_result_timeout_maps_correctly() {
    let timed_out = ToolError::Timeout { timeout_secs: 15 };
    let mapped = tool_error_to_audit_result(&timed_out);
    let is_timeout = matches!(mapped, AuditResult::Timeout);
    assert!(is_timeout);
}
1957
/// `ToolError::Execution` maps to `AuditResult::Error`, and the message
/// retains the text of the underlying I/O error.
#[test]
fn tool_error_to_audit_result_execution_error_maps_correctly() {
    let io_failure = std::io::Error::other("connection refused");
    let failed = ToolError::Execution(io_failure);
    let mapped = tool_error_to_audit_result(&failed);
    let keeps_text =
        matches!(mapped, AuditResult::Error { message } if message.contains("connection refused"));
    assert!(keeps_text);
}
1966
1967    #[tokio::test]
1968    async fn fetch_audit_blocked_url_logged() {
1969        let dir = tempfile::tempdir().unwrap();
1970        let (logger, log_path) = make_file_audit_logger(&dir).await;
1971
1972        let config = ScrapeConfig::default();
1973        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1974
1975        let call = crate::executor::ToolCall {
1976            tool_id: ToolName::new("fetch"),
1977            params: {
1978                let mut m = serde_json::Map::new();
1979                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1980                m
1981            },
1982            caller_id: None,
1983        };
1984        let result = executor.execute_tool_call(&call).await;
1985        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1986
1987        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1988        assert!(
1989            content.contains("\"tool\":\"fetch\""),
1990            "expected tool=fetch in audit: {content}"
1991        );
1992        assert!(
1993            content.contains("\"type\":\"blocked\""),
1994            "expected type=blocked in audit: {content}"
1995        );
1996        assert!(
1997            content.contains("http://example.com"),
1998            "expected URL in audit command field: {content}"
1999        );
2000    }
2001
2002    #[tokio::test]
2003    async fn log_audit_success_writes_to_file() {
2004        let dir = tempfile::tempdir().unwrap();
2005        let (logger, log_path) = make_file_audit_logger(&dir).await;
2006
2007        let config = ScrapeConfig::default();
2008        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2009
2010        executor
2011            .log_audit(
2012                "fetch",
2013                "https://example.com/page",
2014                AuditResult::Success,
2015                42,
2016                None,
2017            )
2018            .await;
2019
2020        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2021        assert!(
2022            content.contains("\"tool\":\"fetch\""),
2023            "expected tool=fetch in audit: {content}"
2024        );
2025        assert!(
2026            content.contains("\"type\":\"success\""),
2027            "expected type=success in audit: {content}"
2028        );
2029        assert!(
2030            content.contains("\"command\":\"https://example.com/page\""),
2031            "expected command URL in audit: {content}"
2032        );
2033        assert!(
2034            content.contains("\"duration_ms\":42"),
2035            "expected duration_ms in audit: {content}"
2036        );
2037    }
2038
2039    #[tokio::test]
2040    async fn log_audit_blocked_writes_to_file() {
2041        let dir = tempfile::tempdir().unwrap();
2042        let (logger, log_path) = make_file_audit_logger(&dir).await;
2043
2044        let config = ScrapeConfig::default();
2045        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2046
2047        executor
2048            .log_audit(
2049                "web_scrape",
2050                "http://evil.com/page",
2051                AuditResult::Blocked {
2052                    reason: "scheme not allowed: http".into(),
2053                },
2054                0,
2055                None,
2056            )
2057            .await;
2058
2059        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2060        assert!(
2061            content.contains("\"tool\":\"web_scrape\""),
2062            "expected tool=web_scrape in audit: {content}"
2063        );
2064        assert!(
2065            content.contains("\"type\":\"blocked\""),
2066            "expected type=blocked in audit: {content}"
2067        );
2068        assert!(
2069            content.contains("scheme not allowed"),
2070            "expected block reason in audit: {content}"
2071        );
2072    }
2073
2074    #[tokio::test]
2075    async fn web_scrape_audit_blocked_url_logged() {
2076        let dir = tempfile::tempdir().unwrap();
2077        let (logger, log_path) = make_file_audit_logger(&dir).await;
2078
2079        let config = ScrapeConfig::default();
2080        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2081
2082        let call = crate::executor::ToolCall {
2083            tool_id: ToolName::new("web_scrape"),
2084            params: {
2085                let mut m = serde_json::Map::new();
2086                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2087                m.insert("select".to_owned(), serde_json::json!("h1"));
2088                m
2089            },
2090            caller_id: None,
2091        };
2092        let result = executor.execute_tool_call(&call).await;
2093        assert!(matches!(result, Err(ToolError::Blocked { .. })));
2094
2095        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2096        assert!(
2097            content.contains("\"tool\":\"web_scrape\""),
2098            "expected tool=web_scrape in audit: {content}"
2099        );
2100        assert!(
2101            content.contains("\"type\":\"blocked\""),
2102            "expected type=blocked in audit: {content}"
2103        );
2104    }
2105
2106    #[tokio::test]
2107    async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2108        let config = ScrapeConfig::default();
2109        let executor = WebScrapeExecutor::new(&config);
2110        assert!(executor.audit_logger.is_none());
2111
2112        let call = crate::executor::ToolCall {
2113            tool_id: ToolName::new("fetch"),
2114            params: {
2115                let mut m = serde_json::Map::new();
2116                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2117                m
2118            },
2119            caller_id: None,
2120        };
2121        // Must not panic even without an audit logger
2122        let result = executor.execute_tool_call(&call).await;
2123        assert!(matches!(result, Err(ToolError::Blocked { .. })));
2124    }
2125
2126    // CR-10: fetch end-to-end via execute_tool_call -> handle_fetch -> fetch_html
2127    #[tokio::test]
2128    async fn fetch_execute_tool_call_end_to_end() {
2129        use wiremock::matchers::{method, path};
2130        use wiremock::{Mock, ResponseTemplate};
2131
2132        let (executor, server) = mock_server_executor().await;
2133        Mock::given(method("GET"))
2134            .and(path("/e2e"))
2135            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2136            .mount(&server)
2137            .await;
2138
2139        let (host, addrs) = server_host_and_addr(&server);
2140        // Call fetch_html directly (bypassing SSRF guard for loopback mock server)
2141        let result = executor
2142            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
2143            .await;
2144        assert!(result.is_ok());
2145        assert!(result.unwrap().contains("end-to-end"));
2146    }
2147
2148    // --- domain_matches ---
2149
/// A pattern without a wildcard matches only the identical host name.
#[test]
fn domain_matches_exact() {
    for (host, expected) in [
        ("example.com", true),
        ("other.com", false),
        ("sub.example.com", false),
    ] {
        assert_eq!(domain_matches("example.com", host), expected);
    }
}
2156
/// `*.example.com` matches exactly one extra label: not the bare parent and
/// not a two-level subdomain.
#[test]
fn domain_matches_wildcard_single_subdomain() {
    for (host, expected) in [
        ("sub.example.com", true),
        ("example.com", false),
        ("sub.sub.example.com", false),
    ] {
        assert_eq!(domain_matches("*.example.com", host), expected);
    }
}
2163
/// The wildcard must stand for a non-empty label, so a host that is just the
/// suffix with a leading dot does not match.
#[test]
fn domain_matches_wildcard_does_not_match_empty_label() {
    let matched = domain_matches("*.example.com", ".example.com");
    assert!(!matched);
}
2169
/// Patterns with more than one wildcard are unsupported and are compared
/// literally, so they do not match a real two-label subdomain.
#[test]
fn domain_matches_multi_wildcard_treated_as_exact() {
    let matched = domain_matches("*.*.example.com", "a.b.example.com");
    assert!(!matched);
}
2175
2176    // --- check_domain_policy ---
2177
/// With both lists empty, every host is permitted.
#[test]
fn check_domain_policy_empty_lists_allow_all() {
    for host in ["example.com", "evil.com"] {
        assert!(check_domain_policy(host, &[], &[]).is_ok());
    }
}
2183
/// A host listed in the denylist is rejected with `ToolError::Blocked`.
#[test]
fn check_domain_policy_denylist_blocks() {
    let denied = [String::from("evil.com")];
    let verdict = check_domain_policy("evil.com", &[], &denied);
    assert!(matches!(verdict.unwrap_err(), ToolError::Blocked { .. }));
}
2190
/// A denylist entry affects only the hosts it matches.
#[test]
fn check_domain_policy_denylist_does_not_block_other_domains() {
    let denied = [String::from("evil.com")];
    let verdict = check_domain_policy("good.com", &[], &denied);
    assert!(verdict.is_ok());
}
2196
/// Hosts matching an allowlist entry — exact or wildcard — are permitted.
#[test]
fn check_domain_policy_allowlist_permits_matching() {
    let allowed = [String::from("docs.rs"), String::from("*.rust-lang.org")];
    for host in ["docs.rs", "blog.rust-lang.org"] {
        assert!(check_domain_policy(host, &allowed, &[]).is_ok());
    }
}
2203
/// With a non-empty allowlist, a host that matches no entry is blocked.
#[test]
fn check_domain_policy_allowlist_blocks_unknown() {
    let allowed = [String::from("docs.rs")];
    let verdict = check_domain_policy("other.com", &allowed, &[]);
    assert!(matches!(verdict.unwrap_err(), ToolError::Blocked { .. }));
}
2210
/// When a host appears in both lists, the denylist takes precedence.
#[test]
fn check_domain_policy_deny_overrides_allow() {
    let allowed = [String::from("example.com")];
    let denied = [String::from("example.com")];
    let verdict = check_domain_policy("example.com", &allowed, &denied);
    assert!(matches!(verdict.unwrap_err(), ToolError::Blocked { .. }));
}
2218
/// A wildcard denylist entry blocks matching subdomains but leaves the parent
/// domain itself reachable.
#[test]
fn check_domain_policy_wildcard_in_denylist() {
    let denied = [String::from("*.evil.com")];

    let subdomain = check_domain_policy("sub.evil.com", &[], &denied);
    assert!(matches!(subdomain.unwrap_err(), ToolError::Blocked { .. }));

    // The bare parent has no extra label for the wildcard to match.
    let parent = check_domain_policy("evil.com", &[], &denied);
    assert!(parent.is_ok());
}
2227}