1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{
15 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
16};
17use crate::net::is_private_ip;
18
/// Parameters for the `fetch` tool: retrieve a URL's body as text.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// Target URL; validated for HTTPS scheme and SSRF safety before use.
    url: String,
}
24
/// One scrape request, parsed from a ```scrape fenced block or from a
/// `web_scrape` tool call's parameters.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// Target URL; validated for HTTPS scheme and SSRF safety before use.
    url: String,
    /// CSS selector identifying the elements to extract.
    select: String,
    /// What to pull from each match: "text", "html", or "attr:<name>".
    #[serde(default = "default_extract")]
    extract: String,
    /// Maximum number of matches to return; downstream default is 10.
    limit: Option<usize>,
}
37
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
41
/// How to extract a value from a matched element.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML markup.
    Html,
    /// The value of the named attribute.
    Attr(String),
}
48
49impl ExtractMode {
50 fn parse(s: &str) -> Self {
51 match s {
52 "text" => Self::Text,
53 "html" => Self::Html,
54 attr if attr.starts_with("attr:") => {
55 Self::Attr(attr.strip_prefix("attr:").unwrap_or(attr).to_owned())
56 }
57 _ => Self::Text,
58 }
59 }
60}
61
/// Tool executor that fetches public HTTPS pages and extracts data via
/// CSS selectors, with SSRF protections (scheme/host checks, DNS pinning,
/// private-IP rejection), a domain allow/deny policy, and optional
/// audit logging.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request HTTP timeout.
    timeout: Duration,
    // Maximum accepted response body size, in bytes.
    max_body_bytes: usize,
    // When non-empty, only hosts matching one of these patterns may be fetched.
    allowed_domains: Vec<String>,
    // Hosts matching any of these patterns are always refused.
    denied_domains: Vec<String>,
    // Optional audit sink; when set, every request outcome is logged.
    audit_logger: Option<Arc<AuditLogger>>,
}
74
75impl WebScrapeExecutor {
76 #[must_use]
77 pub fn new(config: &ScrapeConfig) -> Self {
78 Self {
79 timeout: Duration::from_secs(config.timeout),
80 max_body_bytes: config.max_body_bytes,
81 allowed_domains: config.allowed_domains.clone(),
82 denied_domains: config.denied_domains.clone(),
83 audit_logger: None,
84 }
85 }
86
87 #[must_use]
88 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
89 self.audit_logger = Some(logger);
90 self
91 }
92
93 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
94 let mut builder = reqwest::Client::builder()
95 .timeout(self.timeout)
96 .redirect(reqwest::redirect::Policy::none());
97 builder = builder.resolve_to_addrs(host, addrs);
98 builder.build().unwrap_or_default()
99 }
100}
101
impl ToolExecutor for WebScrapeExecutor {
    /// Advertise the two tools this executor provides: `web_scrape`
    /// (invoked via ```scrape fenced blocks) and `fetch` (structured
    /// tool call).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Execute every ```scrape fenced block found in `response`.
    ///
    /// Returns `Ok(None)` when no blocks are present. Blocks run
    /// sequentially; each attempt is audited, and the first failure
    /// aborts the call, discarding output from earlier successful blocks.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON in any block aborts immediately; this path is
            // not audited because no network request was attempted.
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Audit the failure before propagating it.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            // NOTE(review): hyphenated "web-scrape" here vs. tool id
            // "web_scrape" above — confirm the mismatch is intentional.
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Handle a structured tool call for `web_scrape` or `fetch`.
    /// Unknown tool ids return `Ok(None)` so other executors can claim them.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms, Some(&e))
                            .await;
                        Err(e)
                    }
                }
            }
            _ => Ok(None),
        }
    }

    /// Both tools issue read-only GET requests, so retrying is safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
