1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{
15 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
16};
17use crate::net::is_private_ip;
18
/// Parameters for the `fetch` tool: retrieve a URL's body as plain text.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // Target URL; must be HTTPS (enforced later by `validate_url`).
    url: String,
}
24
/// One `web_scrape` request: fetch `url` and extract matches of `select`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // Target URL; must be HTTPS (enforced later by `validate_url`).
    url: String,
    // CSS selector identifying the elements to extract.
    select: String,
    // Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    // Maximum number of matches returned; `scrape_instruction` falls back to 10.
    limit: Option<usize>,
}
37
/// Serde default for `ScrapeInstruction::extract`: plain text extraction.
fn default_extract() -> String {
    String::from("text")
}
41
/// How to pull a value out of a matched HTML element.
#[derive(Debug)]
enum ExtractMode {
    /// Concatenated text content of the element.
    Text,
    /// Inner HTML of the element.
    Html,
    /// Value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parses an extraction-mode string: "text", "html", or "attr:<name>".
    /// Unknown values fall back to `Text`, matching the serde default.
    fn parse(s: &str) -> Self {
        // A single `strip_prefix` replaces the original's redundant
        // `starts_with` guard + `strip_prefix` + unreachable `unwrap_or`.
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        match s {
            "html" => Self::Html,
            // "text" and anything unrecognized both mean text extraction.
            _ => Self::Text,
        }
    }
}
61
/// Tool executor that fetches HTTPS pages and extracts data via CSS
/// selectors, with SSRF protections (scheme/host checks plus pinned DNS
/// resolution on every redirect hop).
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request HTTP timeout.
    timeout: Duration,
    // Maximum accepted response body size, in bytes.
    max_body_bytes: usize,
    // Optional sink for per-invocation audit entries.
    audit_logger: Option<Arc<AuditLogger>>,
}
72
73impl WebScrapeExecutor {
74 #[must_use]
75 pub fn new(config: &ScrapeConfig) -> Self {
76 Self {
77 timeout: Duration::from_secs(config.timeout),
78 max_body_bytes: config.max_body_bytes,
79 audit_logger: None,
80 }
81 }
82
83 #[must_use]
84 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
85 self.audit_logger = Some(logger);
86 self
87 }
88
89 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
90 let mut builder = reqwest::Client::builder()
91 .timeout(self.timeout)
92 .redirect(reqwest::redirect::Policy::none());
93 builder = builder.resolve_to_addrs(host, addrs);
94 builder.build().unwrap_or_default()
95 }
96}
97
impl ToolExecutor for WebScrapeExecutor {
    /// Advertises the two tools backed by this executor: `web_scrape`
    /// (invoked via a fenced "scrape" block) and `fetch` (structured call).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Executes every fenced "scrape" block found in a model response.
    ///
    /// Returns `Ok(None)` when the response contains no scrape blocks.
    /// Blocks run sequentially and fail fast: the first malformed or failing
    /// block aborts the whole batch (results of already-completed scrapes
    /// are discarded). Every attempt is audit-logged with its duration.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // A block that is not valid JSON aborts the batch immediately
            // (before any audit entry for it is written).
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            // NOTE(review): output name is "web-scrape" (hyphen) while the
            // tool id is "web_scrape" (underscore) — presumably intentional;
            // confirm downstream consumers expect the hyphenated form.
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Dispatches a structured tool call to `web_scrape` or `fetch`,
    /// audit-logging success/failure with duration. Unknown tool ids return
    /// `Ok(None)` so other executors can claim them.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms, Some(&e))
                            .await;
                        Err(e)
                    }
                }
            }
            // Not ours — let another executor handle it.
            _ => Ok(None),
        }
    }

    /// Both tools are read-only network fetches, so retrying is safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
