use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};

use schemars::JsonSchema;
use serde::Deserialize;
use url::Url;

use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
use crate::config::ScrapeConfig;
use crate::executor::{
    ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
};
use crate::net::is_private_ip;

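/// Parameters accepted by the `fetch` tool: just the target URL. Scheme,
/// domain-policy, and SSRF checks are applied later in `handle_fetch`.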
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    url: String,
}

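/// A single scrape request, parsed either from a fenced `scrape` block or
/// from `web_scrape` tool-call parameters: target URL, CSS selector,
/// extraction mode (defaults to "text"), and an optional result limit.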
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    url: String,
    select: String,
    #[serde(default = "default_extract")]
    extract: String,
    limit: Option<usize>,
}

fn default_extract() -> String {
    "text".into()
}

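/// What to pull from each matched element: its text, its inner HTML, or a
/// named attribute (written `attr:<name>` in the `extract` field).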
#[derive(Debug)]
enum ExtractMode {
    Text,
    Html,
    Attr(String),
}

impl ExtractMode {
    fn parse(s: &str) -> Self {
        match s {
            "text" => Self::Text,
            "html" => Self::Html,
            attr if attr.starts_with("attr:") => {
                Self::Attr(attr.strip_prefix("attr:").unwrap_or(attr).to_owned())
            }
            _ => Self::Text,
        }
    }
}

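/// Executor backing the `web_scrape` and `fetch` tools. Requests are
/// HTTPS-only, resolved addresses are pinned before connecting, private and
/// reserved IPs are rejected, redirects are re-validated hop by hop, and
/// response bodies are capped at `max_body_bytes`.
///
/// A sketch of typical construction, mirroring the tests at the bottom of
/// this file (the `audit_logger` handle stands in for whatever
/// `Arc<AuditLogger>` the caller already has):
///
/// ```ignore
/// let executor = WebScrapeExecutor::new(&ScrapeConfig::default())
///     .with_audit(audit_logger);
/// ```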
#[derive(Debug)]
pub struct WebScrapeExecutor {
    timeout: Duration,
    max_body_bytes: usize,
    allowed_domains: Vec<String>,
    denied_domains: Vec<String>,
    audit_logger: Option<Arc<AuditLogger>>,
}

impl WebScrapeExecutor {
    #[must_use]
    pub fn new(config: &ScrapeConfig) -> Self {
        Self {
            timeout: Duration::from_secs(config.timeout),
            max_body_bytes: config.max_body_bytes,
            allowed_domains: config.allowed_domains.clone(),
            denied_domains: config.denied_domains.clone(),
            audit_logger: None,
        }
    }

    #[must_use]
    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
        self.audit_logger = Some(logger);
        self
    }

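    /// Builds a one-shot client with redirects disabled and `host` pinned to
    /// the already-validated `addrs`, so a later DNS answer cannot reroute the
    /// connection. Redirects are instead followed manually in `fetch_html`.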
    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
        let mut builder = reqwest::Client::builder()
            .timeout(self.timeout)
            .redirect(reqwest::redirect::Policy::none());
        builder = builder.resolve_to_addrs(host, addrs);
        builder.build().unwrap_or_default()
    }
}

impl ToolExecutor for WebScrapeExecutor {
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text; responses larger than the configured max body size are rejected\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

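    /// Fenced-block path: parses every `scrape` block in the model response,
    /// runs each one in order, and audits the outcome. The first failure
    /// aborts the batch and is returned to the caller.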
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

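    /// Structured tool-call path for `web_scrape` and `fetch`. Unknown tool
    /// ids return `Ok(None)` so the call can be handled elsewhere.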
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms, Some(&e))
                            .await;
                        Err(e)
                    }
                }
            }
            _ => Ok(None),
        }
    }

    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}

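/// Maps a `ToolError` onto the audit taxonomy: blocked and timeout errors get
/// their dedicated variants, everything else becomes a generic error message.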
fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
    match e {
        ToolError::Blocked { command } => AuditResult::Blocked {
            reason: command.clone(),
        },
        ToolError::Timeout { .. } => AuditResult::Timeout,
        _ => AuditResult::Error {
            message: e.to_string(),
        },
    }
}

impl WebScrapeExecutor {
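    /// Writes one audit entry if a logger is configured; otherwise a no-op.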
    async fn log_audit(
        &self,
        tool: &str,
        command: &str,
        result: AuditResult,
        duration_ms: u64,
        error: Option<&ToolError>,
    ) {
        if let Some(ref logger) = self.audit_logger {
            let (error_category, error_domain, error_phase) =
                error.map_or((None, None, None), |e| {
                    let cat = e.category();
                    (
                        Some(cat.label().to_owned()),
                        Some(cat.domain().label().to_owned()),
                        Some(cat.phase().label().to_owned()),
                    )
                });
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
                error_category,
                error_domain,
                error_phase,
                claim_source: Some(ClaimSource::WebScrape),
                mcp_server_id: None,
                injection_flagged: false,
                embedding_anomalous: false,
                cross_boundary_mcp_to_acp: false,
                adversarial_policy_decision: None,
                exit_code: None,
                truncated: false,
            };
            logger.log(&entry).await;
        }
    }

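    /// `fetch` entry point: validate the URL, apply the domain policy, resolve
    /// and pin the host, then return the raw response body.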
    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
        let parsed = validate_url(&params.url)?;
        check_domain_policy(
            parsed.host_str().unwrap_or(""),
            &self.allowed_domains,
            &self.denied_domains,
        )?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        self.fetch_html(&params.url, &host, &addrs).await
    }

    async fn scrape_instruction(
        &self,
        instruction: &ScrapeInstruction,
    ) -> Result<String, ToolError> {
        let parsed = validate_url(&instruction.url)?;
        check_domain_policy(
            parsed.host_str().unwrap_or(""),
            &self.allowed_domains,
            &self.denied_domains,
        )?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
        let selector = instruction.select.clone();
        let extract = ExtractMode::parse(&instruction.extract);
        let limit = instruction.limit.unwrap_or(10);
        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
            .await
            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
    }

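    /// Fetches `url` through a client pinned to `addrs`, following at most
    /// `MAX_REDIRECTS` redirects by hand. Each `Location` target goes back
    /// through `validate_url` and `resolve_and_validate`, so a redirect cannot
    /// hop to a private address or downgrade the scheme.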
    async fn fetch_html(
        &self,
        url: &str,
        host: &str,
        addrs: &[SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;

        let mut current_url = url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = self.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                let validated = validate_url(next_url.as_str())?;
                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;

                current_url = next_url.to_string();
                current_host = next_host;
                current_addrs = next_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Http {
                    status: status.as_u16(),
                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
                });
            }

            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > self.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    self.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
}

fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}

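/// Applies the configured domain policy: the denylist always wins; when an
/// allowlist is present, bare IP hosts are refused and the host must match
/// one of the allowlist patterns.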
fn check_domain_policy(
    host: &str,
    allowed_domains: &[String],
    denied_domains: &[String],
) -> Result<(), ToolError> {
    if denied_domains.iter().any(|p| domain_matches(p, host)) {
        return Err(ToolError::Blocked {
            command: format!("domain blocked by denylist: {host}"),
        });
    }
    if !allowed_domains.is_empty() {
        let is_ip = host.parse::<std::net::IpAddr>().is_ok()
            || (host.starts_with('[') && host.ends_with(']'));
        if is_ip {
            return Err(ToolError::Blocked {
                command: format!(
                    "bare IP address not allowed when domain allowlist is active: {host}"
                ),
            });
        }
        if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
            return Err(ToolError::Blocked {
                command: format!("domain not in allowlist: {host}"),
            });
        }
    }
    Ok(())
}

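/// Matches a host against an exact domain or a `*.` wildcard. The wildcard
/// covers exactly one label: `*.example.com` matches `sub.example.com` but
/// neither `example.com` nor `a.b.example.com`.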
fn domain_matches(pattern: &str, host: &str) -> bool {
    if pattern.starts_with("*.") {
        let suffix = &pattern[1..];
        if let Some(remainder) = host.strip_suffix(suffix) {
            !remainder.is_empty() && !remainder.contains('.')
        } else {
            false
        }
    } else {
        pattern == host
    }
}

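/// Syntactic URL checks: must parse, must be `https`, and the host must not
/// be an obviously private or local name/IP. DNS-level checks happen
/// separately in `resolve_and_validate`.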
fn validate_url(raw: &str) -> Result<Url, ToolError> {
    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
        command: format!("invalid URL: {raw}"),
    })?;

    if parsed.scheme() != "https" {
        return Err(ToolError::Blocked {
            command: format!("scheme not allowed: {}", parsed.scheme()),
        });
    }

    if let Some(host) = parsed.host()
        && is_private_host(&host)
    {
        return Err(ToolError::Blocked {
            command: format!(
                "private/local host blocked: {}",
                parsed.host_str().unwrap_or("")
            ),
        });
    }

    Ok(parsed)
}

fn is_private_host(host: &url::Host<&str>) -> bool {
    match host {
        url::Host::Domain(d) => {
            #[allow(clippy::case_sensitive_file_extension_comparisons)]
            {
                *d == "localhost"
                    || d.ends_with(".localhost")
                    || d.ends_with(".internal")
                    || d.ends_with(".local")
            }
        }
        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
    }
}

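/// Resolves the URL's host and rejects any address that `is_private_ip`
/// flags. The returned `(host, addrs)` pair is what `build_client` pins, so
/// connections can only reach addresses that passed this check.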
async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
    let Some(host) = url.host_str() else {
        return Ok((String::new(), vec![]));
    };
    let port = url.port_or_known_default().unwrap_or(443);
    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
        .await
        .map_err(|e| ToolError::Blocked {
            command: format!("DNS resolution failed: {e}"),
        })?
        .collect();
    for addr in &addrs {
        if is_private_ip(addr.ip()) {
            return Err(ToolError::Blocked {
                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
            });
        }
    }
    Ok((host.to_owned(), addrs))
}

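/// Runs the CSS selector over the fetched HTML (called on a blocking thread
/// from `scrape_instruction`) and returns up to `limit` non-empty values
/// joined with newlines, or a "No results" message.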
fn parse_and_extract(
    html: &str,
    selector: &str,
    extract: &ExtractMode,
    limit: usize,
) -> Result<String, ToolError> {
    let soup = scrape_core::Soup::parse(html);

    let tags = soup.find_all(selector).map_err(|e| {
        ToolError::Execution(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("invalid selector: {e}"),
        ))
    })?;

    let mut results = Vec::new();

    for tag in tags.into_iter().take(limit) {
        let value = match extract {
            ExtractMode::Text => tag.text(),
            ExtractMode::Html => tag.inner_html(),
            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
        };
        if !value.trim().is_empty() {
            results.push(value.trim().to_owned());
        }
    }

    if results.is_empty() {
        Ok(format!("No results for selector: {selector}"))
    } else {
        Ok(results.join("\n"))
    }
}
613
614#[cfg(test)]
615mod tests {
616 use super::*;
617
618 #[test]
621 fn extract_single_block() {
622 let text =
623 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
624 let blocks = extract_scrape_blocks(text);
625 assert_eq!(blocks.len(), 1);
626 assert!(blocks[0].contains("example.com"));
627 }
628
629 #[test]
630 fn extract_multiple_blocks() {
631 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
632 let blocks = extract_scrape_blocks(text);
633 assert_eq!(blocks.len(), 2);
634 }
635
636 #[test]
637 fn no_blocks_returns_empty() {
638 let blocks = extract_scrape_blocks("plain text, no code blocks");
639 assert!(blocks.is_empty());
640 }
641
642 #[test]
643 fn unclosed_block_ignored() {
644 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
645 assert!(blocks.is_empty());
646 }
647
648 #[test]
649 fn non_scrape_block_ignored() {
650 let text =
651 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
652 let blocks = extract_scrape_blocks(text);
653 assert_eq!(blocks.len(), 1);
654 assert!(blocks[0].contains("x.com"));
655 }
656
657 #[test]
658 fn multiline_json_block() {
659 let text =
660 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
661 let blocks = extract_scrape_blocks(text);
662 assert_eq!(blocks.len(), 1);
663 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
664 assert_eq!(instr.url, "https://example.com");
665 }
666
667 #[test]
670 fn parse_valid_instruction() {
671 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
672 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
673 assert_eq!(instr.url, "https://example.com");
674 assert_eq!(instr.select, "h1");
675 assert_eq!(instr.extract, "text");
676 assert_eq!(instr.limit, Some(5));
677 }
678
679 #[test]
680 fn parse_minimal_instruction() {
681 let json = r#"{"url":"https://example.com","select":"p"}"#;
682 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
683 assert_eq!(instr.extract, "text");
684 assert!(instr.limit.is_none());
685 }
686
687 #[test]
688 fn parse_attr_extract() {
689 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
690 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
691 assert_eq!(instr.extract, "attr:href");
692 }
693
694 #[test]
695 fn parse_invalid_json_errors() {
696 let result = serde_json::from_str::<ScrapeInstruction>("not json");
697 assert!(result.is_err());
698 }
699
700 #[test]
703 fn extract_mode_text() {
704 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
705 }
706
707 #[test]
708 fn extract_mode_html() {
709 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
710 }
711
712 #[test]
713 fn extract_mode_attr() {
714 let mode = ExtractMode::parse("attr:href");
715 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
716 }
717
718 #[test]
719 fn extract_mode_unknown_defaults_to_text() {
720 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
721 }
722
723 #[test]
726 fn valid_https_url() {
727 assert!(validate_url("https://example.com").is_ok());
728 }
729
730 #[test]
731 fn http_rejected() {
732 let err = validate_url("http://example.com").unwrap_err();
733 assert!(matches!(err, ToolError::Blocked { .. }));
734 }
735
736 #[test]
737 fn ftp_rejected() {
738 let err = validate_url("ftp://files.example.com").unwrap_err();
739 assert!(matches!(err, ToolError::Blocked { .. }));
740 }
741
742 #[test]
743 fn file_rejected() {
744 let err = validate_url("file:///etc/passwd").unwrap_err();
745 assert!(matches!(err, ToolError::Blocked { .. }));
746 }
747
748 #[test]
749 fn invalid_url_rejected() {
750 let err = validate_url("not a url").unwrap_err();
751 assert!(matches!(err, ToolError::Blocked { .. }));
752 }
753
754 #[test]
755 fn localhost_blocked() {
756 let err = validate_url("https://localhost/path").unwrap_err();
757 assert!(matches!(err, ToolError::Blocked { .. }));
758 }
759
760 #[test]
761 fn loopback_ip_blocked() {
762 let err = validate_url("https://127.0.0.1/path").unwrap_err();
763 assert!(matches!(err, ToolError::Blocked { .. }));
764 }
765
766 #[test]
767 fn private_10_blocked() {
768 let err = validate_url("https://10.0.0.1/api").unwrap_err();
769 assert!(matches!(err, ToolError::Blocked { .. }));
770 }
771
772 #[test]
773 fn private_172_blocked() {
774 let err = validate_url("https://172.16.0.1/api").unwrap_err();
775 assert!(matches!(err, ToolError::Blocked { .. }));
776 }
777
778 #[test]
779 fn private_192_blocked() {
780 let err = validate_url("https://192.168.1.1/api").unwrap_err();
781 assert!(matches!(err, ToolError::Blocked { .. }));
782 }
783
784 #[test]
785 fn ipv6_loopback_blocked() {
786 let err = validate_url("https://[::1]/path").unwrap_err();
787 assert!(matches!(err, ToolError::Blocked { .. }));
788 }
789
790 #[test]
791 fn public_ip_allowed() {
792 assert!(validate_url("https://93.184.216.34/page").is_ok());
793 }
794
795 #[test]
798 fn extract_text_from_html() {
799 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
800 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
801 assert_eq!(result, "Hello World");
802 }
803
804 #[test]
805 fn extract_multiple_elements() {
806 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
807 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
808 assert_eq!(result, "A\nB\nC");
809 }
810
811 #[test]
812 fn extract_with_limit() {
813 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
814 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
815 assert_eq!(result, "A\nB");
816 }
817
818 #[test]
819 fn extract_attr_href() {
820 let html = r#"<a href="https://example.com">Link</a>"#;
821 let result =
822 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
823 assert_eq!(result, "https://example.com");
824 }
825
826 #[test]
827 fn extract_inner_html() {
828 let html = "<div><span>inner</span></div>";
829 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
830 assert!(result.contains("<span>inner</span>"));
831 }
832
833 #[test]
834 fn no_matches_returns_message() {
835 let html = "<html><body><p>text</p></body></html>";
836 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
837 assert!(result.starts_with("No results for selector:"));
838 }
839
840 #[test]
841 fn empty_text_skipped() {
842 let html = "<ul><li> </li><li>A</li></ul>";
843 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
844 assert_eq!(result, "A");
845 }
846
847 #[test]
848 fn invalid_selector_errors() {
849 let html = "<html><body></body></html>";
850 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
851 assert!(result.is_err());
852 }
853
854 #[test]
855 fn empty_html_returns_no_results() {
856 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
857 assert!(result.starts_with("No results for selector:"));
858 }
859
860 #[test]
861 fn nested_selector() {
862 let html = "<div><span>inner</span></div><span>outer</span>";
863 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
864 assert_eq!(result, "inner");
865 }
866
867 #[test]
868 fn attr_missing_returns_empty() {
869 let html = r"<a>No href</a>";
870 let result =
871 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
872 assert!(result.starts_with("No results for selector:"));
873 }
874
875 #[test]
876 fn extract_html_mode() {
877 let html = "<div><b>bold</b> text</div>";
878 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
879 assert!(result.contains("<b>bold</b>"));
880 }
881
882 #[test]
883 fn limit_zero_returns_no_results() {
884 let html = "<ul><li>A</li><li>B</li></ul>";
885 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
886 assert!(result.starts_with("No results for selector:"));
887 }
888
889 #[test]
892 fn url_with_port_allowed() {
893 assert!(validate_url("https://example.com:8443/path").is_ok());
894 }
895
896 #[test]
897 fn link_local_ip_blocked() {
898 let err = validate_url("https://169.254.1.1/path").unwrap_err();
899 assert!(matches!(err, ToolError::Blocked { .. }));
900 }
901
902 #[test]
903 fn url_no_scheme_rejected() {
904 let err = validate_url("example.com/path").unwrap_err();
905 assert!(matches!(err, ToolError::Blocked { .. }));
906 }
907
908 #[test]
909 fn unspecified_ipv4_blocked() {
910 let err = validate_url("https://0.0.0.0/path").unwrap_err();
911 assert!(matches!(err, ToolError::Blocked { .. }));
912 }
913
914 #[test]
915 fn broadcast_ipv4_blocked() {
916 let err = validate_url("https://255.255.255.255/path").unwrap_err();
917 assert!(matches!(err, ToolError::Blocked { .. }));
918 }
919
920 #[test]
921 fn ipv6_link_local_blocked() {
922 let err = validate_url("https://[fe80::1]/path").unwrap_err();
923 assert!(matches!(err, ToolError::Blocked { .. }));
924 }
925
926 #[test]
927 fn ipv6_unique_local_blocked() {
928 let err = validate_url("https://[fd12::1]/path").unwrap_err();
929 assert!(matches!(err, ToolError::Blocked { .. }));
930 }
931
932 #[test]
933 fn ipv4_mapped_ipv6_loopback_blocked() {
934 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
935 assert!(matches!(err, ToolError::Blocked { .. }));
936 }
937
938 #[test]
939 fn ipv4_mapped_ipv6_private_blocked() {
940 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
941 assert!(matches!(err, ToolError::Blocked { .. }));
942 }
943
944 #[tokio::test]
947 async fn executor_no_blocks_returns_none() {
948 let config = ScrapeConfig::default();
949 let executor = WebScrapeExecutor::new(&config);
950 let result = executor.execute("plain text").await;
951 assert!(result.unwrap().is_none());
952 }
953
954 #[tokio::test]
955 async fn executor_invalid_json_errors() {
956 let config = ScrapeConfig::default();
957 let executor = WebScrapeExecutor::new(&config);
958 let response = "```scrape\nnot json\n```";
959 let result = executor.execute(response).await;
960 assert!(matches!(result, Err(ToolError::Execution(_))));
961 }
962
963 #[tokio::test]
964 async fn executor_blocked_url_errors() {
965 let config = ScrapeConfig::default();
966 let executor = WebScrapeExecutor::new(&config);
967 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
968 let result = executor.execute(response).await;
969 assert!(matches!(result, Err(ToolError::Blocked { .. })));
970 }
971
972 #[tokio::test]
973 async fn executor_private_ip_blocked() {
974 let config = ScrapeConfig::default();
975 let executor = WebScrapeExecutor::new(&config);
976 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
977 let result = executor.execute(response).await;
978 assert!(matches!(result, Err(ToolError::Blocked { .. })));
979 }
980
981 #[tokio::test]
982 async fn executor_unreachable_host_returns_error() {
983 let config = ScrapeConfig {
984 timeout: 1,
985 max_body_bytes: 1_048_576,
986 ..Default::default()
987 };
988 let executor = WebScrapeExecutor::new(&config);
989 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
990 let result = executor.execute(response).await;
991 assert!(matches!(result, Err(ToolError::Execution(_))));
992 }
993
994 #[tokio::test]
995 async fn executor_localhost_url_blocked() {
996 let config = ScrapeConfig::default();
997 let executor = WebScrapeExecutor::new(&config);
998 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
999 let result = executor.execute(response).await;
1000 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1001 }
1002
1003 #[tokio::test]
1004 async fn executor_empty_text_returns_none() {
1005 let config = ScrapeConfig::default();
1006 let executor = WebScrapeExecutor::new(&config);
1007 let result = executor.execute("").await;
1008 assert!(result.unwrap().is_none());
1009 }
1010
1011 #[tokio::test]
1012 async fn executor_multiple_blocks_first_blocked() {
1013 let config = ScrapeConfig::default();
1014 let executor = WebScrapeExecutor::new(&config);
1015 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1016 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1017 let result = executor.execute(response).await;
1018 assert!(result.is_err());
1019 }
1020
1021 #[test]
1022 fn validate_url_empty_string() {
1023 let err = validate_url("").unwrap_err();
1024 assert!(matches!(err, ToolError::Blocked { .. }));
1025 }
1026
1027 #[test]
1028 fn validate_url_javascript_scheme_blocked() {
1029 let err = validate_url("javascript:alert(1)").unwrap_err();
1030 assert!(matches!(err, ToolError::Blocked { .. }));
1031 }
1032
1033 #[test]
1034 fn validate_url_data_scheme_blocked() {
1035 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1036 assert!(matches!(err, ToolError::Blocked { .. }));
1037 }
1038
1039 #[test]
1040 fn is_private_host_public_domain_is_false() {
1041 let host: url::Host<&str> = url::Host::Domain("example.com");
1042 assert!(!is_private_host(&host));
1043 }
1044
1045 #[test]
1046 fn is_private_host_localhost_is_true() {
1047 let host: url::Host<&str> = url::Host::Domain("localhost");
1048 assert!(is_private_host(&host));
1049 }
1050
1051 #[test]
1052 fn is_private_host_ipv6_unspecified_is_true() {
1053 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1054 assert!(is_private_host(&host));
1055 }
1056
1057 #[test]
1058 fn is_private_host_public_ipv6_is_false() {
1059 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1060 assert!(!is_private_host(&host));
1061 }
1062
1063 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1074 let server = wiremock::MockServer::start().await;
1075 let executor = WebScrapeExecutor {
1076 timeout: Duration::from_secs(5),
1077 max_body_bytes: 1_048_576,
1078 allowed_domains: vec![],
1079 denied_domains: vec![],
1080 audit_logger: None,
1081 };
1082 (executor, server)
1083 }
1084
1085 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1087 let uri = server.uri();
1088 let url = Url::parse(&uri).unwrap();
1089 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1090 let port = url.port().unwrap_or(80);
1091 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1092 (host, vec![addr])
1093 }
1094
1095 async fn follow_redirects_raw(
1099 executor: &WebScrapeExecutor,
1100 start_url: &str,
1101 host: &str,
1102 addrs: &[std::net::SocketAddr],
1103 ) -> Result<String, ToolError> {
1104 const MAX_REDIRECTS: usize = 3;
1105 let mut current_url = start_url.to_owned();
1106 let mut current_host = host.to_owned();
1107 let mut current_addrs = addrs.to_vec();
1108
1109 for hop in 0..=MAX_REDIRECTS {
            let client = executor.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
1113 .send()
1114 .await
1115 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1116
1117 let status = resp.status();
1118
1119 if status.is_redirection() {
1120 if hop == MAX_REDIRECTS {
1121 return Err(ToolError::Execution(std::io::Error::other(
1122 "too many redirects",
1123 )));
1124 }
1125
1126 let location = resp
1127 .headers()
1128 .get(reqwest::header::LOCATION)
1129 .and_then(|v| v.to_str().ok())
1130 .ok_or_else(|| {
1131 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1132 })?;
1133
                let base = Url::parse(&current_url)
1135 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1136 let next_url = base
1137 .join(location)
1138 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1139
1140 current_url = next_url.to_string();
1142 let _ = &mut current_host;
1144 let _ = &mut current_addrs;
1145 continue;
1146 }
1147
1148 if !status.is_success() {
1149 return Err(ToolError::Execution(std::io::Error::other(format!(
1150 "HTTP {status}",
1151 ))));
1152 }
1153
1154 let bytes = resp
1155 .bytes()
1156 .await
1157 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1158
1159 if bytes.len() > executor.max_body_bytes {
1160 return Err(ToolError::Execution(std::io::Error::other(format!(
1161 "response too large: {} bytes (max: {})",
1162 bytes.len(),
1163 executor.max_body_bytes,
1164 ))));
1165 }
1166
1167 return String::from_utf8(bytes.to_vec())
1168 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1169 }
1170
1171 Err(ToolError::Execution(std::io::Error::other(
1172 "too many redirects",
1173 )))
1174 }
1175
1176 #[tokio::test]
1177 async fn fetch_html_success_returns_body() {
1178 use wiremock::matchers::{method, path};
1179 use wiremock::{Mock, ResponseTemplate};
1180
1181 let (executor, server) = mock_server_executor().await;
1182 Mock::given(method("GET"))
1183 .and(path("/page"))
1184 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1185 .mount(&server)
1186 .await;
1187
1188 let (host, addrs) = server_host_and_addr(&server);
1189 let url = format!("{}/page", server.uri());
1190 let result = executor.fetch_html(&url, &host, &addrs).await;
1191 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1192 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1193 }
1194
1195 #[tokio::test]
1196 async fn fetch_html_non_2xx_returns_error() {
1197 use wiremock::matchers::{method, path};
1198 use wiremock::{Mock, ResponseTemplate};
1199
1200 let (executor, server) = mock_server_executor().await;
1201 Mock::given(method("GET"))
1202 .and(path("/forbidden"))
1203 .respond_with(ResponseTemplate::new(403))
1204 .mount(&server)
1205 .await;
1206
1207 let (host, addrs) = server_host_and_addr(&server);
1208 let url = format!("{}/forbidden", server.uri());
1209 let result = executor.fetch_html(&url, &host, &addrs).await;
1210 assert!(result.is_err());
1211 let msg = result.unwrap_err().to_string();
1212 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1213 }
1214
1215 #[tokio::test]
1216 async fn fetch_html_404_returns_error() {
1217 use wiremock::matchers::{method, path};
1218 use wiremock::{Mock, ResponseTemplate};
1219
1220 let (executor, server) = mock_server_executor().await;
1221 Mock::given(method("GET"))
1222 .and(path("/missing"))
1223 .respond_with(ResponseTemplate::new(404))
1224 .mount(&server)
1225 .await;
1226
1227 let (host, addrs) = server_host_and_addr(&server);
1228 let url = format!("{}/missing", server.uri());
1229 let result = executor.fetch_html(&url, &host, &addrs).await;
1230 assert!(result.is_err());
1231 let msg = result.unwrap_err().to_string();
1232 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1233 }
1234
1235 #[tokio::test]
1236 async fn fetch_html_redirect_no_location_returns_error() {
1237 use wiremock::matchers::{method, path};
1238 use wiremock::{Mock, ResponseTemplate};
1239
1240 let (executor, server) = mock_server_executor().await;
1241 Mock::given(method("GET"))
1243 .and(path("/redirect-no-loc"))
1244 .respond_with(ResponseTemplate::new(302))
1245 .mount(&server)
1246 .await;
1247
1248 let (host, addrs) = server_host_and_addr(&server);
1249 let url = format!("{}/redirect-no-loc", server.uri());
1250 let result = executor.fetch_html(&url, &host, &addrs).await;
1251 assert!(result.is_err());
1252 let msg = result.unwrap_err().to_string();
1253 assert!(
1254 msg.contains("Location") || msg.contains("location"),
1255 "expected Location-related error: {msg}"
1256 );
1257 }
1258
1259 #[tokio::test]
1260 async fn fetch_html_single_redirect_followed() {
1261 use wiremock::matchers::{method, path};
1262 use wiremock::{Mock, ResponseTemplate};
1263
1264 let (executor, server) = mock_server_executor().await;
1265 let final_url = format!("{}/final", server.uri());
1266
1267 Mock::given(method("GET"))
1268 .and(path("/start"))
1269 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1270 .mount(&server)
1271 .await;
1272
1273 Mock::given(method("GET"))
1274 .and(path("/final"))
1275 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1276 .mount(&server)
1277 .await;
1278
1279 let (host, addrs) = server_host_and_addr(&server);
1280 let url = format!("{}/start", server.uri());
1281 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1282 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1283 assert_eq!(result.unwrap(), "<p>final</p>");
1284 }
1285
1286 #[tokio::test]
1287 async fn fetch_html_three_redirects_allowed() {
1288 use wiremock::matchers::{method, path};
1289 use wiremock::{Mock, ResponseTemplate};
1290
1291 let (executor, server) = mock_server_executor().await;
1292 let hop2 = format!("{}/hop2", server.uri());
1293 let hop3 = format!("{}/hop3", server.uri());
1294 let final_dest = format!("{}/done", server.uri());
1295
1296 Mock::given(method("GET"))
1297 .and(path("/hop1"))
1298 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1299 .mount(&server)
1300 .await;
1301 Mock::given(method("GET"))
1302 .and(path("/hop2"))
1303 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1304 .mount(&server)
1305 .await;
1306 Mock::given(method("GET"))
1307 .and(path("/hop3"))
1308 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1309 .mount(&server)
1310 .await;
1311 Mock::given(method("GET"))
1312 .and(path("/done"))
1313 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1314 .mount(&server)
1315 .await;
1316
1317 let (host, addrs) = server_host_and_addr(&server);
1318 let url = format!("{}/hop1", server.uri());
1319 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1320 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1321 assert_eq!(result.unwrap(), "<p>done</p>");
1322 }
1323
1324 #[tokio::test]
1325 async fn fetch_html_four_redirects_rejected() {
1326 use wiremock::matchers::{method, path};
1327 use wiremock::{Mock, ResponseTemplate};
1328
1329 let (executor, server) = mock_server_executor().await;
1330 let hop2 = format!("{}/r2", server.uri());
1331 let hop3 = format!("{}/r3", server.uri());
1332 let hop4 = format!("{}/r4", server.uri());
1333 let hop5 = format!("{}/r5", server.uri());
1334
1335 for (from, to) in [
1336 ("/r1", &hop2),
1337 ("/r2", &hop3),
1338 ("/r3", &hop4),
1339 ("/r4", &hop5),
1340 ] {
1341 Mock::given(method("GET"))
1342 .and(path(from))
1343 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1344 .mount(&server)
1345 .await;
1346 }
1347
1348 let (host, addrs) = server_host_and_addr(&server);
1349 let url = format!("{}/r1", server.uri());
1350 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1351 assert!(result.is_err(), "4 redirects should be rejected");
1352 let msg = result.unwrap_err().to_string();
1353 assert!(
1354 msg.contains("redirect"),
1355 "expected redirect-related error: {msg}"
1356 );
1357 }
1358
1359 #[tokio::test]
1360 async fn fetch_html_body_too_large_returns_error() {
1361 use wiremock::matchers::{method, path};
1362 use wiremock::{Mock, ResponseTemplate};
1363
1364 let small_limit_executor = WebScrapeExecutor {
1365 timeout: Duration::from_secs(5),
1366 max_body_bytes: 10,
1367 allowed_domains: vec![],
1368 denied_domains: vec![],
1369 audit_logger: None,
1370 };
1371 let server = wiremock::MockServer::start().await;
1372 Mock::given(method("GET"))
1373 .and(path("/big"))
1374 .respond_with(
1375 ResponseTemplate::new(200)
1376 .set_body_string("this body is definitely longer than ten bytes"),
1377 )
1378 .mount(&server)
1379 .await;
1380
1381 let (host, addrs) = server_host_and_addr(&server);
1382 let url = format!("{}/big", server.uri());
1383 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1384 assert!(result.is_err());
1385 let msg = result.unwrap_err().to_string();
1386 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1387 }
1388
1389 #[test]
1390 fn extract_scrape_blocks_empty_block_content() {
1391 let text = "```scrape\n\n```";
1392 let blocks = extract_scrape_blocks(text);
1393 assert_eq!(blocks.len(), 1);
1394 assert!(blocks[0].is_empty());
1395 }
1396
1397 #[test]
1398 fn extract_scrape_blocks_whitespace_only() {
1399 let text = "```scrape\n \n```";
1400 let blocks = extract_scrape_blocks(text);
1401 assert_eq!(blocks.len(), 1);
1402 }
1403
1404 #[test]
1405 fn parse_and_extract_multiple_selectors() {
1406 let html = "<div><h1>Title</h1><p>Para</p></div>";
1407 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1408 assert!(result.contains("Title"));
1409 assert!(result.contains("Para"));
1410 }
1411
1412 #[test]
1413 fn webscrape_executor_new_with_custom_config() {
1414 let config = ScrapeConfig {
1415 timeout: 60,
1416 max_body_bytes: 512,
1417 ..Default::default()
1418 };
1419 let executor = WebScrapeExecutor::new(&config);
1420 assert_eq!(executor.max_body_bytes, 512);
1421 }
1422
1423 #[test]
1424 fn webscrape_executor_debug() {
1425 let config = ScrapeConfig::default();
1426 let executor = WebScrapeExecutor::new(&config);
1427 let dbg = format!("{executor:?}");
1428 assert!(dbg.contains("WebScrapeExecutor"));
1429 }
1430
1431 #[test]
1432 fn extract_mode_attr_empty_name() {
1433 let mode = ExtractMode::parse("attr:");
1434 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1435 }
1436
1437 #[test]
1438 fn default_extract_returns_text() {
1439 assert_eq!(default_extract(), "text");
1440 }
1441
1442 #[test]
1443 fn scrape_instruction_debug() {
1444 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1445 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1446 let dbg = format!("{instr:?}");
1447 assert!(dbg.contains("ScrapeInstruction"));
1448 }
1449
1450 #[test]
1451 fn extract_mode_debug() {
1452 let mode = ExtractMode::Text;
1453 let dbg = format!("{mode:?}");
1454 assert!(dbg.contains("Text"));
1455 }
1456
1457 #[test]
1462 fn max_redirects_constant_is_three() {
1463 const MAX_REDIRECTS: usize = 3;
1467 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1468 }
1469
1470 #[test]
1473 fn redirect_no_location_error_message() {
1474 let err = std::io::Error::other("redirect with no Location");
1475 assert!(err.to_string().contains("redirect with no Location"));
1476 }
1477
1478 #[test]
1480 fn too_many_redirects_error_message() {
1481 let err = std::io::Error::other("too many redirects");
1482 assert!(err.to_string().contains("too many redirects"));
1483 }
1484
1485 #[test]
1487 fn non_2xx_status_error_format() {
1488 let status = reqwest::StatusCode::FORBIDDEN;
1489 let msg = format!("HTTP {status}");
1490 assert!(msg.contains("403"));
1491 }
1492
1493 #[test]
1495 fn not_found_status_error_format() {
1496 let status = reqwest::StatusCode::NOT_FOUND;
1497 let msg = format!("HTTP {status}");
1498 assert!(msg.contains("404"));
1499 }
1500
1501 #[test]
1503 fn relative_redirect_same_host_path() {
1504 let base = Url::parse("https://example.com/current").unwrap();
1505 let resolved = base.join("/other").unwrap();
1506 assert_eq!(resolved.as_str(), "https://example.com/other");
1507 }
1508
1509 #[test]
1511 fn relative_redirect_relative_path() {
1512 let base = Url::parse("https://example.com/a/b").unwrap();
1513 let resolved = base.join("c").unwrap();
1514 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1515 }
1516
1517 #[test]
1519 fn absolute_redirect_overrides_base() {
1520 let base = Url::parse("https://example.com/page").unwrap();
1521 let resolved = base.join("https://other.com/target").unwrap();
1522 assert_eq!(resolved.as_str(), "https://other.com/target");
1523 }
1524
1525 #[test]
1527 fn redirect_http_downgrade_rejected() {
1528 let location = "http://example.com/page";
1529 let base = Url::parse("https://example.com/start").unwrap();
1530 let next = base.join(location).unwrap();
1531 let err = validate_url(next.as_str()).unwrap_err();
1532 assert!(matches!(err, ToolError::Blocked { .. }));
1533 }
1534
1535 #[test]
1537 fn redirect_location_private_ip_blocked() {
1538 let location = "https://192.168.100.1/admin";
1539 let base = Url::parse("https://example.com/start").unwrap();
1540 let next = base.join(location).unwrap();
1541 let err = validate_url(next.as_str()).unwrap_err();
1542 assert!(matches!(err, ToolError::Blocked { .. }));
1543 let ToolError::Blocked { command: cmd } = err else {
1544 panic!("expected Blocked");
1545 };
1546 assert!(
1547 cmd.contains("private") || cmd.contains("scheme"),
1548 "error message should describe the block reason: {cmd}"
1549 );
1550 }
1551
1552 #[test]
1554 fn redirect_location_internal_domain_blocked() {
1555 let location = "https://metadata.internal/latest/meta-data/";
1556 let base = Url::parse("https://example.com/start").unwrap();
1557 let next = base.join(location).unwrap();
1558 let err = validate_url(next.as_str()).unwrap_err();
1559 assert!(matches!(err, ToolError::Blocked { .. }));
1560 }
1561
1562 #[test]
1564 fn redirect_chain_three_hops_all_public() {
1565 let hops = [
1566 "https://redirect1.example.com/hop1",
1567 "https://redirect2.example.com/hop2",
1568 "https://destination.example.com/final",
1569 ];
1570 for hop in hops {
1571 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1572 }
1573 }
1574
1575 #[test]
1580 fn redirect_to_private_ip_rejected_by_validate_url() {
1581 let private_targets = [
1583 "https://127.0.0.1/secret",
1584 "https://10.0.0.1/internal",
1585 "https://192.168.1.1/admin",
1586 "https://172.16.0.1/data",
1587 "https://[::1]/path",
1588 "https://[fe80::1]/path",
1589 "https://localhost/path",
1590 "https://service.internal/api",
1591 ];
1592 for target in private_targets {
1593 let result = validate_url(target);
1594 assert!(result.is_err(), "expected error for {target}");
1595 assert!(
1596 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1597 "expected Blocked for {target}"
1598 );
1599 }
1600 }
1601
1602 #[test]
1604 fn redirect_relative_url_resolves_correctly() {
1605 let base = Url::parse("https://example.com/page").unwrap();
1606 let relative = "/other";
1607 let resolved = base.join(relative).unwrap();
1608 assert_eq!(resolved.as_str(), "https://example.com/other");
1609 }
1610
1611 #[test]
1613 fn redirect_to_http_rejected() {
1614 let err = validate_url("http://example.com/page").unwrap_err();
1615 assert!(matches!(err, ToolError::Blocked { .. }));
1616 }
1617
1618 #[test]
1619 fn ipv4_mapped_ipv6_link_local_blocked() {
1620 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1621 assert!(matches!(err, ToolError::Blocked { .. }));
1622 }
1623
1624 #[test]
1625 fn ipv4_mapped_ipv6_public_allowed() {
1626 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1627 }
1628
1629 #[tokio::test]
1632 async fn fetch_http_scheme_blocked() {
1633 let config = ScrapeConfig::default();
1634 let executor = WebScrapeExecutor::new(&config);
1635 let call = crate::executor::ToolCall {
1636 tool_id: "fetch".to_owned(),
1637 params: {
1638 let mut m = serde_json::Map::new();
1639 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1640 m
1641 },
1642 };
1643 let result = executor.execute_tool_call(&call).await;
1644 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1645 }
1646
1647 #[tokio::test]
1648 async fn fetch_private_ip_blocked() {
1649 let config = ScrapeConfig::default();
1650 let executor = WebScrapeExecutor::new(&config);
1651 let call = crate::executor::ToolCall {
1652 tool_id: "fetch".to_owned(),
1653 params: {
1654 let mut m = serde_json::Map::new();
1655 m.insert(
1656 "url".to_owned(),
1657 serde_json::json!("https://192.168.1.1/secret"),
1658 );
1659 m
1660 },
1661 };
1662 let result = executor.execute_tool_call(&call).await;
1663 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1664 }
1665
1666 #[tokio::test]
1667 async fn fetch_localhost_blocked() {
1668 let config = ScrapeConfig::default();
1669 let executor = WebScrapeExecutor::new(&config);
1670 let call = crate::executor::ToolCall {
1671 tool_id: "fetch".to_owned(),
1672 params: {
1673 let mut m = serde_json::Map::new();
1674 m.insert(
1675 "url".to_owned(),
1676 serde_json::json!("https://localhost/page"),
1677 );
1678 m
1679 },
1680 };
1681 let result = executor.execute_tool_call(&call).await;
1682 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1683 }
1684
1685 #[tokio::test]
1686 async fn fetch_unknown_tool_returns_none() {
1687 let config = ScrapeConfig::default();
1688 let executor = WebScrapeExecutor::new(&config);
1689 let call = crate::executor::ToolCall {
1690 tool_id: "unknown_tool".to_owned(),
1691 params: serde_json::Map::new(),
1692 };
1693 let result = executor.execute_tool_call(&call).await;
1694 assert!(result.unwrap().is_none());
1695 }
1696
1697 #[tokio::test]
1698 async fn fetch_returns_body_via_mock() {
1699 use wiremock::matchers::{method, path};
1700 use wiremock::{Mock, ResponseTemplate};
1701
1702 let (executor, server) = mock_server_executor().await;
1703 Mock::given(method("GET"))
1704 .and(path("/content"))
1705 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1706 .mount(&server)
1707 .await;
1708
1709 let (host, addrs) = server_host_and_addr(&server);
1710 let url = format!("{}/content", server.uri());
1711 let result = executor.fetch_html(&url, &host, &addrs).await;
1712 assert!(result.is_ok());
1713 assert_eq!(result.unwrap(), "plain text content");
1714 }
1715
1716 #[test]
1717 fn tool_definitions_returns_web_scrape_and_fetch() {
1718 let config = ScrapeConfig::default();
1719 let executor = WebScrapeExecutor::new(&config);
1720 let defs = executor.tool_definitions();
1721 assert_eq!(defs.len(), 2);
1722 assert_eq!(defs[0].id, "web_scrape");
1723 assert_eq!(
1724 defs[0].invocation,
1725 crate::registry::InvocationHint::FencedBlock("scrape")
1726 );
1727 assert_eq!(defs[1].id, "fetch");
1728 assert_eq!(
1729 defs[1].invocation,
1730 crate::registry::InvocationHint::ToolCall
1731 );
1732 }
1733
1734 #[test]
1735 fn tool_definitions_schema_has_all_params() {
1736 let config = ScrapeConfig::default();
1737 let executor = WebScrapeExecutor::new(&config);
1738 let defs = executor.tool_definitions();
1739 let obj = defs[0].schema.as_object().unwrap();
1740 let props = obj["properties"].as_object().unwrap();
1741 assert!(props.contains_key("url"));
1742 assert!(props.contains_key("select"));
1743 assert!(props.contains_key("extract"));
1744 assert!(props.contains_key("limit"));
1745 let req = obj["required"].as_array().unwrap();
1746 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1747 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1748 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1749 }
1750
1751 #[test]
1754 fn subdomain_localhost_blocked() {
1755 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1756 assert!(is_private_host(&host));
1757 }
1758
1759 #[test]
1760 fn internal_tld_blocked() {
1761 let host: url::Host<&str> = url::Host::Domain("service.internal");
1762 assert!(is_private_host(&host));
1763 }
1764
1765 #[test]
1766 fn local_tld_blocked() {
1767 let host: url::Host<&str> = url::Host::Domain("printer.local");
1768 assert!(is_private_host(&host));
1769 }
1770
1771 #[test]
1772 fn public_domain_not_blocked() {
1773 let host: url::Host<&str> = url::Host::Domain("example.com");
1774 assert!(!is_private_host(&host));
1775 }
1776
1777 #[tokio::test]
1780 async fn resolve_loopback_rejected() {
1781 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1783 let result = resolve_and_validate(&url).await;
1785 assert!(
1786 result.is_err(),
1787 "loopback IP must be rejected by resolve_and_validate"
1788 );
1789 let err = result.unwrap_err();
1790 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1791 }
1792
1793 #[tokio::test]
1794 async fn resolve_private_10_rejected() {
1795 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1796 let result = resolve_and_validate(&url).await;
1797 assert!(result.is_err());
1798 assert!(matches!(
1799 result.unwrap_err(),
1800 crate::executor::ToolError::Blocked { .. }
1801 ));
1802 }
1803
1804 #[tokio::test]
1805 async fn resolve_private_192_rejected() {
1806 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1807 let result = resolve_and_validate(&url).await;
1808 assert!(result.is_err());
1809 assert!(matches!(
1810 result.unwrap_err(),
1811 crate::executor::ToolError::Blocked { .. }
1812 ));
1813 }
1814
1815 #[tokio::test]
1816 async fn resolve_ipv6_loopback_rejected() {
1817 let url = url::Url::parse("https://[::1]/path").unwrap();
1818 let result = resolve_and_validate(&url).await;
1819 assert!(result.is_err());
1820 assert!(matches!(
1821 result.unwrap_err(),
1822 crate::executor::ToolError::Blocked { .. }
1823 ));
1824 }
1825
1826 #[tokio::test]
1827 async fn resolve_no_host_returns_ok() {
1828 let url = url::Url::parse("https://example.com/path").unwrap();
1830 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1832 let result = resolve_and_validate(&url_no_host).await;
1834 assert!(result.is_ok());
1835 let (host, addrs) = result.unwrap();
1836 assert!(host.is_empty());
1837 assert!(addrs.is_empty());
1838 drop(url);
1839 drop(url_no_host);
1840 }
1841
1842 async fn make_file_audit_logger(
1846 dir: &tempfile::TempDir,
1847 ) -> (
1848 std::sync::Arc<crate::audit::AuditLogger>,
1849 std::path::PathBuf,
1850 ) {
1851 use crate::audit::AuditLogger;
1852 use crate::config::AuditConfig;
1853 let path = dir.path().join("audit.log");
1854 let config = AuditConfig {
1855 enabled: true,
1856 destination: path.display().to_string(),
1857 ..Default::default()
1858 };
1859 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1860 (logger, path)
1861 }
1862
1863 #[tokio::test]
1864 async fn with_audit_sets_logger() {
1865 let config = ScrapeConfig::default();
1866 let executor = WebScrapeExecutor::new(&config);
1867 assert!(executor.audit_logger.is_none());
1868
1869 let dir = tempfile::tempdir().unwrap();
1870 let (logger, _path) = make_file_audit_logger(&dir).await;
1871 let executor = executor.with_audit(logger);
1872 assert!(executor.audit_logger.is_some());
1873 }
1874
1875 #[test]
1876 fn tool_error_to_audit_result_blocked_maps_correctly() {
1877 let err = ToolError::Blocked {
1878 command: "scheme not allowed: http".into(),
1879 };
1880 let result = tool_error_to_audit_result(&err);
1881 assert!(
1882 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1883 );
1884 }
1885
1886 #[test]
1887 fn tool_error_to_audit_result_timeout_maps_correctly() {
1888 let err = ToolError::Timeout { timeout_secs: 15 };
1889 let result = tool_error_to_audit_result(&err);
1890 assert!(matches!(result, AuditResult::Timeout));
1891 }
1892
1893 #[test]
1894 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1895 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1896 let result = tool_error_to_audit_result(&err);
1897 assert!(
1898 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1899 );
1900 }
1901
1902 #[tokio::test]
1903 async fn fetch_audit_blocked_url_logged() {
1904 let dir = tempfile::tempdir().unwrap();
1905 let (logger, log_path) = make_file_audit_logger(&dir).await;
1906
1907 let config = ScrapeConfig::default();
1908 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1909
1910 let call = crate::executor::ToolCall {
1911 tool_id: "fetch".to_owned(),
1912 params: {
1913 let mut m = serde_json::Map::new();
1914 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1915 m
1916 },
1917 };
1918 let result = executor.execute_tool_call(&call).await;
1919 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1920
1921 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1922 assert!(
1923 content.contains("\"tool\":\"fetch\""),
1924 "expected tool=fetch in audit: {content}"
1925 );
1926 assert!(
1927 content.contains("\"type\":\"blocked\""),
1928 "expected type=blocked in audit: {content}"
1929 );
1930 assert!(
1931 content.contains("http://example.com"),
1932 "expected URL in audit command field: {content}"
1933 );
1934 }
1935
1936 #[tokio::test]
1937 async fn log_audit_success_writes_to_file() {
1938 let dir = tempfile::tempdir().unwrap();
1939 let (logger, log_path) = make_file_audit_logger(&dir).await;
1940
1941 let config = ScrapeConfig::default();
1942 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1943
1944 executor
1945 .log_audit(
1946 "fetch",
1947 "https://example.com/page",
1948 AuditResult::Success,
1949 42,
1950 None,
1951 )
1952 .await;
1953
1954 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1955 assert!(
1956 content.contains("\"tool\":\"fetch\""),
1957 "expected tool=fetch in audit: {content}"
1958 );
1959 assert!(
1960 content.contains("\"type\":\"success\""),
1961 "expected type=success in audit: {content}"
1962 );
1963 assert!(
1964 content.contains("\"command\":\"https://example.com/page\""),
1965 "expected command URL in audit: {content}"
1966 );
1967 assert!(
1968 content.contains("\"duration_ms\":42"),
1969 "expected duration_ms in audit: {content}"
1970 );
1971 }
1972
1973 #[tokio::test]
1974 async fn log_audit_blocked_writes_to_file() {
1975 let dir = tempfile::tempdir().unwrap();
1976 let (logger, log_path) = make_file_audit_logger(&dir).await;
1977
1978 let config = ScrapeConfig::default();
1979 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1980
1981 executor
1982 .log_audit(
1983 "web_scrape",
1984 "http://evil.com/page",
1985 AuditResult::Blocked {
1986 reason: "scheme not allowed: http".into(),
1987 },
1988 0,
1989 None,
1990 )
1991 .await;
1992
1993 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1994 assert!(
1995 content.contains("\"tool\":\"web_scrape\""),
1996 "expected tool=web_scrape in audit: {content}"
1997 );
1998 assert!(
1999 content.contains("\"type\":\"blocked\""),
2000 "expected type=blocked in audit: {content}"
2001 );
2002 assert!(
2003 content.contains("scheme not allowed"),
2004 "expected block reason in audit: {content}"
2005 );
2006 }
2007
2008 #[tokio::test]
2009 async fn web_scrape_audit_blocked_url_logged() {
2010 let dir = tempfile::tempdir().unwrap();
2011 let (logger, log_path) = make_file_audit_logger(&dir).await;
2012
2013 let config = ScrapeConfig::default();
2014 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2015
2016 let call = crate::executor::ToolCall {
2017 tool_id: "web_scrape".to_owned(),
2018 params: {
2019 let mut m = serde_json::Map::new();
2020 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2021 m.insert("select".to_owned(), serde_json::json!("h1"));
2022 m
2023 },
2024 };
2025 let result = executor.execute_tool_call(&call).await;
2026 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2027
2028 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2029 assert!(
2030 content.contains("\"tool\":\"web_scrape\""),
2031 "expected tool=web_scrape in audit: {content}"
2032 );
2033 assert!(
2034 content.contains("\"type\":\"blocked\""),
2035 "expected type=blocked in audit: {content}"
2036 );
2037 }
2038
2039 #[tokio::test]
2040 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2041 let config = ScrapeConfig::default();
2042 let executor = WebScrapeExecutor::new(&config);
2043 assert!(executor.audit_logger.is_none());
2044
2045 let call = crate::executor::ToolCall {
2046 tool_id: "fetch".to_owned(),
2047 params: {
2048 let mut m = serde_json::Map::new();
2049 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2050 m
2051 },
2052 };
2053 let result = executor.execute_tool_call(&call).await;
2055 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2056 }
2057
2058 #[tokio::test]
2060 async fn fetch_execute_tool_call_end_to_end() {
2061 use wiremock::matchers::{method, path};
2062 use wiremock::{Mock, ResponseTemplate};
2063
2064 let (executor, server) = mock_server_executor().await;
2065 Mock::given(method("GET"))
2066 .and(path("/e2e"))
2067 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2068 .mount(&server)
2069 .await;
2070
2071 let (host, addrs) = server_host_and_addr(&server);
2072 let result = executor
2074 .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
2075 .await;
2076 assert!(result.is_ok());
2077 assert!(result.unwrap().contains("end-to-end"));
2078 }
2079
2080 #[test]
2083 fn domain_matches_exact() {
2084 assert!(domain_matches("example.com", "example.com"));
2085 assert!(!domain_matches("example.com", "other.com"));
2086 assert!(!domain_matches("example.com", "sub.example.com"));
2087 }
2088
2089 #[test]
2090 fn domain_matches_wildcard_single_subdomain() {
2091 assert!(domain_matches("*.example.com", "sub.example.com"));
2092 assert!(!domain_matches("*.example.com", "example.com"));
2093 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2094 }
2095
2096 #[test]
2097 fn domain_matches_wildcard_does_not_match_empty_label() {
2098 assert!(!domain_matches("*.example.com", ".example.com"));
2100 }
2101
2102 #[test]
2103 fn domain_matches_multi_wildcard_treated_as_exact() {
2104 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2106 }
2107
2108 #[test]
2111 fn check_domain_policy_empty_lists_allow_all() {
2112 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2113 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2114 }
2115
2116 #[test]
2117 fn check_domain_policy_denylist_blocks() {
2118 let denied = vec!["evil.com".to_string()];
2119 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2120 assert!(matches!(err, ToolError::Blocked { .. }));
2121 }
2122
2123 #[test]
2124 fn check_domain_policy_denylist_does_not_block_other_domains() {
2125 let denied = vec!["evil.com".to_string()];
2126 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2127 }
2128
2129 #[test]
2130 fn check_domain_policy_allowlist_permits_matching() {
2131 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2132 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2133 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2134 }
2135
2136 #[test]
2137 fn check_domain_policy_allowlist_blocks_unknown() {
2138 let allowed = vec!["docs.rs".to_string()];
2139 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2140 assert!(matches!(err, ToolError::Blocked { .. }));
2141 }
2142
2143 #[test]
2144 fn check_domain_policy_deny_overrides_allow() {
2145 let allowed = vec!["example.com".to_string()];
2146 let denied = vec!["example.com".to_string()];
2147 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2148 assert!(matches!(err, ToolError::Blocked { .. }));
2149 }
2150
2151 #[test]
2152 fn check_domain_policy_wildcard_in_denylist() {
2153 let denied = vec!["*.evil.com".to_string()];
2154 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2155 assert!(matches!(err, ToolError::Blocked { .. }));
2156 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2158 }
2159}