267
268fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
269 match e {
270 ToolError::Blocked { command } => AuditResult::Blocked {
271 reason: command.clone(),
272 },
273 ToolError::Timeout { .. } => AuditResult::Timeout,
274 _ => AuditResult::Error {
275 message: e.to_string(),
276 },
277 }
278}
279
280impl WebScrapeExecutor {
281 async fn log_audit(
282 &self,
283 tool: &str,
284 command: &str,
285 result: AuditResult,
286 duration_ms: u64,
287 error: Option<&ToolError>,
288 ) {
289 if let Some(ref logger) = self.audit_logger {
290 let (error_category, error_domain, error_phase) =
291 error.map_or((None, None, None), |e| {
292 let cat = e.category();
293 (
294 Some(cat.label().to_owned()),
295 Some(cat.domain().label().to_owned()),
296 Some(cat.phase().label().to_owned()),
297 )
298 });
299 let entry = AuditEntry {
300 timestamp: chrono_now(),
301 tool: tool.into(),
302 command: command.into(),
303 result,
304 duration_ms,
305 error_category,
306 error_domain,
307 error_phase,
308 claim_source: Some(ClaimSource::WebScrape),
309 mcp_server_id: None,
310 injection_flagged: false,
311 embedding_anomalous: false,
312 cross_boundary_mcp_to_acp: false,
313 adversarial_policy_decision: None,
314 exit_code: None,
315 truncated: false,
316 caller_id: None,
317 policy_match: None,
318 };
319 logger.log(&entry).await;
320 }
321 }
322
323 async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
324 let parsed = validate_url(¶ms.url)?;
325 check_domain_policy(
326 parsed.host_str().unwrap_or(""),
327 &self.allowed_domains,
328 &self.denied_domains,
329 )?;
330 let (host, addrs) = resolve_and_validate(&parsed).await?;
331 self.fetch_html(¶ms.url, &host, &addrs).await
332 }
333
334 async fn scrape_instruction(
335 &self,
336 instruction: &ScrapeInstruction,
337 ) -> Result<String, ToolError> {
338 let parsed = validate_url(&instruction.url)?;
339 check_domain_policy(
340 parsed.host_str().unwrap_or(""),
341 &self.allowed_domains,
342 &self.denied_domains,
343 )?;
344 let (host, addrs) = resolve_and_validate(&parsed).await?;
345 let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
346 let selector = instruction.select.clone();
347 let extract = ExtractMode::parse(&instruction.extract);
348 let limit = instruction.limit.unwrap_or(10);
349 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
350 .await
351 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
352 }
353
354 async fn fetch_html(
364 &self,
365 url: &str,
366 host: &str,
367 addrs: &[SocketAddr],
368 ) -> Result<String, ToolError> {
369 const MAX_REDIRECTS: usize = 3;
370
371 let mut current_url = url.to_owned();
372 let mut current_host = host.to_owned();
373 let mut current_addrs = addrs.to_vec();
374
375 for hop in 0..=MAX_REDIRECTS {
376 let client = self.build_client(¤t_host, ¤t_addrs);
378 let resp = client
379 .get(¤t_url)
380 .send()
381 .await
382 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
383
384 let status = resp.status();
385
386 if status.is_redirection() {
387 if hop == MAX_REDIRECTS {
388 return Err(ToolError::Execution(std::io::Error::other(
389 "too many redirects",
390 )));
391 }
392
393 let location = resp
394 .headers()
395 .get(reqwest::header::LOCATION)
396 .and_then(|v| v.to_str().ok())
397 .ok_or_else(|| {
398 ToolError::Execution(std::io::Error::other("redirect with no Location"))
399 })?;
400
401 let base = Url::parse(¤t_url)
403 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
404 let next_url = base
405 .join(location)
406 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
407
408 let validated = validate_url(next_url.as_str())?;
409 let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
410
411 current_url = next_url.to_string();
412 current_host = next_host;
413 current_addrs = next_addrs;
414 continue;
415 }
416
417 if !status.is_success() {
418 return Err(ToolError::Http {
419 status: status.as_u16(),
420 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
421 });
422 }
423
424 let bytes = resp
425 .bytes()
426 .await
427 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
428
429 if bytes.len() > self.max_body_bytes {
430 return Err(ToolError::Execution(std::io::Error::other(format!(
431 "response too large: {} bytes (max: {})",
432 bytes.len(),
433 self.max_body_bytes,
434 ))));
435 }
436
437 return String::from_utf8(bytes.to_vec())
438 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
439 }
440
441 Err(ToolError::Execution(std::io::Error::other(
442 "too many redirects",
443 )))
444 }
445}
446
447fn extract_scrape_blocks(text: &str) -> Vec<&str> {
448 crate::executor::extract_fenced_blocks(text, "scrape")
449}
450
451fn check_domain_policy(
463 host: &str,
464 allowed_domains: &[String],
465 denied_domains: &[String],
466) -> Result<(), ToolError> {
467 if denied_domains.iter().any(|p| domain_matches(p, host)) {
468 return Err(ToolError::Blocked {
469 command: format!("domain blocked by denylist: {host}"),
470 });
471 }
472 if !allowed_domains.is_empty() {
473 let is_ip = host.parse::<std::net::IpAddr>().is_ok()
475 || (host.starts_with('[') && host.ends_with(']'));
476 if is_ip {
477 return Err(ToolError::Blocked {
478 command: format!(
479 "bare IP address not allowed when domain allowlist is active: {host}"
480 ),
481 });
482 }
483 if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
484 return Err(ToolError::Blocked {
485 command: format!("domain not in allowlist: {host}"),
486 });
487 }
488 }
489 Ok(())
490}
491
/// Does `host` match `pattern`?
///
/// A plain pattern must equal the host exactly. A wildcard pattern
/// (`*.example.com`) matches exactly one extra label: `a.example.com`
/// matches, but the bare `example.com` and the deeper
/// `a.b.example.com` do not.
fn domain_matches(pattern: &str, host: &str) -> bool {
    match pattern.strip_prefix('*') {
        // `suffix` keeps its leading dot, e.g. ".example.com".
        Some(suffix) if suffix.starts_with('.') => host
            .strip_suffix(suffix)
            .is_some_and(|label| !label.is_empty() && !label.contains('.')),
        _ => pattern == host,
    }
}
511
512fn validate_url(raw: &str) -> Result<Url, ToolError> {
513 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
514 command: format!("invalid URL: {raw}"),
515 })?;
516
517 if parsed.scheme() != "https" {
518 return Err(ToolError::Blocked {
519 command: format!("scheme not allowed: {}", parsed.scheme()),
520 });
521 }
522
523 if let Some(host) = parsed.host()
524 && is_private_host(&host)
525 {
526 return Err(ToolError::Blocked {
527 command: format!(
528 "private/local host blocked: {}",
529 parsed.host_str().unwrap_or("")
530 ),
531 });
532 }
533
534 Ok(parsed)
535}
536
537fn is_private_host(host: &url::Host<&str>) -> bool {
538 match host {
539 url::Host::Domain(d) => {
540 #[allow(clippy::case_sensitive_file_extension_comparisons)]
543 {
544 *d == "localhost"
545 || d.ends_with(".localhost")
546 || d.ends_with(".internal")
547 || d.ends_with(".local")
548 }
549 }
550 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
551 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
552 }
553}
554
555async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
561 let Some(host) = url.host_str() else {
562 return Ok((String::new(), vec![]));
563 };
564 let port = url.port_or_known_default().unwrap_or(443);
565 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
566 .await
567 .map_err(|e| ToolError::Blocked {
568 command: format!("DNS resolution failed: {e}"),
569 })?
570 .collect();
571 for addr in &addrs {
572 if is_private_ip(addr.ip()) {
573 return Err(ToolError::Blocked {
574 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
575 });
576 }
577 }
578 Ok((host.to_owned(), addrs))
579}
580
581fn parse_and_extract(
582 html: &str,
583 selector: &str,
584 extract: &ExtractMode,
585 limit: usize,
586) -> Result<String, ToolError> {
587 let soup = scrape_core::Soup::parse(html);
588
589 let tags = soup.find_all(selector).map_err(|e| {
590 ToolError::Execution(std::io::Error::new(
591 std::io::ErrorKind::InvalidData,
592 format!("invalid selector: {e}"),
593 ))
594 })?;
595
596 let mut results = Vec::new();
597
598 for tag in tags.into_iter().take(limit) {
599 let value = match extract {
600 ExtractMode::Text => tag.text(),
601 ExtractMode::Html => tag.inner_html(),
602 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
603 };
604 if !value.trim().is_empty() {
605 results.push(value.trim().to_owned());
606 }
607 }
608
609 if results.is_empty() {
610 Ok(format!("No results for selector: {selector}"))
611 } else {
612 Ok(results.join("\n"))
613 }
614}
615
616#[cfg(test)]
617mod tests {
618 use super::*;
619
620 #[test]
623 fn extract_single_block() {
624 let text =
625 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
626 let blocks = extract_scrape_blocks(text);
627 assert_eq!(blocks.len(), 1);
628 assert!(blocks[0].contains("example.com"));
629 }
630
631 #[test]
632 fn extract_multiple_blocks() {
633 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
634 let blocks = extract_scrape_blocks(text);
635 assert_eq!(blocks.len(), 2);
636 }
637
638 #[test]
639 fn no_blocks_returns_empty() {
640 let blocks = extract_scrape_blocks("plain text, no code blocks");
641 assert!(blocks.is_empty());
642 }
643
644 #[test]
645 fn unclosed_block_ignored() {
646 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
647 assert!(blocks.is_empty());
648 }
649
650 #[test]
651 fn non_scrape_block_ignored() {
652 let text =
653 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
654 let blocks = extract_scrape_blocks(text);
655 assert_eq!(blocks.len(), 1);
656 assert!(blocks[0].contains("x.com"));
657 }
658
659 #[test]
660 fn multiline_json_block() {
661 let text =
662 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
663 let blocks = extract_scrape_blocks(text);
664 assert_eq!(blocks.len(), 1);
665 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
666 assert_eq!(instr.url, "https://example.com");
667 }
668
669 #[test]
672 fn parse_valid_instruction() {
673 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
674 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
675 assert_eq!(instr.url, "https://example.com");
676 assert_eq!(instr.select, "h1");
677 assert_eq!(instr.extract, "text");
678 assert_eq!(instr.limit, Some(5));
679 }
680
681 #[test]
682 fn parse_minimal_instruction() {
683 let json = r#"{"url":"https://example.com","select":"p"}"#;
684 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
685 assert_eq!(instr.extract, "text");
686 assert!(instr.limit.is_none());
687 }
688
689 #[test]
690 fn parse_attr_extract() {
691 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
692 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
693 assert_eq!(instr.extract, "attr:href");
694 }
695
696 #[test]
697 fn parse_invalid_json_errors() {
698 let result = serde_json::from_str::<ScrapeInstruction>("not json");
699 assert!(result.is_err());
700 }
701
702 #[test]
705 fn extract_mode_text() {
706 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
707 }
708
709 #[test]
710 fn extract_mode_html() {
711 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
712 }
713
714 #[test]
715 fn extract_mode_attr() {
716 let mode = ExtractMode::parse("attr:href");
717 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
718 }
719
720 #[test]
721 fn extract_mode_unknown_defaults_to_text() {
722 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
723 }
724
725 #[test]
728 fn valid_https_url() {
729 assert!(validate_url("https://example.com").is_ok());
730 }
731
732 #[test]
733 fn http_rejected() {
734 let err = validate_url("http://example.com").unwrap_err();
735 assert!(matches!(err, ToolError::Blocked { .. }));
736 }
737
738 #[test]
739 fn ftp_rejected() {
740 let err = validate_url("ftp://files.example.com").unwrap_err();
741 assert!(matches!(err, ToolError::Blocked { .. }));
742 }
743
744 #[test]
745 fn file_rejected() {
746 let err = validate_url("file:///etc/passwd").unwrap_err();
747 assert!(matches!(err, ToolError::Blocked { .. }));
748 }
749
750 #[test]
751 fn invalid_url_rejected() {
752 let err = validate_url("not a url").unwrap_err();
753 assert!(matches!(err, ToolError::Blocked { .. }));
754 }
755
756 #[test]
757 fn localhost_blocked() {
758 let err = validate_url("https://localhost/path").unwrap_err();
759 assert!(matches!(err, ToolError::Blocked { .. }));
760 }
761
762 #[test]
763 fn loopback_ip_blocked() {
764 let err = validate_url("https://127.0.0.1/path").unwrap_err();
765 assert!(matches!(err, ToolError::Blocked { .. }));
766 }
767
768 #[test]
769 fn private_10_blocked() {
770 let err = validate_url("https://10.0.0.1/api").unwrap_err();
771 assert!(matches!(err, ToolError::Blocked { .. }));
772 }
773
774 #[test]
775 fn private_172_blocked() {
776 let err = validate_url("https://172.16.0.1/api").unwrap_err();
777 assert!(matches!(err, ToolError::Blocked { .. }));
778 }
779
780 #[test]
781 fn private_192_blocked() {
782 let err = validate_url("https://192.168.1.1/api").unwrap_err();
783 assert!(matches!(err, ToolError::Blocked { .. }));
784 }
785
786 #[test]
787 fn ipv6_loopback_blocked() {
788 let err = validate_url("https://[::1]/path").unwrap_err();
789 assert!(matches!(err, ToolError::Blocked { .. }));
790 }
791
792 #[test]
793 fn public_ip_allowed() {
794 assert!(validate_url("https://93.184.216.34/page").is_ok());
795 }
796
797 #[test]
800 fn extract_text_from_html() {
801 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
802 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
803 assert_eq!(result, "Hello World");
804 }
805
806 #[test]
807 fn extract_multiple_elements() {
808 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
809 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
810 assert_eq!(result, "A\nB\nC");
811 }
812
813 #[test]
814 fn extract_with_limit() {
815 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
816 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
817 assert_eq!(result, "A\nB");
818 }
819
820 #[test]
821 fn extract_attr_href() {
822 let html = r#"<a href="https://example.com">Link</a>"#;
823 let result =
824 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
825 assert_eq!(result, "https://example.com");
826 }
827
828 #[test]
829 fn extract_inner_html() {
830 let html = "<div><span>inner</span></div>";
831 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
832 assert!(result.contains("<span>inner</span>"));
833 }
834
835 #[test]
836 fn no_matches_returns_message() {
837 let html = "<html><body><p>text</p></body></html>";
838 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
839 assert!(result.starts_with("No results for selector:"));
840 }
841
842 #[test]
843 fn empty_text_skipped() {
844 let html = "<ul><li> </li><li>A</li></ul>";
845 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
846 assert_eq!(result, "A");
847 }
848
849 #[test]
850 fn invalid_selector_errors() {
851 let html = "<html><body></body></html>";
852 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
853 assert!(result.is_err());
854 }
855
856 #[test]
857 fn empty_html_returns_no_results() {
858 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
859 assert!(result.starts_with("No results for selector:"));
860 }
861
862 #[test]
863 fn nested_selector() {
864 let html = "<div><span>inner</span></div><span>outer</span>";
865 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
866 assert_eq!(result, "inner");
867 }
868
869 #[test]
870 fn attr_missing_returns_empty() {
871 let html = r"<a>No href</a>";
872 let result =
873 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
874 assert!(result.starts_with("No results for selector:"));
875 }
876
877 #[test]
878 fn extract_html_mode() {
879 let html = "<div><b>bold</b> text</div>";
880 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
881 assert!(result.contains("<b>bold</b>"));
882 }
883
884 #[test]
885 fn limit_zero_returns_no_results() {
886 let html = "<ul><li>A</li><li>B</li></ul>";
887 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
888 assert!(result.starts_with("No results for selector:"));
889 }
890
891 #[test]
894 fn url_with_port_allowed() {
895 assert!(validate_url("https://example.com:8443/path").is_ok());
896 }
897
898 #[test]
899 fn link_local_ip_blocked() {
900 let err = validate_url("https://169.254.1.1/path").unwrap_err();
901 assert!(matches!(err, ToolError::Blocked { .. }));
902 }
903
904 #[test]
905 fn url_no_scheme_rejected() {
906 let err = validate_url("example.com/path").unwrap_err();
907 assert!(matches!(err, ToolError::Blocked { .. }));
908 }
909
910 #[test]
911 fn unspecified_ipv4_blocked() {
912 let err = validate_url("https://0.0.0.0/path").unwrap_err();
913 assert!(matches!(err, ToolError::Blocked { .. }));
914 }
915
916 #[test]
917 fn broadcast_ipv4_blocked() {
918 let err = validate_url("https://255.255.255.255/path").unwrap_err();
919 assert!(matches!(err, ToolError::Blocked { .. }));
920 }
921
922 #[test]
923 fn ipv6_link_local_blocked() {
924 let err = validate_url("https://[fe80::1]/path").unwrap_err();
925 assert!(matches!(err, ToolError::Blocked { .. }));
926 }
927
928 #[test]
929 fn ipv6_unique_local_blocked() {
930 let err = validate_url("https://[fd12::1]/path").unwrap_err();
931 assert!(matches!(err, ToolError::Blocked { .. }));
932 }
933
934 #[test]
935 fn ipv4_mapped_ipv6_loopback_blocked() {
936 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
937 assert!(matches!(err, ToolError::Blocked { .. }));
938 }
939
940 #[test]
941 fn ipv4_mapped_ipv6_private_blocked() {
942 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
943 assert!(matches!(err, ToolError::Blocked { .. }));
944 }
945
946 #[tokio::test]
949 async fn executor_no_blocks_returns_none() {
950 let config = ScrapeConfig::default();
951 let executor = WebScrapeExecutor::new(&config);
952 let result = executor.execute("plain text").await;
953 assert!(result.unwrap().is_none());
954 }
955
956 #[tokio::test]
957 async fn executor_invalid_json_errors() {
958 let config = ScrapeConfig::default();
959 let executor = WebScrapeExecutor::new(&config);
960 let response = "```scrape\nnot json\n```";
961 let result = executor.execute(response).await;
962 assert!(matches!(result, Err(ToolError::Execution(_))));
963 }
964
965 #[tokio::test]
966 async fn executor_blocked_url_errors() {
967 let config = ScrapeConfig::default();
968 let executor = WebScrapeExecutor::new(&config);
969 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
970 let result = executor.execute(response).await;
971 assert!(matches!(result, Err(ToolError::Blocked { .. })));
972 }
973
974 #[tokio::test]
975 async fn executor_private_ip_blocked() {
976 let config = ScrapeConfig::default();
977 let executor = WebScrapeExecutor::new(&config);
978 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
979 let result = executor.execute(response).await;
980 assert!(matches!(result, Err(ToolError::Blocked { .. })));
981 }
982
983 #[tokio::test]
984 async fn executor_unreachable_host_returns_error() {
985 let config = ScrapeConfig {
986 timeout: 1,
987 max_body_bytes: 1_048_576,
988 ..Default::default()
989 };
990 let executor = WebScrapeExecutor::new(&config);
991 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
992 let result = executor.execute(response).await;
993 assert!(matches!(result, Err(ToolError::Execution(_))));
994 }
995
996 #[tokio::test]
997 async fn executor_localhost_url_blocked() {
998 let config = ScrapeConfig::default();
999 let executor = WebScrapeExecutor::new(&config);
1000 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
1001 let result = executor.execute(response).await;
1002 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1003 }
1004
1005 #[tokio::test]
1006 async fn executor_empty_text_returns_none() {
1007 let config = ScrapeConfig::default();
1008 let executor = WebScrapeExecutor::new(&config);
1009 let result = executor.execute("").await;
1010 assert!(result.unwrap().is_none());
1011 }
1012
1013 #[tokio::test]
1014 async fn executor_multiple_blocks_first_blocked() {
1015 let config = ScrapeConfig::default();
1016 let executor = WebScrapeExecutor::new(&config);
1017 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1018 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1019 let result = executor.execute(response).await;
1020 assert!(result.is_err());
1021 }
1022
1023 #[test]
1024 fn validate_url_empty_string() {
1025 let err = validate_url("").unwrap_err();
1026 assert!(matches!(err, ToolError::Blocked { .. }));
1027 }
1028
1029 #[test]
1030 fn validate_url_javascript_scheme_blocked() {
1031 let err = validate_url("javascript:alert(1)").unwrap_err();
1032 assert!(matches!(err, ToolError::Blocked { .. }));
1033 }
1034
1035 #[test]
1036 fn validate_url_data_scheme_blocked() {
1037 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1038 assert!(matches!(err, ToolError::Blocked { .. }));
1039 }
1040
1041 #[test]
1042 fn is_private_host_public_domain_is_false() {
1043 let host: url::Host<&str> = url::Host::Domain("example.com");
1044 assert!(!is_private_host(&host));
1045 }
1046
1047 #[test]
1048 fn is_private_host_localhost_is_true() {
1049 let host: url::Host<&str> = url::Host::Domain("localhost");
1050 assert!(is_private_host(&host));
1051 }
1052
1053 #[test]
1054 fn is_private_host_ipv6_unspecified_is_true() {
1055 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1056 assert!(is_private_host(&host));
1057 }
1058
1059 #[test]
1060 fn is_private_host_public_ipv6_is_false() {
1061 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1062 assert!(!is_private_host(&host));
1063 }
1064
1065 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1076 let server = wiremock::MockServer::start().await;
1077 let executor = WebScrapeExecutor {
1078 timeout: Duration::from_secs(5),
1079 max_body_bytes: 1_048_576,
1080 allowed_domains: vec![],
1081 denied_domains: vec![],
1082 audit_logger: None,
1083 };
1084 (executor, server)
1085 }
1086
1087 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1089 let uri = server.uri();
1090 let url = Url::parse(&uri).unwrap();
1091 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1092 let port = url.port().unwrap_or(80);
1093 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1094 (host, vec![addr])
1095 }
1096
1097 async fn follow_redirects_raw(
1101 executor: &WebScrapeExecutor,
1102 start_url: &str,
1103 host: &str,
1104 addrs: &[std::net::SocketAddr],
1105 ) -> Result<String, ToolError> {
1106 const MAX_REDIRECTS: usize = 3;
1107 let mut current_url = start_url.to_owned();
1108 let mut current_host = host.to_owned();
1109 let mut current_addrs = addrs.to_vec();
1110
1111 for hop in 0..=MAX_REDIRECTS {
1112 let client = executor.build_client(¤t_host, ¤t_addrs);
1113 let resp = client
1114 .get(¤t_url)
1115 .send()
1116 .await
1117 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1118
1119 let status = resp.status();
1120
1121 if status.is_redirection() {
1122 if hop == MAX_REDIRECTS {
1123 return Err(ToolError::Execution(std::io::Error::other(
1124 "too many redirects",
1125 )));
1126 }
1127
1128 let location = resp
1129 .headers()
1130 .get(reqwest::header::LOCATION)
1131 .and_then(|v| v.to_str().ok())
1132 .ok_or_else(|| {
1133 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1134 })?;
1135
1136 let base = Url::parse(¤t_url)
1137 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1138 let next_url = base
1139 .join(location)
1140 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1141
1142 current_url = next_url.to_string();
1144 let _ = &mut current_host;
1146 let _ = &mut current_addrs;
1147 continue;
1148 }
1149
1150 if !status.is_success() {
1151 return Err(ToolError::Execution(std::io::Error::other(format!(
1152 "HTTP {status}",
1153 ))));
1154 }
1155
1156 let bytes = resp
1157 .bytes()
1158 .await
1159 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1160
1161 if bytes.len() > executor.max_body_bytes {
1162 return Err(ToolError::Execution(std::io::Error::other(format!(
1163 "response too large: {} bytes (max: {})",
1164 bytes.len(),
1165 executor.max_body_bytes,
1166 ))));
1167 }
1168
1169 return String::from_utf8(bytes.to_vec())
1170 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1171 }
1172
1173 Err(ToolError::Execution(std::io::Error::other(
1174 "too many redirects",
1175 )))
1176 }
1177
    // Happy path: a 200 response body is returned verbatim.
    #[tokio::test]
    async fn fetch_html_success_returns_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/page", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap(), "<h1>OK</h1>");
    }
