Skip to main content

zeph_tools/
scrape.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
15use crate::net::is_private_ip;
16
// Parameters of the `fetch` tool call; deserialized from `ToolCall::params` in
// `execute_tool_call` and used to generate the tool's schema in `tool_definitions`.
//
// NOTE(review): the `///` field comment below is kept verbatim on purpose — `schemars`
// presumably exports doc comments as schema descriptions, so rewording it would change
// what the model sees. Confirm before editing.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// HTTPS URL to fetch
    url: String,
}
22
// One scrape request: parsed either from the JSON body of a fenced `scrape` block
// (`execute`) or from the params of a `web_scrape` tool call (`execute_tool_call`).
//
// NOTE(review): the `///` field comments below are kept verbatim on purpose — `schemars`
// presumably exports doc comments as schema descriptions. Confirm before editing.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// HTTPS URL to scrape
    url: String,
    /// CSS selector
    select: String,
    /// Extract mode: text, html, or attr:<name>
    // Falls back to `default_extract()` ("text") when the field is absent from the input.
    #[serde(default = "default_extract")]
    extract: String,
    /// Max results to return
    // `None` is replaced with 10 by `scrape_instruction`.
    limit: Option<usize>,
}
35
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
39
/// What to pull out of each element matched by the CSS selector.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parses the textual extract mode of a scrape instruction.
    ///
    /// `"html"` selects [`ExtractMode::Html`]; `"attr:<name>"` selects
    /// [`ExtractMode::Attr`] with everything after the prefix as the attribute name
    /// (possibly empty). Every other input, including `"text"` and unrecognized
    /// values, yields [`ExtractMode::Text`].
    fn parse(s: &str) -> Self {
        if let Some(attr_name) = s.strip_prefix("attr:") {
            return Self::Attr(attr_name.to_owned());
        }
        if s == "html" { Self::Html } else { Self::Text }
    }
}
59
/// Extracts data from web pages via CSS selectors.
///
/// Detects ` ```scrape ` blocks in LLM responses containing JSON instructions,
/// fetches the URL, and parses HTML with `scrape-core`. Also serves the
/// `web_scrape` and `fetch` structured tool calls.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Timeout applied to every HTTP client built for a request hop.
    timeout: Duration,
    /// Largest response body, in bytes, accepted before the fetch is rejected.
    max_body_bytes: usize,
    /// Optional audit sink; when `None`, no audit entries are written.
    audit_logger: Option<Arc<AuditLogger>>,
}
70
71impl WebScrapeExecutor {
72    #[must_use]
73    pub fn new(config: &ScrapeConfig) -> Self {
74        Self {
75            timeout: Duration::from_secs(config.timeout),
76            max_body_bytes: config.max_body_bytes,
77            audit_logger: None,
78        }
79    }
80
81    #[must_use]
82    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
83        self.audit_logger = Some(logger);
84        self
85    }
86
87    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
88        let mut builder = reqwest::Client::builder()
89            .timeout(self.timeout)
90            .redirect(reqwest::redirect::Policy::none());
91        builder = builder.resolve_to_addrs(host, addrs);
92        builder.build().unwrap_or_default()
93    }
94}
95
96impl ToolExecutor for WebScrapeExecutor {
97    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
98        use crate::registry::{InvocationHint, ToolDef};
99        vec![
100            ToolDef {
101                id: "web_scrape".into(),
102                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
103                schema: schemars::schema_for!(ScrapeInstruction),
104                invocation: InvocationHint::FencedBlock("scrape"),
105            },
106            ToolDef {
107                id: "fetch".into(),
108                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
109                schema: schemars::schema_for!(FetchParams),
110                invocation: InvocationHint::ToolCall,
111            },
112        ]
113    }
114
115    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
116        let blocks = extract_scrape_blocks(response);
117        if blocks.is_empty() {
118            return Ok(None);
119        }
120
121        let mut outputs = Vec::with_capacity(blocks.len());
122        #[allow(clippy::cast_possible_truncation)]
123        let blocks_executed = blocks.len() as u32;
124
125        for block in &blocks {
126            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
127                ToolError::Execution(std::io::Error::new(
128                    std::io::ErrorKind::InvalidData,
129                    e.to_string(),
130                ))
131            })?;
132            let start = Instant::now();
133            let scrape_result = self.scrape_instruction(&instruction).await;
134            #[allow(clippy::cast_possible_truncation)]
135            let duration_ms = start.elapsed().as_millis() as u64;
136            match scrape_result {
137                Ok(output) => {
138                    self.log_audit(
139                        "web_scrape",
140                        &instruction.url,
141                        AuditResult::Success,
142                        duration_ms,
143                    )
144                    .await;
145                    outputs.push(output);
146                }
147                Err(e) => {
148                    let audit_result = tool_error_to_audit_result(&e);
149                    self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
150                        .await;
151                    return Err(e);
152                }
153            }
154        }
155
156        Ok(Some(ToolOutput {
157            tool_name: "web-scrape".to_owned(),
158            summary: outputs.join("\n\n"),
159            blocks_executed,
160            filter_stats: None,
161            diff: None,
162            streamed: false,
163            terminal_id: None,
164            locations: None,
165            raw_response: None,
166        }))
167    }
168
169    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
170        match call.tool_id.as_str() {
171            "web_scrape" => {
172                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
173                let start = Instant::now();
174                let result = self.scrape_instruction(&instruction).await;
175                #[allow(clippy::cast_possible_truncation)]
176                let duration_ms = start.elapsed().as_millis() as u64;
177                match result {
178                    Ok(output) => {
179                        self.log_audit(
180                            "web_scrape",
181                            &instruction.url,
182                            AuditResult::Success,
183                            duration_ms,
184                        )
185                        .await;
186                        Ok(Some(ToolOutput {
187                            tool_name: "web-scrape".to_owned(),
188                            summary: output,
189                            blocks_executed: 1,
190                            filter_stats: None,
191                            diff: None,
192                            streamed: false,
193                            terminal_id: None,
194                            locations: None,
195                            raw_response: None,
196                        }))
197                    }
198                    Err(e) => {
199                        let audit_result = tool_error_to_audit_result(&e);
200                        self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
201                            .await;
202                        Err(e)
203                    }
204                }
205            }
206            "fetch" => {
207                let p: FetchParams = deserialize_params(&call.params)?;
208                let start = Instant::now();
209                let result = self.handle_fetch(&p).await;
210                #[allow(clippy::cast_possible_truncation)]
211                let duration_ms = start.elapsed().as_millis() as u64;
212                match result {
213                    Ok(output) => {
214                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms)
215                            .await;
216                        Ok(Some(ToolOutput {
217                            tool_name: "fetch".to_owned(),
218                            summary: output,
219                            blocks_executed: 1,
220                            filter_stats: None,
221                            diff: None,
222                            streamed: false,
223                            terminal_id: None,
224                            locations: None,
225                            raw_response: None,
226                        }))
227                    }
228                    Err(e) => {
229                        let audit_result = tool_error_to_audit_result(&e);
230                        self.log_audit("fetch", &p.url, audit_result, duration_ms)
231                            .await;
232                        Err(e)
233                    }
234                }
235            }
236            _ => Ok(None),
237        }
238    }
239
240    fn is_tool_retryable(&self, tool_id: &str) -> bool {
241        matches!(tool_id, "web_scrape" | "fetch")
242    }
243}
244
245fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
246    match e {
247        ToolError::Blocked { command } => AuditResult::Blocked {
248            reason: command.clone(),
249        },
250        ToolError::Timeout { .. } => AuditResult::Timeout,
251        _ => AuditResult::Error {
252            message: e.to_string(),
253        },
254    }
255}
256
257impl WebScrapeExecutor {
258    async fn log_audit(&self, tool: &str, command: &str, result: AuditResult, duration_ms: u64) {
259        if let Some(ref logger) = self.audit_logger {
260            let entry = AuditEntry {
261                timestamp: chrono_now(),
262                tool: tool.into(),
263                command: command.into(),
264                result,
265                duration_ms,
266                error_category: None,
267            };
268            logger.log(&entry).await;
269        }
270    }
271
272    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
273        let parsed = validate_url(&params.url)?;
274        let (host, addrs) = resolve_and_validate(&parsed).await?;
275        self.fetch_html(&params.url, &host, &addrs).await
276    }
277
278    async fn scrape_instruction(
279        &self,
280        instruction: &ScrapeInstruction,
281    ) -> Result<String, ToolError> {
282        let parsed = validate_url(&instruction.url)?;
283        let (host, addrs) = resolve_and_validate(&parsed).await?;
284        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
285        let selector = instruction.select.clone();
286        let extract = ExtractMode::parse(&instruction.extract);
287        let limit = instruction.limit.unwrap_or(10);
288        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
289            .await
290            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
291    }
292
293    /// Fetches the HTML at `url`, manually following up to 3 redirects.
294    ///
295    /// Each redirect target is validated with `validate_url` and `resolve_and_validate`
296    /// before following, preventing SSRF via redirect chains.
297    ///
298    /// # Errors
299    ///
300    /// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
301    /// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
302    async fn fetch_html(
303        &self,
304        url: &str,
305        host: &str,
306        addrs: &[SocketAddr],
307    ) -> Result<String, ToolError> {
308        const MAX_REDIRECTS: usize = 3;
309
310        let mut current_url = url.to_owned();
311        let mut current_host = host.to_owned();
312        let mut current_addrs = addrs.to_vec();
313
314        for hop in 0..=MAX_REDIRECTS {
315            // Build a per-hop client pinned to the current hop's validated addresses.
316            let client = self.build_client(&current_host, &current_addrs);
317            let resp = client
318                .get(&current_url)
319                .send()
320                .await
321                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
322
323            let status = resp.status();
324
325            if status.is_redirection() {
326                if hop == MAX_REDIRECTS {
327                    return Err(ToolError::Execution(std::io::Error::other(
328                        "too many redirects",
329                    )));
330                }
331
332                let location = resp
333                    .headers()
334                    .get(reqwest::header::LOCATION)
335                    .and_then(|v| v.to_str().ok())
336                    .ok_or_else(|| {
337                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
338                    })?;
339
340                // Resolve relative redirect URLs against the current URL.
341                let base = Url::parse(&current_url)
342                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
343                let next_url = base
344                    .join(location)
345                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
346
347                let validated = validate_url(next_url.as_str())?;
348                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
349
350                current_url = next_url.to_string();
351                current_host = next_host;
352                current_addrs = next_addrs;
353                continue;
354            }
355
356            if !status.is_success() {
357                return Err(ToolError::Http {
358                    status: status.as_u16(),
359                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
360                });
361            }
362
363            let bytes = resp
364                .bytes()
365                .await
366                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
367
368            if bytes.len() > self.max_body_bytes {
369                return Err(ToolError::Execution(std::io::Error::other(format!(
370                    "response too large: {} bytes (max: {})",
371                    bytes.len(),
372                    self.max_body_bytes,
373                ))));
374            }
375
376            return String::from_utf8(bytes.to_vec())
377                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
378        }
379
380        Err(ToolError::Execution(std::io::Error::other(
381            "too many redirects",
382        )))
383    }
384}
385
386fn extract_scrape_blocks(text: &str) -> Vec<&str> {
387    crate::executor::extract_fenced_blocks(text, "scrape")
388}
389
390fn validate_url(raw: &str) -> Result<Url, ToolError> {
391    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
392        command: format!("invalid URL: {raw}"),
393    })?;
394
395    if parsed.scheme() != "https" {
396        return Err(ToolError::Blocked {
397            command: format!("scheme not allowed: {}", parsed.scheme()),
398        });
399    }
400
401    if let Some(host) = parsed.host()
402        && is_private_host(&host)
403    {
404        return Err(ToolError::Blocked {
405            command: format!(
406                "private/local host blocked: {}",
407                parsed.host_str().unwrap_or("")
408            ),
409        });
410    }
411
412    Ok(parsed)
413}
414
415fn is_private_host(host: &url::Host<&str>) -> bool {
416    match host {
417        url::Host::Domain(d) => {
418            // Exact match or subdomain of localhost (e.g. foo.localhost)
419            // and .internal/.local TLDs used in cloud/k8s environments.
420            #[allow(clippy::case_sensitive_file_extension_comparisons)]
421            {
422                *d == "localhost"
423                    || d.ends_with(".localhost")
424                    || d.ends_with(".internal")
425                    || d.ends_with(".local")
426            }
427        }
428        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
429        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
430    }
431}
432
433/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
434/// and returns the hostname and validated socket addresses.
435///
436/// Returning the addresses allows the caller to pin the HTTP client to these exact
437/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
438async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
439    let Some(host) = url.host_str() else {
440        return Ok((String::new(), vec![]));
441    };
442    let port = url.port_or_known_default().unwrap_or(443);
443    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
444        .await
445        .map_err(|e| ToolError::Blocked {
446            command: format!("DNS resolution failed: {e}"),
447        })?
448        .collect();
449    for addr in &addrs {
450        if is_private_ip(addr.ip()) {
451            return Err(ToolError::Blocked {
452                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
453            });
454        }
455    }
456    Ok((host.to_owned(), addrs))
457}
458
459fn parse_and_extract(
460    html: &str,
461    selector: &str,
462    extract: &ExtractMode,
463    limit: usize,
464) -> Result<String, ToolError> {
465    let soup = scrape_core::Soup::parse(html);
466
467    let tags = soup.find_all(selector).map_err(|e| {
468        ToolError::Execution(std::io::Error::new(
469            std::io::ErrorKind::InvalidData,
470            format!("invalid selector: {e}"),
471        ))
472    })?;
473
474    let mut results = Vec::new();
475
476    for tag in tags.into_iter().take(limit) {
477        let value = match extract {
478            ExtractMode::Text => tag.text(),
479            ExtractMode::Html => tag.inner_html(),
480            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
481        };
482        if !value.trim().is_empty() {
483            results.push(value.trim().to_owned());
484        }
485    }
486
487    if results.is_empty() {
488        Ok(format!("No results for selector: {selector}"))
489    } else {
490        Ok(results.join("\n"))
491    }
492}
493
494#[cfg(test)]
495mod tests {
496    use super::*;
497
498    // --- extract_scrape_blocks ---
499
500    #[test]
501    fn extract_single_block() {
502        let text =
503            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
504        let blocks = extract_scrape_blocks(text);
505        assert_eq!(blocks.len(), 1);
506        assert!(blocks[0].contains("example.com"));
507    }
508
509    #[test]
510    fn extract_multiple_blocks() {
511        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
512        let blocks = extract_scrape_blocks(text);
513        assert_eq!(blocks.len(), 2);
514    }
515
516    #[test]
517    fn no_blocks_returns_empty() {
518        let blocks = extract_scrape_blocks("plain text, no code blocks");
519        assert!(blocks.is_empty());
520    }
521
522    #[test]
523    fn unclosed_block_ignored() {
524        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
525        assert!(blocks.is_empty());
526    }
527
528    #[test]
529    fn non_scrape_block_ignored() {
530        let text =
531            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
532        let blocks = extract_scrape_blocks(text);
533        assert_eq!(blocks.len(), 1);
534        assert!(blocks[0].contains("x.com"));
535    }
536
537    #[test]
538    fn multiline_json_block() {
539        let text =
540            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
541        let blocks = extract_scrape_blocks(text);
542        assert_eq!(blocks.len(), 1);
543        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
544        assert_eq!(instr.url, "https://example.com");
545    }
546
547    // --- ScrapeInstruction parsing ---
548
549    #[test]
550    fn parse_valid_instruction() {
551        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
552        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
553        assert_eq!(instr.url, "https://example.com");
554        assert_eq!(instr.select, "h1");
555        assert_eq!(instr.extract, "text");
556        assert_eq!(instr.limit, Some(5));
557    }
558
559    #[test]
560    fn parse_minimal_instruction() {
561        let json = r#"{"url":"https://example.com","select":"p"}"#;
562        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
563        assert_eq!(instr.extract, "text");
564        assert!(instr.limit.is_none());
565    }
566
567    #[test]
568    fn parse_attr_extract() {
569        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
570        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
571        assert_eq!(instr.extract, "attr:href");
572    }
573
574    #[test]
575    fn parse_invalid_json_errors() {
576        let result = serde_json::from_str::<ScrapeInstruction>("not json");
577        assert!(result.is_err());
578    }
579
580    // --- ExtractMode ---
581
582    #[test]
583    fn extract_mode_text() {
584        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
585    }
586
587    #[test]
588    fn extract_mode_html() {
589        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
590    }
591
592    #[test]
593    fn extract_mode_attr() {
594        let mode = ExtractMode::parse("attr:href");
595        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
596    }
597
598    #[test]
599    fn extract_mode_unknown_defaults_to_text() {
600        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
601    }
602
603    // --- validate_url ---
604
605    #[test]
606    fn valid_https_url() {
607        assert!(validate_url("https://example.com").is_ok());
608    }
609
610    #[test]
611    fn http_rejected() {
612        let err = validate_url("http://example.com").unwrap_err();
613        assert!(matches!(err, ToolError::Blocked { .. }));
614    }
615
616    #[test]
617    fn ftp_rejected() {
618        let err = validate_url("ftp://files.example.com").unwrap_err();
619        assert!(matches!(err, ToolError::Blocked { .. }));
620    }
621
622    #[test]
623    fn file_rejected() {
624        let err = validate_url("file:///etc/passwd").unwrap_err();
625        assert!(matches!(err, ToolError::Blocked { .. }));
626    }
627
628    #[test]
629    fn invalid_url_rejected() {
630        let err = validate_url("not a url").unwrap_err();
631        assert!(matches!(err, ToolError::Blocked { .. }));
632    }
633
634    #[test]
635    fn localhost_blocked() {
636        let err = validate_url("https://localhost/path").unwrap_err();
637        assert!(matches!(err, ToolError::Blocked { .. }));
638    }
639
640    #[test]
641    fn loopback_ip_blocked() {
642        let err = validate_url("https://127.0.0.1/path").unwrap_err();
643        assert!(matches!(err, ToolError::Blocked { .. }));
644    }
645
646    #[test]
647    fn private_10_blocked() {
648        let err = validate_url("https://10.0.0.1/api").unwrap_err();
649        assert!(matches!(err, ToolError::Blocked { .. }));
650    }
651
652    #[test]
653    fn private_172_blocked() {
654        let err = validate_url("https://172.16.0.1/api").unwrap_err();
655        assert!(matches!(err, ToolError::Blocked { .. }));
656    }
657
658    #[test]
659    fn private_192_blocked() {
660        let err = validate_url("https://192.168.1.1/api").unwrap_err();
661        assert!(matches!(err, ToolError::Blocked { .. }));
662    }
663
664    #[test]
665    fn ipv6_loopback_blocked() {
666        let err = validate_url("https://[::1]/path").unwrap_err();
667        assert!(matches!(err, ToolError::Blocked { .. }));
668    }
669
670    #[test]
671    fn public_ip_allowed() {
672        assert!(validate_url("https://93.184.216.34/page").is_ok());
673    }
674
675    // --- parse_and_extract ---
676
677    #[test]
678    fn extract_text_from_html() {
679        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
680        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
681        assert_eq!(result, "Hello World");
682    }
683
684    #[test]
685    fn extract_multiple_elements() {
686        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
687        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
688        assert_eq!(result, "A\nB\nC");
689    }
690
691    #[test]
692    fn extract_with_limit() {
693        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
694        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
695        assert_eq!(result, "A\nB");
696    }
697
698    #[test]
699    fn extract_attr_href() {
700        let html = r#"<a href="https://example.com">Link</a>"#;
701        let result =
702            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
703        assert_eq!(result, "https://example.com");
704    }
705
706    #[test]
707    fn extract_inner_html() {
708        let html = "<div><span>inner</span></div>";
709        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
710        assert!(result.contains("<span>inner</span>"));
711    }
712
713    #[test]
714    fn no_matches_returns_message() {
715        let html = "<html><body><p>text</p></body></html>";
716        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
717        assert!(result.starts_with("No results for selector:"));
718    }
719
720    #[test]
721    fn empty_text_skipped() {
722        let html = "<ul><li>  </li><li>A</li></ul>";
723        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
724        assert_eq!(result, "A");
725    }
726
727    #[test]
728    fn invalid_selector_errors() {
729        let html = "<html><body></body></html>";
730        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
731        assert!(result.is_err());
732    }
733
734    #[test]
735    fn empty_html_returns_no_results() {
736        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
737        assert!(result.starts_with("No results for selector:"));
738    }
739
740    #[test]
741    fn nested_selector() {
742        let html = "<div><span>inner</span></div><span>outer</span>";
743        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
744        assert_eq!(result, "inner");
745    }
746
747    #[test]
748    fn attr_missing_returns_empty() {
749        let html = r"<a>No href</a>";
750        let result =
751            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
752        assert!(result.starts_with("No results for selector:"));
753    }
754
755    #[test]
756    fn extract_html_mode() {
757        let html = "<div><b>bold</b> text</div>";
758        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
759        assert!(result.contains("<b>bold</b>"));
760    }
761
762    #[test]
763    fn limit_zero_returns_no_results() {
764        let html = "<ul><li>A</li><li>B</li></ul>";
765        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
766        assert!(result.starts_with("No results for selector:"));
767    }
768
769    // --- validate_url edge cases ---
770
771    #[test]
772    fn url_with_port_allowed() {
773        assert!(validate_url("https://example.com:8443/path").is_ok());
774    }
775
776    #[test]
777    fn link_local_ip_blocked() {
778        let err = validate_url("https://169.254.1.1/path").unwrap_err();
779        assert!(matches!(err, ToolError::Blocked { .. }));
780    }
781
782    #[test]
783    fn url_no_scheme_rejected() {
784        let err = validate_url("example.com/path").unwrap_err();
785        assert!(matches!(err, ToolError::Blocked { .. }));
786    }
787
788    #[test]
789    fn unspecified_ipv4_blocked() {
790        let err = validate_url("https://0.0.0.0/path").unwrap_err();
791        assert!(matches!(err, ToolError::Blocked { .. }));
792    }
793
794    #[test]
795    fn broadcast_ipv4_blocked() {
796        let err = validate_url("https://255.255.255.255/path").unwrap_err();
797        assert!(matches!(err, ToolError::Blocked { .. }));
798    }
799
800    #[test]
801    fn ipv6_link_local_blocked() {
802        let err = validate_url("https://[fe80::1]/path").unwrap_err();
803        assert!(matches!(err, ToolError::Blocked { .. }));
804    }
805
806    #[test]
807    fn ipv6_unique_local_blocked() {
808        let err = validate_url("https://[fd12::1]/path").unwrap_err();
809        assert!(matches!(err, ToolError::Blocked { .. }));
810    }
811
812    #[test]
813    fn ipv4_mapped_ipv6_loopback_blocked() {
814        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
815        assert!(matches!(err, ToolError::Blocked { .. }));
816    }
817
818    #[test]
819    fn ipv4_mapped_ipv6_private_blocked() {
820        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
821        assert!(matches!(err, ToolError::Blocked { .. }));
822    }
823
824    // --- WebScrapeExecutor (no-network) ---
825
826    #[tokio::test]
827    async fn executor_no_blocks_returns_none() {
828        let config = ScrapeConfig::default();
829        let executor = WebScrapeExecutor::new(&config);
830        let result = executor.execute("plain text").await;
831        assert!(result.unwrap().is_none());
832    }
833
834    #[tokio::test]
835    async fn executor_invalid_json_errors() {
836        let config = ScrapeConfig::default();
837        let executor = WebScrapeExecutor::new(&config);
838        let response = "```scrape\nnot json\n```";
839        let result = executor.execute(response).await;
840        assert!(matches!(result, Err(ToolError::Execution(_))));
841    }
842
843    #[tokio::test]
844    async fn executor_blocked_url_errors() {
845        let config = ScrapeConfig::default();
846        let executor = WebScrapeExecutor::new(&config);
847        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
848        let result = executor.execute(response).await;
849        assert!(matches!(result, Err(ToolError::Blocked { .. })));
850    }
851
852    #[tokio::test]
853    async fn executor_private_ip_blocked() {
854        let config = ScrapeConfig::default();
855        let executor = WebScrapeExecutor::new(&config);
856        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
857        let result = executor.execute(response).await;
858        assert!(matches!(result, Err(ToolError::Blocked { .. })));
859    }
860
861    #[tokio::test]
862    async fn executor_unreachable_host_returns_error() {
863        let config = ScrapeConfig {
864            timeout: 1,
865            max_body_bytes: 1_048_576,
866        };
867        let executor = WebScrapeExecutor::new(&config);
868        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
869        let result = executor.execute(response).await;
870        assert!(matches!(result, Err(ToolError::Execution(_))));
871    }
872
873    #[tokio::test]
874    async fn executor_localhost_url_blocked() {
875        let config = ScrapeConfig::default();
876        let executor = WebScrapeExecutor::new(&config);
877        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
878        let result = executor.execute(response).await;
879        assert!(matches!(result, Err(ToolError::Blocked { .. })));
880    }
881
882    #[tokio::test]
883    async fn executor_empty_text_returns_none() {
884        let config = ScrapeConfig::default();
885        let executor = WebScrapeExecutor::new(&config);
886        let result = executor.execute("").await;
887        assert!(result.unwrap().is_none());
888    }
889
890    #[tokio::test]
891    async fn executor_multiple_blocks_first_blocked() {
892        let config = ScrapeConfig::default();
893        let executor = WebScrapeExecutor::new(&config);
894        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
895             ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
896        let result = executor.execute(response).await;
897        assert!(result.is_err());
898    }
899
900    #[test]
901    fn validate_url_empty_string() {
902        let err = validate_url("").unwrap_err();
903        assert!(matches!(err, ToolError::Blocked { .. }));
904    }
905
906    #[test]
907    fn validate_url_javascript_scheme_blocked() {
908        let err = validate_url("javascript:alert(1)").unwrap_err();
909        assert!(matches!(err, ToolError::Blocked { .. }));
910    }
911
912    #[test]
913    fn validate_url_data_scheme_blocked() {
914        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
915        assert!(matches!(err, ToolError::Blocked { .. }));
916    }
917
918    #[test]
919    fn is_private_host_public_domain_is_false() {
920        let host: url::Host<&str> = url::Host::Domain("example.com");
921        assert!(!is_private_host(&host));
922    }
923
924    #[test]
925    fn is_private_host_localhost_is_true() {
926        let host: url::Host<&str> = url::Host::Domain("localhost");
927        assert!(is_private_host(&host));
928    }
929
930    #[test]
931    fn is_private_host_ipv6_unspecified_is_true() {
932        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
933        assert!(is_private_host(&host));
934    }
935
936    #[test]
937    fn is_private_host_public_ipv6_is_false() {
938        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
939        assert!(!is_private_host(&host));
940    }
941
942    // --- fetch_html redirect logic: wiremock HTTP server tests ---
943    //
944    // These tests use a local wiremock server to exercise the redirect-following logic
945    // in `fetch_html` without requiring an external HTTPS connection. The server binds to
946    // 127.0.0.1, and tests call `fetch_html` directly (bypassing `validate_url`) to avoid
947    // the SSRF guard that would otherwise block loopback connections.
948
949    /// Helper: returns executor + (`server_url`, `server_addr`) from a running wiremock mock server.
950    /// The server address is passed to `fetch_html` via `resolve_to_addrs` so the client
951    /// connects to the mock instead of doing a real DNS lookup.
952    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
953        let server = wiremock::MockServer::start().await;
954        let executor = WebScrapeExecutor {
955            timeout: Duration::from_secs(5),
956            max_body_bytes: 1_048_576,
957            audit_logger: None,
958        };
959        (executor, server)
960    }
961
962    /// Parses the mock server's URI into (`host_str`, `socket_addr`) for use with `build_client`.
963    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
964        let uri = server.uri();
965        let url = Url::parse(&uri).unwrap();
966        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
967        let port = url.port().unwrap_or(80);
968        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
969        (host, vec![addr])
970    }
971
972    /// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
973    /// `resolve_and_validate`. This lets us exercise the redirect-counting and
974    /// missing-Location logic against a plain HTTP wiremock server.
975    async fn follow_redirects_raw(
976        executor: &WebScrapeExecutor,
977        start_url: &str,
978        host: &str,
979        addrs: &[std::net::SocketAddr],
980    ) -> Result<String, ToolError> {
981        const MAX_REDIRECTS: usize = 3;
982        let mut current_url = start_url.to_owned();
983        let mut current_host = host.to_owned();
984        let mut current_addrs = addrs.to_vec();
985
986        for hop in 0..=MAX_REDIRECTS {
987            let client = executor.build_client(&current_host, &current_addrs);
988            let resp = client
989                .get(&current_url)
990                .send()
991                .await
992                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
993
994            let status = resp.status();
995
996            if status.is_redirection() {
997                if hop == MAX_REDIRECTS {
998                    return Err(ToolError::Execution(std::io::Error::other(
999                        "too many redirects",
1000                    )));
1001                }
1002
1003                let location = resp
1004                    .headers()
1005                    .get(reqwest::header::LOCATION)
1006                    .and_then(|v| v.to_str().ok())
1007                    .ok_or_else(|| {
1008                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
1009                    })?;
1010
1011                let base = Url::parse(&current_url)
1012                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1013                let next_url = base
1014                    .join(location)
1015                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1016
1017                // Re-use same host/addrs (mock server is always the same endpoint).
1018                current_url = next_url.to_string();
1019                // Preserve host/addrs as-is since the mock server doesn't change.
1020                let _ = &mut current_host;
1021                let _ = &mut current_addrs;
1022                continue;
1023            }
1024
1025            if !status.is_success() {
1026                return Err(ToolError::Execution(std::io::Error::other(format!(
1027                    "HTTP {status}",
1028                ))));
1029            }
1030
1031            let bytes = resp
1032                .bytes()
1033                .await
1034                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1035
1036            if bytes.len() > executor.max_body_bytes {
1037                return Err(ToolError::Execution(std::io::Error::other(format!(
1038                    "response too large: {} bytes (max: {})",
1039                    bytes.len(),
1040                    executor.max_body_bytes,
1041                ))));
1042            }
1043
1044            return String::from_utf8(bytes.to_vec())
1045                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1046        }
1047
1048        Err(ToolError::Execution(std::io::Error::other(
1049            "too many redirects",
1050        )))
1051    }
1052
1053    #[tokio::test]
1054    async fn fetch_html_success_returns_body() {
1055        use wiremock::matchers::{method, path};
1056        use wiremock::{Mock, ResponseTemplate};
1057
1058        let (executor, server) = mock_server_executor().await;
1059        Mock::given(method("GET"))
1060            .and(path("/page"))
1061            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1062            .mount(&server)
1063            .await;
1064
1065        let (host, addrs) = server_host_and_addr(&server);
1066        let url = format!("{}/page", server.uri());
1067        let result = executor.fetch_html(&url, &host, &addrs).await;
1068        assert!(result.is_ok(), "expected Ok, got: {result:?}");
1069        assert_eq!(result.unwrap(), "<h1>OK</h1>");
1070    }
1071
1072    #[tokio::test]
1073    async fn fetch_html_non_2xx_returns_error() {
1074        use wiremock::matchers::{method, path};
1075        use wiremock::{Mock, ResponseTemplate};
1076
1077        let (executor, server) = mock_server_executor().await;
1078        Mock::given(method("GET"))
1079            .and(path("/forbidden"))
1080            .respond_with(ResponseTemplate::new(403))
1081            .mount(&server)
1082            .await;
1083
1084        let (host, addrs) = server_host_and_addr(&server);
1085        let url = format!("{}/forbidden", server.uri());
1086        let result = executor.fetch_html(&url, &host, &addrs).await;
1087        assert!(result.is_err());
1088        let msg = result.unwrap_err().to_string();
1089        assert!(msg.contains("403"), "expected 403 in error: {msg}");
1090    }
1091
1092    #[tokio::test]
1093    async fn fetch_html_404_returns_error() {
1094        use wiremock::matchers::{method, path};
1095        use wiremock::{Mock, ResponseTemplate};
1096
1097        let (executor, server) = mock_server_executor().await;
1098        Mock::given(method("GET"))
1099            .and(path("/missing"))
1100            .respond_with(ResponseTemplate::new(404))
1101            .mount(&server)
1102            .await;
1103
1104        let (host, addrs) = server_host_and_addr(&server);
1105        let url = format!("{}/missing", server.uri());
1106        let result = executor.fetch_html(&url, &host, &addrs).await;
1107        assert!(result.is_err());
1108        let msg = result.unwrap_err().to_string();
1109        assert!(msg.contains("404"), "expected 404 in error: {msg}");
1110    }
1111
1112    #[tokio::test]
1113    async fn fetch_html_redirect_no_location_returns_error() {
1114        use wiremock::matchers::{method, path};
1115        use wiremock::{Mock, ResponseTemplate};
1116
1117        let (executor, server) = mock_server_executor().await;
1118        // 302 with no Location header
1119        Mock::given(method("GET"))
1120            .and(path("/redirect-no-loc"))
1121            .respond_with(ResponseTemplate::new(302))
1122            .mount(&server)
1123            .await;
1124
1125        let (host, addrs) = server_host_and_addr(&server);
1126        let url = format!("{}/redirect-no-loc", server.uri());
1127        let result = executor.fetch_html(&url, &host, &addrs).await;
1128        assert!(result.is_err());
1129        let msg = result.unwrap_err().to_string();
1130        assert!(
1131            msg.contains("Location") || msg.contains("location"),
1132            "expected Location-related error: {msg}"
1133        );
1134    }
1135
1136    #[tokio::test]
1137    async fn fetch_html_single_redirect_followed() {
1138        use wiremock::matchers::{method, path};
1139        use wiremock::{Mock, ResponseTemplate};
1140
1141        let (executor, server) = mock_server_executor().await;
1142        let final_url = format!("{}/final", server.uri());
1143
1144        Mock::given(method("GET"))
1145            .and(path("/start"))
1146            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1147            .mount(&server)
1148            .await;
1149
1150        Mock::given(method("GET"))
1151            .and(path("/final"))
1152            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1153            .mount(&server)
1154            .await;
1155
1156        let (host, addrs) = server_host_and_addr(&server);
1157        let url = format!("{}/start", server.uri());
1158        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1159        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1160        assert_eq!(result.unwrap(), "<p>final</p>");
1161    }
1162
1163    #[tokio::test]
1164    async fn fetch_html_three_redirects_allowed() {
1165        use wiremock::matchers::{method, path};
1166        use wiremock::{Mock, ResponseTemplate};
1167
1168        let (executor, server) = mock_server_executor().await;
1169        let hop2 = format!("{}/hop2", server.uri());
1170        let hop3 = format!("{}/hop3", server.uri());
1171        let final_dest = format!("{}/done", server.uri());
1172
1173        Mock::given(method("GET"))
1174            .and(path("/hop1"))
1175            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1176            .mount(&server)
1177            .await;
1178        Mock::given(method("GET"))
1179            .and(path("/hop2"))
1180            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1181            .mount(&server)
1182            .await;
1183        Mock::given(method("GET"))
1184            .and(path("/hop3"))
1185            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1186            .mount(&server)
1187            .await;
1188        Mock::given(method("GET"))
1189            .and(path("/done"))
1190            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1191            .mount(&server)
1192            .await;
1193
1194        let (host, addrs) = server_host_and_addr(&server);
1195        let url = format!("{}/hop1", server.uri());
1196        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1197        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1198        assert_eq!(result.unwrap(), "<p>done</p>");
1199    }
1200
1201    #[tokio::test]
1202    async fn fetch_html_four_redirects_rejected() {
1203        use wiremock::matchers::{method, path};
1204        use wiremock::{Mock, ResponseTemplate};
1205
1206        let (executor, server) = mock_server_executor().await;
1207        let hop2 = format!("{}/r2", server.uri());
1208        let hop3 = format!("{}/r3", server.uri());
1209        let hop4 = format!("{}/r4", server.uri());
1210        let hop5 = format!("{}/r5", server.uri());
1211
1212        for (from, to) in [
1213            ("/r1", &hop2),
1214            ("/r2", &hop3),
1215            ("/r3", &hop4),
1216            ("/r4", &hop5),
1217        ] {
1218            Mock::given(method("GET"))
1219                .and(path(from))
1220                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1221                .mount(&server)
1222                .await;
1223        }
1224
1225        let (host, addrs) = server_host_and_addr(&server);
1226        let url = format!("{}/r1", server.uri());
1227        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1228        assert!(result.is_err(), "4 redirects should be rejected");
1229        let msg = result.unwrap_err().to_string();
1230        assert!(
1231            msg.contains("redirect"),
1232            "expected redirect-related error: {msg}"
1233        );
1234    }
1235
1236    #[tokio::test]
1237    async fn fetch_html_body_too_large_returns_error() {
1238        use wiremock::matchers::{method, path};
1239        use wiremock::{Mock, ResponseTemplate};
1240
1241        let small_limit_executor = WebScrapeExecutor {
1242            timeout: Duration::from_secs(5),
1243            max_body_bytes: 10,
1244            audit_logger: None,
1245        };
1246        let server = wiremock::MockServer::start().await;
1247        Mock::given(method("GET"))
1248            .and(path("/big"))
1249            .respond_with(
1250                ResponseTemplate::new(200)
1251                    .set_body_string("this body is definitely longer than ten bytes"),
1252            )
1253            .mount(&server)
1254            .await;
1255
1256        let (host, addrs) = server_host_and_addr(&server);
1257        let url = format!("{}/big", server.uri());
1258        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1259        assert!(result.is_err());
1260        let msg = result.unwrap_err().to_string();
1261        assert!(msg.contains("too large"), "expected too-large error: {msg}");
1262    }
1263
1264    #[test]
1265    fn extract_scrape_blocks_empty_block_content() {
1266        let text = "```scrape\n\n```";
1267        let blocks = extract_scrape_blocks(text);
1268        assert_eq!(blocks.len(), 1);
1269        assert!(blocks[0].is_empty());
1270    }
1271
1272    #[test]
1273    fn extract_scrape_blocks_whitespace_only() {
1274        let text = "```scrape\n   \n```";
1275        let blocks = extract_scrape_blocks(text);
1276        assert_eq!(blocks.len(), 1);
1277    }
1278
1279    #[test]
1280    fn parse_and_extract_multiple_selectors() {
1281        let html = "<div><h1>Title</h1><p>Para</p></div>";
1282        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1283        assert!(result.contains("Title"));
1284        assert!(result.contains("Para"));
1285    }
1286
1287    #[test]
1288    fn webscrape_executor_new_with_custom_config() {
1289        let config = ScrapeConfig {
1290            timeout: 60,
1291            max_body_bytes: 512,
1292        };
1293        let executor = WebScrapeExecutor::new(&config);
1294        assert_eq!(executor.max_body_bytes, 512);
1295    }
1296
1297    #[test]
1298    fn webscrape_executor_debug() {
1299        let config = ScrapeConfig::default();
1300        let executor = WebScrapeExecutor::new(&config);
1301        let dbg = format!("{executor:?}");
1302        assert!(dbg.contains("WebScrapeExecutor"));
1303    }
1304
1305    #[test]
1306    fn extract_mode_attr_empty_name() {
1307        let mode = ExtractMode::parse("attr:");
1308        assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1309    }
1310
1311    #[test]
1312    fn default_extract_returns_text() {
1313        assert_eq!(default_extract(), "text");
1314    }
1315
1316    #[test]
1317    fn scrape_instruction_debug() {
1318        let json = r#"{"url":"https://example.com","select":"h1"}"#;
1319        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1320        let dbg = format!("{instr:?}");
1321        assert!(dbg.contains("ScrapeInstruction"));
1322    }
1323
1324    #[test]
1325    fn extract_mode_debug() {
1326        let mode = ExtractMode::Text;
1327        let dbg = format!("{mode:?}");
1328        assert!(dbg.contains("Text"));
1329    }
1330
1331    // --- fetch_html redirect logic: constant and validation unit tests ---
1332
1333    /// `MAX_REDIRECTS` is 3; the 4th redirect attempt must be rejected.
1334    /// Verify the boundary is correct by inspecting the constant value.
1335    #[test]
1336    fn max_redirects_constant_is_three() {
1337        // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
1338        // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
1339        // This test documents the expected limit.
1340        const MAX_REDIRECTS: usize = 3;
1341        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1342    }
1343
1344    /// Verifies that a Location-less redirect would produce an error string containing the
1345    /// expected message, matching the error path in `fetch_html`.
1346    #[test]
1347    fn redirect_no_location_error_message() {
1348        let err = std::io::Error::other("redirect with no Location");
1349        assert!(err.to_string().contains("redirect with no Location"));
1350    }
1351
1352    /// Verifies that a too-many-redirects condition produces the expected error string.
1353    #[test]
1354    fn too_many_redirects_error_message() {
1355        let err = std::io::Error::other("too many redirects");
1356        assert!(err.to_string().contains("too many redirects"));
1357    }
1358
1359    /// Verifies that a non-2xx HTTP status produces an error message with the status code.
1360    #[test]
1361    fn non_2xx_status_error_format() {
1362        let status = reqwest::StatusCode::FORBIDDEN;
1363        let msg = format!("HTTP {status}");
1364        assert!(msg.contains("403"));
1365    }
1366
1367    /// Verifies that a 404 response status code formats into the expected error message.
1368    #[test]
1369    fn not_found_status_error_format() {
1370        let status = reqwest::StatusCode::NOT_FOUND;
1371        let msg = format!("HTTP {status}");
1372        assert!(msg.contains("404"));
1373    }
1374
1375    /// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
1376    #[test]
1377    fn relative_redirect_same_host_path() {
1378        let base = Url::parse("https://example.com/current").unwrap();
1379        let resolved = base.join("/other").unwrap();
1380        assert_eq!(resolved.as_str(), "https://example.com/other");
1381    }
1382
1383    /// Verifies relative redirect resolution preserves scheme and host.
1384    #[test]
1385    fn relative_redirect_relative_path() {
1386        let base = Url::parse("https://example.com/a/b").unwrap();
1387        let resolved = base.join("c").unwrap();
1388        assert_eq!(resolved.as_str(), "https://example.com/a/c");
1389    }
1390
1391    /// Verifies that an absolute redirect URL overrides base URL completely.
1392    #[test]
1393    fn absolute_redirect_overrides_base() {
1394        let base = Url::parse("https://example.com/page").unwrap();
1395        let resolved = base.join("https://other.com/target").unwrap();
1396        assert_eq!(resolved.as_str(), "https://other.com/target");
1397    }
1398
1399    /// Verifies that a redirect Location of http:// (downgrade) is rejected.
1400    #[test]
1401    fn redirect_http_downgrade_rejected() {
1402        let location = "http://example.com/page";
1403        let base = Url::parse("https://example.com/start").unwrap();
1404        let next = base.join(location).unwrap();
1405        let err = validate_url(next.as_str()).unwrap_err();
1406        assert!(matches!(err, ToolError::Blocked { .. }));
1407    }
1408
1409    /// Verifies that a redirect to a private IP literal is blocked.
1410    #[test]
1411    fn redirect_location_private_ip_blocked() {
1412        let location = "https://192.168.100.1/admin";
1413        let base = Url::parse("https://example.com/start").unwrap();
1414        let next = base.join(location).unwrap();
1415        let err = validate_url(next.as_str()).unwrap_err();
1416        assert!(matches!(err, ToolError::Blocked { .. }));
1417        let ToolError::Blocked { command: cmd } = err else {
1418            panic!("expected Blocked");
1419        };
1420        assert!(
1421            cmd.contains("private") || cmd.contains("scheme"),
1422            "error message should describe the block reason: {cmd}"
1423        );
1424    }
1425
1426    /// Verifies that a redirect to a .internal domain is blocked.
1427    #[test]
1428    fn redirect_location_internal_domain_blocked() {
1429        let location = "https://metadata.internal/latest/meta-data/";
1430        let base = Url::parse("https://example.com/start").unwrap();
1431        let next = base.join(location).unwrap();
1432        let err = validate_url(next.as_str()).unwrap_err();
1433        assert!(matches!(err, ToolError::Blocked { .. }));
1434    }
1435
1436    /// Verifies that a chain of 3 valid public redirects passes `validate_url` at every hop.
1437    #[test]
1438    fn redirect_chain_three_hops_all_public() {
1439        let hops = [
1440            "https://redirect1.example.com/hop1",
1441            "https://redirect2.example.com/hop2",
1442            "https://destination.example.com/final",
1443        ];
1444        for hop in hops {
1445            assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1446        }
1447    }
1448
1449    // --- SSRF redirect chain defense ---
1450
1451    /// Verifies that a redirect Location pointing to a private IP is rejected by `validate_url`
1452    /// before any connection attempt — simulating the validation step inside `fetch_html`.
1453    #[test]
1454    fn redirect_to_private_ip_rejected_by_validate_url() {
1455        // These would appear as Location headers in a redirect response.
1456        let private_targets = [
1457            "https://127.0.0.1/secret",
1458            "https://10.0.0.1/internal",
1459            "https://192.168.1.1/admin",
1460            "https://172.16.0.1/data",
1461            "https://[::1]/path",
1462            "https://[fe80::1]/path",
1463            "https://localhost/path",
1464            "https://service.internal/api",
1465        ];
1466        for target in private_targets {
1467            let result = validate_url(target);
1468            assert!(result.is_err(), "expected error for {target}");
1469            assert!(
1470                matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1471                "expected Blocked for {target}"
1472            );
1473        }
1474    }
1475
1476    /// Verifies that relative redirect URLs are resolved correctly before validation.
1477    #[test]
1478    fn redirect_relative_url_resolves_correctly() {
1479        let base = Url::parse("https://example.com/page").unwrap();
1480        let relative = "/other";
1481        let resolved = base.join(relative).unwrap();
1482        assert_eq!(resolved.as_str(), "https://example.com/other");
1483    }
1484
1485    /// Verifies that a protocol-relative redirect to http:// is rejected (scheme check).
1486    #[test]
1487    fn redirect_to_http_rejected() {
1488        let err = validate_url("http://example.com/page").unwrap_err();
1489        assert!(matches!(err, ToolError::Blocked { .. }));
1490    }
1491
1492    #[test]
1493    fn ipv4_mapped_ipv6_link_local_blocked() {
1494        let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1495        assert!(matches!(err, ToolError::Blocked { .. }));
1496    }
1497
1498    #[test]
1499    fn ipv4_mapped_ipv6_public_allowed() {
1500        assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1501    }
1502
1503    // --- fetch tool ---
1504
1505    #[tokio::test]
1506    async fn fetch_http_scheme_blocked() {
1507        let config = ScrapeConfig::default();
1508        let executor = WebScrapeExecutor::new(&config);
1509        let call = crate::executor::ToolCall {
1510            tool_id: "fetch".to_owned(),
1511            params: {
1512                let mut m = serde_json::Map::new();
1513                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1514                m
1515            },
1516        };
1517        let result = executor.execute_tool_call(&call).await;
1518        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1519    }
1520
1521    #[tokio::test]
1522    async fn fetch_private_ip_blocked() {
1523        let config = ScrapeConfig::default();
1524        let executor = WebScrapeExecutor::new(&config);
1525        let call = crate::executor::ToolCall {
1526            tool_id: "fetch".to_owned(),
1527            params: {
1528                let mut m = serde_json::Map::new();
1529                m.insert(
1530                    "url".to_owned(),
1531                    serde_json::json!("https://192.168.1.1/secret"),
1532                );
1533                m
1534            },
1535        };
1536        let result = executor.execute_tool_call(&call).await;
1537        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1538    }
1539
1540    #[tokio::test]
1541    async fn fetch_localhost_blocked() {
1542        let config = ScrapeConfig::default();
1543        let executor = WebScrapeExecutor::new(&config);
1544        let call = crate::executor::ToolCall {
1545            tool_id: "fetch".to_owned(),
1546            params: {
1547                let mut m = serde_json::Map::new();
1548                m.insert(
1549                    "url".to_owned(),
1550                    serde_json::json!("https://localhost/page"),
1551                );
1552                m
1553            },
1554        };
1555        let result = executor.execute_tool_call(&call).await;
1556        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1557    }
1558
1559    #[tokio::test]
1560    async fn fetch_unknown_tool_returns_none() {
1561        let config = ScrapeConfig::default();
1562        let executor = WebScrapeExecutor::new(&config);
1563        let call = crate::executor::ToolCall {
1564            tool_id: "unknown_tool".to_owned(),
1565            params: serde_json::Map::new(),
1566        };
1567        let result = executor.execute_tool_call(&call).await;
1568        assert!(result.unwrap().is_none());
1569    }
1570
1571    #[tokio::test]
1572    async fn fetch_returns_body_via_mock() {
1573        use wiremock::matchers::{method, path};
1574        use wiremock::{Mock, ResponseTemplate};
1575
1576        let (executor, server) = mock_server_executor().await;
1577        Mock::given(method("GET"))
1578            .and(path("/content"))
1579            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1580            .mount(&server)
1581            .await;
1582
1583        let (host, addrs) = server_host_and_addr(&server);
1584        let url = format!("{}/content", server.uri());
1585        let result = executor.fetch_html(&url, &host, &addrs).await;
1586        assert!(result.is_ok());
1587        assert_eq!(result.unwrap(), "plain text content");
1588    }
1589
1590    #[test]
1591    fn tool_definitions_returns_web_scrape_and_fetch() {
1592        let config = ScrapeConfig::default();
1593        let executor = WebScrapeExecutor::new(&config);
1594        let defs = executor.tool_definitions();
1595        assert_eq!(defs.len(), 2);
1596        assert_eq!(defs[0].id, "web_scrape");
1597        assert_eq!(
1598            defs[0].invocation,
1599            crate::registry::InvocationHint::FencedBlock("scrape")
1600        );
1601        assert_eq!(defs[1].id, "fetch");
1602        assert_eq!(
1603            defs[1].invocation,
1604            crate::registry::InvocationHint::ToolCall
1605        );
1606    }
1607
1608    #[test]
1609    fn tool_definitions_schema_has_all_params() {
1610        let config = ScrapeConfig::default();
1611        let executor = WebScrapeExecutor::new(&config);
1612        let defs = executor.tool_definitions();
1613        let obj = defs[0].schema.as_object().unwrap();
1614        let props = obj["properties"].as_object().unwrap();
1615        assert!(props.contains_key("url"));
1616        assert!(props.contains_key("select"));
1617        assert!(props.contains_key("extract"));
1618        assert!(props.contains_key("limit"));
1619        let req = obj["required"].as_array().unwrap();
1620        assert!(req.iter().any(|v| v.as_str() == Some("url")));
1621        assert!(req.iter().any(|v| v.as_str() == Some("select")));
1622        assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1623    }
1624
1625    // --- is_private_host: new domain checks (AUD-02) ---
1626
1627    #[test]
1628    fn subdomain_localhost_blocked() {
1629        let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1630        assert!(is_private_host(&host));
1631    }
1632
1633    #[test]
1634    fn internal_tld_blocked() {
1635        let host: url::Host<&str> = url::Host::Domain("service.internal");
1636        assert!(is_private_host(&host));
1637    }
1638
1639    #[test]
1640    fn local_tld_blocked() {
1641        let host: url::Host<&str> = url::Host::Domain("printer.local");
1642        assert!(is_private_host(&host));
1643    }
1644
1645    #[test]
1646    fn public_domain_not_blocked() {
1647        let host: url::Host<&str> = url::Host::Domain("example.com");
1648        assert!(!is_private_host(&host));
1649    }
1650
1651    // --- resolve_and_validate: private IP rejection ---
1652
1653    #[tokio::test]
1654    async fn resolve_loopback_rejected() {
1655        // 127.0.0.1 resolves directly (literal IP in DNS query)
1656        let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1657        // validate_url catches this before resolve_and_validate, but test directly
1658        let result = resolve_and_validate(&url).await;
1659        assert!(
1660            result.is_err(),
1661            "loopback IP must be rejected by resolve_and_validate"
1662        );
1663        let err = result.unwrap_err();
1664        assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1665    }
1666
1667    #[tokio::test]
1668    async fn resolve_private_10_rejected() {
1669        let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1670        let result = resolve_and_validate(&url).await;
1671        assert!(result.is_err());
1672        assert!(matches!(
1673            result.unwrap_err(),
1674            crate::executor::ToolError::Blocked { .. }
1675        ));
1676    }
1677
1678    #[tokio::test]
1679    async fn resolve_private_192_rejected() {
1680        let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1681        let result = resolve_and_validate(&url).await;
1682        assert!(result.is_err());
1683        assert!(matches!(
1684            result.unwrap_err(),
1685            crate::executor::ToolError::Blocked { .. }
1686        ));
1687    }
1688
1689    #[tokio::test]
1690    async fn resolve_ipv6_loopback_rejected() {
1691        let url = url::Url::parse("https://[::1]/path").unwrap();
1692        let result = resolve_and_validate(&url).await;
1693        assert!(result.is_err());
1694        assert!(matches!(
1695            result.unwrap_err(),
1696            crate::executor::ToolError::Blocked { .. }
1697        ));
1698    }
1699
1700    #[tokio::test]
1701    async fn resolve_no_host_returns_ok() {
1702        // URL without a resolvable host — should pass through
1703        let url = url::Url::parse("https://example.com/path").unwrap();
1704        // We can't do a live DNS test, but we can verify a URL with no host
1705        let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1706        // data: URLs have no host; resolve_and_validate should return Ok with empty addrs
1707        let result = resolve_and_validate(&url_no_host).await;
1708        assert!(result.is_ok());
1709        let (host, addrs) = result.unwrap();
1710        assert!(host.is_empty());
1711        assert!(addrs.is_empty());
1712        drop(url);
1713        drop(url_no_host);
1714    }
1715
1716    // --- audit logging ---
1717
1718    /// Helper: build an `AuditLogger` writing to a temp file, and return the logger + path.
1719    async fn make_file_audit_logger(
1720        dir: &tempfile::TempDir,
1721    ) -> (
1722        std::sync::Arc<crate::audit::AuditLogger>,
1723        std::path::PathBuf,
1724    ) {
1725        use crate::audit::AuditLogger;
1726        use crate::config::AuditConfig;
1727        let path = dir.path().join("audit.log");
1728        let config = AuditConfig {
1729            enabled: true,
1730            destination: path.display().to_string(),
1731        };
1732        let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1733        (logger, path)
1734    }
1735
1736    #[tokio::test]
1737    async fn with_audit_sets_logger() {
1738        let config = ScrapeConfig::default();
1739        let executor = WebScrapeExecutor::new(&config);
1740        assert!(executor.audit_logger.is_none());
1741
1742        let dir = tempfile::tempdir().unwrap();
1743        let (logger, _path) = make_file_audit_logger(&dir).await;
1744        let executor = executor.with_audit(logger);
1745        assert!(executor.audit_logger.is_some());
1746    }
1747
1748    #[test]
1749    fn tool_error_to_audit_result_blocked_maps_correctly() {
1750        let err = ToolError::Blocked {
1751            command: "scheme not allowed: http".into(),
1752        };
1753        let result = tool_error_to_audit_result(&err);
1754        assert!(
1755            matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1756        );
1757    }
1758
1759    #[test]
1760    fn tool_error_to_audit_result_timeout_maps_correctly() {
1761        let err = ToolError::Timeout { timeout_secs: 15 };
1762        let result = tool_error_to_audit_result(&err);
1763        assert!(matches!(result, AuditResult::Timeout));
1764    }
1765
1766    #[test]
1767    fn tool_error_to_audit_result_execution_error_maps_correctly() {
1768        let err = ToolError::Execution(std::io::Error::other("connection refused"));
1769        let result = tool_error_to_audit_result(&err);
1770        assert!(
1771            matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1772        );
1773    }
1774
1775    #[tokio::test]
1776    async fn fetch_audit_blocked_url_logged() {
1777        let dir = tempfile::tempdir().unwrap();
1778        let (logger, log_path) = make_file_audit_logger(&dir).await;
1779
1780        let config = ScrapeConfig::default();
1781        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1782
1783        let call = crate::executor::ToolCall {
1784            tool_id: "fetch".to_owned(),
1785            params: {
1786                let mut m = serde_json::Map::new();
1787                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1788                m
1789            },
1790        };
1791        let result = executor.execute_tool_call(&call).await;
1792        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1793
1794        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1795        assert!(
1796            content.contains("\"tool\":\"fetch\""),
1797            "expected tool=fetch in audit: {content}"
1798        );
1799        assert!(
1800            content.contains("\"type\":\"blocked\""),
1801            "expected type=blocked in audit: {content}"
1802        );
1803        assert!(
1804            content.contains("http://example.com"),
1805            "expected URL in audit command field: {content}"
1806        );
1807    }
1808
1809    #[tokio::test]
1810    async fn log_audit_success_writes_to_file() {
1811        let dir = tempfile::tempdir().unwrap();
1812        let (logger, log_path) = make_file_audit_logger(&dir).await;
1813
1814        let config = ScrapeConfig::default();
1815        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1816
1817        executor
1818            .log_audit(
1819                "fetch",
1820                "https://example.com/page",
1821                AuditResult::Success,
1822                42,
1823            )
1824            .await;
1825
1826        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1827        assert!(
1828            content.contains("\"tool\":\"fetch\""),
1829            "expected tool=fetch in audit: {content}"
1830        );
1831        assert!(
1832            content.contains("\"type\":\"success\""),
1833            "expected type=success in audit: {content}"
1834        );
1835        assert!(
1836            content.contains("\"command\":\"https://example.com/page\""),
1837            "expected command URL in audit: {content}"
1838        );
1839        assert!(
1840            content.contains("\"duration_ms\":42"),
1841            "expected duration_ms in audit: {content}"
1842        );
1843    }
1844
1845    #[tokio::test]
1846    async fn log_audit_blocked_writes_to_file() {
1847        let dir = tempfile::tempdir().unwrap();
1848        let (logger, log_path) = make_file_audit_logger(&dir).await;
1849
1850        let config = ScrapeConfig::default();
1851        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1852
1853        executor
1854            .log_audit(
1855                "web_scrape",
1856                "http://evil.com/page",
1857                AuditResult::Blocked {
1858                    reason: "scheme not allowed: http".into(),
1859                },
1860                0,
1861            )
1862            .await;
1863
1864        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1865        assert!(
1866            content.contains("\"tool\":\"web_scrape\""),
1867            "expected tool=web_scrape in audit: {content}"
1868        );
1869        assert!(
1870            content.contains("\"type\":\"blocked\""),
1871            "expected type=blocked in audit: {content}"
1872        );
1873        assert!(
1874            content.contains("scheme not allowed"),
1875            "expected block reason in audit: {content}"
1876        );
1877    }
1878
1879    #[tokio::test]
1880    async fn web_scrape_audit_blocked_url_logged() {
1881        let dir = tempfile::tempdir().unwrap();
1882        let (logger, log_path) = make_file_audit_logger(&dir).await;
1883
1884        let config = ScrapeConfig::default();
1885        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1886
1887        let call = crate::executor::ToolCall {
1888            tool_id: "web_scrape".to_owned(),
1889            params: {
1890                let mut m = serde_json::Map::new();
1891                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1892                m.insert("select".to_owned(), serde_json::json!("h1"));
1893                m
1894            },
1895        };
1896        let result = executor.execute_tool_call(&call).await;
1897        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1898
1899        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1900        assert!(
1901            content.contains("\"tool\":\"web_scrape\""),
1902            "expected tool=web_scrape in audit: {content}"
1903        );
1904        assert!(
1905            content.contains("\"type\":\"blocked\""),
1906            "expected type=blocked in audit: {content}"
1907        );
1908    }
1909
1910    #[tokio::test]
1911    async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1912        let config = ScrapeConfig::default();
1913        let executor = WebScrapeExecutor::new(&config);
1914        assert!(executor.audit_logger.is_none());
1915
1916        let call = crate::executor::ToolCall {
1917            tool_id: "fetch".to_owned(),
1918            params: {
1919                let mut m = serde_json::Map::new();
1920                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1921                m
1922            },
1923        };
1924        // Must not panic even without an audit logger
1925        let result = executor.execute_tool_call(&call).await;
1926        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1927    }
1928
1929    // CR-10: fetch end-to-end via execute_tool_call -> handle_fetch -> fetch_html
1930    #[tokio::test]
1931    async fn fetch_execute_tool_call_end_to_end() {
1932        use wiremock::matchers::{method, path};
1933        use wiremock::{Mock, ResponseTemplate};
1934
1935        let (executor, server) = mock_server_executor().await;
1936        Mock::given(method("GET"))
1937            .and(path("/e2e"))
1938            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1939            .mount(&server)
1940            .await;
1941
1942        let (host, addrs) = server_host_and_addr(&server);
1943        // Call fetch_html directly (bypassing SSRF guard for loopback mock server)
1944        let result = executor
1945            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1946            .await;
1947        assert!(result.is_ok());
1948        assert!(result.unwrap().contains("end-to-end"));
1949    }
1950}