Skip to main content

zeph_tools/
scrape.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
15use crate::net::is_private_ip;
16
// Parameters accepted by the `fetch` tool call (deserialized via `deserialize_params`).
// NOTE: schemars uses the `///` field docs as JSON-schema descriptions shown to the
// model, so treat their wording as part of the tool contract.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// HTTPS URL to fetch
    url: String,
}
22
// One scrape request, parsed from a ```scrape fenced block or a `web_scrape` tool call.
// NOTE: schemars uses the `///` field docs as JSON-schema descriptions shown to the
// model, so treat their wording as part of the tool contract.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// HTTPS URL to scrape
    url: String,
    /// CSS selector
    select: String,
    /// Extract mode: text, html, or attr:<name>
    // When the field is absent, serde fills it from `default_extract` ("text").
    #[serde(default = "default_extract")]
    extract: String,
    /// Max results to return
    // `None` is replaced by 10 when the instruction is executed.
    limit: Option<usize>,
}
35
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
39
/// What to pull out of each element matched by the CSS selector.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute (empty when the attribute is missing).
    Attr(String),
}

impl ExtractMode {
    /// Parses the user-supplied extract mode.
    ///
    /// Accepts `"text"`, `"html"` and `"attr:<name>"`. Any other value deliberately
    /// falls back to [`ExtractMode::Text`] instead of failing.
    fn parse(s: &str) -> Self {
        if let Some(attr_name) = s.strip_prefix("attr:") {
            return Self::Attr(attr_name.to_owned());
        }
        match s {
            "html" => Self::Html,
            // Covers "text" as well as every unrecognised mode.
            _ => Self::Text,
        }
    }
}
59
/// Extracts data from web pages via CSS selectors.
///
/// Detects ` ```scrape ` blocks in LLM responses containing JSON instructions,
/// fetches the URL, and parses HTML with `scrape-core`.
///
/// It also serves the `web_scrape` and `fetch` structured tool calls. Every request is
/// HTTPS-only and refused for private/local destinations. Each connection is pinned to
/// the addresses that passed DNS validation, and redirects are followed manually (at
/// most three), re-validating every hop.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Per-request timeout applied to every HTTP client this executor builds.
    timeout: Duration,
    /// Maximum accepted response body size in bytes; larger bodies are rejected.
    max_body_bytes: usize,
    /// Optional audit sink; `None` disables audit logging.
    audit_logger: Option<Arc<AuditLogger>>,
}
70
71impl WebScrapeExecutor {
72    #[must_use]
73    pub fn new(config: &ScrapeConfig) -> Self {
74        Self {
75            timeout: Duration::from_secs(config.timeout),
76            max_body_bytes: config.max_body_bytes,
77            audit_logger: None,
78        }
79    }
80
81    #[must_use]
82    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
83        self.audit_logger = Some(logger);
84        self
85    }
86
87    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
88        let mut builder = reqwest::Client::builder()
89            .timeout(self.timeout)
90            .redirect(reqwest::redirect::Policy::none());
91        builder = builder.resolve_to_addrs(host, addrs);
92        builder.build().unwrap_or_default()
93    }
94}
95
96impl ToolExecutor for WebScrapeExecutor {
97    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
98        use crate::registry::{InvocationHint, ToolDef};
99        vec![
100            ToolDef {
101                id: "web_scrape".into(),
102                description: "Extract structured data from a web page using CSS selectors.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures\nExample: {\"url\": \"https://example.com\", \"select\": \"h1\", \"extract\": \"text\"}".into(),
103                schema: schemars::schema_for!(ScrapeInstruction),
104                invocation: InvocationHint::FencedBlock("scrape"),
105            },
106            ToolDef {
107                id: "fetch".into(),
108                description: "Fetch a URL and return the response body as plain text.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures\nExample: {\"url\": \"https://api.example.com/data.json\"}".into(),
109                schema: schemars::schema_for!(FetchParams),
110                invocation: InvocationHint::ToolCall,
111            },
112        ]
113    }
114
115    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
116        let blocks = extract_scrape_blocks(response);
117        if blocks.is_empty() {
118            return Ok(None);
119        }
120
121        let mut outputs = Vec::with_capacity(blocks.len());
122        #[allow(clippy::cast_possible_truncation)]
123        let blocks_executed = blocks.len() as u32;
124
125        for block in &blocks {
126            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
127                ToolError::Execution(std::io::Error::new(
128                    std::io::ErrorKind::InvalidData,
129                    e.to_string(),
130                ))
131            })?;
132            let start = Instant::now();
133            let scrape_result = self.scrape_instruction(&instruction).await;
134            #[allow(clippy::cast_possible_truncation)]
135            let duration_ms = start.elapsed().as_millis() as u64;
136            match scrape_result {
137                Ok(output) => {
138                    self.log_audit(
139                        "web_scrape",
140                        &instruction.url,
141                        AuditResult::Success,
142                        duration_ms,
143                    )
144                    .await;
145                    outputs.push(output);
146                }
147                Err(e) => {
148                    let audit_result = tool_error_to_audit_result(&e);
149                    self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
150                        .await;
151                    return Err(e);
152                }
153            }
154        }
155
156        Ok(Some(ToolOutput {
157            tool_name: "web-scrape".to_owned(),
158            summary: outputs.join("\n\n"),
159            blocks_executed,
160            filter_stats: None,
161            diff: None,
162            streamed: false,
163            terminal_id: None,
164            locations: None,
165            raw_response: None,
166        }))
167    }
168
169    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
170        match call.tool_id.as_str() {
171            "web_scrape" => {
172                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
173                let start = Instant::now();
174                let result = self.scrape_instruction(&instruction).await;
175                #[allow(clippy::cast_possible_truncation)]
176                let duration_ms = start.elapsed().as_millis() as u64;
177                match result {
178                    Ok(output) => {
179                        self.log_audit(
180                            "web_scrape",
181                            &instruction.url,
182                            AuditResult::Success,
183                            duration_ms,
184                        )
185                        .await;
186                        Ok(Some(ToolOutput {
187                            tool_name: "web-scrape".to_owned(),
188                            summary: output,
189                            blocks_executed: 1,
190                            filter_stats: None,
191                            diff: None,
192                            streamed: false,
193                            terminal_id: None,
194                            locations: None,
195                            raw_response: None,
196                        }))
197                    }
198                    Err(e) => {
199                        let audit_result = tool_error_to_audit_result(&e);
200                        self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
201                            .await;
202                        Err(e)
203                    }
204                }
205            }
206            "fetch" => {
207                let p: FetchParams = deserialize_params(&call.params)?;
208                let start = Instant::now();
209                let result = self.handle_fetch(&p).await;
210                #[allow(clippy::cast_possible_truncation)]
211                let duration_ms = start.elapsed().as_millis() as u64;
212                match result {
213                    Ok(output) => {
214                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms)
215                            .await;
216                        Ok(Some(ToolOutput {
217                            tool_name: "fetch".to_owned(),
218                            summary: output,
219                            blocks_executed: 1,
220                            filter_stats: None,
221                            diff: None,
222                            streamed: false,
223                            terminal_id: None,
224                            locations: None,
225                            raw_response: None,
226                        }))
227                    }
228                    Err(e) => {
229                        let audit_result = tool_error_to_audit_result(&e);
230                        self.log_audit("fetch", &p.url, audit_result, duration_ms)
231                            .await;
232                        Err(e)
233                    }
234                }
235            }
236            _ => Ok(None),
237        }
238    }
239
240    fn is_tool_retryable(&self, tool_id: &str) -> bool {
241        matches!(tool_id, "web_scrape" | "fetch")
242    }
243}
244
245fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
246    match e {
247        ToolError::Blocked { command } => AuditResult::Blocked {
248            reason: command.clone(),
249        },
250        ToolError::Timeout { .. } => AuditResult::Timeout,
251        _ => AuditResult::Error {
252            message: e.to_string(),
253        },
254    }
255}
256
257impl WebScrapeExecutor {
258    async fn log_audit(&self, tool: &str, command: &str, result: AuditResult, duration_ms: u64) {
259        if let Some(ref logger) = self.audit_logger {
260            let entry = AuditEntry {
261                timestamp: chrono_now(),
262                tool: tool.into(),
263                command: command.into(),
264                result,
265                duration_ms,
266            };
267            logger.log(&entry).await;
268        }
269    }
270
271    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
272        let parsed = validate_url(&params.url)?;
273        let (host, addrs) = resolve_and_validate(&parsed).await?;
274        self.fetch_html(&params.url, &host, &addrs).await
275    }
276
277    async fn scrape_instruction(
278        &self,
279        instruction: &ScrapeInstruction,
280    ) -> Result<String, ToolError> {
281        let parsed = validate_url(&instruction.url)?;
282        let (host, addrs) = resolve_and_validate(&parsed).await?;
283        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
284        let selector = instruction.select.clone();
285        let extract = ExtractMode::parse(&instruction.extract);
286        let limit = instruction.limit.unwrap_or(10);
287        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
288            .await
289            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
290    }
291
292    /// Fetches the HTML at `url`, manually following up to 3 redirects.
293    ///
294    /// Each redirect target is validated with `validate_url` and `resolve_and_validate`
295    /// before following, preventing SSRF via redirect chains.
296    ///
297    /// # Errors
298    ///
299    /// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
300    /// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
301    async fn fetch_html(
302        &self,
303        url: &str,
304        host: &str,
305        addrs: &[SocketAddr],
306    ) -> Result<String, ToolError> {
307        const MAX_REDIRECTS: usize = 3;
308
309        let mut current_url = url.to_owned();
310        let mut current_host = host.to_owned();
311        let mut current_addrs = addrs.to_vec();
312
313        for hop in 0..=MAX_REDIRECTS {
314            // Build a per-hop client pinned to the current hop's validated addresses.
315            let client = self.build_client(&current_host, &current_addrs);
316            let resp = client
317                .get(&current_url)
318                .send()
319                .await
320                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
321
322            let status = resp.status();
323
324            if status.is_redirection() {
325                if hop == MAX_REDIRECTS {
326                    return Err(ToolError::Execution(std::io::Error::other(
327                        "too many redirects",
328                    )));
329                }
330
331                let location = resp
332                    .headers()
333                    .get(reqwest::header::LOCATION)
334                    .and_then(|v| v.to_str().ok())
335                    .ok_or_else(|| {
336                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
337                    })?;
338
339                // Resolve relative redirect URLs against the current URL.
340                let base = Url::parse(&current_url)
341                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
342                let next_url = base
343                    .join(location)
344                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
345
346                let validated = validate_url(next_url.as_str())?;
347                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
348
349                current_url = next_url.to_string();
350                current_host = next_host;
351                current_addrs = next_addrs;
352                continue;
353            }
354
355            if !status.is_success() {
356                return Err(ToolError::Execution(std::io::Error::other(format!(
357                    "HTTP {status}",
358                ))));
359            }
360
361            let bytes = resp
362                .bytes()
363                .await
364                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
365
366            if bytes.len() > self.max_body_bytes {
367                return Err(ToolError::Execution(std::io::Error::other(format!(
368                    "response too large: {} bytes (max: {})",
369                    bytes.len(),
370                    self.max_body_bytes,
371                ))));
372            }
373
374            return String::from_utf8(bytes.to_vec())
375                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
376        }
377
378        Err(ToolError::Execution(std::io::Error::other(
379            "too many redirects",
380        )))
381    }
382}
383
384fn extract_scrape_blocks(text: &str) -> Vec<&str> {
385    crate::executor::extract_fenced_blocks(text, "scrape")
386}
387
388fn validate_url(raw: &str) -> Result<Url, ToolError> {
389    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
390        command: format!("invalid URL: {raw}"),
391    })?;
392
393    if parsed.scheme() != "https" {
394        return Err(ToolError::Blocked {
395            command: format!("scheme not allowed: {}", parsed.scheme()),
396        });
397    }
398
399    if let Some(host) = parsed.host()
400        && is_private_host(&host)
401    {
402        return Err(ToolError::Blocked {
403            command: format!(
404                "private/local host blocked: {}",
405                parsed.host_str().unwrap_or("")
406            ),
407        });
408    }
409
410    Ok(parsed)
411}
412
413fn is_private_host(host: &url::Host<&str>) -> bool {
414    match host {
415        url::Host::Domain(d) => {
416            // Exact match or subdomain of localhost (e.g. foo.localhost)
417            // and .internal/.local TLDs used in cloud/k8s environments.
418            #[allow(clippy::case_sensitive_file_extension_comparisons)]
419            {
420                *d == "localhost"
421                    || d.ends_with(".localhost")
422                    || d.ends_with(".internal")
423                    || d.ends_with(".local")
424            }
425        }
426        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
427        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
428    }
429}
430
431/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
432/// and returns the hostname and validated socket addresses.
433///
434/// Returning the addresses allows the caller to pin the HTTP client to these exact
435/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
436async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
437    let Some(host) = url.host_str() else {
438        return Ok((String::new(), vec![]));
439    };
440    let port = url.port_or_known_default().unwrap_or(443);
441    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
442        .await
443        .map_err(|e| ToolError::Blocked {
444            command: format!("DNS resolution failed: {e}"),
445        })?
446        .collect();
447    for addr in &addrs {
448        if is_private_ip(addr.ip()) {
449            return Err(ToolError::Blocked {
450                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
451            });
452        }
453    }
454    Ok((host.to_owned(), addrs))
455}
456
457fn parse_and_extract(
458    html: &str,
459    selector: &str,
460    extract: &ExtractMode,
461    limit: usize,
462) -> Result<String, ToolError> {
463    let soup = scrape_core::Soup::parse(html);
464
465    let tags = soup.find_all(selector).map_err(|e| {
466        ToolError::Execution(std::io::Error::new(
467            std::io::ErrorKind::InvalidData,
468            format!("invalid selector: {e}"),
469        ))
470    })?;
471
472    let mut results = Vec::new();
473
474    for tag in tags.into_iter().take(limit) {
475        let value = match extract {
476            ExtractMode::Text => tag.text(),
477            ExtractMode::Html => tag.inner_html(),
478            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
479        };
480        if !value.trim().is_empty() {
481            results.push(value.trim().to_owned());
482        }
483    }
484
485    if results.is_empty() {
486        Ok(format!("No results for selector: {selector}"))
487    } else {
488        Ok(results.join("\n"))
489    }
490}
491
492#[cfg(test)]
493mod tests {
494    use super::*;
495
496    // --- extract_scrape_blocks ---
497
498    #[test]
499    fn extract_single_block() {
500        let text =
501            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
502        let blocks = extract_scrape_blocks(text);
503        assert_eq!(blocks.len(), 1);
504        assert!(blocks[0].contains("example.com"));
505    }
506
507    #[test]
508    fn extract_multiple_blocks() {
509        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
510        let blocks = extract_scrape_blocks(text);
511        assert_eq!(blocks.len(), 2);
512    }
513
514    #[test]
515    fn no_blocks_returns_empty() {
516        let blocks = extract_scrape_blocks("plain text, no code blocks");
517        assert!(blocks.is_empty());
518    }
519
520    #[test]
521    fn unclosed_block_ignored() {
522        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
523        assert!(blocks.is_empty());
524    }
525
526    #[test]
527    fn non_scrape_block_ignored() {
528        let text =
529            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
530        let blocks = extract_scrape_blocks(text);
531        assert_eq!(blocks.len(), 1);
532        assert!(blocks[0].contains("x.com"));
533    }
534
535    #[test]
536    fn multiline_json_block() {
537        let text =
538            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
539        let blocks = extract_scrape_blocks(text);
540        assert_eq!(blocks.len(), 1);
541        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
542        assert_eq!(instr.url, "https://example.com");
543    }
544
545    // --- ScrapeInstruction parsing ---
546
547    #[test]
548    fn parse_valid_instruction() {
549        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
550        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
551        assert_eq!(instr.url, "https://example.com");
552        assert_eq!(instr.select, "h1");
553        assert_eq!(instr.extract, "text");
554        assert_eq!(instr.limit, Some(5));
555    }
556
557    #[test]
558    fn parse_minimal_instruction() {
559        let json = r#"{"url":"https://example.com","select":"p"}"#;
560        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
561        assert_eq!(instr.extract, "text");
562        assert!(instr.limit.is_none());
563    }
564
565    #[test]
566    fn parse_attr_extract() {
567        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
568        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
569        assert_eq!(instr.extract, "attr:href");
570    }
571
572    #[test]
573    fn parse_invalid_json_errors() {
574        let result = serde_json::from_str::<ScrapeInstruction>("not json");
575        assert!(result.is_err());
576    }
577
578    // --- ExtractMode ---
579
580    #[test]
581    fn extract_mode_text() {
582        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
583    }
584
585    #[test]
586    fn extract_mode_html() {
587        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
588    }
589
590    #[test]
591    fn extract_mode_attr() {
592        let mode = ExtractMode::parse("attr:href");
593        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
594    }
595
596    #[test]
597    fn extract_mode_unknown_defaults_to_text() {
598        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
599    }
600
601    // --- validate_url ---
602
603    #[test]
604    fn valid_https_url() {
605        assert!(validate_url("https://example.com").is_ok());
606    }
607
608    #[test]
609    fn http_rejected() {
610        let err = validate_url("http://example.com").unwrap_err();
611        assert!(matches!(err, ToolError::Blocked { .. }));
612    }
613
614    #[test]
615    fn ftp_rejected() {
616        let err = validate_url("ftp://files.example.com").unwrap_err();
617        assert!(matches!(err, ToolError::Blocked { .. }));
618    }
619
620    #[test]
621    fn file_rejected() {
622        let err = validate_url("file:///etc/passwd").unwrap_err();
623        assert!(matches!(err, ToolError::Blocked { .. }));
624    }
625
626    #[test]
627    fn invalid_url_rejected() {
628        let err = validate_url("not a url").unwrap_err();
629        assert!(matches!(err, ToolError::Blocked { .. }));
630    }
631
632    #[test]
633    fn localhost_blocked() {
634        let err = validate_url("https://localhost/path").unwrap_err();
635        assert!(matches!(err, ToolError::Blocked { .. }));
636    }
637
638    #[test]
639    fn loopback_ip_blocked() {
640        let err = validate_url("https://127.0.0.1/path").unwrap_err();
641        assert!(matches!(err, ToolError::Blocked { .. }));
642    }
643
644    #[test]
645    fn private_10_blocked() {
646        let err = validate_url("https://10.0.0.1/api").unwrap_err();
647        assert!(matches!(err, ToolError::Blocked { .. }));
648    }
649
650    #[test]
651    fn private_172_blocked() {
652        let err = validate_url("https://172.16.0.1/api").unwrap_err();
653        assert!(matches!(err, ToolError::Blocked { .. }));
654    }
655
656    #[test]
657    fn private_192_blocked() {
658        let err = validate_url("https://192.168.1.1/api").unwrap_err();
659        assert!(matches!(err, ToolError::Blocked { .. }));
660    }
661
662    #[test]
663    fn ipv6_loopback_blocked() {
664        let err = validate_url("https://[::1]/path").unwrap_err();
665        assert!(matches!(err, ToolError::Blocked { .. }));
666    }
667
668    #[test]
669    fn public_ip_allowed() {
670        assert!(validate_url("https://93.184.216.34/page").is_ok());
671    }
672
673    // --- parse_and_extract ---
674
675    #[test]
676    fn extract_text_from_html() {
677        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
678        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
679        assert_eq!(result, "Hello World");
680    }
681
682    #[test]
683    fn extract_multiple_elements() {
684        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
685        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
686        assert_eq!(result, "A\nB\nC");
687    }
688
689    #[test]
690    fn extract_with_limit() {
691        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
692        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
693        assert_eq!(result, "A\nB");
694    }
695
696    #[test]
697    fn extract_attr_href() {
698        let html = r#"<a href="https://example.com">Link</a>"#;
699        let result =
700            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
701        assert_eq!(result, "https://example.com");
702    }
703
704    #[test]
705    fn extract_inner_html() {
706        let html = "<div><span>inner</span></div>";
707        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
708        assert!(result.contains("<span>inner</span>"));
709    }
710
711    #[test]
712    fn no_matches_returns_message() {
713        let html = "<html><body><p>text</p></body></html>";
714        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
715        assert!(result.starts_with("No results for selector:"));
716    }
717
718    #[test]
719    fn empty_text_skipped() {
720        let html = "<ul><li>  </li><li>A</li></ul>";
721        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
722        assert_eq!(result, "A");
723    }
724
725    #[test]
726    fn invalid_selector_errors() {
727        let html = "<html><body></body></html>";
728        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
729        assert!(result.is_err());
730    }
731
732    #[test]
733    fn empty_html_returns_no_results() {
734        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
735        assert!(result.starts_with("No results for selector:"));
736    }
737
738    #[test]
739    fn nested_selector() {
740        let html = "<div><span>inner</span></div><span>outer</span>";
741        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
742        assert_eq!(result, "inner");
743    }
744
745    #[test]
746    fn attr_missing_returns_empty() {
747        let html = r"<a>No href</a>";
748        let result =
749            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
750        assert!(result.starts_with("No results for selector:"));
751    }
752
753    #[test]
754    fn extract_html_mode() {
755        let html = "<div><b>bold</b> text</div>";
756        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
757        assert!(result.contains("<b>bold</b>"));
758    }
759
760    #[test]
761    fn limit_zero_returns_no_results() {
762        let html = "<ul><li>A</li><li>B</li></ul>";
763        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
764        assert!(result.starts_with("No results for selector:"));
765    }
766
767    // --- validate_url edge cases ---
768
769    #[test]
770    fn url_with_port_allowed() {
771        assert!(validate_url("https://example.com:8443/path").is_ok());
772    }
773
774    #[test]
775    fn link_local_ip_blocked() {
776        let err = validate_url("https://169.254.1.1/path").unwrap_err();
777        assert!(matches!(err, ToolError::Blocked { .. }));
778    }
779
780    #[test]
781    fn url_no_scheme_rejected() {
782        let err = validate_url("example.com/path").unwrap_err();
783        assert!(matches!(err, ToolError::Blocked { .. }));
784    }
785
786    #[test]
787    fn unspecified_ipv4_blocked() {
788        let err = validate_url("https://0.0.0.0/path").unwrap_err();
789        assert!(matches!(err, ToolError::Blocked { .. }));
790    }
791
792    #[test]
793    fn broadcast_ipv4_blocked() {
794        let err = validate_url("https://255.255.255.255/path").unwrap_err();
795        assert!(matches!(err, ToolError::Blocked { .. }));
796    }
797
798    #[test]
799    fn ipv6_link_local_blocked() {
800        let err = validate_url("https://[fe80::1]/path").unwrap_err();
801        assert!(matches!(err, ToolError::Blocked { .. }));
802    }
803
804    #[test]
805    fn ipv6_unique_local_blocked() {
806        let err = validate_url("https://[fd12::1]/path").unwrap_err();
807        assert!(matches!(err, ToolError::Blocked { .. }));
808    }
809
810    #[test]
811    fn ipv4_mapped_ipv6_loopback_blocked() {
812        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
813        assert!(matches!(err, ToolError::Blocked { .. }));
814    }
815
816    #[test]
817    fn ipv4_mapped_ipv6_private_blocked() {
818        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
819        assert!(matches!(err, ToolError::Blocked { .. }));
820    }
821
822    // --- WebScrapeExecutor (no-network) ---
823
824    #[tokio::test]
825    async fn executor_no_blocks_returns_none() {
826        let config = ScrapeConfig::default();
827        let executor = WebScrapeExecutor::new(&config);
828        let result = executor.execute("plain text").await;
829        assert!(result.unwrap().is_none());
830    }
831
832    #[tokio::test]
833    async fn executor_invalid_json_errors() {
834        let config = ScrapeConfig::default();
835        let executor = WebScrapeExecutor::new(&config);
836        let response = "```scrape\nnot json\n```";
837        let result = executor.execute(response).await;
838        assert!(matches!(result, Err(ToolError::Execution(_))));
839    }
840
841    #[tokio::test]
842    async fn executor_blocked_url_errors() {
843        let config = ScrapeConfig::default();
844        let executor = WebScrapeExecutor::new(&config);
845        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
846        let result = executor.execute(response).await;
847        assert!(matches!(result, Err(ToolError::Blocked { .. })));
848    }
849
850    #[tokio::test]
851    async fn executor_private_ip_blocked() {
852        let config = ScrapeConfig::default();
853        let executor = WebScrapeExecutor::new(&config);
854        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
855        let result = executor.execute(response).await;
856        assert!(matches!(result, Err(ToolError::Blocked { .. })));
857    }
858
859    #[tokio::test]
860    async fn executor_unreachable_host_returns_error() {
861        let config = ScrapeConfig {
862            timeout: 1,
863            max_body_bytes: 1_048_576,
864        };
865        let executor = WebScrapeExecutor::new(&config);
866        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
867        let result = executor.execute(response).await;
868        assert!(matches!(result, Err(ToolError::Execution(_))));
869    }
870
871    #[tokio::test]
872    async fn executor_localhost_url_blocked() {
873        let config = ScrapeConfig::default();
874        let executor = WebScrapeExecutor::new(&config);
875        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
876        let result = executor.execute(response).await;
877        assert!(matches!(result, Err(ToolError::Blocked { .. })));
878    }
879
880    #[tokio::test]
881    async fn executor_empty_text_returns_none() {
882        let config = ScrapeConfig::default();
883        let executor = WebScrapeExecutor::new(&config);
884        let result = executor.execute("").await;
885        assert!(result.unwrap().is_none());
886    }
887
888    #[tokio::test]
889    async fn executor_multiple_blocks_first_blocked() {
890        let config = ScrapeConfig::default();
891        let executor = WebScrapeExecutor::new(&config);
892        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
893             ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
894        let result = executor.execute(response).await;
895        assert!(result.is_err());
896    }
897
898    #[test]
899    fn validate_url_empty_string() {
900        let err = validate_url("").unwrap_err();
901        assert!(matches!(err, ToolError::Blocked { .. }));
902    }
903
904    #[test]
905    fn validate_url_javascript_scheme_blocked() {
906        let err = validate_url("javascript:alert(1)").unwrap_err();
907        assert!(matches!(err, ToolError::Blocked { .. }));
908    }
909
910    #[test]
911    fn validate_url_data_scheme_blocked() {
912        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
913        assert!(matches!(err, ToolError::Blocked { .. }));
914    }
915
916    #[test]
917    fn is_private_host_public_domain_is_false() {
918        let host: url::Host<&str> = url::Host::Domain("example.com");
919        assert!(!is_private_host(&host));
920    }
921
922    #[test]
923    fn is_private_host_localhost_is_true() {
924        let host: url::Host<&str> = url::Host::Domain("localhost");
925        assert!(is_private_host(&host));
926    }
927
928    #[test]
929    fn is_private_host_ipv6_unspecified_is_true() {
930        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
931        assert!(is_private_host(&host));
932    }
933
934    #[test]
935    fn is_private_host_public_ipv6_is_false() {
936        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
937        assert!(!is_private_host(&host));
938    }
939
    // --- fetch_html redirect logic: wiremock HTTP server tests ---
    //
    // These tests use a local wiremock server to exercise `fetch_html`'s response handling
    // and redirect-following logic without requiring an external HTTPS connection. The
    // server binds to 127.0.0.1, so tests call `fetch_html` directly (bypassing
    // `validate_url`) to avoid the SSRF guard that would otherwise block loopback
    // connections. Tests that follow redirects use `follow_redirects_raw` instead, because
    // `fetch_html` validates each redirect target and would reject the loopback Location.
946
947    /// Helper: returns executor + (`server_url`, `server_addr`) from a running wiremock mock server.
948    /// The server address is passed to `fetch_html` via `resolve_to_addrs` so the client
949    /// connects to the mock instead of doing a real DNS lookup.
950    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
951        let server = wiremock::MockServer::start().await;
952        let executor = WebScrapeExecutor {
953            timeout: Duration::from_secs(5),
954            max_body_bytes: 1_048_576,
955            audit_logger: None,
956        };
957        (executor, server)
958    }
959
960    /// Parses the mock server's URI into (`host_str`, `socket_addr`) for use with `build_client`.
961    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
962        let uri = server.uri();
963        let url = Url::parse(&uri).unwrap();
964        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
965        let port = url.port().unwrap_or(80);
966        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
967        (host, vec![addr])
968    }
969
970    /// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
971    /// `resolve_and_validate`. This lets us exercise the redirect-counting and
972    /// missing-Location logic against a plain HTTP wiremock server.
973    async fn follow_redirects_raw(
974        executor: &WebScrapeExecutor,
975        start_url: &str,
976        host: &str,
977        addrs: &[std::net::SocketAddr],
978    ) -> Result<String, ToolError> {
979        const MAX_REDIRECTS: usize = 3;
980        let mut current_url = start_url.to_owned();
981        let mut current_host = host.to_owned();
982        let mut current_addrs = addrs.to_vec();
983
984        for hop in 0..=MAX_REDIRECTS {
985            let client = executor.build_client(&current_host, &current_addrs);
986            let resp = client
987                .get(&current_url)
988                .send()
989                .await
990                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
991
992            let status = resp.status();
993
994            if status.is_redirection() {
995                if hop == MAX_REDIRECTS {
996                    return Err(ToolError::Execution(std::io::Error::other(
997                        "too many redirects",
998                    )));
999                }
1000
1001                let location = resp
1002                    .headers()
1003                    .get(reqwest::header::LOCATION)
1004                    .and_then(|v| v.to_str().ok())
1005                    .ok_or_else(|| {
1006                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
1007                    })?;
1008
1009                let base = Url::parse(&current_url)
1010                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1011                let next_url = base
1012                    .join(location)
1013                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1014
1015                // Re-use same host/addrs (mock server is always the same endpoint).
1016                current_url = next_url.to_string();
1017                // Preserve host/addrs as-is since the mock server doesn't change.
1018                let _ = &mut current_host;
1019                let _ = &mut current_addrs;
1020                continue;
1021            }
1022
1023            if !status.is_success() {
1024                return Err(ToolError::Execution(std::io::Error::other(format!(
1025                    "HTTP {status}",
1026                ))));
1027            }
1028
1029            let bytes = resp
1030                .bytes()
1031                .await
1032                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1033
1034            if bytes.len() > executor.max_body_bytes {
1035                return Err(ToolError::Execution(std::io::Error::other(format!(
1036                    "response too large: {} bytes (max: {})",
1037                    bytes.len(),
1038                    executor.max_body_bytes,
1039                ))));
1040            }
1041
1042            return String::from_utf8(bytes.to_vec())
1043                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1044        }
1045
1046        Err(ToolError::Execution(std::io::Error::other(
1047            "too many redirects",
1048        )))
1049    }
1050
1051    #[tokio::test]
1052    async fn fetch_html_success_returns_body() {
1053        use wiremock::matchers::{method, path};
1054        use wiremock::{Mock, ResponseTemplate};
1055
1056        let (executor, server) = mock_server_executor().await;
1057        Mock::given(method("GET"))
1058            .and(path("/page"))
1059            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1060            .mount(&server)
1061            .await;
1062
1063        let (host, addrs) = server_host_and_addr(&server);
1064        let url = format!("{}/page", server.uri());
1065        let result = executor.fetch_html(&url, &host, &addrs).await;
1066        assert!(result.is_ok(), "expected Ok, got: {result:?}");
1067        assert_eq!(result.unwrap(), "<h1>OK</h1>");
1068    }
1069
1070    #[tokio::test]
1071    async fn fetch_html_non_2xx_returns_error() {
1072        use wiremock::matchers::{method, path};
1073        use wiremock::{Mock, ResponseTemplate};
1074
1075        let (executor, server) = mock_server_executor().await;
1076        Mock::given(method("GET"))
1077            .and(path("/forbidden"))
1078            .respond_with(ResponseTemplate::new(403))
1079            .mount(&server)
1080            .await;
1081
1082        let (host, addrs) = server_host_and_addr(&server);
1083        let url = format!("{}/forbidden", server.uri());
1084        let result = executor.fetch_html(&url, &host, &addrs).await;
1085        assert!(result.is_err());
1086        let msg = result.unwrap_err().to_string();
1087        assert!(msg.contains("403"), "expected 403 in error: {msg}");
1088    }
1089
1090    #[tokio::test]
1091    async fn fetch_html_404_returns_error() {
1092        use wiremock::matchers::{method, path};
1093        use wiremock::{Mock, ResponseTemplate};
1094
1095        let (executor, server) = mock_server_executor().await;
1096        Mock::given(method("GET"))
1097            .and(path("/missing"))
1098            .respond_with(ResponseTemplate::new(404))
1099            .mount(&server)
1100            .await;
1101
1102        let (host, addrs) = server_host_and_addr(&server);
1103        let url = format!("{}/missing", server.uri());
1104        let result = executor.fetch_html(&url, &host, &addrs).await;
1105        assert!(result.is_err());
1106        let msg = result.unwrap_err().to_string();
1107        assert!(msg.contains("404"), "expected 404 in error: {msg}");
1108    }
1109
1110    #[tokio::test]
1111    async fn fetch_html_redirect_no_location_returns_error() {
1112        use wiremock::matchers::{method, path};
1113        use wiremock::{Mock, ResponseTemplate};
1114
1115        let (executor, server) = mock_server_executor().await;
1116        // 302 with no Location header
1117        Mock::given(method("GET"))
1118            .and(path("/redirect-no-loc"))
1119            .respond_with(ResponseTemplate::new(302))
1120            .mount(&server)
1121            .await;
1122
1123        let (host, addrs) = server_host_and_addr(&server);
1124        let url = format!("{}/redirect-no-loc", server.uri());
1125        let result = executor.fetch_html(&url, &host, &addrs).await;
1126        assert!(result.is_err());
1127        let msg = result.unwrap_err().to_string();
1128        assert!(
1129            msg.contains("Location") || msg.contains("location"),
1130            "expected Location-related error: {msg}"
1131        );
1132    }
1133
1134    #[tokio::test]
1135    async fn fetch_html_single_redirect_followed() {
1136        use wiremock::matchers::{method, path};
1137        use wiremock::{Mock, ResponseTemplate};
1138
1139        let (executor, server) = mock_server_executor().await;
1140        let final_url = format!("{}/final", server.uri());
1141
1142        Mock::given(method("GET"))
1143            .and(path("/start"))
1144            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1145            .mount(&server)
1146            .await;
1147
1148        Mock::given(method("GET"))
1149            .and(path("/final"))
1150            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1151            .mount(&server)
1152            .await;
1153
1154        let (host, addrs) = server_host_and_addr(&server);
1155        let url = format!("{}/start", server.uri());
1156        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1157        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1158        assert_eq!(result.unwrap(), "<p>final</p>");
1159    }
1160
1161    #[tokio::test]
1162    async fn fetch_html_three_redirects_allowed() {
1163        use wiremock::matchers::{method, path};
1164        use wiremock::{Mock, ResponseTemplate};
1165
1166        let (executor, server) = mock_server_executor().await;
1167        let hop2 = format!("{}/hop2", server.uri());
1168        let hop3 = format!("{}/hop3", server.uri());
1169        let final_dest = format!("{}/done", server.uri());
1170
1171        Mock::given(method("GET"))
1172            .and(path("/hop1"))
1173            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1174            .mount(&server)
1175            .await;
1176        Mock::given(method("GET"))
1177            .and(path("/hop2"))
1178            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1179            .mount(&server)
1180            .await;
1181        Mock::given(method("GET"))
1182            .and(path("/hop3"))
1183            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1184            .mount(&server)
1185            .await;
1186        Mock::given(method("GET"))
1187            .and(path("/done"))
1188            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1189            .mount(&server)
1190            .await;
1191
1192        let (host, addrs) = server_host_and_addr(&server);
1193        let url = format!("{}/hop1", server.uri());
1194        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1195        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1196        assert_eq!(result.unwrap(), "<p>done</p>");
1197    }
1198
1199    #[tokio::test]
1200    async fn fetch_html_four_redirects_rejected() {
1201        use wiremock::matchers::{method, path};
1202        use wiremock::{Mock, ResponseTemplate};
1203
1204        let (executor, server) = mock_server_executor().await;
1205        let hop2 = format!("{}/r2", server.uri());
1206        let hop3 = format!("{}/r3", server.uri());
1207        let hop4 = format!("{}/r4", server.uri());
1208        let hop5 = format!("{}/r5", server.uri());
1209
1210        for (from, to) in [
1211            ("/r1", &hop2),
1212            ("/r2", &hop3),
1213            ("/r3", &hop4),
1214            ("/r4", &hop5),
1215        ] {
1216            Mock::given(method("GET"))
1217                .and(path(from))
1218                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1219                .mount(&server)
1220                .await;
1221        }
1222
1223        let (host, addrs) = server_host_and_addr(&server);
1224        let url = format!("{}/r1", server.uri());
1225        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1226        assert!(result.is_err(), "4 redirects should be rejected");
1227        let msg = result.unwrap_err().to_string();
1228        assert!(
1229            msg.contains("redirect"),
1230            "expected redirect-related error: {msg}"
1231        );
1232    }
1233
1234    #[tokio::test]
1235    async fn fetch_html_body_too_large_returns_error() {
1236        use wiremock::matchers::{method, path};
1237        use wiremock::{Mock, ResponseTemplate};
1238
1239        let small_limit_executor = WebScrapeExecutor {
1240            timeout: Duration::from_secs(5),
1241            max_body_bytes: 10,
1242            audit_logger: None,
1243        };
1244        let server = wiremock::MockServer::start().await;
1245        Mock::given(method("GET"))
1246            .and(path("/big"))
1247            .respond_with(
1248                ResponseTemplate::new(200)
1249                    .set_body_string("this body is definitely longer than ten bytes"),
1250            )
1251            .mount(&server)
1252            .await;
1253
1254        let (host, addrs) = server_host_and_addr(&server);
1255        let url = format!("{}/big", server.uri());
1256        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1257        assert!(result.is_err());
1258        let msg = result.unwrap_err().to_string();
1259        assert!(msg.contains("too large"), "expected too-large error: {msg}");
1260    }
1261
1262    #[test]
1263    fn extract_scrape_blocks_empty_block_content() {
1264        let text = "```scrape\n\n```";
1265        let blocks = extract_scrape_blocks(text);
1266        assert_eq!(blocks.len(), 1);
1267        assert!(blocks[0].is_empty());
1268    }
1269
1270    #[test]
1271    fn extract_scrape_blocks_whitespace_only() {
1272        let text = "```scrape\n   \n```";
1273        let blocks = extract_scrape_blocks(text);
1274        assert_eq!(blocks.len(), 1);
1275    }
1276
1277    #[test]
1278    fn parse_and_extract_multiple_selectors() {
1279        let html = "<div><h1>Title</h1><p>Para</p></div>";
1280        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1281        assert!(result.contains("Title"));
1282        assert!(result.contains("Para"));
1283    }
1284
1285    #[test]
1286    fn webscrape_executor_new_with_custom_config() {
1287        let config = ScrapeConfig {
1288            timeout: 60,
1289            max_body_bytes: 512,
1290        };
1291        let executor = WebScrapeExecutor::new(&config);
1292        assert_eq!(executor.max_body_bytes, 512);
1293    }
1294
1295    #[test]
1296    fn webscrape_executor_debug() {
1297        let config = ScrapeConfig::default();
1298        let executor = WebScrapeExecutor::new(&config);
1299        let dbg = format!("{executor:?}");
1300        assert!(dbg.contains("WebScrapeExecutor"));
1301    }
1302
1303    #[test]
1304    fn extract_mode_attr_empty_name() {
1305        let mode = ExtractMode::parse("attr:");
1306        assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1307    }
1308
1309    #[test]
1310    fn default_extract_returns_text() {
1311        assert_eq!(default_extract(), "text");
1312    }
1313
1314    #[test]
1315    fn scrape_instruction_debug() {
1316        let json = r#"{"url":"https://example.com","select":"h1"}"#;
1317        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1318        let dbg = format!("{instr:?}");
1319        assert!(dbg.contains("ScrapeInstruction"));
1320    }
1321
1322    #[test]
1323    fn extract_mode_debug() {
1324        let mode = ExtractMode::Text;
1325        let dbg = format!("{mode:?}");
1326        assert!(dbg.contains("Text"));
1327    }
1328
1329    // --- fetch_html redirect logic: constant and validation unit tests ---
1330
1331    /// `MAX_REDIRECTS` is 3; the 4th redirect attempt must be rejected.
1332    /// Verify the boundary is correct by inspecting the constant value.
1333    #[test]
1334    fn max_redirects_constant_is_three() {
1335        // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
1336        // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
1337        // This test documents the expected limit.
1338        const MAX_REDIRECTS: usize = 3;
1339        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1340    }
1341
1342    /// Verifies that a Location-less redirect would produce an error string containing the
1343    /// expected message, matching the error path in `fetch_html`.
1344    #[test]
1345    fn redirect_no_location_error_message() {
1346        let err = std::io::Error::other("redirect with no Location");
1347        assert!(err.to_string().contains("redirect with no Location"));
1348    }
1349
1350    /// Verifies that a too-many-redirects condition produces the expected error string.
1351    #[test]
1352    fn too_many_redirects_error_message() {
1353        let err = std::io::Error::other("too many redirects");
1354        assert!(err.to_string().contains("too many redirects"));
1355    }
1356
1357    /// Verifies that a non-2xx HTTP status produces an error message with the status code.
1358    #[test]
1359    fn non_2xx_status_error_format() {
1360        let status = reqwest::StatusCode::FORBIDDEN;
1361        let msg = format!("HTTP {status}");
1362        assert!(msg.contains("403"));
1363    }
1364
1365    /// Verifies that a 404 response status code formats into the expected error message.
1366    #[test]
1367    fn not_found_status_error_format() {
1368        let status = reqwest::StatusCode::NOT_FOUND;
1369        let msg = format!("HTTP {status}");
1370        assert!(msg.contains("404"));
1371    }
1372
1373    /// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
1374    #[test]
1375    fn relative_redirect_same_host_path() {
1376        let base = Url::parse("https://example.com/current").unwrap();
1377        let resolved = base.join("/other").unwrap();
1378        assert_eq!(resolved.as_str(), "https://example.com/other");
1379    }
1380
1381    /// Verifies relative redirect resolution preserves scheme and host.
1382    #[test]
1383    fn relative_redirect_relative_path() {
1384        let base = Url::parse("https://example.com/a/b").unwrap();
1385        let resolved = base.join("c").unwrap();
1386        assert_eq!(resolved.as_str(), "https://example.com/a/c");
1387    }
1388
1389    /// Verifies that an absolute redirect URL overrides base URL completely.
1390    #[test]
1391    fn absolute_redirect_overrides_base() {
1392        let base = Url::parse("https://example.com/page").unwrap();
1393        let resolved = base.join("https://other.com/target").unwrap();
1394        assert_eq!(resolved.as_str(), "https://other.com/target");
1395    }
1396
1397    /// Verifies that a redirect Location of http:// (downgrade) is rejected.
1398    #[test]
1399    fn redirect_http_downgrade_rejected() {
1400        let location = "http://example.com/page";
1401        let base = Url::parse("https://example.com/start").unwrap();
1402        let next = base.join(location).unwrap();
1403        let err = validate_url(next.as_str()).unwrap_err();
1404        assert!(matches!(err, ToolError::Blocked { .. }));
1405    }
1406
1407    /// Verifies that a redirect to a private IP literal is blocked.
1408    #[test]
1409    fn redirect_location_private_ip_blocked() {
1410        let location = "https://192.168.100.1/admin";
1411        let base = Url::parse("https://example.com/start").unwrap();
1412        let next = base.join(location).unwrap();
1413        let err = validate_url(next.as_str()).unwrap_err();
1414        assert!(matches!(err, ToolError::Blocked { .. }));
1415        let ToolError::Blocked { command: cmd } = err else {
1416            panic!("expected Blocked");
1417        };
1418        assert!(
1419            cmd.contains("private") || cmd.contains("scheme"),
1420            "error message should describe the block reason: {cmd}"
1421        );
1422    }
1423
1424    /// Verifies that a redirect to a .internal domain is blocked.
1425    #[test]
1426    fn redirect_location_internal_domain_blocked() {
1427        let location = "https://metadata.internal/latest/meta-data/";
1428        let base = Url::parse("https://example.com/start").unwrap();
1429        let next = base.join(location).unwrap();
1430        let err = validate_url(next.as_str()).unwrap_err();
1431        assert!(matches!(err, ToolError::Blocked { .. }));
1432    }
1433
1434    /// Verifies that a chain of 3 valid public redirects passes `validate_url` at every hop.
1435    #[test]
1436    fn redirect_chain_three_hops_all_public() {
1437        let hops = [
1438            "https://redirect1.example.com/hop1",
1439            "https://redirect2.example.com/hop2",
1440            "https://destination.example.com/final",
1441        ];
1442        for hop in hops {
1443            assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1444        }
1445    }
1446
1447    // --- SSRF redirect chain defense ---
1448
1449    /// Verifies that a redirect Location pointing to a private IP is rejected by `validate_url`
1450    /// before any connection attempt — simulating the validation step inside `fetch_html`.
1451    #[test]
1452    fn redirect_to_private_ip_rejected_by_validate_url() {
1453        // These would appear as Location headers in a redirect response.
1454        let private_targets = [
1455            "https://127.0.0.1/secret",
1456            "https://10.0.0.1/internal",
1457            "https://192.168.1.1/admin",
1458            "https://172.16.0.1/data",
1459            "https://[::1]/path",
1460            "https://[fe80::1]/path",
1461            "https://localhost/path",
1462            "https://service.internal/api",
1463        ];
1464        for target in private_targets {
1465            let result = validate_url(target);
1466            assert!(result.is_err(), "expected error for {target}");
1467            assert!(
1468                matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1469                "expected Blocked for {target}"
1470            );
1471        }
1472    }
1473
1474    /// Verifies that relative redirect URLs are resolved correctly before validation.
1475    #[test]
1476    fn redirect_relative_url_resolves_correctly() {
1477        let base = Url::parse("https://example.com/page").unwrap();
1478        let relative = "/other";
1479        let resolved = base.join(relative).unwrap();
1480        assert_eq!(resolved.as_str(), "https://example.com/other");
1481    }
1482
1483    /// Verifies that a protocol-relative redirect to http:// is rejected (scheme check).
1484    #[test]
1485    fn redirect_to_http_rejected() {
1486        let err = validate_url("http://example.com/page").unwrap_err();
1487        assert!(matches!(err, ToolError::Blocked { .. }));
1488    }
1489
1490    #[test]
1491    fn ipv4_mapped_ipv6_link_local_blocked() {
1492        let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1493        assert!(matches!(err, ToolError::Blocked { .. }));
1494    }
1495
1496    #[test]
1497    fn ipv4_mapped_ipv6_public_allowed() {
1498        assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1499    }
1500
1501    // --- fetch tool ---
1502
1503    #[tokio::test]
1504    async fn fetch_http_scheme_blocked() {
1505        let config = ScrapeConfig::default();
1506        let executor = WebScrapeExecutor::new(&config);
1507        let call = crate::executor::ToolCall {
1508            tool_id: "fetch".to_owned(),
1509            params: {
1510                let mut m = serde_json::Map::new();
1511                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1512                m
1513            },
1514        };
1515        let result = executor.execute_tool_call(&call).await;
1516        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1517    }
1518
1519    #[tokio::test]
1520    async fn fetch_private_ip_blocked() {
1521        let config = ScrapeConfig::default();
1522        let executor = WebScrapeExecutor::new(&config);
1523        let call = crate::executor::ToolCall {
1524            tool_id: "fetch".to_owned(),
1525            params: {
1526                let mut m = serde_json::Map::new();
1527                m.insert(
1528                    "url".to_owned(),
1529                    serde_json::json!("https://192.168.1.1/secret"),
1530                );
1531                m
1532            },
1533        };
1534        let result = executor.execute_tool_call(&call).await;
1535        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1536    }
1537
1538    #[tokio::test]
1539    async fn fetch_localhost_blocked() {
1540        let config = ScrapeConfig::default();
1541        let executor = WebScrapeExecutor::new(&config);
1542        let call = crate::executor::ToolCall {
1543            tool_id: "fetch".to_owned(),
1544            params: {
1545                let mut m = serde_json::Map::new();
1546                m.insert(
1547                    "url".to_owned(),
1548                    serde_json::json!("https://localhost/page"),
1549                );
1550                m
1551            },
1552        };
1553        let result = executor.execute_tool_call(&call).await;
1554        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1555    }
1556
1557    #[tokio::test]
1558    async fn fetch_unknown_tool_returns_none() {
1559        let config = ScrapeConfig::default();
1560        let executor = WebScrapeExecutor::new(&config);
1561        let call = crate::executor::ToolCall {
1562            tool_id: "unknown_tool".to_owned(),
1563            params: serde_json::Map::new(),
1564        };
1565        let result = executor.execute_tool_call(&call).await;
1566        assert!(result.unwrap().is_none());
1567    }
1568
1569    #[tokio::test]
1570    async fn fetch_returns_body_via_mock() {
1571        use wiremock::matchers::{method, path};
1572        use wiremock::{Mock, ResponseTemplate};
1573
1574        let (executor, server) = mock_server_executor().await;
1575        Mock::given(method("GET"))
1576            .and(path("/content"))
1577            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1578            .mount(&server)
1579            .await;
1580
1581        let (host, addrs) = server_host_and_addr(&server);
1582        let url = format!("{}/content", server.uri());
1583        let result = executor.fetch_html(&url, &host, &addrs).await;
1584        assert!(result.is_ok());
1585        assert_eq!(result.unwrap(), "plain text content");
1586    }
1587
1588    #[test]
1589    fn tool_definitions_returns_web_scrape_and_fetch() {
1590        let config = ScrapeConfig::default();
1591        let executor = WebScrapeExecutor::new(&config);
1592        let defs = executor.tool_definitions();
1593        assert_eq!(defs.len(), 2);
1594        assert_eq!(defs[0].id, "web_scrape");
1595        assert_eq!(
1596            defs[0].invocation,
1597            crate::registry::InvocationHint::FencedBlock("scrape")
1598        );
1599        assert_eq!(defs[1].id, "fetch");
1600        assert_eq!(
1601            defs[1].invocation,
1602            crate::registry::InvocationHint::ToolCall
1603        );
1604    }
1605
1606    #[test]
1607    fn tool_definitions_schema_has_all_params() {
1608        let config = ScrapeConfig::default();
1609        let executor = WebScrapeExecutor::new(&config);
1610        let defs = executor.tool_definitions();
1611        let obj = defs[0].schema.as_object().unwrap();
1612        let props = obj["properties"].as_object().unwrap();
1613        assert!(props.contains_key("url"));
1614        assert!(props.contains_key("select"));
1615        assert!(props.contains_key("extract"));
1616        assert!(props.contains_key("limit"));
1617        let req = obj["required"].as_array().unwrap();
1618        assert!(req.iter().any(|v| v.as_str() == Some("url")));
1619        assert!(req.iter().any(|v| v.as_str() == Some("select")));
1620        assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1621    }
1622
1623    // --- is_private_host: new domain checks (AUD-02) ---
1624
1625    #[test]
1626    fn subdomain_localhost_blocked() {
1627        let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1628        assert!(is_private_host(&host));
1629    }
1630
1631    #[test]
1632    fn internal_tld_blocked() {
1633        let host: url::Host<&str> = url::Host::Domain("service.internal");
1634        assert!(is_private_host(&host));
1635    }
1636
1637    #[test]
1638    fn local_tld_blocked() {
1639        let host: url::Host<&str> = url::Host::Domain("printer.local");
1640        assert!(is_private_host(&host));
1641    }
1642
1643    #[test]
1644    fn public_domain_not_blocked() {
1645        let host: url::Host<&str> = url::Host::Domain("example.com");
1646        assert!(!is_private_host(&host));
1647    }
1648
1649    // --- resolve_and_validate: private IP rejection ---
1650
1651    #[tokio::test]
1652    async fn resolve_loopback_rejected() {
1653        // 127.0.0.1 resolves directly (literal IP in DNS query)
1654        let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1655        // validate_url catches this before resolve_and_validate, but test directly
1656        let result = resolve_and_validate(&url).await;
1657        assert!(
1658            result.is_err(),
1659            "loopback IP must be rejected by resolve_and_validate"
1660        );
1661        let err = result.unwrap_err();
1662        assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1663    }
1664
1665    #[tokio::test]
1666    async fn resolve_private_10_rejected() {
1667        let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1668        let result = resolve_and_validate(&url).await;
1669        assert!(result.is_err());
1670        assert!(matches!(
1671            result.unwrap_err(),
1672            crate::executor::ToolError::Blocked { .. }
1673        ));
1674    }
1675
1676    #[tokio::test]
1677    async fn resolve_private_192_rejected() {
1678        let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1679        let result = resolve_and_validate(&url).await;
1680        assert!(result.is_err());
1681        assert!(matches!(
1682            result.unwrap_err(),
1683            crate::executor::ToolError::Blocked { .. }
1684        ));
1685    }
1686
1687    #[tokio::test]
1688    async fn resolve_ipv6_loopback_rejected() {
1689        let url = url::Url::parse("https://[::1]/path").unwrap();
1690        let result = resolve_and_validate(&url).await;
1691        assert!(result.is_err());
1692        assert!(matches!(
1693            result.unwrap_err(),
1694            crate::executor::ToolError::Blocked { .. }
1695        ));
1696    }
1697
1698    #[tokio::test]
1699    async fn resolve_no_host_returns_ok() {
1700        // URL without a resolvable host — should pass through
1701        let url = url::Url::parse("https://example.com/path").unwrap();
1702        // We can't do a live DNS test, but we can verify a URL with no host
1703        let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1704        // data: URLs have no host; resolve_and_validate should return Ok with empty addrs
1705        let result = resolve_and_validate(&url_no_host).await;
1706        assert!(result.is_ok());
1707        let (host, addrs) = result.unwrap();
1708        assert!(host.is_empty());
1709        assert!(addrs.is_empty());
1710        drop(url);
1711        drop(url_no_host);
1712    }
1713
1714    // --- audit logging ---
1715
1716    /// Helper: build an `AuditLogger` writing to a temp file, and return the logger + path.
1717    async fn make_file_audit_logger(
1718        dir: &tempfile::TempDir,
1719    ) -> (
1720        std::sync::Arc<crate::audit::AuditLogger>,
1721        std::path::PathBuf,
1722    ) {
1723        use crate::audit::AuditLogger;
1724        use crate::config::AuditConfig;
1725        let path = dir.path().join("audit.log");
1726        let config = AuditConfig {
1727            enabled: true,
1728            destination: path.display().to_string(),
1729        };
1730        let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1731        (logger, path)
1732    }
1733
1734    #[tokio::test]
1735    async fn with_audit_sets_logger() {
1736        let config = ScrapeConfig::default();
1737        let executor = WebScrapeExecutor::new(&config);
1738        assert!(executor.audit_logger.is_none());
1739
1740        let dir = tempfile::tempdir().unwrap();
1741        let (logger, _path) = make_file_audit_logger(&dir).await;
1742        let executor = executor.with_audit(logger);
1743        assert!(executor.audit_logger.is_some());
1744    }
1745
1746    #[test]
1747    fn tool_error_to_audit_result_blocked_maps_correctly() {
1748        let err = ToolError::Blocked {
1749            command: "scheme not allowed: http".into(),
1750        };
1751        let result = tool_error_to_audit_result(&err);
1752        assert!(
1753            matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1754        );
1755    }
1756
1757    #[test]
1758    fn tool_error_to_audit_result_timeout_maps_correctly() {
1759        let err = ToolError::Timeout { timeout_secs: 15 };
1760        let result = tool_error_to_audit_result(&err);
1761        assert!(matches!(result, AuditResult::Timeout));
1762    }
1763
1764    #[test]
1765    fn tool_error_to_audit_result_execution_error_maps_correctly() {
1766        let err = ToolError::Execution(std::io::Error::other("connection refused"));
1767        let result = tool_error_to_audit_result(&err);
1768        assert!(
1769            matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1770        );
1771    }
1772
1773    #[tokio::test]
1774    async fn fetch_audit_blocked_url_logged() {
1775        let dir = tempfile::tempdir().unwrap();
1776        let (logger, log_path) = make_file_audit_logger(&dir).await;
1777
1778        let config = ScrapeConfig::default();
1779        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1780
1781        let call = crate::executor::ToolCall {
1782            tool_id: "fetch".to_owned(),
1783            params: {
1784                let mut m = serde_json::Map::new();
1785                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1786                m
1787            },
1788        };
1789        let result = executor.execute_tool_call(&call).await;
1790        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1791
1792        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1793        assert!(
1794            content.contains("\"tool\":\"fetch\""),
1795            "expected tool=fetch in audit: {content}"
1796        );
1797        assert!(
1798            content.contains("\"type\":\"blocked\""),
1799            "expected type=blocked in audit: {content}"
1800        );
1801        assert!(
1802            content.contains("http://example.com"),
1803            "expected URL in audit command field: {content}"
1804        );
1805    }
1806
1807    #[tokio::test]
1808    async fn log_audit_success_writes_to_file() {
1809        let dir = tempfile::tempdir().unwrap();
1810        let (logger, log_path) = make_file_audit_logger(&dir).await;
1811
1812        let config = ScrapeConfig::default();
1813        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1814
1815        executor
1816            .log_audit(
1817                "fetch",
1818                "https://example.com/page",
1819                AuditResult::Success,
1820                42,
1821            )
1822            .await;
1823
1824        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1825        assert!(
1826            content.contains("\"tool\":\"fetch\""),
1827            "expected tool=fetch in audit: {content}"
1828        );
1829        assert!(
1830            content.contains("\"type\":\"success\""),
1831            "expected type=success in audit: {content}"
1832        );
1833        assert!(
1834            content.contains("\"command\":\"https://example.com/page\""),
1835            "expected command URL in audit: {content}"
1836        );
1837        assert!(
1838            content.contains("\"duration_ms\":42"),
1839            "expected duration_ms in audit: {content}"
1840        );
1841    }
1842
1843    #[tokio::test]
1844    async fn log_audit_blocked_writes_to_file() {
1845        let dir = tempfile::tempdir().unwrap();
1846        let (logger, log_path) = make_file_audit_logger(&dir).await;
1847
1848        let config = ScrapeConfig::default();
1849        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1850
1851        executor
1852            .log_audit(
1853                "web_scrape",
1854                "http://evil.com/page",
1855                AuditResult::Blocked {
1856                    reason: "scheme not allowed: http".into(),
1857                },
1858                0,
1859            )
1860            .await;
1861
1862        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1863        assert!(
1864            content.contains("\"tool\":\"web_scrape\""),
1865            "expected tool=web_scrape in audit: {content}"
1866        );
1867        assert!(
1868            content.contains("\"type\":\"blocked\""),
1869            "expected type=blocked in audit: {content}"
1870        );
1871        assert!(
1872            content.contains("scheme not allowed"),
1873            "expected block reason in audit: {content}"
1874        );
1875    }
1876
1877    #[tokio::test]
1878    async fn web_scrape_audit_blocked_url_logged() {
1879        let dir = tempfile::tempdir().unwrap();
1880        let (logger, log_path) = make_file_audit_logger(&dir).await;
1881
1882        let config = ScrapeConfig::default();
1883        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1884
1885        let call = crate::executor::ToolCall {
1886            tool_id: "web_scrape".to_owned(),
1887            params: {
1888                let mut m = serde_json::Map::new();
1889                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1890                m.insert("select".to_owned(), serde_json::json!("h1"));
1891                m
1892            },
1893        };
1894        let result = executor.execute_tool_call(&call).await;
1895        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1896
1897        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1898        assert!(
1899            content.contains("\"tool\":\"web_scrape\""),
1900            "expected tool=web_scrape in audit: {content}"
1901        );
1902        assert!(
1903            content.contains("\"type\":\"blocked\""),
1904            "expected type=blocked in audit: {content}"
1905        );
1906    }
1907
1908    #[tokio::test]
1909    async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1910        let config = ScrapeConfig::default();
1911        let executor = WebScrapeExecutor::new(&config);
1912        assert!(executor.audit_logger.is_none());
1913
1914        let call = crate::executor::ToolCall {
1915            tool_id: "fetch".to_owned(),
1916            params: {
1917                let mut m = serde_json::Map::new();
1918                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1919                m
1920            },
1921        };
1922        // Must not panic even without an audit logger
1923        let result = executor.execute_tool_call(&call).await;
1924        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1925    }
1926
1927    // CR-10: fetch end-to-end via execute_tool_call -> handle_fetch -> fetch_html
1928    #[tokio::test]
1929    async fn fetch_execute_tool_call_end_to_end() {
1930        use wiremock::matchers::{method, path};
1931        use wiremock::{Mock, ResponseTemplate};
1932
1933        let (executor, server) = mock_server_executor().await;
1934        Mock::given(method("GET"))
1935            .and(path("/e2e"))
1936            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1937            .mount(&server)
1938            .await;
1939
1940        let (host, addrs) = server_host_and_addr(&server);
1941        // Call fetch_html directly (bypassing SSRF guard for loopback mock server)
1942        let result = executor
1943            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1944            .await;
1945        assert!(result.is_ok());
1946        assert!(result.unwrap().contains("end-to-end"));
1947    }
1948}