1196
    // A 403 response surfaces as an error whose message carries the status.
    #[tokio::test]
    async fn fetch_html_non_2xx_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/forbidden"))
            .respond_with(ResponseTemplate::new(403))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/forbidden", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("403"), "expected 403 in error: {msg}");
    }
1216
    // A 404 response is an error, with the status visible in the message.
    #[tokio::test]
    async fn fetch_html_404_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/missing", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("404"), "expected 404 in error: {msg}");
    }
1236
    // A 302 with no Location header cannot be followed and must error.
    #[tokio::test]
    async fn fetch_html_redirect_no_location_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/redirect-no-loc"))
            .respond_with(ResponseTemplate::new(302))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/redirect-no-loc", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Location") || msg.contains("location"),
            "expected Location-related error: {msg}"
        );
    }
1260
    // One absolute-URL redirect hop is followed to the final body.
    #[tokio::test]
    async fn fetch_html_single_redirect_followed() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let final_url = format!("{}/final", server.uri());

        Mock::given(method("GET"))
            .and(path("/start"))
            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/final"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/start", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
        assert_eq!(result.unwrap(), "<p>final</p>");
    }
1287
    // Exactly MAX_REDIRECTS (3) hops are within budget and succeed.
    #[tokio::test]
    async fn fetch_html_three_redirects_allowed() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let hop2 = format!("{}/hop2", server.uri());
        let hop3 = format!("{}/hop3", server.uri());
        let final_dest = format!("{}/done", server.uri());

        Mock::given(method("GET"))
            .and(path("/hop1"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/hop2"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/hop3"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/done"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/hop1", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
        assert_eq!(result.unwrap(), "<p>done</p>");
    }
1325
    // A fourth redirect hop exceeds the budget and is rejected.
    #[tokio::test]
    async fn fetch_html_four_redirects_rejected() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let hop2 = format!("{}/r2", server.uri());
        let hop3 = format!("{}/r3", server.uri());
        let hop4 = format!("{}/r4", server.uri());
        let hop5 = format!("{}/r5", server.uri());

        // Chain /r1 -> /r2 -> /r3 -> /r4 -> /r5: four redirects in a row.
        for (from, to) in [
            ("/r1", &hop2),
            ("/r2", &hop3),
            ("/r3", &hop4),
            ("/r4", &hop5),
        ] {
            Mock::given(method("GET"))
                .and(path(from))
                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
                .mount(&server)
                .await;
        }

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/r1", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_err(), "4 redirects should be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("redirect"),
            "expected redirect-related error: {msg}"
        );
    }