263
264fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
265 match e {
266 ToolError::Blocked { command } => AuditResult::Blocked {
267 reason: command.clone(),
268 },
269 ToolError::Timeout { .. } => AuditResult::Timeout,
270 _ => AuditResult::Error {
271 message: e.to_string(),
272 },
273 }
274}
275
impl WebScrapeExecutor {
    /// Writes one audit entry for a scrape/fetch attempt, if a logger is
    /// configured. `command` carries the target URL; `error` (when present)
    /// is expanded into category/domain/phase labels for the entry.
    async fn log_audit(
        &self,
        tool: &str,
        command: &str,
        result: AuditResult,
        duration_ms: u64,
        error: Option<&ToolError>,
    ) {
        if let Some(ref logger) = self.audit_logger {
            let (error_category, error_domain, error_phase) =
                error.map_or((None, None, None), |e| {
                    let cat = e.category();
                    (
                        Some(cat.label().to_owned()),
                        Some(cat.domain().label().to_owned()),
                        Some(cat.phase().label().to_owned()),
                    )
                });
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
                error_category,
                error_domain,
                error_phase,
                claim_source: Some(ClaimSource::WebScrape),
                mcp_server_id: None,
                injection_flagged: false,
                embedding_anomalous: false,
                cross_boundary_mcp_to_acp: false,
                adversarial_policy_decision: None,
                exit_code: None,
                truncated: false,
            };
            logger.log(&entry).await;
        }
    }

    /// Implements the `fetch` tool: SSRF-validate the URL, resolve and vet
    /// its addresses, then download the body as UTF-8 text.
    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
        let parsed = validate_url(&params.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        self.fetch_html(&params.url, &host, &addrs).await
    }

    /// Implements `web_scrape`: fetch the page, then run the CPU-bound HTML
    /// parsing and selector extraction on a blocking thread.
    async fn scrape_instruction(
        &self,
        instruction: &ScrapeInstruction,
    ) -> Result<String, ToolError> {
        let parsed = validate_url(&instruction.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
        let selector = instruction.select.clone();
        let extract = ExtractMode::parse(&instruction.extract);
        // Default to 10 matches when the instruction does not set a limit.
        let limit = instruction.limit.unwrap_or(10);
        // spawn_blocking keeps HTML parsing off the async runtime; a panic or
        // cancellation in the task surfaces as an Execution error.
        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
            .await
            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
    }

    /// Downloads `url`, handling redirects manually so that every hop is
    /// re-validated against the SSRF rules (scheme, host, resolved IPs) and
    /// the client is re-pinned to the vetted addresses of the new host.
    ///
    /// Errors on: network failures, > MAX_REDIRECTS hops, a redirect without
    /// a usable Location header, non-2xx status, oversized bodies, and
    /// non-UTF-8 bodies.
    async fn fetch_html(
        &self,
        url: &str,
        host: &str,
        addrs: &[SocketAddr],
    ) -> Result<String, ToolError> {
        // Redirect chains longer than this are refused.
        const MAX_REDIRECTS: usize = 3;

        let mut current_url = url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            // The client never follows redirects on its own (see build_client).
            let client = self.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                // Resolve relative Location values against the current URL.
                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                // Re-apply the full SSRF gate (scheme/host/IPs) to the target.
                let validated = validate_url(next_url.as_str())?;
                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;

                current_url = next_url.to_string();
                current_host = next_host;
                current_addrs = next_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Http {
                    status: status.as_u16(),
                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
                });
            }

            // NOTE(review): the body is fully buffered before this size check,
            // so an oversized response still occupies memory once; a streaming
            // cap would bound memory — confirm this is acceptable.
            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > self.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    self.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        // Dynamically unreachable (the hop == MAX_REDIRECTS arm above returns
        // first); kept because the compiler cannot prove the loop diverges.
        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
}
430
/// Returns the contents of every fenced "scrape" code block in `text`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
434
435fn validate_url(raw: &str) -> Result<Url, ToolError> {
436 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
437 command: format!("invalid URL: {raw}"),
438 })?;
439
440 if parsed.scheme() != "https" {
441 return Err(ToolError::Blocked {
442 command: format!("scheme not allowed: {}", parsed.scheme()),
443 });
444 }
445
446 if let Some(host) = parsed.host()
447 && is_private_host(&host)
448 {
449 return Err(ToolError::Blocked {
450 command: format!(
451 "private/local host blocked: {}",
452 parsed.host_str().unwrap_or("")
453 ),
454 });
455 }
456
457 Ok(parsed)
458}
459
460fn is_private_host(host: &url::Host<&str>) -> bool {
461 match host {
462 url::Host::Domain(d) => {
463 #[allow(clippy::case_sensitive_file_extension_comparisons)]
466 {
467 *d == "localhost"
468 || d.ends_with(".localhost")
469 || d.ends_with(".internal")
470 || d.ends_with(".local")
471 }
472 }
473 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
474 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
475 }
476}
477
478async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
484 let Some(host) = url.host_str() else {
485 return Ok((String::new(), vec![]));
486 };
487 let port = url.port_or_known_default().unwrap_or(443);
488 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
489 .await
490 .map_err(|e| ToolError::Blocked {
491 command: format!("DNS resolution failed: {e}"),
492 })?
493 .collect();
494 for addr in &addrs {
495 if is_private_ip(addr.ip()) {
496 return Err(ToolError::Blocked {
497 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
498 });
499 }
500 }
501 Ok((host.to_owned(), addrs))
502}
503
504fn parse_and_extract(
505 html: &str,
506 selector: &str,
507 extract: &ExtractMode,
508 limit: usize,
509) -> Result<String, ToolError> {
510 let soup = scrape_core::Soup::parse(html);
511
512 let tags = soup.find_all(selector).map_err(|e| {
513 ToolError::Execution(std::io::Error::new(
514 std::io::ErrorKind::InvalidData,
515 format!("invalid selector: {e}"),
516 ))
517 })?;
518
519 let mut results = Vec::new();
520
521 for tag in tags.into_iter().take(limit) {
522 let value = match extract {
523 ExtractMode::Text => tag.text(),
524 ExtractMode::Html => tag.inner_html(),
525 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
526 };
527 if !value.trim().is_empty() {
528 results.push(value.trim().to_owned());
529 }
530 }
531
532 if results.is_empty() {
533 Ok(format!("No results for selector: {selector}"))
534 } else {
535 Ok(results.join("\n"))
536 }
537}
538
539#[cfg(test)]
540mod tests {
541 use super::*;
542
    // ---- extract_scrape_blocks: fenced "scrape" block parsing ----

    #[test]
    fn extract_single_block() {
        let text =
            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("example.com"));
    }

    #[test]
    fn extract_multiple_blocks() {
        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 2);
    }

    #[test]
    fn no_blocks_returns_empty() {
        let blocks = extract_scrape_blocks("plain text, no code blocks");
        assert!(blocks.is_empty());
    }

    // A block with no closing fence must not be treated as a block.
    #[test]
    fn unclosed_block_ignored() {
        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
        assert!(blocks.is_empty());
    }

    // Only "scrape" fences count; other languages are skipped.
    #[test]
    fn non_scrape_block_ignored() {
        let text =
            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("x.com"));
    }

    // Pretty-printed JSON spanning multiple lines inside one fence.
    #[test]
    fn multiline_json_block() {
        let text =
            "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
        assert_eq!(instr.url, "https://example.com");
    }
