1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
15use crate::net::is_private_ip;
16
/// Parameters for the `fetch` tool.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// URL to retrieve; must be HTTPS (enforced by `validate_url`).
    url: String,
}
22
/// A single scrape request, deserialized either from a fenced `scrape`
/// JSON block or from structured tool-call parameters.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// URL of the page to scrape; must be HTTPS (enforced by `validate_url`).
    url: String,
    /// CSS selector choosing the elements to extract.
    select: String,
    /// Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    /// Maximum number of matches to return; `None` means 10 (applied in
    /// `scrape_instruction`).
    limit: Option<usize>,
}
35
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
39
/// How a value is pulled out of each matched element.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML markup.
    Html,
    /// The value of a named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parse a mode string: "text", "html", or "attr:<name>".
    /// Anything unrecognized falls back to `Text`.
    fn parse(s: &str) -> Self {
        if s == "html" {
            return Self::Html;
        }
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        Self::Text
    }
}
59
/// Tool executor that fetches and scrapes web pages with SSRF protections:
/// HTTPS-only URLs, private host/IP blocking, DNS pinning, and manually
/// validated redirects.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request timeout applied to the HTTP client.
    timeout: Duration,
    // Maximum accepted response body size, in bytes.
    max_body_bytes: usize,
    // Optional audit sink; when `None`, invocations are not logged.
    audit_logger: Option<Arc<AuditLogger>>,
}
70
71impl WebScrapeExecutor {
72 #[must_use]
73 pub fn new(config: &ScrapeConfig) -> Self {
74 Self {
75 timeout: Duration::from_secs(config.timeout),
76 max_body_bytes: config.max_body_bytes,
77 audit_logger: None,
78 }
79 }
80
81 #[must_use]
82 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
83 self.audit_logger = Some(logger);
84 self
85 }
86
87 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
88 let mut builder = reqwest::Client::builder()
89 .timeout(self.timeout)
90 .redirect(reqwest::redirect::Policy::none());
91 builder = builder.resolve_to_addrs(host, addrs);
92 builder.build().unwrap_or_default()
93 }
94}
95
impl ToolExecutor for WebScrapeExecutor {
    /// Advertise the two tools this executor handles: `web_scrape`
    /// (invoked via fenced `scrape` blocks) and `fetch` (structured tool call).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Free-form path: scan `response` for fenced `scrape` blocks, parse each
    /// as a JSON `ScrapeInstruction`, and execute them in order.
    ///
    /// Returns `Ok(None)` when no blocks are present. Fails fast: the first
    /// failing block aborts the call and discards earlier successes, but
    /// every attempt — success or failure — is audit-logged with its
    /// duration first.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON in any block aborts the whole call.
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Audit the failure before propagating it.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
                        .await;
                    return Err(e);
                }
            }
        }

        // NOTE(review): the output name is hyphenated ("web-scrape") while the
        // tool id uses an underscore ("web_scrape") — confirm downstream
        // consumers expect this distinction.
        Ok(Some(ToolOutput {
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }))
    }

    /// Structured path: dispatch on `call.tool_id`. Returns `Ok(None)` for
    /// ids this executor does not own, so another executor can claim them.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                        }))
                    }
                    Err(e) => {
                        // Audit the failure before propagating it.
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
                            .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms)
                            .await;
                        Err(e)
                    }
                }
            }
            // Not one of ours.
            _ => Ok(None),
        }
    }

    /// Both tools issue only GET requests, so retrying them is safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