1360
    // Bodies over max_body_bytes are rejected with a "too large" error.
    #[tokio::test]
    async fn fetch_html_body_too_large_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        // 10-byte cap so any realistic body trips the limit.
        let small_limit_executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 10,
            allowed_domains: vec![],
            denied_domains: vec![],
            audit_logger: None,
        };
        let server = wiremock::MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/big"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("this body is definitely longer than ten bytes"),
            )
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/big", server.uri());
        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "expected too-large error: {msg}");
    }
1390
1391 #[test]
1392 fn extract_scrape_blocks_empty_block_content() {
1393 let text = "```scrape\n\n```";
1394 let blocks = extract_scrape_blocks(text);
1395 assert_eq!(blocks.len(), 1);
1396 assert!(blocks[0].is_empty());
1397 }
1398
1399 #[test]
1400 fn extract_scrape_blocks_whitespace_only() {
1401 let text = "```scrape\n \n```";
1402 let blocks = extract_scrape_blocks(text);
1403 assert_eq!(blocks.len(), 1);
1404 }
1405
1406 #[test]
1407 fn parse_and_extract_multiple_selectors() {
1408 let html = "<div><h1>Title</h1><p>Para</p></div>";
1409 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1410 assert!(result.contains("Title"));
1411 assert!(result.contains("Para"));
1412 }
1413
1414 #[test]
1415 fn webscrape_executor_new_with_custom_config() {
1416 let config = ScrapeConfig {
1417 timeout: 60,
1418 max_body_bytes: 512,
1419 ..Default::default()
1420 };
1421 let executor = WebScrapeExecutor::new(&config);
1422 assert_eq!(executor.max_body_bytes, 512);
1423 }
1424
1425 #[test]
1426 fn webscrape_executor_debug() {
1427 let config = ScrapeConfig::default();
1428 let executor = WebScrapeExecutor::new(&config);
1429 let dbg = format!("{executor:?}");
1430 assert!(dbg.contains("WebScrapeExecutor"));
1431 }
1432
1433 #[test]
1434 fn extract_mode_attr_empty_name() {
1435 let mode = ExtractMode::parse("attr:");
1436 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1437 }
1438
1439 #[test]
1440 fn default_extract_returns_text() {
1441 assert_eq!(default_extract(), "text");
1442 }
1443
1444 #[test]
1445 fn scrape_instruction_debug() {
1446 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1447 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1448 let dbg = format!("{instr:?}");
1449 assert!(dbg.contains("ScrapeInstruction"));
1450 }
1451
1452 #[test]
1453 fn extract_mode_debug() {
1454 let mode = ExtractMode::Text;
1455 let dbg = format!("{mode:?}");
1456 assert!(dbg.contains("Text"));
1457 }
1458
1459 #[test]
1464 fn max_redirects_constant_is_three() {
1465 const MAX_REDIRECTS: usize = 3;
1469 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1470 }
1471
1472 #[test]
1475 fn redirect_no_location_error_message() {
1476 let err = std::io::Error::other("redirect with no Location");
1477 assert!(err.to_string().contains("redirect with no Location"));
1478 }
1479
1480 #[test]
1482 fn too_many_redirects_error_message() {
1483 let err = std::io::Error::other("too many redirects");
1484 assert!(err.to_string().contains("too many redirects"));
1485 }
1486
1487 #[test]
1489 fn non_2xx_status_error_format() {
1490 let status = reqwest::StatusCode::FORBIDDEN;
1491 let msg = format!("HTTP {status}");
1492 assert!(msg.contains("403"));
1493 }
1494
1495 #[test]
1497 fn not_found_status_error_format() {
1498 let status = reqwest::StatusCode::NOT_FOUND;
1499 let msg = format!("HTTP {status}");
1500 assert!(msg.contains("404"));
1501 }
1502
1503 #[test]
1505 fn relative_redirect_same_host_path() {
1506 let base = Url::parse("https://example.com/current").unwrap();
1507 let resolved = base.join("/other").unwrap();
1508 assert_eq!(resolved.as_str(), "https://example.com/other");
1509 }
1510
1511 #[test]
1513 fn relative_redirect_relative_path() {
1514 let base = Url::parse("https://example.com/a/b").unwrap();
1515 let resolved = base.join("c").unwrap();
1516 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1517 }
1518
1519 #[test]
1521 fn absolute_redirect_overrides_base() {
1522 let base = Url::parse("https://example.com/page").unwrap();
1523 let resolved = base.join("https://other.com/target").unwrap();
1524 assert_eq!(resolved.as_str(), "https://other.com/target");
1525 }
1526
1527 #[test]
1529 fn redirect_http_downgrade_rejected() {
1530 let location = "http://example.com/page";
1531 let base = Url::parse("https://example.com/start").unwrap();
1532 let next = base.join(location).unwrap();
1533 let err = validate_url(next.as_str()).unwrap_err();
1534 assert!(matches!(err, ToolError::Blocked { .. }));
1535 }
1536
1537 #[test]
1539 fn redirect_location_private_ip_blocked() {
1540 let location = "https://192.168.100.1/admin";
1541 let base = Url::parse("https://example.com/start").unwrap();
1542 let next = base.join(location).unwrap();
1543 let err = validate_url(next.as_str()).unwrap_err();
1544 assert!(matches!(err, ToolError::Blocked { .. }));
1545 let ToolError::Blocked { command: cmd } = err else {
1546 panic!("expected Blocked");
1547 };
1548 assert!(
1549 cmd.contains("private") || cmd.contains("scheme"),
1550 "error message should describe the block reason: {cmd}"
1551 );
1552 }
1553
1554 #[test]
1556 fn redirect_location_internal_domain_blocked() {
1557 let location = "https://metadata.internal/latest/meta-data/";
1558 let base = Url::parse("https://example.com/start").unwrap();
1559 let next = base.join(location).unwrap();
1560 let err = validate_url(next.as_str()).unwrap_err();
1561 assert!(matches!(err, ToolError::Blocked { .. }));
1562 }
1563
1564 #[test]
1566 fn redirect_chain_three_hops_all_public() {
1567 let hops = [
1568 "https://redirect1.example.com/hop1",
1569 "https://redirect2.example.com/hop2",
1570 "https://destination.example.com/final",
1571 ];
1572 for hop in hops {
1573 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1574 }
1575 }
1576
1577 #[test]
1582 fn redirect_to_private_ip_rejected_by_validate_url() {
1583 let private_targets = [
1585 "https://127.0.0.1/secret",
1586 "https://10.0.0.1/internal",
1587 "https://192.168.1.1/admin",
1588 "https://172.16.0.1/data",
1589 "https://[::1]/path",
1590 "https://[fe80::1]/path",
1591 "https://localhost/path",
1592 "https://service.internal/api",
1593 ];
1594 for target in private_targets {
1595 let result = validate_url(target);
1596 assert!(result.is_err(), "expected error for {target}");
1597 assert!(
1598 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1599 "expected Blocked for {target}"
1600 );
1601 }
1602 }
1603
1604 #[test]
1606 fn redirect_relative_url_resolves_correctly() {
1607 let base = Url::parse("https://example.com/page").unwrap();
1608 let relative = "/other";
1609 let resolved = base.join(relative).unwrap();
1610 assert_eq!(resolved.as_str(), "https://example.com/other");
1611 }
1612
1613 #[test]
1615 fn redirect_to_http_rejected() {
1616 let err = validate_url("http://example.com/page").unwrap_err();
1617 assert!(matches!(err, ToolError::Blocked { .. }));
1618 }
1619
1620 #[test]
1621 fn ipv4_mapped_ipv6_link_local_blocked() {
1622 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1623 assert!(matches!(err, ToolError::Blocked { .. }));
1624 }
1625
1626 #[test]
1627 fn ipv4_mapped_ipv6_public_allowed() {
1628 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1629 }
1630
1631 #[tokio::test]
1634 async fn fetch_http_scheme_blocked() {
1635 let config = ScrapeConfig::default();
1636 let executor = WebScrapeExecutor::new(&config);
1637 let call = crate::executor::ToolCall {
1638 tool_id: "fetch".to_owned(),
1639 params: {
1640 let mut m = serde_json::Map::new();
1641 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1642 m
1643 },
1644 caller_id: None,
1645 };
1646 let result = executor.execute_tool_call(&call).await;
1647 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1648 }
1649
1650 #[tokio::test]
1651 async fn fetch_private_ip_blocked() {
1652 let config = ScrapeConfig::default();
1653 let executor = WebScrapeExecutor::new(&config);
1654 let call = crate::executor::ToolCall {
1655 tool_id: "fetch".to_owned(),
1656 params: {
1657 let mut m = serde_json::Map::new();
1658 m.insert(
1659 "url".to_owned(),
1660 serde_json::json!("https://192.168.1.1/secret"),
1661 );
1662 m
1663 },
1664 caller_id: None,
1665 };
1666 let result = executor.execute_tool_call(&call).await;
1667 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1668 }
1669
1670 #[tokio::test]
1671 async fn fetch_localhost_blocked() {
1672 let config = ScrapeConfig::default();
1673 let executor = WebScrapeExecutor::new(&config);
1674 let call = crate::executor::ToolCall {
1675 tool_id: "fetch".to_owned(),
1676 params: {
1677 let mut m = serde_json::Map::new();
1678 m.insert(
1679 "url".to_owned(),
1680 serde_json::json!("https://localhost/page"),
1681 );
1682 m
1683 },
1684 caller_id: None,
1685 };
1686 let result = executor.execute_tool_call(&call).await;
1687 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1688 }
1689
1690 #[tokio::test]
1691 async fn fetch_unknown_tool_returns_none() {
1692 let config = ScrapeConfig::default();
1693 let executor = WebScrapeExecutor::new(&config);
1694 let call = crate::executor::ToolCall {
1695 tool_id: "unknown_tool".to_owned(),
1696 params: serde_json::Map::new(),
1697 caller_id: None,
1698 };
1699 let result = executor.execute_tool_call(&call).await;
1700 assert!(result.unwrap().is_none());
1701 }
1702
    // Non-HTML bodies are returned as-is by fetch_html.
    #[tokio::test]
    async fn fetch_returns_body_via_mock() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/content"))
            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/content", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "plain text content");
    }