591
    // ---- ScrapeInstruction deserialization ----

    #[test]
    fn parse_valid_instruction() {
        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.url, "https://example.com");
        assert_eq!(instr.select, "h1");
        assert_eq!(instr.extract, "text");
        assert_eq!(instr.limit, Some(5));
    }

    // `extract` defaults to "text"; `limit` stays unset.
    #[test]
    fn parse_minimal_instruction() {
        let json = r#"{"url":"https://example.com","select":"p"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "text");
        assert!(instr.limit.is_none());
    }

    #[test]
    fn parse_attr_extract() {
        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "attr:href");
    }

    #[test]
    fn parse_invalid_json_errors() {
        let result = serde_json::from_str::<ScrapeInstruction>("not json");
        assert!(result.is_err());
    }

    // ---- ExtractMode::parse ----

    #[test]
    fn extract_mode_text() {
        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
    }

    #[test]
    fn extract_mode_html() {
        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
    }

    #[test]
    fn extract_mode_attr() {
        let mode = ExtractMode::parse("attr:href");
        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
    }

    // Unrecognized modes silently fall back to text extraction.
    #[test]
    fn extract_mode_unknown_defaults_to_text() {
        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
    }
647
    // ---- validate_url: scheme and host gating ----

    #[test]
    fn valid_https_url() {
        assert!(validate_url("https://example.com").is_ok());
    }

    #[test]
    fn http_rejected() {
        let err = validate_url("http://example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ftp_rejected() {
        let err = validate_url("ftp://files.example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn file_rejected() {
        let err = validate_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn invalid_url_rejected() {
        let err = validate_url("not a url").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // ---- validate_url: private/local hosts and IPs ----

    #[test]
    fn localhost_blocked() {
        let err = validate_url("https://localhost/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn loopback_ip_blocked() {
        let err = validate_url("https://127.0.0.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_10_blocked() {
        let err = validate_url("https://10.0.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_172_blocked() {
        let err = validate_url("https://172.16.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_192_blocked() {
        let err = validate_url("https://192.168.1.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_loopback_blocked() {
        let err = validate_url("https://[::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // A public, non-reserved address passes the gate.
    #[test]
    fn public_ip_allowed() {
        assert!(validate_url("https://93.184.216.34/page").is_ok());
    }
719
    // ---- parse_and_extract: selector-based extraction ----

    #[test]
    fn extract_text_from_html() {
        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "Hello World");
    }

    // Multiple matches are joined with newlines.
    #[test]
    fn extract_multiple_elements() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A\nB\nC");
    }

    #[test]
    fn extract_with_limit() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
        assert_eq!(result, "A\nB");
    }

    #[test]
    fn extract_attr_href() {
        let html = r#"<a href="https://example.com">Link</a>"#;
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert_eq!(result, "https://example.com");
    }

    #[test]
    fn extract_inner_html() {
        let html = "<div><span>inner</span></div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<span>inner</span>"));
    }

    #[test]
    fn no_matches_returns_message() {
        let html = "<html><body><p>text</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    // Whitespace-only values are dropped from the result list.
    #[test]
    fn empty_text_skipped() {
        let html = "<ul><li> </li><li>A</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A");
    }

    #[test]
    fn invalid_selector_errors() {
        let html = "<html><body></body></html>";
        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
        assert!(result.is_err());
    }

    #[test]
    fn empty_html_returns_no_results() {
        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn nested_selector() {
        let html = "<div><span>inner</span></div><span>outer</span>";
        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "inner");
    }

    // A missing attribute yields an empty value, which is then skipped.
    #[test]
    fn attr_missing_returns_empty() {
        let html = r"<a>No href</a>";
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn extract_html_mode() {
        let html = "<div><b>bold</b> text</div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<b>bold</b>"));
    }

    #[test]
    fn limit_zero_returns_no_results() {
        let html = "<ul><li>A</li><li>B</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }
813
    // ---- validate_url: additional SSRF edge cases ----

    #[test]
    fn url_with_port_allowed() {
        assert!(validate_url("https://example.com:8443/path").is_ok());
    }

    #[test]
    fn link_local_ip_blocked() {
        let err = validate_url("https://169.254.1.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn url_no_scheme_rejected() {
        let err = validate_url("example.com/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn unspecified_ipv4_blocked() {
        let err = validate_url("https://0.0.0.0/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn broadcast_ipv4_blocked() {
        let err = validate_url("https://255.255.255.255/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_link_local_blocked() {
        let err = validate_url("https://[fe80::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_unique_local_blocked() {
        let err = validate_url("https://[fd12::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // IPv4-mapped IPv6 (::ffff:a.b.c.d) must not bypass the IPv4 checks.
    #[test]
    fn ipv4_mapped_ipv6_loopback_blocked() {
        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv4_mapped_ipv6_private_blocked() {
        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }
868
    // ---- WebScrapeExecutor::execute (fenced-block path) ----

    #[tokio::test]
    async fn executor_no_blocks_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("plain text").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn executor_invalid_json_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\nnot json\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_blocked_url_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_private_ip_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    // 192.0.2.0/24 (TEST-NET-1) is public-looking but unroutable, so this
    // passes validation and then fails at the network layer.
    #[tokio::test]
    async fn executor_unreachable_host_returns_error() {
        let config = ScrapeConfig {
            timeout: 1,
            max_body_bytes: 1_048_576,
        };
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_localhost_url_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_empty_text_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("").await;
        assert!(result.unwrap().is_none());
    }

    // Execution is fail-fast: a blocked first instruction aborts the batch.
    #[tokio::test]
    async fn executor_multiple_blocks_first_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
                        ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(result.is_err());
    }

    // ---- validate_url: degenerate inputs ----

    #[test]
    fn validate_url_empty_string() {
        let err = validate_url("").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_javascript_scheme_blocked() {
        let err = validate_url("javascript:alert(1)").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_data_scheme_blocked() {
        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // ---- is_private_host called directly ----

    #[test]
    fn is_private_host_public_domain_is_false() {
        let host: url::Host<&str> = url::Host::Domain("example.com");
        assert!(!is_private_host(&host));
    }

    #[test]
    fn is_private_host_localhost_is_true() {
        let host: url::Host<&str> = url::Host::Domain("localhost");
        assert!(is_private_host(&host));
    }

    #[test]
    fn is_private_host_ipv6_unspecified_is_true() {
        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
        assert!(is_private_host(&host));
    }

    #[test]
    fn is_private_host_public_ipv6_is_false() {
        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
        assert!(!is_private_host(&host));
    }
986
    /// Builds an executor with fixed test settings plus a running wiremock
    /// server (the executor is constructed directly, bypassing `new`).
    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
        let server = wiremock::MockServer::start().await;
        let executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 1_048_576,
            audit_logger: None,
        };
        (executor, server)
    }

    /// Extracts the mock server's host and socket address so tests can call
    /// `fetch_html` directly with pre-resolved addresses (bypassing DNS and
    /// the private-IP validation that would reject 127.0.0.1).
    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
        let uri = server.uri();
        let url = Url::parse(&uri).unwrap();
        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
        let port = url.port().unwrap_or(80);
        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
        (host, vec![addr])
    }
1016
    /// Test-only mirror of `fetch_html`'s manual redirect loop. Unlike the
    /// production path it deliberately skips per-hop SSRF re-validation and
    /// keeps the original pinned host/addrs for every hop (see the
    /// `let _ = &mut ...` no-ops below), so redirects can be exercised
    /// against a local mock server that would otherwise be blocked.
    async fn follow_redirects_raw(
        executor: &WebScrapeExecutor,
        start_url: &str,
        host: &str,
        addrs: &[std::net::SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;
        let mut current_url = start_url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = executor.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                current_url = next_url.to_string();
                // Intentionally keep the same pinned host/addrs (no
                // re-validation); these no-ops silence unused-mut warnings.
                let _ = &mut current_host;
                let _ = &mut current_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "HTTP {status}",
                ))));
            }

            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > executor.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    executor.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
1097
    // ---- fetch_html against a live wiremock server ----

    #[tokio::test]
    async fn fetch_html_success_returns_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/page", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap(), "<h1>OK</h1>");
    }

    // Non-2xx statuses surface as errors carrying the status code.
    #[tokio::test]
    async fn fetch_html_non_2xx_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/forbidden"))
            .respond_with(ResponseTemplate::new(403))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/forbidden", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("403"), "expected 403 in error: {msg}");
    }

    #[tokio::test]
    async fn fetch_html_404_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/missing", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("404"), "expected 404 in error: {msg}");
    }
1156
1157 #[tokio::test]
1158 async fn fetch_html_redirect_no_location_returns_error() {
1159 use wiremock::matchers::{method, path};
1160 use wiremock::{Mock, ResponseTemplate};
1161
1162 let (executor, server) = mock_server_executor().await;
1163 Mock::given(method("GET"))
1165 .and(path("/redirect-no-loc"))
1166 .respond_with(ResponseTemplate::new(302))
1167 .mount(&server)
1168 .await;
1169
1170 let (host, addrs) = server_host_and_addr(&server);
1171 let url = format!("{}/redirect-no-loc", server.uri());
1172 let result = executor.fetch_html(&url, &host, &addrs).await;
1173 assert!(result.is_err());
1174 let msg = result.unwrap_err().to_string();
1175 assert!(
1176 msg.contains("Location") || msg.contains("location"),
1177 "expected Location-related error: {msg}"
1178 );
1179 }
1180
1181 #[tokio::test]
1182 async fn fetch_html_single_redirect_followed() {
1183 use wiremock::matchers::{method, path};
1184 use wiremock::{Mock, ResponseTemplate};
1185
1186 let (executor, server) = mock_server_executor().await;
1187 let final_url = format!("{}/final", server.uri());
1188
1189 Mock::given(method("GET"))
1190 .and(path("/start"))
1191 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1192 .mount(&server)
1193 .await;
1194
1195 Mock::given(method("GET"))
1196 .and(path("/final"))
1197 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1198 .mount(&server)
1199 .await;
1200
1201 let (host, addrs) = server_host_and_addr(&server);
1202 let url = format!("{}/start", server.uri());
1203 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1204 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1205 assert_eq!(result.unwrap(), "<p>final</p>");
1206 }
1207
1208 #[tokio::test]
1209 async fn fetch_html_three_redirects_allowed() {
1210 use wiremock::matchers::{method, path};
1211 use wiremock::{Mock, ResponseTemplate};
1212
1213 let (executor, server) = mock_server_executor().await;
1214 let hop2 = format!("{}/hop2", server.uri());
1215 let hop3 = format!("{}/hop3", server.uri());
1216 let final_dest = format!("{}/done", server.uri());
1217
1218 Mock::given(method("GET"))
1219 .and(path("/hop1"))
1220 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1221 .mount(&server)
1222 .await;
1223 Mock::given(method("GET"))
1224 .and(path("/hop2"))
1225 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1226 .mount(&server)
1227 .await;
1228 Mock::given(method("GET"))
1229 .and(path("/hop3"))
1230 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1231 .mount(&server)
1232 .await;
1233 Mock::given(method("GET"))
1234 .and(path("/done"))
1235 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1236 .mount(&server)
1237 .await;
1238
1239 let (host, addrs) = server_host_and_addr(&server);
1240 let url = format!("{}/hop1", server.uri());
1241 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1242 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1243 assert_eq!(result.unwrap(), "<p>done</p>");
1244 }
1245
1246 #[tokio::test]
1247 async fn fetch_html_four_redirects_rejected() {
1248 use wiremock::matchers::{method, path};
1249 use wiremock::{Mock, ResponseTemplate};
1250
1251 let (executor, server) = mock_server_executor().await;
1252 let hop2 = format!("{}/r2", server.uri());
1253 let hop3 = format!("{}/r3", server.uri());
1254 let hop4 = format!("{}/r4", server.uri());
1255 let hop5 = format!("{}/r5", server.uri());
1256
1257 for (from, to) in [
1258 ("/r1", &hop2),
1259 ("/r2", &hop3),
1260 ("/r3", &hop4),
1261 ("/r4", &hop5),
1262 ] {
1263 Mock::given(method("GET"))
1264 .and(path(from))
1265 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1266 .mount(&server)
1267 .await;
1268 }
1269
1270 let (host, addrs) = server_host_and_addr(&server);
1271 let url = format!("{}/r1", server.uri());
1272 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1273 assert!(result.is_err(), "4 redirects should be rejected");
1274 let msg = result.unwrap_err().to_string();
1275 assert!(
1276 msg.contains("redirect"),
1277 "expected redirect-related error: {msg}"
1278 );
1279 }
1280
1281 #[tokio::test]
1282 async fn fetch_html_body_too_large_returns_error() {
1283 use wiremock::matchers::{method, path};
1284 use wiremock::{Mock, ResponseTemplate};
1285
1286 let small_limit_executor = WebScrapeExecutor {
1287 timeout: Duration::from_secs(5),
1288 max_body_bytes: 10,
1289 audit_logger: None,
1290 };
1291 let server = wiremock::MockServer::start().await;
1292 Mock::given(method("GET"))
1293 .and(path("/big"))
1294 .respond_with(
1295 ResponseTemplate::new(200)
1296 .set_body_string("this body is definitely longer than ten bytes"),
1297 )
1298 .mount(&server)
1299 .await;
1300
1301 let (host, addrs) = server_host_and_addr(&server);
1302 let url = format!("{}/big", server.uri());
1303 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1304 assert!(result.is_err());
1305 let msg = result.unwrap_err().to_string();
1306 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1307 }
1308
1309 #[test]
1310 fn extract_scrape_blocks_empty_block_content() {
1311 let text = "```scrape\n\n```";
1312 let blocks = extract_scrape_blocks(text);
1313 assert_eq!(blocks.len(), 1);
1314 assert!(blocks[0].is_empty());
1315 }
1316
1317 #[test]
1318 fn extract_scrape_blocks_whitespace_only() {
1319 let text = "```scrape\n \n```";
1320 let blocks = extract_scrape_blocks(text);
1321 assert_eq!(blocks.len(), 1);
1322 }
1323
1324 #[test]
1325 fn parse_and_extract_multiple_selectors() {
1326 let html = "<div><h1>Title</h1><p>Para</p></div>";
1327 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1328 assert!(result.contains("Title"));
1329 assert!(result.contains("Para"));
1330 }
1331
1332 #[test]
1333 fn webscrape_executor_new_with_custom_config() {
1334 let config = ScrapeConfig {
1335 timeout: 60,
1336 max_body_bytes: 512,
1337 };
1338 let executor = WebScrapeExecutor::new(&config);
1339 assert_eq!(executor.max_body_bytes, 512);
1340 }
1341
1342 #[test]
1343 fn webscrape_executor_debug() {
1344 let config = ScrapeConfig::default();
1345 let executor = WebScrapeExecutor::new(&config);
1346 let dbg = format!("{executor:?}");
1347 assert!(dbg.contains("WebScrapeExecutor"));
1348 }
1349
1350 #[test]
1351 fn extract_mode_attr_empty_name() {
1352 let mode = ExtractMode::parse("attr:");
1353 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1354 }
1355
1356 #[test]
1357 fn default_extract_returns_text() {
1358 assert_eq!(default_extract(), "text");
1359 }
1360
1361 #[test]
1362 fn scrape_instruction_debug() {
1363 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1364 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1365 let dbg = format!("{instr:?}");
1366 assert!(dbg.contains("ScrapeInstruction"));
1367 }
1368
1369 #[test]
1370 fn extract_mode_debug() {
1371 let mode = ExtractMode::Text;
1372 let dbg = format!("{mode:?}");
1373 assert!(dbg.contains("Text"));
1374 }
1375
    // NOTE(review): the five tests below are self-referential smoke tests — each
    // re-declares the literal it then asserts on, rather than exercising the
    // production code in fetch_html. They document intent but cannot catch a
    // regression; consider exporting MAX_REDIRECTS / the error constructors so
    // these can assert against the real values.
    #[test]
    fn max_redirects_constant_is_three() {
        // Local shadow of the limit hard-coded inside fetch_html.
        const MAX_REDIRECTS: usize = 3;
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }

    #[test]
    fn redirect_no_location_error_message() {
        // Mirrors the message built when a redirect carries no Location header.
        let err = std::io::Error::other("redirect with no Location");
        assert!(err.to_string().contains("redirect with no Location"));
    }

    #[test]
    fn too_many_redirects_error_message() {
        // Mirrors the message returned when the redirect limit is exhausted.
        let err = std::io::Error::other("too many redirects");
        assert!(err.to_string().contains("too many redirects"));
    }

    #[test]
    fn non_2xx_status_error_format() {
        // Checks reqwest's StatusCode Display includes the numeric code,
        // matching the "HTTP {status}" format used for non-2xx responses.
        let status = reqwest::StatusCode::FORBIDDEN;
        let msg = format!("HTTP {status}");
        assert!(msg.contains("403"));
    }

    #[test]
    fn not_found_status_error_format() {
        // Same as above for 404.
        let status = reqwest::StatusCode::NOT_FOUND;
        let msg = format!("HTTP {status}");
        assert!(msg.contains("404"));
    }
1419
1420 #[test]
1422 fn relative_redirect_same_host_path() {
1423 let base = Url::parse("https://example.com/current").unwrap();
1424 let resolved = base.join("/other").unwrap();
1425 assert_eq!(resolved.as_str(), "https://example.com/other");
1426 }
1427
1428 #[test]
1430 fn relative_redirect_relative_path() {
1431 let base = Url::parse("https://example.com/a/b").unwrap();
1432 let resolved = base.join("c").unwrap();
1433 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1434 }
1435
1436 #[test]
1438 fn absolute_redirect_overrides_base() {
1439 let base = Url::parse("https://example.com/page").unwrap();
1440 let resolved = base.join("https://other.com/target").unwrap();
1441 assert_eq!(resolved.as_str(), "https://other.com/target");
1442 }
1443
1444 #[test]
1446 fn redirect_http_downgrade_rejected() {
1447 let location = "http://example.com/page";
1448 let base = Url::parse("https://example.com/start").unwrap();
1449 let next = base.join(location).unwrap();
1450 let err = validate_url(next.as_str()).unwrap_err();
1451 assert!(matches!(err, ToolError::Blocked { .. }));
1452 }
1453
1454 #[test]
1456 fn redirect_location_private_ip_blocked() {
1457 let location = "https://192.168.100.1/admin";
1458 let base = Url::parse("https://example.com/start").unwrap();
1459 let next = base.join(location).unwrap();
1460 let err = validate_url(next.as_str()).unwrap_err();
1461 assert!(matches!(err, ToolError::Blocked { .. }));
1462 let ToolError::Blocked { command: cmd } = err else {
1463 panic!("expected Blocked");
1464 };
1465 assert!(
1466 cmd.contains("private") || cmd.contains("scheme"),
1467 "error message should describe the block reason: {cmd}"
1468 );
1469 }
1470
1471 #[test]
1473 fn redirect_location_internal_domain_blocked() {
1474 let location = "https://metadata.internal/latest/meta-data/";
1475 let base = Url::parse("https://example.com/start").unwrap();
1476 let next = base.join(location).unwrap();
1477 let err = validate_url(next.as_str()).unwrap_err();
1478 assert!(matches!(err, ToolError::Blocked { .. }));
1479 }
1480
1481 #[test]
1483 fn redirect_chain_three_hops_all_public() {
1484 let hops = [
1485 "https://redirect1.example.com/hop1",
1486 "https://redirect2.example.com/hop2",
1487 "https://destination.example.com/final",
1488 ];
1489 for hop in hops {
1490 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1491 }
1492 }
1493
1494 #[test]
1499 fn redirect_to_private_ip_rejected_by_validate_url() {
1500 let private_targets = [
1502 "https://127.0.0.1/secret",
1503 "https://10.0.0.1/internal",
1504 "https://192.168.1.1/admin",
1505 "https://172.16.0.1/data",
1506 "https://[::1]/path",
1507 "https://[fe80::1]/path",
1508 "https://localhost/path",
1509 "https://service.internal/api",
1510 ];
1511 for target in private_targets {
1512 let result = validate_url(target);
1513 assert!(result.is_err(), "expected error for {target}");
1514 assert!(
1515 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1516 "expected Blocked for {target}"
1517 );
1518 }
1519 }
1520
1521 #[test]
1523 fn redirect_relative_url_resolves_correctly() {
1524 let base = Url::parse("https://example.com/page").unwrap();
1525 let relative = "/other";
1526 let resolved = base.join(relative).unwrap();
1527 assert_eq!(resolved.as_str(), "https://example.com/other");
1528 }
1529
1530 #[test]
1532 fn redirect_to_http_rejected() {
1533 let err = validate_url("http://example.com/page").unwrap_err();
1534 assert!(matches!(err, ToolError::Blocked { .. }));
1535 }
1536
1537 #[test]
1538 fn ipv4_mapped_ipv6_link_local_blocked() {
1539 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1540 assert!(matches!(err, ToolError::Blocked { .. }));
1541 }
1542
1543 #[test]
1544 fn ipv4_mapped_ipv6_public_allowed() {
1545 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1546 }
1547
1548 #[tokio::test]
1551 async fn fetch_http_scheme_blocked() {
1552 let config = ScrapeConfig::default();
1553 let executor = WebScrapeExecutor::new(&config);
1554 let call = crate::executor::ToolCall {
1555 tool_id: "fetch".to_owned(),
1556 params: {
1557 let mut m = serde_json::Map::new();
1558 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1559 m
1560 },
1561 };
1562 let result = executor.execute_tool_call(&call).await;
1563 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1564 }
1565
1566 #[tokio::test]
1567 async fn fetch_private_ip_blocked() {
1568 let config = ScrapeConfig::default();
1569 let executor = WebScrapeExecutor::new(&config);
1570 let call = crate::executor::ToolCall {
1571 tool_id: "fetch".to_owned(),
1572 params: {
1573 let mut m = serde_json::Map::new();
1574 m.insert(
1575 "url".to_owned(),
1576 serde_json::json!("https://192.168.1.1/secret"),
1577 );
1578 m
1579 },
1580 };
1581 let result = executor.execute_tool_call(&call).await;
1582 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1583 }
1584
1585 #[tokio::test]
1586 async fn fetch_localhost_blocked() {
1587 let config = ScrapeConfig::default();
1588 let executor = WebScrapeExecutor::new(&config);
1589 let call = crate::executor::ToolCall {
1590 tool_id: "fetch".to_owned(),
1591 params: {
1592 let mut m = serde_json::Map::new();
1593 m.insert(
1594 "url".to_owned(),
1595 serde_json::json!("https://localhost/page"),
1596 );
1597 m
1598 },
1599 };
1600 let result = executor.execute_tool_call(&call).await;
1601 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1602 }
1603
1604 #[tokio::test]
1605 async fn fetch_unknown_tool_returns_none() {
1606 let config = ScrapeConfig::default();
1607 let executor = WebScrapeExecutor::new(&config);
1608 let call = crate::executor::ToolCall {
1609 tool_id: "unknown_tool".to_owned(),
1610 params: serde_json::Map::new(),
1611 };
1612 let result = executor.execute_tool_call(&call).await;
1613 assert!(result.unwrap().is_none());
1614 }
1615
1616 #[tokio::test]
1617 async fn fetch_returns_body_via_mock() {
1618 use wiremock::matchers::{method, path};
1619 use wiremock::{Mock, ResponseTemplate};
1620
1621 let (executor, server) = mock_server_executor().await;
1622 Mock::given(method("GET"))
1623 .and(path("/content"))
1624 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1625 .mount(&server)
1626 .await;
1627
1628 let (host, addrs) = server_host_and_addr(&server);
1629 let url = format!("{}/content", server.uri());
1630 let result = executor.fetch_html(&url, &host, &addrs).await;
1631 assert!(result.is_ok());
1632 assert_eq!(result.unwrap(), "plain text content");
1633 }
1634
1635 #[test]
1636 fn tool_definitions_returns_web_scrape_and_fetch() {
1637 let config = ScrapeConfig::default();
1638 let executor = WebScrapeExecutor::new(&config);
1639 let defs = executor.tool_definitions();
1640 assert_eq!(defs.len(), 2);
1641 assert_eq!(defs[0].id, "web_scrape");
1642 assert_eq!(
1643 defs[0].invocation,
1644 crate::registry::InvocationHint::FencedBlock("scrape")
1645 );
1646 assert_eq!(defs[1].id, "fetch");
1647 assert_eq!(
1648 defs[1].invocation,
1649 crate::registry::InvocationHint::ToolCall
1650 );
1651 }
1652
1653 #[test]
1654 fn tool_definitions_schema_has_all_params() {
1655 let config = ScrapeConfig::default();
1656 let executor = WebScrapeExecutor::new(&config);
1657 let defs = executor.tool_definitions();
1658 let obj = defs[0].schema.as_object().unwrap();
1659 let props = obj["properties"].as_object().unwrap();
1660 assert!(props.contains_key("url"));
1661 assert!(props.contains_key("select"));
1662 assert!(props.contains_key("extract"));
1663 assert!(props.contains_key("limit"));
1664 let req = obj["required"].as_array().unwrap();
1665 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1666 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1667 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1668 }
1669
1670 #[test]
1673 fn subdomain_localhost_blocked() {
1674 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1675 assert!(is_private_host(&host));
1676 }
1677
1678 #[test]
1679 fn internal_tld_blocked() {
1680 let host: url::Host<&str> = url::Host::Domain("service.internal");
1681 assert!(is_private_host(&host));
1682 }
1683
1684 #[test]
1685 fn local_tld_blocked() {
1686 let host: url::Host<&str> = url::Host::Domain("printer.local");
1687 assert!(is_private_host(&host));
1688 }
1689
1690 #[test]
1691 fn public_domain_not_blocked() {
1692 let host: url::Host<&str> = url::Host::Domain("example.com");
1693 assert!(!is_private_host(&host));
1694 }
1695
1696 #[tokio::test]
1699 async fn resolve_loopback_rejected() {
1700 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1702 let result = resolve_and_validate(&url).await;
1704 assert!(
1705 result.is_err(),
1706 "loopback IP must be rejected by resolve_and_validate"
1707 );
1708 let err = result.unwrap_err();
1709 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1710 }
1711
1712 #[tokio::test]
1713 async fn resolve_private_10_rejected() {
1714 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1715 let result = resolve_and_validate(&url).await;
1716 assert!(result.is_err());
1717 assert!(matches!(
1718 result.unwrap_err(),
1719 crate::executor::ToolError::Blocked { .. }
1720 ));
1721 }
1722
1723 #[tokio::test]
1724 async fn resolve_private_192_rejected() {
1725 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1726 let result = resolve_and_validate(&url).await;
1727 assert!(result.is_err());
1728 assert!(matches!(
1729 result.unwrap_err(),
1730 crate::executor::ToolError::Blocked { .. }
1731 ));
1732 }
1733
1734 #[tokio::test]
1735 async fn resolve_ipv6_loopback_rejected() {
1736 let url = url::Url::parse("https://[::1]/path").unwrap();
1737 let result = resolve_and_validate(&url).await;
1738 assert!(result.is_err());
1739 assert!(matches!(
1740 result.unwrap_err(),
1741 crate::executor::ToolError::Blocked { .. }
1742 ));
1743 }
1744
1745 #[tokio::test]
1746 async fn resolve_no_host_returns_ok() {
1747 let url = url::Url::parse("https://example.com/path").unwrap();
1749 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1751 let result = resolve_and_validate(&url_no_host).await;
1753 assert!(result.is_ok());
1754 let (host, addrs) = result.unwrap();
1755 assert!(host.is_empty());
1756 assert!(addrs.is_empty());
1757 drop(url);
1758 drop(url_no_host);
1759 }
1760
1761 async fn make_file_audit_logger(
1765 dir: &tempfile::TempDir,
1766 ) -> (
1767 std::sync::Arc<crate::audit::AuditLogger>,
1768 std::path::PathBuf,
1769 ) {
1770 use crate::audit::AuditLogger;
1771 use crate::config::AuditConfig;
1772 let path = dir.path().join("audit.log");
1773 let config = AuditConfig {
1774 enabled: true,
1775 destination: path.display().to_string(),
1776 };
1777 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1778 (logger, path)
1779 }
1780
1781 #[tokio::test]
1782 async fn with_audit_sets_logger() {
1783 let config = ScrapeConfig::default();
1784 let executor = WebScrapeExecutor::new(&config);
1785 assert!(executor.audit_logger.is_none());
1786
1787 let dir = tempfile::tempdir().unwrap();
1788 let (logger, _path) = make_file_audit_logger(&dir).await;
1789 let executor = executor.with_audit(logger);
1790 assert!(executor.audit_logger.is_some());
1791 }
1792
1793 #[test]
1794 fn tool_error_to_audit_result_blocked_maps_correctly() {
1795 let err = ToolError::Blocked {
1796 command: "scheme not allowed: http".into(),
1797 };
1798 let result = tool_error_to_audit_result(&err);
1799 assert!(
1800 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1801 );
1802 }
1803
1804 #[test]
1805 fn tool_error_to_audit_result_timeout_maps_correctly() {
1806 let err = ToolError::Timeout { timeout_secs: 15 };
1807 let result = tool_error_to_audit_result(&err);
1808 assert!(matches!(result, AuditResult::Timeout));
1809 }
1810
1811 #[test]
1812 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1813 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1814 let result = tool_error_to_audit_result(&err);
1815 assert!(
1816 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1817 );
1818 }
1819
1820 #[tokio::test]
1821 async fn fetch_audit_blocked_url_logged() {
1822 let dir = tempfile::tempdir().unwrap();
1823 let (logger, log_path) = make_file_audit_logger(&dir).await;
1824
1825 let config = ScrapeConfig::default();
1826 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1827
1828 let call = crate::executor::ToolCall {
1829 tool_id: "fetch".to_owned(),
1830 params: {
1831 let mut m = serde_json::Map::new();
1832 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1833 m
1834 },
1835 };
1836 let result = executor.execute_tool_call(&call).await;
1837 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1838
1839 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1840 assert!(
1841 content.contains("\"tool\":\"fetch\""),
1842 "expected tool=fetch in audit: {content}"
1843 );
1844 assert!(
1845 content.contains("\"type\":\"blocked\""),
1846 "expected type=blocked in audit: {content}"
1847 );
1848 assert!(
1849 content.contains("http://example.com"),
1850 "expected URL in audit command field: {content}"
1851 );
1852 }
1853
1854 #[tokio::test]
1855 async fn log_audit_success_writes_to_file() {
1856 let dir = tempfile::tempdir().unwrap();
1857 let (logger, log_path) = make_file_audit_logger(&dir).await;
1858
1859 let config = ScrapeConfig::default();
1860 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1861
1862 executor
1863 .log_audit(
1864 "fetch",
1865 "https://example.com/page",
1866 AuditResult::Success,
1867 42,
1868 None,
1869 )
1870 .await;
1871
1872 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1873 assert!(
1874 content.contains("\"tool\":\"fetch\""),
1875 "expected tool=fetch in audit: {content}"
1876 );
1877 assert!(
1878 content.contains("\"type\":\"success\""),
1879 "expected type=success in audit: {content}"
1880 );
1881 assert!(
1882 content.contains("\"command\":\"https://example.com/page\""),
1883 "expected command URL in audit: {content}"
1884 );
1885 assert!(
1886 content.contains("\"duration_ms\":42"),
1887 "expected duration_ms in audit: {content}"
1888 );
1889 }
1890
1891 #[tokio::test]
1892 async fn log_audit_blocked_writes_to_file() {
1893 let dir = tempfile::tempdir().unwrap();
1894 let (logger, log_path) = make_file_audit_logger(&dir).await;
1895
1896 let config = ScrapeConfig::default();
1897 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1898
1899 executor
1900 .log_audit(
1901 "web_scrape",
1902 "http://evil.com/page",
1903 AuditResult::Blocked {
1904 reason: "scheme not allowed: http".into(),
1905 },
1906 0,
1907 None,
1908 )
1909 .await;
1910
1911 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1912 assert!(
1913 content.contains("\"tool\":\"web_scrape\""),
1914 "expected tool=web_scrape in audit: {content}"
1915 );
1916 assert!(
1917 content.contains("\"type\":\"blocked\""),
1918 "expected type=blocked in audit: {content}"
1919 );
1920 assert!(
1921 content.contains("scheme not allowed"),
1922 "expected block reason in audit: {content}"
1923 );
1924 }
1925
1926 #[tokio::test]
1927 async fn web_scrape_audit_blocked_url_logged() {
1928 let dir = tempfile::tempdir().unwrap();
1929 let (logger, log_path) = make_file_audit_logger(&dir).await;
1930
1931 let config = ScrapeConfig::default();
1932 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1933
1934 let call = crate::executor::ToolCall {
1935 tool_id: "web_scrape".to_owned(),
1936 params: {
1937 let mut m = serde_json::Map::new();
1938 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1939 m.insert("select".to_owned(), serde_json::json!("h1"));
1940 m
1941 },
1942 };
1943 let result = executor.execute_tool_call(&call).await;
1944 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1945
1946 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1947 assert!(
1948 content.contains("\"tool\":\"web_scrape\""),
1949 "expected tool=web_scrape in audit: {content}"
1950 );
1951 assert!(
1952 content.contains("\"type\":\"blocked\""),
1953 "expected type=blocked in audit: {content}"
1954 );
1955 }
1956
1957 #[tokio::test]
1958 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1959 let config = ScrapeConfig::default();
1960 let executor = WebScrapeExecutor::new(&config);
1961 assert!(executor.audit_logger.is_none());
1962
1963 let call = crate::executor::ToolCall {
1964 tool_id: "fetch".to_owned(),
1965 params: {
1966 let mut m = serde_json::Map::new();
1967 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1968 m
1969 },
1970 };
1971 let result = executor.execute_tool_call(&call).await;
1973 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1974 }
1975
    // NOTE(review): despite its name, this test calls fetch_html directly and
    // never goes through execute_tool_call — presumably because the tool-call
    // path re-validates the host and would block the mock server's loopback
    // address (TODO confirm). Consider renaming it, or restructuring so the
    // claimed end-to-end path is actually exercised.
    #[tokio::test]
    async fn fetch_execute_tool_call_end_to_end() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/e2e"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let result = executor
            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("end-to-end"));
    }
1997}