244
245fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
246 match e {
247 ToolError::Blocked { command } => AuditResult::Blocked {
248 reason: command.clone(),
249 },
250 ToolError::Timeout { .. } => AuditResult::Timeout,
251 _ => AuditResult::Error {
252 message: e.to_string(),
253 },
254 }
255}
256
impl WebScrapeExecutor {
    /// Write one audit entry if a logger is attached; no-op otherwise.
    /// `command` is the tool's primary argument (here: the URL).
    async fn log_audit(&self, tool: &str, command: &str, result: AuditResult, duration_ms: u64) {
        if let Some(ref logger) = self.audit_logger {
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
                error_category: None,
            };
            logger.log(&entry).await;
        }
    }

    /// `fetch` tool implementation: validate the URL, resolve and pin its
    /// addresses, then download the body as text.
    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
        let parsed = validate_url(&params.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        self.fetch_html(&params.url, &host, &addrs).await
    }

    /// `web_scrape` tool implementation: fetch the page, then run the CSS
    /// selector extraction on the blocking thread pool so HTML parsing does
    /// not stall the async runtime. `limit` defaults to 10 matches.
    async fn scrape_instruction(
        &self,
        instruction: &ScrapeInstruction,
    ) -> Result<String, ToolError> {
        let parsed = validate_url(&instruction.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
        // Clone/move what the blocking closure needs; it must be 'static.
        let selector = instruction.select.clone();
        let extract = ExtractMode::parse(&instruction.extract);
        let limit = instruction.limit.unwrap_or(10);
        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
            .await
            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
    }

    /// Fetch `url` as UTF-8 text with a client pinned to `addrs` for `host`.
    ///
    /// Redirects are followed manually (the client has automatic redirects
    /// disabled in `build_client`) so every hop is re-run through
    /// `validate_url` and `resolve_and_validate`, keeping the SSRF checks
    /// effective across the whole redirect chain. At most `MAX_REDIRECTS`
    /// redirects are followed.
    ///
    /// # Errors
    /// Network failures, non-2xx responses (`ToolError::Http`), redirect
    /// chains that are too long or lack a `Location` header, bodies larger
    /// than `max_body_bytes`, and non-UTF-8 bodies.
    async fn fetch_html(
        &self,
        url: &str,
        host: &str,
        addrs: &[SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;

        let mut current_url = url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            // Fresh client per hop, pinned to the addresses validated for
            // the current host.
            let client = self.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                // Every allowed redirect has already been consumed.
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                // Resolve relative Location values against the current URL.
                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                // Re-apply the full SSRF validation to the redirect target
                // and re-pin its freshly resolved addresses.
                let validated = validate_url(next_url.as_str())?;
                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;

                current_url = next_url.to_string();
                current_host = next_host;
                current_addrs = next_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Http {
                    status: status.as_u16(),
                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
                });
            }

            // The body is fully buffered before the size check, so up to the
            // whole response is read even when it exceeds the limit.
            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > self.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    self.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        // Unreachable in practice (the hop guard above returns first);
        // kept to satisfy the compiler.
        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
}
385
/// Collect the contents of all fenced `scrape` code blocks in `text`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
389
390fn validate_url(raw: &str) -> Result<Url, ToolError> {
391 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
392 command: format!("invalid URL: {raw}"),
393 })?;
394
395 if parsed.scheme() != "https" {
396 return Err(ToolError::Blocked {
397 command: format!("scheme not allowed: {}", parsed.scheme()),
398 });
399 }
400
401 if let Some(host) = parsed.host()
402 && is_private_host(&host)
403 {
404 return Err(ToolError::Blocked {
405 command: format!(
406 "private/local host blocked: {}",
407 parsed.host_str().unwrap_or("")
408 ),
409 });
410 }
411
412 Ok(parsed)
413}
414
415fn is_private_host(host: &url::Host<&str>) -> bool {
416 match host {
417 url::Host::Domain(d) => {
418 #[allow(clippy::case_sensitive_file_extension_comparisons)]
421 {
422 *d == "localhost"
423 || d.ends_with(".localhost")
424 || d.ends_with(".internal")
425 || d.ends_with(".local")
426 }
427 }
428 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
429 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
430 }
431}
432
433async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
439 let Some(host) = url.host_str() else {
440 return Ok((String::new(), vec![]));
441 };
442 let port = url.port_or_known_default().unwrap_or(443);
443 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
444 .await
445 .map_err(|e| ToolError::Blocked {
446 command: format!("DNS resolution failed: {e}"),
447 })?
448 .collect();
449 for addr in &addrs {
450 if is_private_ip(addr.ip()) {
451 return Err(ToolError::Blocked {
452 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
453 });
454 }
455 }
456 Ok((host.to_owned(), addrs))
457}
458
459fn parse_and_extract(
460 html: &str,
461 selector: &str,
462 extract: &ExtractMode,
463 limit: usize,
464) -> Result<String, ToolError> {
465 let soup = scrape_core::Soup::parse(html);
466
467 let tags = soup.find_all(selector).map_err(|e| {
468 ToolError::Execution(std::io::Error::new(
469 std::io::ErrorKind::InvalidData,
470 format!("invalid selector: {e}"),
471 ))
472 })?;
473
474 let mut results = Vec::new();
475
476 for tag in tags.into_iter().take(limit) {
477 let value = match extract {
478 ExtractMode::Text => tag.text(),
479 ExtractMode::Html => tag.inner_html(),
480 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
481 };
482 if !value.trim().is_empty() {
483 results.push(value.trim().to_owned());
484 }
485 }
486
487 if results.is_empty() {
488 Ok(format!("No results for selector: {selector}"))
489 } else {
490 Ok(results.join("\n"))
491 }
492}
493
494#[cfg(test)]
495mod tests {
496 use super::*;
497
498 #[test]
501 fn extract_single_block() {
502 let text =
503 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
504 let blocks = extract_scrape_blocks(text);
505 assert_eq!(blocks.len(), 1);
506 assert!(blocks[0].contains("example.com"));
507 }
508
509 #[test]
510 fn extract_multiple_blocks() {
511 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
512 let blocks = extract_scrape_blocks(text);
513 assert_eq!(blocks.len(), 2);
514 }
515
516 #[test]
517 fn no_blocks_returns_empty() {
518 let blocks = extract_scrape_blocks("plain text, no code blocks");
519 assert!(blocks.is_empty());
520 }
521
522 #[test]
523 fn unclosed_block_ignored() {
524 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
525 assert!(blocks.is_empty());
526 }
527
528 #[test]
529 fn non_scrape_block_ignored() {
530 let text =
531 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
532 let blocks = extract_scrape_blocks(text);
533 assert_eq!(blocks.len(), 1);
534 assert!(blocks[0].contains("x.com"));
535 }
536
537 #[test]
538 fn multiline_json_block() {
539 let text =
540 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
541 let blocks = extract_scrape_blocks(text);
542 assert_eq!(blocks.len(), 1);
543 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
544 assert_eq!(instr.url, "https://example.com");
545 }
546
547 #[test]
550 fn parse_valid_instruction() {
551 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
552 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
553 assert_eq!(instr.url, "https://example.com");
554 assert_eq!(instr.select, "h1");
555 assert_eq!(instr.extract, "text");
556 assert_eq!(instr.limit, Some(5));
557 }
558
559 #[test]
560 fn parse_minimal_instruction() {
561 let json = r#"{"url":"https://example.com","select":"p"}"#;
562 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
563 assert_eq!(instr.extract, "text");
564 assert!(instr.limit.is_none());
565 }
566
567 #[test]
568 fn parse_attr_extract() {
569 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
570 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
571 assert_eq!(instr.extract, "attr:href");
572 }
573
574 #[test]
575 fn parse_invalid_json_errors() {
576 let result = serde_json::from_str::<ScrapeInstruction>("not json");
577 assert!(result.is_err());
578 }
579
580 #[test]
583 fn extract_mode_text() {
584 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
585 }
586
587 #[test]
588 fn extract_mode_html() {
589 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
590 }
591
592 #[test]
593 fn extract_mode_attr() {
594 let mode = ExtractMode::parse("attr:href");
595 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
596 }
597
598 #[test]
599 fn extract_mode_unknown_defaults_to_text() {
600 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
601 }
602
603 #[test]
606 fn valid_https_url() {
607 assert!(validate_url("https://example.com").is_ok());
608 }
609
610 #[test]
611 fn http_rejected() {
612 let err = validate_url("http://example.com").unwrap_err();
613 assert!(matches!(err, ToolError::Blocked { .. }));
614 }
615
616 #[test]
617 fn ftp_rejected() {
618 let err = validate_url("ftp://files.example.com").unwrap_err();
619 assert!(matches!(err, ToolError::Blocked { .. }));
620 }
621
622 #[test]
623 fn file_rejected() {
624 let err = validate_url("file:///etc/passwd").unwrap_err();
625 assert!(matches!(err, ToolError::Blocked { .. }));
626 }
627
628 #[test]
629 fn invalid_url_rejected() {
630 let err = validate_url("not a url").unwrap_err();
631 assert!(matches!(err, ToolError::Blocked { .. }));
632 }
633
634 #[test]
635 fn localhost_blocked() {
636 let err = validate_url("https://localhost/path").unwrap_err();
637 assert!(matches!(err, ToolError::Blocked { .. }));
638 }
639
640 #[test]
641 fn loopback_ip_blocked() {
642 let err = validate_url("https://127.0.0.1/path").unwrap_err();
643 assert!(matches!(err, ToolError::Blocked { .. }));
644 }
645
646 #[test]
647 fn private_10_blocked() {
648 let err = validate_url("https://10.0.0.1/api").unwrap_err();
649 assert!(matches!(err, ToolError::Blocked { .. }));
650 }
651
652 #[test]
653 fn private_172_blocked() {
654 let err = validate_url("https://172.16.0.1/api").unwrap_err();
655 assert!(matches!(err, ToolError::Blocked { .. }));
656 }
657
658 #[test]
659 fn private_192_blocked() {
660 let err = validate_url("https://192.168.1.1/api").unwrap_err();
661 assert!(matches!(err, ToolError::Blocked { .. }));
662 }
663
664 #[test]
665 fn ipv6_loopback_blocked() {
666 let err = validate_url("https://[::1]/path").unwrap_err();
667 assert!(matches!(err, ToolError::Blocked { .. }));
668 }
669
670 #[test]
671 fn public_ip_allowed() {
672 assert!(validate_url("https://93.184.216.34/page").is_ok());
673 }
674
675 #[test]
678 fn extract_text_from_html() {
679 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
680 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
681 assert_eq!(result, "Hello World");
682 }
683
684 #[test]
685 fn extract_multiple_elements() {
686 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
687 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
688 assert_eq!(result, "A\nB\nC");
689 }
690
691 #[test]
692 fn extract_with_limit() {
693 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
694 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
695 assert_eq!(result, "A\nB");
696 }
697
698 #[test]
699 fn extract_attr_href() {
700 let html = r#"<a href="https://example.com">Link</a>"#;
701 let result =
702 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
703 assert_eq!(result, "https://example.com");
704 }
705
706 #[test]
707 fn extract_inner_html() {
708 let html = "<div><span>inner</span></div>";
709 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
710 assert!(result.contains("<span>inner</span>"));
711 }
712
713 #[test]
714 fn no_matches_returns_message() {
715 let html = "<html><body><p>text</p></body></html>";
716 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
717 assert!(result.starts_with("No results for selector:"));
718 }
719
720 #[test]
721 fn empty_text_skipped() {
722 let html = "<ul><li> </li><li>A</li></ul>";
723 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
724 assert_eq!(result, "A");
725 }
726
727 #[test]
728 fn invalid_selector_errors() {
729 let html = "<html><body></body></html>";
730 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
731 assert!(result.is_err());
732 }
733
734 #[test]
735 fn empty_html_returns_no_results() {
736 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
737 assert!(result.starts_with("No results for selector:"));
738 }
739
740 #[test]
741 fn nested_selector() {
742 let html = "<div><span>inner</span></div><span>outer</span>";
743 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
744 assert_eq!(result, "inner");
745 }
746
747 #[test]
748 fn attr_missing_returns_empty() {
749 let html = r"<a>No href</a>";
750 let result =
751 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
752 assert!(result.starts_with("No results for selector:"));
753 }
754
755 #[test]
756 fn extract_html_mode() {
757 let html = "<div><b>bold</b> text</div>";
758 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
759 assert!(result.contains("<b>bold</b>"));
760 }
761
762 #[test]
763 fn limit_zero_returns_no_results() {
764 let html = "<ul><li>A</li><li>B</li></ul>";
765 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
766 assert!(result.starts_with("No results for selector:"));
767 }
768
769 #[test]
772 fn url_with_port_allowed() {
773 assert!(validate_url("https://example.com:8443/path").is_ok());
774 }
775
776 #[test]
777 fn link_local_ip_blocked() {
778 let err = validate_url("https://169.254.1.1/path").unwrap_err();
779 assert!(matches!(err, ToolError::Blocked { .. }));
780 }
781
782 #[test]
783 fn url_no_scheme_rejected() {
784 let err = validate_url("example.com/path").unwrap_err();
785 assert!(matches!(err, ToolError::Blocked { .. }));
786 }
787
788 #[test]
789 fn unspecified_ipv4_blocked() {
790 let err = validate_url("https://0.0.0.0/path").unwrap_err();
791 assert!(matches!(err, ToolError::Blocked { .. }));
792 }
793
794 #[test]
795 fn broadcast_ipv4_blocked() {
796 let err = validate_url("https://255.255.255.255/path").unwrap_err();
797 assert!(matches!(err, ToolError::Blocked { .. }));
798 }
799
800 #[test]
801 fn ipv6_link_local_blocked() {
802 let err = validate_url("https://[fe80::1]/path").unwrap_err();
803 assert!(matches!(err, ToolError::Blocked { .. }));
804 }
805
806 #[test]
807 fn ipv6_unique_local_blocked() {
808 let err = validate_url("https://[fd12::1]/path").unwrap_err();
809 assert!(matches!(err, ToolError::Blocked { .. }));
810 }
811
812 #[test]
813 fn ipv4_mapped_ipv6_loopback_blocked() {
814 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
815 assert!(matches!(err, ToolError::Blocked { .. }));
816 }
817
818 #[test]
819 fn ipv4_mapped_ipv6_private_blocked() {
820 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
821 assert!(matches!(err, ToolError::Blocked { .. }));
822 }
823
824 #[tokio::test]
827 async fn executor_no_blocks_returns_none() {
828 let config = ScrapeConfig::default();
829 let executor = WebScrapeExecutor::new(&config);
830 let result = executor.execute("plain text").await;
831 assert!(result.unwrap().is_none());
832 }
833
834 #[tokio::test]
835 async fn executor_invalid_json_errors() {
836 let config = ScrapeConfig::default();
837 let executor = WebScrapeExecutor::new(&config);
838 let response = "```scrape\nnot json\n```";
839 let result = executor.execute(response).await;
840 assert!(matches!(result, Err(ToolError::Execution(_))));
841 }
842
843 #[tokio::test]
844 async fn executor_blocked_url_errors() {
845 let config = ScrapeConfig::default();
846 let executor = WebScrapeExecutor::new(&config);
847 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
848 let result = executor.execute(response).await;
849 assert!(matches!(result, Err(ToolError::Blocked { .. })));
850 }
851
852 #[tokio::test]
853 async fn executor_private_ip_blocked() {
854 let config = ScrapeConfig::default();
855 let executor = WebScrapeExecutor::new(&config);
856 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
857 let result = executor.execute(response).await;
858 assert!(matches!(result, Err(ToolError::Blocked { .. })));
859 }
860
861 #[tokio::test]
862 async fn executor_unreachable_host_returns_error() {
863 let config = ScrapeConfig {
864 timeout: 1,
865 max_body_bytes: 1_048_576,
866 };
867 let executor = WebScrapeExecutor::new(&config);
868 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
869 let result = executor.execute(response).await;
870 assert!(matches!(result, Err(ToolError::Execution(_))));
871 }
872
873 #[tokio::test]
874 async fn executor_localhost_url_blocked() {
875 let config = ScrapeConfig::default();
876 let executor = WebScrapeExecutor::new(&config);
877 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
878 let result = executor.execute(response).await;
879 assert!(matches!(result, Err(ToolError::Blocked { .. })));
880 }
881
882 #[tokio::test]
883 async fn executor_empty_text_returns_none() {
884 let config = ScrapeConfig::default();
885 let executor = WebScrapeExecutor::new(&config);
886 let result = executor.execute("").await;
887 assert!(result.unwrap().is_none());
888 }
889
890 #[tokio::test]
891 async fn executor_multiple_blocks_first_blocked() {
892 let config = ScrapeConfig::default();
893 let executor = WebScrapeExecutor::new(&config);
894 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
895 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
896 let result = executor.execute(response).await;
897 assert!(result.is_err());
898 }
899
900 #[test]
901 fn validate_url_empty_string() {
902 let err = validate_url("").unwrap_err();
903 assert!(matches!(err, ToolError::Blocked { .. }));
904 }
905
906 #[test]
907 fn validate_url_javascript_scheme_blocked() {
908 let err = validate_url("javascript:alert(1)").unwrap_err();
909 assert!(matches!(err, ToolError::Blocked { .. }));
910 }
911
912 #[test]
913 fn validate_url_data_scheme_blocked() {
914 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
915 assert!(matches!(err, ToolError::Blocked { .. }));
916 }
917
918 #[test]
919 fn is_private_host_public_domain_is_false() {
920 let host: url::Host<&str> = url::Host::Domain("example.com");
921 assert!(!is_private_host(&host));
922 }
923
924 #[test]
925 fn is_private_host_localhost_is_true() {
926 let host: url::Host<&str> = url::Host::Domain("localhost");
927 assert!(is_private_host(&host));
928 }
929
930 #[test]
931 fn is_private_host_ipv6_unspecified_is_true() {
932 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
933 assert!(is_private_host(&host));
934 }
935
936 #[test]
937 fn is_private_host_public_ipv6_is_false() {
938 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
939 assert!(!is_private_host(&host));
940 }
941
942 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
953 let server = wiremock::MockServer::start().await;
954 let executor = WebScrapeExecutor {
955 timeout: Duration::from_secs(5),
956 max_body_bytes: 1_048_576,
957 audit_logger: None,
958 };
959 (executor, server)
960 }
961
962 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
964 let uri = server.uri();
965 let url = Url::parse(&uri).unwrap();
966 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
967 let port = url.port().unwrap_or(80);
968 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
969 (host, vec![addr])
970 }
971
972 async fn follow_redirects_raw(
976 executor: &WebScrapeExecutor,
977 start_url: &str,
978 host: &str,
979 addrs: &[std::net::SocketAddr],
980 ) -> Result<String, ToolError> {
981 const MAX_REDIRECTS: usize = 3;
982 let mut current_url = start_url.to_owned();
983 let mut current_host = host.to_owned();
984 let mut current_addrs = addrs.to_vec();
985
986 for hop in 0..=MAX_REDIRECTS {
987 let client = executor.build_client(¤t_host, ¤t_addrs);
988 let resp = client
989 .get(¤t_url)
990 .send()
991 .await
992 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
993
994 let status = resp.status();
995
996 if status.is_redirection() {
997 if hop == MAX_REDIRECTS {
998 return Err(ToolError::Execution(std::io::Error::other(
999 "too many redirects",
1000 )));
1001 }
1002
1003 let location = resp
1004 .headers()
1005 .get(reqwest::header::LOCATION)
1006 .and_then(|v| v.to_str().ok())
1007 .ok_or_else(|| {
1008 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1009 })?;
1010
1011 let base = Url::parse(¤t_url)
1012 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1013 let next_url = base
1014 .join(location)
1015 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1016
1017 current_url = next_url.to_string();
1019 let _ = &mut current_host;
1021 let _ = &mut current_addrs;
1022 continue;
1023 }
1024
1025 if !status.is_success() {
1026 return Err(ToolError::Execution(std::io::Error::other(format!(
1027 "HTTP {status}",
1028 ))));
1029 }
1030
1031 let bytes = resp
1032 .bytes()
1033 .await
1034 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1035
1036 if bytes.len() > executor.max_body_bytes {
1037 return Err(ToolError::Execution(std::io::Error::other(format!(
1038 "response too large: {} bytes (max: {})",
1039 bytes.len(),
1040 executor.max_body_bytes,
1041 ))));
1042 }
1043
1044 return String::from_utf8(bytes.to_vec())
1045 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1046 }
1047
1048 Err(ToolError::Execution(std::io::Error::other(
1049 "too many redirects",
1050 )))
1051 }
1052
1053 #[tokio::test]
1054 async fn fetch_html_success_returns_body() {
1055 use wiremock::matchers::{method, path};
1056 use wiremock::{Mock, ResponseTemplate};
1057
1058 let (executor, server) = mock_server_executor().await;
1059 Mock::given(method("GET"))
1060 .and(path("/page"))
1061 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1062 .mount(&server)
1063 .await;
1064
1065 let (host, addrs) = server_host_and_addr(&server);
1066 let url = format!("{}/page", server.uri());
1067 let result = executor.fetch_html(&url, &host, &addrs).await;
1068 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1069 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1070 }
1071
1072 #[tokio::test]
1073 async fn fetch_html_non_2xx_returns_error() {
1074 use wiremock::matchers::{method, path};
1075 use wiremock::{Mock, ResponseTemplate};
1076
1077 let (executor, server) = mock_server_executor().await;
1078 Mock::given(method("GET"))
1079 .and(path("/forbidden"))
1080 .respond_with(ResponseTemplate::new(403))
1081 .mount(&server)
1082 .await;
1083
1084 let (host, addrs) = server_host_and_addr(&server);
1085 let url = format!("{}/forbidden", server.uri());
1086 let result = executor.fetch_html(&url, &host, &addrs).await;
1087 assert!(result.is_err());
1088 let msg = result.unwrap_err().to_string();
1089 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1090 }
1091
1092 #[tokio::test]
1093 async fn fetch_html_404_returns_error() {
1094 use wiremock::matchers::{method, path};
1095 use wiremock::{Mock, ResponseTemplate};
1096
1097 let (executor, server) = mock_server_executor().await;
1098 Mock::given(method("GET"))
1099 .and(path("/missing"))
1100 .respond_with(ResponseTemplate::new(404))
1101 .mount(&server)
1102 .await;
1103
1104 let (host, addrs) = server_host_and_addr(&server);
1105 let url = format!("{}/missing", server.uri());
1106 let result = executor.fetch_html(&url, &host, &addrs).await;
1107 assert!(result.is_err());
1108 let msg = result.unwrap_err().to_string();
1109 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1110 }
1111
1112 #[tokio::test]
1113 async fn fetch_html_redirect_no_location_returns_error() {
1114 use wiremock::matchers::{method, path};
1115 use wiremock::{Mock, ResponseTemplate};
1116
1117 let (executor, server) = mock_server_executor().await;
1118 Mock::given(method("GET"))
1120 .and(path("/redirect-no-loc"))
1121 .respond_with(ResponseTemplate::new(302))
1122 .mount(&server)
1123 .await;
1124
1125 let (host, addrs) = server_host_and_addr(&server);
1126 let url = format!("{}/redirect-no-loc", server.uri());
1127 let result = executor.fetch_html(&url, &host, &addrs).await;
1128 assert!(result.is_err());
1129 let msg = result.unwrap_err().to_string();
1130 assert!(
1131 msg.contains("Location") || msg.contains("location"),
1132 "expected Location-related error: {msg}"
1133 );
1134 }
1135
1136 #[tokio::test]
1137 async fn fetch_html_single_redirect_followed() {
1138 use wiremock::matchers::{method, path};
1139 use wiremock::{Mock, ResponseTemplate};
1140
1141 let (executor, server) = mock_server_executor().await;
1142 let final_url = format!("{}/final", server.uri());
1143
1144 Mock::given(method("GET"))
1145 .and(path("/start"))
1146 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1147 .mount(&server)
1148 .await;
1149
1150 Mock::given(method("GET"))
1151 .and(path("/final"))
1152 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1153 .mount(&server)
1154 .await;
1155
1156 let (host, addrs) = server_host_and_addr(&server);
1157 let url = format!("{}/start", server.uri());
1158 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1159 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1160 assert_eq!(result.unwrap(), "<p>final</p>");
1161 }
1162
1163 #[tokio::test]
1164 async fn fetch_html_three_redirects_allowed() {
1165 use wiremock::matchers::{method, path};
1166 use wiremock::{Mock, ResponseTemplate};
1167
1168 let (executor, server) = mock_server_executor().await;
1169 let hop2 = format!("{}/hop2", server.uri());
1170 let hop3 = format!("{}/hop3", server.uri());
1171 let final_dest = format!("{}/done", server.uri());
1172
1173 Mock::given(method("GET"))
1174 .and(path("/hop1"))
1175 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1176 .mount(&server)
1177 .await;
1178 Mock::given(method("GET"))
1179 .and(path("/hop2"))
1180 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1181 .mount(&server)
1182 .await;
1183 Mock::given(method("GET"))
1184 .and(path("/hop3"))
1185 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1186 .mount(&server)
1187 .await;
1188 Mock::given(method("GET"))
1189 .and(path("/done"))
1190 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1191 .mount(&server)
1192 .await;
1193
1194 let (host, addrs) = server_host_and_addr(&server);
1195 let url = format!("{}/hop1", server.uri());
1196 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1197 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1198 assert_eq!(result.unwrap(), "<p>done</p>");
1199 }
1200
1201 #[tokio::test]
1202 async fn fetch_html_four_redirects_rejected() {
1203 use wiremock::matchers::{method, path};
1204 use wiremock::{Mock, ResponseTemplate};
1205
1206 let (executor, server) = mock_server_executor().await;
1207 let hop2 = format!("{}/r2", server.uri());
1208 let hop3 = format!("{}/r3", server.uri());
1209 let hop4 = format!("{}/r4", server.uri());
1210 let hop5 = format!("{}/r5", server.uri());
1211
1212 for (from, to) in [
1213 ("/r1", &hop2),
1214 ("/r2", &hop3),
1215 ("/r3", &hop4),
1216 ("/r4", &hop5),
1217 ] {
1218 Mock::given(method("GET"))
1219 .and(path(from))
1220 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1221 .mount(&server)
1222 .await;
1223 }
1224
1225 let (host, addrs) = server_host_and_addr(&server);
1226 let url = format!("{}/r1", server.uri());
1227 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1228 assert!(result.is_err(), "4 redirects should be rejected");
1229 let msg = result.unwrap_err().to_string();
1230 assert!(
1231 msg.contains("redirect"),
1232 "expected redirect-related error: {msg}"
1233 );
1234 }
1235
1236 #[tokio::test]
1237 async fn fetch_html_body_too_large_returns_error() {
1238 use wiremock::matchers::{method, path};
1239 use wiremock::{Mock, ResponseTemplate};
1240
1241 let small_limit_executor = WebScrapeExecutor {
1242 timeout: Duration::from_secs(5),
1243 max_body_bytes: 10,
1244 audit_logger: None,
1245 };
1246 let server = wiremock::MockServer::start().await;
1247 Mock::given(method("GET"))
1248 .and(path("/big"))
1249 .respond_with(
1250 ResponseTemplate::new(200)
1251 .set_body_string("this body is definitely longer than ten bytes"),
1252 )
1253 .mount(&server)
1254 .await;
1255
1256 let (host, addrs) = server_host_and_addr(&server);
1257 let url = format!("{}/big", server.uri());
1258 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1259 assert!(result.is_err());
1260 let msg = result.unwrap_err().to_string();
1261 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1262 }
1263
1264 #[test]
1265 fn extract_scrape_blocks_empty_block_content() {
1266 let text = "```scrape\n\n```";
1267 let blocks = extract_scrape_blocks(text);
1268 assert_eq!(blocks.len(), 1);
1269 assert!(blocks[0].is_empty());
1270 }
1271
1272 #[test]
1273 fn extract_scrape_blocks_whitespace_only() {
1274 let text = "```scrape\n \n```";
1275 let blocks = extract_scrape_blocks(text);
1276 assert_eq!(blocks.len(), 1);
1277 }
1278
1279 #[test]
1280 fn parse_and_extract_multiple_selectors() {
1281 let html = "<div><h1>Title</h1><p>Para</p></div>";
1282 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1283 assert!(result.contains("Title"));
1284 assert!(result.contains("Para"));
1285 }
1286
1287 #[test]
1288 fn webscrape_executor_new_with_custom_config() {
1289 let config = ScrapeConfig {
1290 timeout: 60,
1291 max_body_bytes: 512,
1292 };
1293 let executor = WebScrapeExecutor::new(&config);
1294 assert_eq!(executor.max_body_bytes, 512);
1295 }
1296
1297 #[test]
1298 fn webscrape_executor_debug() {
1299 let config = ScrapeConfig::default();
1300 let executor = WebScrapeExecutor::new(&config);
1301 let dbg = format!("{executor:?}");
1302 assert!(dbg.contains("WebScrapeExecutor"));
1303 }
1304
1305 #[test]
1306 fn extract_mode_attr_empty_name() {
1307 let mode = ExtractMode::parse("attr:");
1308 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1309 }
1310
1311 #[test]
1312 fn default_extract_returns_text() {
1313 assert_eq!(default_extract(), "text");
1314 }
1315
1316 #[test]
1317 fn scrape_instruction_debug() {
1318 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1319 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1320 let dbg = format!("{instr:?}");
1321 assert!(dbg.contains("ScrapeInstruction"));
1322 }
1323
1324 #[test]
1325 fn extract_mode_debug() {
1326 let mode = ExtractMode::Text;
1327 let dbg = format!("{mode:?}");
1328 assert!(dbg.contains("Text"));
1329 }
1330
    /// Documents the redirect budget used by `fetch_html` (three hops).
    /// NOTE(review): this re-declares the constant locally, so it only
    /// records the intended contract — it does not verify the value the
    /// implementation actually uses. Consider referencing the real constant.
    #[test]
    fn max_redirects_constant_is_three() {
        const MAX_REDIRECTS: usize = 3;
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }
1343
1344 #[test]
1347 fn redirect_no_location_error_message() {
1348 let err = std::io::Error::other("redirect with no Location");
1349 assert!(err.to_string().contains("redirect with no Location"));
1350 }
1351
1352 #[test]
1354 fn too_many_redirects_error_message() {
1355 let err = std::io::Error::other("too many redirects");
1356 assert!(err.to_string().contains("too many redirects"));
1357 }
1358
1359 #[test]
1361 fn non_2xx_status_error_format() {
1362 let status = reqwest::StatusCode::FORBIDDEN;
1363 let msg = format!("HTTP {status}");
1364 assert!(msg.contains("403"));
1365 }
1366
1367 #[test]
1369 fn not_found_status_error_format() {
1370 let status = reqwest::StatusCode::NOT_FOUND;
1371 let msg = format!("HTTP {status}");
1372 assert!(msg.contains("404"));
1373 }
1374
1375 #[test]
1377 fn relative_redirect_same_host_path() {
1378 let base = Url::parse("https://example.com/current").unwrap();
1379 let resolved = base.join("/other").unwrap();
1380 assert_eq!(resolved.as_str(), "https://example.com/other");
1381 }
1382
1383 #[test]
1385 fn relative_redirect_relative_path() {
1386 let base = Url::parse("https://example.com/a/b").unwrap();
1387 let resolved = base.join("c").unwrap();
1388 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1389 }
1390
1391 #[test]
1393 fn absolute_redirect_overrides_base() {
1394 let base = Url::parse("https://example.com/page").unwrap();
1395 let resolved = base.join("https://other.com/target").unwrap();
1396 assert_eq!(resolved.as_str(), "https://other.com/target");
1397 }
1398
1399 #[test]
1401 fn redirect_http_downgrade_rejected() {
1402 let location = "http://example.com/page";
1403 let base = Url::parse("https://example.com/start").unwrap();
1404 let next = base.join(location).unwrap();
1405 let err = validate_url(next.as_str()).unwrap_err();
1406 assert!(matches!(err, ToolError::Blocked { .. }));
1407 }
1408
1409 #[test]
1411 fn redirect_location_private_ip_blocked() {
1412 let location = "https://192.168.100.1/admin";
1413 let base = Url::parse("https://example.com/start").unwrap();
1414 let next = base.join(location).unwrap();
1415 let err = validate_url(next.as_str()).unwrap_err();
1416 assert!(matches!(err, ToolError::Blocked { .. }));
1417 let ToolError::Blocked { command: cmd } = err else {
1418 panic!("expected Blocked");
1419 };
1420 assert!(
1421 cmd.contains("private") || cmd.contains("scheme"),
1422 "error message should describe the block reason: {cmd}"
1423 );
1424 }
1425
1426 #[test]
1428 fn redirect_location_internal_domain_blocked() {
1429 let location = "https://metadata.internal/latest/meta-data/";
1430 let base = Url::parse("https://example.com/start").unwrap();
1431 let next = base.join(location).unwrap();
1432 let err = validate_url(next.as_str()).unwrap_err();
1433 assert!(matches!(err, ToolError::Blocked { .. }));
1434 }
1435
1436 #[test]
1438 fn redirect_chain_three_hops_all_public() {
1439 let hops = [
1440 "https://redirect1.example.com/hop1",
1441 "https://redirect2.example.com/hop2",
1442 "https://destination.example.com/final",
1443 ];
1444 for hop in hops {
1445 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1446 }
1447 }
1448
1449 #[test]
1454 fn redirect_to_private_ip_rejected_by_validate_url() {
1455 let private_targets = [
1457 "https://127.0.0.1/secret",
1458 "https://10.0.0.1/internal",
1459 "https://192.168.1.1/admin",
1460 "https://172.16.0.1/data",
1461 "https://[::1]/path",
1462 "https://[fe80::1]/path",
1463 "https://localhost/path",
1464 "https://service.internal/api",
1465 ];
1466 for target in private_targets {
1467 let result = validate_url(target);
1468 assert!(result.is_err(), "expected error for {target}");
1469 assert!(
1470 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1471 "expected Blocked for {target}"
1472 );
1473 }
1474 }
1475
1476 #[test]
1478 fn redirect_relative_url_resolves_correctly() {
1479 let base = Url::parse("https://example.com/page").unwrap();
1480 let relative = "/other";
1481 let resolved = base.join(relative).unwrap();
1482 assert_eq!(resolved.as_str(), "https://example.com/other");
1483 }
1484
1485 #[test]
1487 fn redirect_to_http_rejected() {
1488 let err = validate_url("http://example.com/page").unwrap_err();
1489 assert!(matches!(err, ToolError::Blocked { .. }));
1490 }
1491
1492 #[test]
1493 fn ipv4_mapped_ipv6_link_local_blocked() {
1494 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1495 assert!(matches!(err, ToolError::Blocked { .. }));
1496 }
1497
1498 #[test]
1499 fn ipv4_mapped_ipv6_public_allowed() {
1500 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1501 }
1502
1503 #[tokio::test]
1506 async fn fetch_http_scheme_blocked() {
1507 let config = ScrapeConfig::default();
1508 let executor = WebScrapeExecutor::new(&config);
1509 let call = crate::executor::ToolCall {
1510 tool_id: "fetch".to_owned(),
1511 params: {
1512 let mut m = serde_json::Map::new();
1513 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1514 m
1515 },
1516 };
1517 let result = executor.execute_tool_call(&call).await;
1518 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1519 }
1520
1521 #[tokio::test]
1522 async fn fetch_private_ip_blocked() {
1523 let config = ScrapeConfig::default();
1524 let executor = WebScrapeExecutor::new(&config);
1525 let call = crate::executor::ToolCall {
1526 tool_id: "fetch".to_owned(),
1527 params: {
1528 let mut m = serde_json::Map::new();
1529 m.insert(
1530 "url".to_owned(),
1531 serde_json::json!("https://192.168.1.1/secret"),
1532 );
1533 m
1534 },
1535 };
1536 let result = executor.execute_tool_call(&call).await;
1537 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1538 }
1539
1540 #[tokio::test]
1541 async fn fetch_localhost_blocked() {
1542 let config = ScrapeConfig::default();
1543 let executor = WebScrapeExecutor::new(&config);
1544 let call = crate::executor::ToolCall {
1545 tool_id: "fetch".to_owned(),
1546 params: {
1547 let mut m = serde_json::Map::new();
1548 m.insert(
1549 "url".to_owned(),
1550 serde_json::json!("https://localhost/page"),
1551 );
1552 m
1553 },
1554 };
1555 let result = executor.execute_tool_call(&call).await;
1556 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1557 }
1558
1559 #[tokio::test]
1560 async fn fetch_unknown_tool_returns_none() {
1561 let config = ScrapeConfig::default();
1562 let executor = WebScrapeExecutor::new(&config);
1563 let call = crate::executor::ToolCall {
1564 tool_id: "unknown_tool".to_owned(),
1565 params: serde_json::Map::new(),
1566 };
1567 let result = executor.execute_tool_call(&call).await;
1568 assert!(result.unwrap().is_none());
1569 }
1570
1571 #[tokio::test]
1572 async fn fetch_returns_body_via_mock() {
1573 use wiremock::matchers::{method, path};
1574 use wiremock::{Mock, ResponseTemplate};
1575
1576 let (executor, server) = mock_server_executor().await;
1577 Mock::given(method("GET"))
1578 .and(path("/content"))
1579 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1580 .mount(&server)
1581 .await;
1582
1583 let (host, addrs) = server_host_and_addr(&server);
1584 let url = format!("{}/content", server.uri());
1585 let result = executor.fetch_html(&url, &host, &addrs).await;
1586 assert!(result.is_ok());
1587 assert_eq!(result.unwrap(), "plain text content");
1588 }
1589
1590 #[test]
1591 fn tool_definitions_returns_web_scrape_and_fetch() {
1592 let config = ScrapeConfig::default();
1593 let executor = WebScrapeExecutor::new(&config);
1594 let defs = executor.tool_definitions();
1595 assert_eq!(defs.len(), 2);
1596 assert_eq!(defs[0].id, "web_scrape");
1597 assert_eq!(
1598 defs[0].invocation,
1599 crate::registry::InvocationHint::FencedBlock("scrape")
1600 );
1601 assert_eq!(defs[1].id, "fetch");
1602 assert_eq!(
1603 defs[1].invocation,
1604 crate::registry::InvocationHint::ToolCall
1605 );
1606 }
1607
1608 #[test]
1609 fn tool_definitions_schema_has_all_params() {
1610 let config = ScrapeConfig::default();
1611 let executor = WebScrapeExecutor::new(&config);
1612 let defs = executor.tool_definitions();
1613 let obj = defs[0].schema.as_object().unwrap();
1614 let props = obj["properties"].as_object().unwrap();
1615 assert!(props.contains_key("url"));
1616 assert!(props.contains_key("select"));
1617 assert!(props.contains_key("extract"));
1618 assert!(props.contains_key("limit"));
1619 let req = obj["required"].as_array().unwrap();
1620 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1621 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1622 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1623 }
1624
1625 #[test]
1628 fn subdomain_localhost_blocked() {
1629 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1630 assert!(is_private_host(&host));
1631 }
1632
1633 #[test]
1634 fn internal_tld_blocked() {
1635 let host: url::Host<&str> = url::Host::Domain("service.internal");
1636 assert!(is_private_host(&host));
1637 }
1638
1639 #[test]
1640 fn local_tld_blocked() {
1641 let host: url::Host<&str> = url::Host::Domain("printer.local");
1642 assert!(is_private_host(&host));
1643 }
1644
1645 #[test]
1646 fn public_domain_not_blocked() {
1647 let host: url::Host<&str> = url::Host::Domain("example.com");
1648 assert!(!is_private_host(&host));
1649 }
1650
1651 #[tokio::test]
1654 async fn resolve_loopback_rejected() {
1655 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1657 let result = resolve_and_validate(&url).await;
1659 assert!(
1660 result.is_err(),
1661 "loopback IP must be rejected by resolve_and_validate"
1662 );
1663 let err = result.unwrap_err();
1664 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1665 }
1666
1667 #[tokio::test]
1668 async fn resolve_private_10_rejected() {
1669 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1670 let result = resolve_and_validate(&url).await;
1671 assert!(result.is_err());
1672 assert!(matches!(
1673 result.unwrap_err(),
1674 crate::executor::ToolError::Blocked { .. }
1675 ));
1676 }
1677
1678 #[tokio::test]
1679 async fn resolve_private_192_rejected() {
1680 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1681 let result = resolve_and_validate(&url).await;
1682 assert!(result.is_err());
1683 assert!(matches!(
1684 result.unwrap_err(),
1685 crate::executor::ToolError::Blocked { .. }
1686 ));
1687 }
1688
1689 #[tokio::test]
1690 async fn resolve_ipv6_loopback_rejected() {
1691 let url = url::Url::parse("https://[::1]/path").unwrap();
1692 let result = resolve_and_validate(&url).await;
1693 assert!(result.is_err());
1694 assert!(matches!(
1695 result.unwrap_err(),
1696 crate::executor::ToolError::Blocked { .. }
1697 ));
1698 }
1699
1700 #[tokio::test]
1701 async fn resolve_no_host_returns_ok() {
1702 let url = url::Url::parse("https://example.com/path").unwrap();
1704 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1706 let result = resolve_and_validate(&url_no_host).await;
1708 assert!(result.is_ok());
1709 let (host, addrs) = result.unwrap();
1710 assert!(host.is_empty());
1711 assert!(addrs.is_empty());
1712 drop(url);
1713 drop(url_no_host);
1714 }
1715
1716 async fn make_file_audit_logger(
1720 dir: &tempfile::TempDir,
1721 ) -> (
1722 std::sync::Arc<crate::audit::AuditLogger>,
1723 std::path::PathBuf,
1724 ) {
1725 use crate::audit::AuditLogger;
1726 use crate::config::AuditConfig;
1727 let path = dir.path().join("audit.log");
1728 let config = AuditConfig {
1729 enabled: true,
1730 destination: path.display().to_string(),
1731 };
1732 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1733 (logger, path)
1734 }
1735
1736 #[tokio::test]
1737 async fn with_audit_sets_logger() {
1738 let config = ScrapeConfig::default();
1739 let executor = WebScrapeExecutor::new(&config);
1740 assert!(executor.audit_logger.is_none());
1741
1742 let dir = tempfile::tempdir().unwrap();
1743 let (logger, _path) = make_file_audit_logger(&dir).await;
1744 let executor = executor.with_audit(logger);
1745 assert!(executor.audit_logger.is_some());
1746 }
1747
1748 #[test]
1749 fn tool_error_to_audit_result_blocked_maps_correctly() {
1750 let err = ToolError::Blocked {
1751 command: "scheme not allowed: http".into(),
1752 };
1753 let result = tool_error_to_audit_result(&err);
1754 assert!(
1755 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1756 );
1757 }
1758
1759 #[test]
1760 fn tool_error_to_audit_result_timeout_maps_correctly() {
1761 let err = ToolError::Timeout { timeout_secs: 15 };
1762 let result = tool_error_to_audit_result(&err);
1763 assert!(matches!(result, AuditResult::Timeout));
1764 }
1765
1766 #[test]
1767 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1768 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1769 let result = tool_error_to_audit_result(&err);
1770 assert!(
1771 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1772 );
1773 }
1774
    /// A blocked `fetch` call is recorded in the audit log with the tool
    /// name, a "blocked" result type, and the offending URL.
    #[tokio::test]
    async fn fetch_audit_blocked_url_logged() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        // http:// scheme triggers the block path before any network I/O.
        let call = crate::executor::ToolCall {
            tool_id: "fetch".to_owned(),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m
            },
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));

        // The audit file should now contain one JSON entry for the block.
        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"fetch\""),
            "expected tool=fetch in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
        assert!(
            content.contains("http://example.com"),
            "expected URL in audit command field: {content}"
        );
    }