1721
1722 #[test]
1723 fn tool_definitions_returns_web_scrape_and_fetch() {
1724 let config = ScrapeConfig::default();
1725 let executor = WebScrapeExecutor::new(&config);
1726 let defs = executor.tool_definitions();
1727 assert_eq!(defs.len(), 2);
1728 assert_eq!(defs[0].id, "web_scrape");
1729 assert_eq!(
1730 defs[0].invocation,
1731 crate::registry::InvocationHint::FencedBlock("scrape")
1732 );
1733 assert_eq!(defs[1].id, "fetch");
1734 assert_eq!(
1735 defs[1].invocation,
1736 crate::registry::InvocationHint::ToolCall
1737 );
1738 }
1739
1740 #[test]
1741 fn tool_definitions_schema_has_all_params() {
1742 let config = ScrapeConfig::default();
1743 let executor = WebScrapeExecutor::new(&config);
1744 let defs = executor.tool_definitions();
1745 let obj = defs[0].schema.as_object().unwrap();
1746 let props = obj["properties"].as_object().unwrap();
1747 assert!(props.contains_key("url"));
1748 assert!(props.contains_key("select"));
1749 assert!(props.contains_key("extract"));
1750 assert!(props.contains_key("limit"));
1751 let req = obj["required"].as_array().unwrap();
1752 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1753 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1754 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1755 }
1756
1757 #[test]
1760 fn subdomain_localhost_blocked() {
1761 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1762 assert!(is_private_host(&host));
1763 }
1764
1765 #[test]
1766 fn internal_tld_blocked() {
1767 let host: url::Host<&str> = url::Host::Domain("service.internal");
1768 assert!(is_private_host(&host));
1769 }
1770
1771 #[test]
1772 fn local_tld_blocked() {
1773 let host: url::Host<&str> = url::Host::Domain("printer.local");
1774 assert!(is_private_host(&host));
1775 }
1776
1777 #[test]
1778 fn public_domain_not_blocked() {
1779 let host: url::Host<&str> = url::Host::Domain("example.com");
1780 assert!(!is_private_host(&host));
1781 }
1782
1783 #[tokio::test]
1786 async fn resolve_loopback_rejected() {
1787 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1789 let result = resolve_and_validate(&url).await;
1791 assert!(
1792 result.is_err(),
1793 "loopback IP must be rejected by resolve_and_validate"
1794 );
1795 let err = result.unwrap_err();
1796 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1797 }
1798
1799 #[tokio::test]
1800 async fn resolve_private_10_rejected() {
1801 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1802 let result = resolve_and_validate(&url).await;
1803 assert!(result.is_err());
1804 assert!(matches!(
1805 result.unwrap_err(),
1806 crate::executor::ToolError::Blocked { .. }
1807 ));
1808 }
1809
1810 #[tokio::test]
1811 async fn resolve_private_192_rejected() {
1812 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1813 let result = resolve_and_validate(&url).await;
1814 assert!(result.is_err());
1815 assert!(matches!(
1816 result.unwrap_err(),
1817 crate::executor::ToolError::Blocked { .. }
1818 ));
1819 }
1820
1821 #[tokio::test]
1822 async fn resolve_ipv6_loopback_rejected() {
1823 let url = url::Url::parse("https://[::1]/path").unwrap();
1824 let result = resolve_and_validate(&url).await;
1825 assert!(result.is_err());
1826 assert!(matches!(
1827 result.unwrap_err(),
1828 crate::executor::ToolError::Blocked { .. }
1829 ));
1830 }
1831
1832 #[tokio::test]
1833 async fn resolve_no_host_returns_ok() {
1834 let url = url::Url::parse("https://example.com/path").unwrap();
1836 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1838 let result = resolve_and_validate(&url_no_host).await;
1840 assert!(result.is_ok());
1841 let (host, addrs) = result.unwrap();
1842 assert!(host.is_empty());
1843 assert!(addrs.is_empty());
1844 drop(url);
1845 drop(url_no_host);
1846 }
1847
1848 async fn make_file_audit_logger(
1852 dir: &tempfile::TempDir,
1853 ) -> (
1854 std::sync::Arc<crate::audit::AuditLogger>,
1855 std::path::PathBuf,
1856 ) {
1857 use crate::audit::AuditLogger;
1858 use crate::config::AuditConfig;
1859 let path = dir.path().join("audit.log");
1860 let config = AuditConfig {
1861 enabled: true,
1862 destination: path.display().to_string(),
1863 ..Default::default()
1864 };
1865 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1866 (logger, path)
1867 }
1868
1869 #[tokio::test]
1870 async fn with_audit_sets_logger() {
1871 let config = ScrapeConfig::default();
1872 let executor = WebScrapeExecutor::new(&config);
1873 assert!(executor.audit_logger.is_none());
1874
1875 let dir = tempfile::tempdir().unwrap();
1876 let (logger, _path) = make_file_audit_logger(&dir).await;
1877 let executor = executor.with_audit(logger);
1878 assert!(executor.audit_logger.is_some());
1879 }
1880
1881 #[test]
1882 fn tool_error_to_audit_result_blocked_maps_correctly() {
1883 let err = ToolError::Blocked {
1884 command: "scheme not allowed: http".into(),
1885 };
1886 let result = tool_error_to_audit_result(&err);
1887 assert!(
1888 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1889 );
1890 }
1891
1892 #[test]
1893 fn tool_error_to_audit_result_timeout_maps_correctly() {
1894 let err = ToolError::Timeout { timeout_secs: 15 };
1895 let result = tool_error_to_audit_result(&err);
1896 assert!(matches!(result, AuditResult::Timeout));
1897 }
1898
1899 #[test]
1900 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1901 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1902 let result = tool_error_to_audit_result(&err);
1903 assert!(
1904 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1905 );
1906 }
1907
    // A blocked fetch writes a blocked-type audit record carrying the URL.
    #[tokio::test]
    async fn fetch_audit_blocked_url_logged() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        let call = crate::executor::ToolCall {
            tool_id: "fetch".to_owned(),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m
            },
            caller_id: None,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));

        // Assertions match on the serialized audit JSON written to disk.
        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"fetch\""),
            "expected tool=fetch in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
        assert!(
            content.contains("http://example.com"),
            "expected URL in audit command field: {content}"
        );
    }
1942
    // log_audit writes a success record with tool, command and duration.
    #[tokio::test]
    async fn log_audit_success_writes_to_file() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        executor
            .log_audit(
                "fetch",
                "https://example.com/page",
                AuditResult::Success,
                42,
                None,
            )
            .await;

        // Assertions match on the serialized audit JSON written to disk.
        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"fetch\""),
            "expected tool=fetch in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"success\""),
            "expected type=success in audit: {content}"
        );
        assert!(
            content.contains("\"command\":\"https://example.com/page\""),
            "expected command URL in audit: {content}"
        );
        assert!(
            content.contains("\"duration_ms\":42"),
            "expected duration_ms in audit: {content}"
        );
    }
1979
    // log_audit writes a blocked record carrying the block reason.
    #[tokio::test]
    async fn log_audit_blocked_writes_to_file() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        executor
            .log_audit(
                "web_scrape",
                "http://evil.com/page",
                AuditResult::Blocked {
                    reason: "scheme not allowed: http".into(),
                },
                0,
                None,
            )
            .await;

        // Assertions match on the serialized audit JSON written to disk.
        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"web_scrape\""),
            "expected tool=web_scrape in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
        assert!(
            content.contains("scheme not allowed"),
            "expected block reason in audit: {content}"
        );
    }