1808
    /// `log_audit` with a Success result writes a JSON entry containing the
    /// tool name, result type, command (URL), and duration.
    #[tokio::test]
    async fn log_audit_success_writes_to_file() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        executor
            .log_audit(
                "fetch",
                "https://example.com/page",
                AuditResult::Success,
                42,
            )
            .await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"fetch\""),
            "expected tool=fetch in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"success\""),
            "expected type=success in audit: {content}"
        );
        assert!(
            content.contains("\"command\":\"https://example.com/page\""),
            "expected command URL in audit: {content}"
        );
        assert!(
            content.contains("\"duration_ms\":42"),
            "expected duration_ms in audit: {content}"
        );
    }
1844
    /// `log_audit` with a Blocked result writes a JSON entry containing the
    /// tool name, "blocked" type, and the block reason.
    #[tokio::test]
    async fn log_audit_blocked_writes_to_file() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        executor
            .log_audit(
                "web_scrape",
                "http://evil.com/page",
                AuditResult::Blocked {
                    reason: "scheme not allowed: http".into(),
                },
                0,
            )
            .await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"web_scrape\""),
            "expected tool=web_scrape in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
        assert!(
            content.contains("scheme not allowed"),
            "expected block reason in audit: {content}"
        );
    }
1878
    /// A blocked `web_scrape` call (http:// URL) is recorded in the audit
    /// log with the tool name and a "blocked" result type.
    #[tokio::test]
    async fn web_scrape_audit_blocked_url_logged() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        let call = crate::executor::ToolCall {
            tool_id: "web_scrape".to_owned(),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m.insert("select".to_owned(), serde_json::json!("h1"));
                m
            },
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"web_scrape\""),
            "expected tool=web_scrape in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
    }
1909
1910 #[tokio::test]
1911 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1912 let config = ScrapeConfig::default();
1913 let executor = WebScrapeExecutor::new(&config);
1914 assert!(executor.audit_logger.is_none());
1915
1916 let call = crate::executor::ToolCall {
1917 tool_id: "fetch".to_owned(),
1918 params: {
1919 let mut m = serde_json::Map::new();
1920 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1921 m
1922 },
1923 };
1924 let result = executor.execute_tool_call(&call).await;
1926 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1927 }
1928
    /// NOTE(review): despite the name, this test drives `fetch_html`
    /// directly rather than going through `execute_tool_call` — presumably
    /// because the full tool-call path would validate/resolve the mock
    /// server's loopback address. TODO confirm and either rename the test
    /// or route it through `execute_tool_call`.
    #[tokio::test]
    async fn fetch_execute_tool_call_end_to_end() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/e2e"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let result = executor
            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("end-to-end"));
    }
1950}