2014
    // A blocked web_scrape call is audited just like a blocked fetch.
    #[tokio::test]
    async fn web_scrape_audit_blocked_url_logged() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        let call = crate::executor::ToolCall {
            tool_id: "web_scrape".to_owned(),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m.insert("select".to_owned(), serde_json::json!("h1"));
                m
            },
            caller_id: None,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));

        // Assertions match on the serialized audit JSON written to disk.
        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"web_scrape\""),
            "expected tool=web_scrape in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
    }
2046
2047 #[tokio::test]
2048 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2049 let config = ScrapeConfig::default();
2050 let executor = WebScrapeExecutor::new(&config);
2051 assert!(executor.audit_logger.is_none());
2052
2053 let call = crate::executor::ToolCall {
2054 tool_id: "fetch".to_owned(),
2055 params: {
2056 let mut m = serde_json::Map::new();
2057 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2058 m
2059 },
2060 caller_id: None,
2061 };
2062 let result = executor.execute_tool_call(&call).await;
2064 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2065 }
2066
    // NOTE(review): despite the name, this exercises fetch_html directly
    // rather than execute_tool_call — consider renaming or routing through
    // the tool-call path.
    #[tokio::test]
    async fn fetch_execute_tool_call_end_to_end() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/e2e"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let result = executor
            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("end-to-end"));
    }
2088
2089 #[test]
2092 fn domain_matches_exact() {
2093 assert!(domain_matches("example.com", "example.com"));
2094 assert!(!domain_matches("example.com", "other.com"));
2095 assert!(!domain_matches("example.com", "sub.example.com"));
2096 }
2097
2098 #[test]
2099 fn domain_matches_wildcard_single_subdomain() {
2100 assert!(domain_matches("*.example.com", "sub.example.com"));
2101 assert!(!domain_matches("*.example.com", "example.com"));
2102 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2103 }
2104
2105 #[test]
2106 fn domain_matches_wildcard_does_not_match_empty_label() {
2107 assert!(!domain_matches("*.example.com", ".example.com"));
2109 }
2110
2111 #[test]
2112 fn domain_matches_multi_wildcard_treated_as_exact() {
2113 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2115 }
2116
2117 #[test]
2120 fn check_domain_policy_empty_lists_allow_all() {
2121 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2122 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2123 }
2124
2125 #[test]
2126 fn check_domain_policy_denylist_blocks() {
2127 let denied = vec!["evil.com".to_string()];
2128 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2129 assert!(matches!(err, ToolError::Blocked { .. }));
2130 }
2131
2132 #[test]
2133 fn check_domain_policy_denylist_does_not_block_other_domains() {
2134 let denied = vec!["evil.com".to_string()];
2135 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2136 }
2137
2138 #[test]
2139 fn check_domain_policy_allowlist_permits_matching() {
2140 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2141 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2142 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2143 }
2144
2145 #[test]
2146 fn check_domain_policy_allowlist_blocks_unknown() {
2147 let allowed = vec!["docs.rs".to_string()];
2148 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2149 assert!(matches!(err, ToolError::Blocked { .. }));
2150 }
2151
2152 #[test]
2153 fn check_domain_policy_deny_overrides_allow() {
2154 let allowed = vec!["example.com".to_string()];
2155 let denied = vec!["example.com".to_string()];
2156 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2157 assert!(matches!(err, ToolError::Blocked { .. }));
2158 }
2159
2160 #[test]
2161 fn check_domain_policy_wildcard_in_denylist() {
2162 let denied = vec!["*.evil.com".to_string()];
2163 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2164 assert!(matches!(err, ToolError::Blocked { .. }));
2165 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2167 }
2168}