1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{
15 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
16};
17use crate::net::is_private_ip;
18
/// Parameters for the `fetch` tool: retrieve a URL's body verbatim.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// Target URL; must be HTTPS (enforced by `validate_url`).
    url: String,
}
24
/// One `web_scrape` request, parsed either from structured tool params or
/// from the JSON body of a ```scrape fenced block.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// Target URL; must be HTTPS (enforced by `validate_url`).
    url: String,
    /// CSS selector to match against the fetched document.
    select: String,
    /// Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    /// Maximum number of matches returned; `scrape_instruction` defaults to 10
    /// when absent.
    limit: Option<usize>,
}
37
/// Serde default for `ScrapeInstruction::extract`: plain text extraction.
fn default_extract() -> String {
    String::from("text")
}
41
/// How matched elements are turned into output strings.
#[derive(Debug)]
enum ExtractMode {
    /// Concatenated text content of the element.
    Text,
    /// Inner HTML of the element.
    Html,
    /// Value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Interpret the `extract` field of an instruction.
    ///
    /// Anything that is not "html" or "attr:<name>" (including unknown
    /// values) falls back to `Text`.
    fn parse(s: &str) -> Self {
        if s == "html" {
            return Self::Html;
        }
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        Self::Text
    }
}
61
/// Tool executor that fetches public HTTPS pages and extracts data with CSS
/// selectors. SSRF defenses: scheme/host validation, private-IP blocking
/// after DNS resolution, address pinning, and manual redirect re-validation.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request timeout applied to the HTTP client.
    timeout: Duration,
    // Hard cap on response body size, in bytes.
    max_body_bytes: usize,
    // Optional sink for per-request audit entries.
    audit_logger: Option<Arc<AuditLogger>>,
}
72
73impl WebScrapeExecutor {
74 #[must_use]
75 pub fn new(config: &ScrapeConfig) -> Self {
76 Self {
77 timeout: Duration::from_secs(config.timeout),
78 max_body_bytes: config.max_body_bytes,
79 audit_logger: None,
80 }
81 }
82
83 #[must_use]
84 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
85 self.audit_logger = Some(logger);
86 self
87 }
88
89 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
90 let mut builder = reqwest::Client::builder()
91 .timeout(self.timeout)
92 .redirect(reqwest::redirect::Policy::none());
93 builder = builder.resolve_to_addrs(host, addrs);
94 builder.build().unwrap_or_default()
95 }
96}
97
impl ToolExecutor for WebScrapeExecutor {
    /// Advertise two tools: `web_scrape` (invoked via ```scrape fenced
    /// blocks) and `fetch` (invoked via structured tool calls).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                // NOTE(review): this description says the body is "truncated if
                // exceeding max body size", but `handle_fetch`/`fetch_html`
                // return an error for oversized bodies — confirm which is
                // intended and align text/behavior.
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Run every ```scrape block found in `response`.
    ///
    /// Returns `Ok(None)` when no blocks are present. Each instruction is
    /// audited (success or failure) with its own duration; the first failing
    /// block aborts the batch and its error is propagated.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON in a block fails the whole call (no audit entry,
            // since no URL was parsed yet).
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        tool_error_to_audit_result(&e),
                        duration_ms,
                        Some(&e),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Dispatch a structured tool call to `web_scrape` or `fetch`.
    ///
    /// Returns `Ok(None)` for tool ids this executor does not own. Both
    /// success and failure paths are audited with the call duration.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            tool_error_to_audit_result(&e),
                            duration_ms,
                            Some(&e),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            // fetch results share the WebScrape claim source.
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        self.log_audit("fetch", &p.url, tool_error_to_audit_result(&e), duration_ms, Some(&e))
                            .await;
                        Err(e)
                    }
                }
            }
            _ => Ok(None),
        }
    }

    /// Both tools are idempotent reads, so retries are safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
263
264fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
265 match e {
266 ToolError::Blocked { command } => AuditResult::Blocked {
267 reason: command.clone(),
268 },
269 ToolError::Timeout { .. } => AuditResult::Timeout,
270 _ => AuditResult::Error {
271 message: e.to_string(),
272 },
273 }
274}
275
276impl WebScrapeExecutor {
277 async fn log_audit(
278 &self,
279 tool: &str,
280 command: &str,
281 result: AuditResult,
282 duration_ms: u64,
283 error: Option<&ToolError>,
284 ) {
285 if let Some(ref logger) = self.audit_logger {
286 let (error_category, error_domain, error_phase) =
287 error.map_or((None, None, None), |e| {
288 let cat = e.category();
289 (
290 Some(cat.label().to_owned()),
291 Some(cat.domain().label().to_owned()),
292 Some(cat.phase().label().to_owned()),
293 )
294 });
295 let entry = AuditEntry {
296 timestamp: chrono_now(),
297 tool: tool.into(),
298 command: command.into(),
299 result,
300 duration_ms,
301 error_category,
302 error_domain,
303 error_phase,
304 claim_source: Some(ClaimSource::WebScrape),
305 mcp_server_id: None,
306 injection_flagged: false,
307 embedding_anomalous: false,
308 cross_boundary_mcp_to_acp: false,
309 adversarial_policy_decision: None,
310 };
311 logger.log(&entry).await;
312 }
313 }
314
315 async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
316 let parsed = validate_url(¶ms.url)?;
317 let (host, addrs) = resolve_and_validate(&parsed).await?;
318 self.fetch_html(¶ms.url, &host, &addrs).await
319 }
320
321 async fn scrape_instruction(
322 &self,
323 instruction: &ScrapeInstruction,
324 ) -> Result<String, ToolError> {
325 let parsed = validate_url(&instruction.url)?;
326 let (host, addrs) = resolve_and_validate(&parsed).await?;
327 let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
328 let selector = instruction.select.clone();
329 let extract = ExtractMode::parse(&instruction.extract);
330 let limit = instruction.limit.unwrap_or(10);
331 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
332 .await
333 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
334 }
335
336 async fn fetch_html(
346 &self,
347 url: &str,
348 host: &str,
349 addrs: &[SocketAddr],
350 ) -> Result<String, ToolError> {
351 const MAX_REDIRECTS: usize = 3;
352
353 let mut current_url = url.to_owned();
354 let mut current_host = host.to_owned();
355 let mut current_addrs = addrs.to_vec();
356
357 for hop in 0..=MAX_REDIRECTS {
358 let client = self.build_client(¤t_host, ¤t_addrs);
360 let resp = client
361 .get(¤t_url)
362 .send()
363 .await
364 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
365
366 let status = resp.status();
367
368 if status.is_redirection() {
369 if hop == MAX_REDIRECTS {
370 return Err(ToolError::Execution(std::io::Error::other(
371 "too many redirects",
372 )));
373 }
374
375 let location = resp
376 .headers()
377 .get(reqwest::header::LOCATION)
378 .and_then(|v| v.to_str().ok())
379 .ok_or_else(|| {
380 ToolError::Execution(std::io::Error::other("redirect with no Location"))
381 })?;
382
383 let base = Url::parse(¤t_url)
385 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
386 let next_url = base
387 .join(location)
388 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
389
390 let validated = validate_url(next_url.as_str())?;
391 let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
392
393 current_url = next_url.to_string();
394 current_host = next_host;
395 current_addrs = next_addrs;
396 continue;
397 }
398
399 if !status.is_success() {
400 return Err(ToolError::Http {
401 status: status.as_u16(),
402 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
403 });
404 }
405
406 let bytes = resp
407 .bytes()
408 .await
409 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
410
411 if bytes.len() > self.max_body_bytes {
412 return Err(ToolError::Execution(std::io::Error::other(format!(
413 "response too large: {} bytes (max: {})",
414 bytes.len(),
415 self.max_body_bytes,
416 ))));
417 }
418
419 return String::from_utf8(bytes.to_vec())
420 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
421 }
422
423 Err(ToolError::Execution(std::io::Error::other(
424 "too many redirects",
425 )))
426 }
427}
428
/// Return the contents of every ```scrape fenced block in `text`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
432
433fn validate_url(raw: &str) -> Result<Url, ToolError> {
434 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
435 command: format!("invalid URL: {raw}"),
436 })?;
437
438 if parsed.scheme() != "https" {
439 return Err(ToolError::Blocked {
440 command: format!("scheme not allowed: {}", parsed.scheme()),
441 });
442 }
443
444 if let Some(host) = parsed.host()
445 && is_private_host(&host)
446 {
447 return Err(ToolError::Blocked {
448 command: format!(
449 "private/local host blocked: {}",
450 parsed.host_str().unwrap_or("")
451 ),
452 });
453 }
454
455 Ok(parsed)
456}
457
458fn is_private_host(host: &url::Host<&str>) -> bool {
459 match host {
460 url::Host::Domain(d) => {
461 #[allow(clippy::case_sensitive_file_extension_comparisons)]
464 {
465 *d == "localhost"
466 || d.ends_with(".localhost")
467 || d.ends_with(".internal")
468 || d.ends_with(".local")
469 }
470 }
471 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
472 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
473 }
474}
475
476async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
482 let Some(host) = url.host_str() else {
483 return Ok((String::new(), vec![]));
484 };
485 let port = url.port_or_known_default().unwrap_or(443);
486 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
487 .await
488 .map_err(|e| ToolError::Blocked {
489 command: format!("DNS resolution failed: {e}"),
490 })?
491 .collect();
492 for addr in &addrs {
493 if is_private_ip(addr.ip()) {
494 return Err(ToolError::Blocked {
495 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
496 });
497 }
498 }
499 Ok((host.to_owned(), addrs))
500}
501
502fn parse_and_extract(
503 html: &str,
504 selector: &str,
505 extract: &ExtractMode,
506 limit: usize,
507) -> Result<String, ToolError> {
508 let soup = scrape_core::Soup::parse(html);
509
510 let tags = soup.find_all(selector).map_err(|e| {
511 ToolError::Execution(std::io::Error::new(
512 std::io::ErrorKind::InvalidData,
513 format!("invalid selector: {e}"),
514 ))
515 })?;
516
517 let mut results = Vec::new();
518
519 for tag in tags.into_iter().take(limit) {
520 let value = match extract {
521 ExtractMode::Text => tag.text(),
522 ExtractMode::Html => tag.inner_html(),
523 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
524 };
525 if !value.trim().is_empty() {
526 results.push(value.trim().to_owned());
527 }
528 }
529
530 if results.is_empty() {
531 Ok(format!("No results for selector: {selector}"))
532 } else {
533 Ok(results.join("\n"))
534 }
535}
536
537#[cfg(test)]
538mod tests {
539 use super::*;
540
    // --- fenced ```scrape block extraction ---

    #[test]
    fn extract_single_block() {
        let text =
            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("example.com"));
    }

    #[test]
    fn extract_multiple_blocks() {
        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 2);
    }

    #[test]
    fn no_blocks_returns_empty() {
        let blocks = extract_scrape_blocks("plain text, no code blocks");
        assert!(blocks.is_empty());
    }

    // A block without a closing fence must not be treated as a block.
    #[test]
    fn unclosed_block_ignored() {
        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
        assert!(blocks.is_empty());
    }

    // Only ```scrape-tagged fences count; other languages are skipped.
    #[test]
    fn non_scrape_block_ignored() {
        let text =
            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("x.com"));
    }

    #[test]
    fn multiline_json_block() {
        let text =
            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
        assert_eq!(instr.url, "https://example.com");
    }
589
    // --- ScrapeInstruction JSON deserialization ---

    #[test]
    fn parse_valid_instruction() {
        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.url, "https://example.com");
        assert_eq!(instr.select, "h1");
        assert_eq!(instr.extract, "text");
        assert_eq!(instr.limit, Some(5));
    }

    // Omitting optional fields applies the serde defaults.
    #[test]
    fn parse_minimal_instruction() {
        let json = r#"{"url":"https://example.com","select":"p"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "text");
        assert!(instr.limit.is_none());
    }

    #[test]
    fn parse_attr_extract() {
        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "attr:href");
    }

    #[test]
    fn parse_invalid_json_errors() {
        let result = serde_json::from_str::<ScrapeInstruction>("not json");
        assert!(result.is_err());
    }
622
    // --- ExtractMode::parse ---

    #[test]
    fn extract_mode_text() {
        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
    }

    #[test]
    fn extract_mode_html() {
        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
    }

    #[test]
    fn extract_mode_attr() {
        let mode = ExtractMode::parse("attr:href");
        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
    }

    #[test]
    fn extract_mode_unknown_defaults_to_text() {
        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
    }
645
    // --- validate_url: scheme and private-host policy ---

    #[test]
    fn valid_https_url() {
        assert!(validate_url("https://example.com").is_ok());
    }

    #[test]
    fn http_rejected() {
        let err = validate_url("http://example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ftp_rejected() {
        let err = validate_url("ftp://files.example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn file_rejected() {
        let err = validate_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn invalid_url_rejected() {
        let err = validate_url("not a url").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn localhost_blocked() {
        let err = validate_url("https://localhost/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn loopback_ip_blocked() {
        let err = validate_url("https://127.0.0.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_10_blocked() {
        let err = validate_url("https://10.0.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_172_blocked() {
        let err = validate_url("https://172.16.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_192_blocked() {
        let err = validate_url("https://192.168.1.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_loopback_blocked() {
        let err = validate_url("https://[::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // 93.184.216.34 is the classic example.com address — a public IP.
    #[test]
    fn public_ip_allowed() {
        assert!(validate_url("https://93.184.216.34/page").is_ok());
    }
717
    // --- parse_and_extract: selector matching and extraction modes ---

    #[test]
    fn extract_text_from_html() {
        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "Hello World");
    }

    #[test]
    fn extract_multiple_elements() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A\nB\nC");
    }

    #[test]
    fn extract_with_limit() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
        assert_eq!(result, "A\nB");
    }

    #[test]
    fn extract_attr_href() {
        let html = r#"<a href="https://example.com">Link</a>"#;
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert_eq!(result, "https://example.com");
    }

    #[test]
    fn extract_inner_html() {
        let html = "<div><span>inner</span></div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<span>inner</span>"));
    }

    #[test]
    fn no_matches_returns_message() {
        let html = "<html><body><p>text</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    // Whitespace-only matches are filtered out of the result list.
    #[test]
    fn empty_text_skipped() {
        let html = "<ul><li> </li><li>A</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A");
    }

    #[test]
    fn invalid_selector_errors() {
        let html = "<html><body></body></html>";
        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
        assert!(result.is_err());
    }

    #[test]
    fn empty_html_returns_no_results() {
        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn nested_selector() {
        let html = "<div><span>inner</span></div><span>outer</span>";
        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "inner");
    }

    // A missing attribute extracts as "", which is then filtered as empty.
    #[test]
    fn attr_missing_returns_empty() {
        let html = r"<a>No href</a>";
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn extract_html_mode() {
        let html = "<div><b>bold</b> text</div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<b>bold</b>"));
    }

    #[test]
    fn limit_zero_returns_no_results() {
        let html = "<ul><li>A</li><li>B</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }
811
    // --- validate_url: additional edge cases ---

    #[test]
    fn url_with_port_allowed() {
        assert!(validate_url("https://example.com:8443/path").is_ok());
    }

    #[test]
    fn link_local_ip_blocked() {
        let err = validate_url("https://169.254.1.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn url_no_scheme_rejected() {
        let err = validate_url("example.com/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn unspecified_ipv4_blocked() {
        let err = validate_url("https://0.0.0.0/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn broadcast_ipv4_blocked() {
        let err = validate_url("https://255.255.255.255/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_link_local_blocked() {
        let err = validate_url("https://[fe80::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_unique_local_blocked() {
        let err = validate_url("https://[fd12::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // IPv4-mapped IPv6 must not bypass the IPv4 private-range checks.
    #[test]
    fn ipv4_mapped_ipv6_loopback_blocked() {
        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv4_mapped_ipv6_private_blocked() {
        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }
866
    // --- end-to-end executor behavior (no network: blocked before connect,
    //     except the unreachable-host case which uses TEST-NET-1) ---

    #[tokio::test]
    async fn executor_no_blocks_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("plain text").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn executor_invalid_json_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\nnot json\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_blocked_url_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_private_ip_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    // 192.0.2.1 is TEST-NET-1 (RFC 5737): routable-looking but unreachable.
    #[tokio::test]
    async fn executor_unreachable_host_returns_error() {
        let config = ScrapeConfig {
            timeout: 1,
            max_body_bytes: 1_048_576,
        };
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_localhost_url_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_empty_text_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("").await;
        assert!(result.unwrap().is_none());
    }

    // The first failing block aborts the batch even if later blocks are valid.
    #[tokio::test]
    async fn executor_multiple_blocks_first_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
                        ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(result.is_err());
    }
942
    // --- misc validate_url schemes and is_private_host directly ---

    #[test]
    fn validate_url_empty_string() {
        let err = validate_url("").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_javascript_scheme_blocked() {
        let err = validate_url("javascript:alert(1)").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_data_scheme_blocked() {
        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn is_private_host_public_domain_is_false() {
        let host: url::Host<&str> = url::Host::Domain("example.com");
        assert!(!is_private_host(&host));
    }

    #[test]
    fn is_private_host_localhost_is_true() {
        let host: url::Host<&str> = url::Host::Domain("localhost");
        assert!(is_private_host(&host));
    }

    #[test]
    fn is_private_host_ipv6_unspecified_is_true() {
        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
        assert!(is_private_host(&host));
    }

    // 2001:db8::/32 is the IPv6 documentation range — treated as public here.
    #[test]
    fn is_private_host_public_ipv6_is_false() {
        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
        assert!(!is_private_host(&host));
    }
984
    // Spin up a wiremock server plus an executor built directly (bypassing
    // ScrapeConfig) with a generous timeout and 1 MiB body cap.
    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
        let server = wiremock::MockServer::start().await;
        let executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 1_048_576,
            audit_logger: None,
        };
        (executor, server)
    }

    // Derive the (host, pinned-address) pair fetch_html expects from the
    // mock server's URI, mirroring what resolve_and_validate would produce.
    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
        let uri = server.uri();
        let url = Url::parse(&uri).unwrap();
        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
        let port = url.port().unwrap_or(80);
        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
        (host, vec![addr])
    }
1014
    // Test-only mirror of `fetch_html`'s redirect loop. Unlike the real
    // implementation, redirect targets are NOT re-validated or re-resolved:
    // the host/address pin stays at the original server (the no-op
    // `let _ = &mut …` lines mark the deliberately skipped updates). This
    // works because the mock server only redirects to itself — presumably
    // intentional so tests don't need real DNS; confirm before reusing.
    async fn follow_redirects_raw(
        executor: &WebScrapeExecutor,
        start_url: &str,
        host: &str,
        addrs: &[std::net::SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;
        let mut current_url = start_url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = executor.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                current_url = next_url.to_string();
                // Intentionally no re-validation/re-resolution here (see
                // header comment); keep bindings "used" for the compiler.
                let _ = &mut current_host;
                let _ = &mut current_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "HTTP {status}",
                ))));
            }

            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > executor.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    executor.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
1095
    // --- fetch_html against a local wiremock server ---

    #[tokio::test]
    async fn fetch_html_success_returns_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/page", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap(), "<h1>OK</h1>");
    }

    #[tokio::test]
    async fn fetch_html_non_2xx_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/forbidden"))
            .respond_with(ResponseTemplate::new(403))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/forbidden", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("403"), "expected 403 in error: {msg}");
    }

    #[tokio::test]
    async fn fetch_html_404_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/missing", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("404"), "expected 404 in error: {msg}");
    }
1154
1155 #[tokio::test]
1156 async fn fetch_html_redirect_no_location_returns_error() {
1157 use wiremock::matchers::{method, path};
1158 use wiremock::{Mock, ResponseTemplate};
1159
1160 let (executor, server) = mock_server_executor().await;
1161 Mock::given(method("GET"))
1163 .and(path("/redirect-no-loc"))
1164 .respond_with(ResponseTemplate::new(302))
1165 .mount(&server)
1166 .await;
1167
1168 let (host, addrs) = server_host_and_addr(&server);
1169 let url = format!("{}/redirect-no-loc", server.uri());
1170 let result = executor.fetch_html(&url, &host, &addrs).await;
1171 assert!(result.is_err());
1172 let msg = result.unwrap_err().to_string();
1173 assert!(
1174 msg.contains("Location") || msg.contains("location"),
1175 "expected Location-related error: {msg}"
1176 );
1177 }
1178
1179 #[tokio::test]
1180 async fn fetch_html_single_redirect_followed() {
1181 use wiremock::matchers::{method, path};
1182 use wiremock::{Mock, ResponseTemplate};
1183
1184 let (executor, server) = mock_server_executor().await;
1185 let final_url = format!("{}/final", server.uri());
1186
1187 Mock::given(method("GET"))
1188 .and(path("/start"))
1189 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1190 .mount(&server)
1191 .await;
1192
1193 Mock::given(method("GET"))
1194 .and(path("/final"))
1195 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1196 .mount(&server)
1197 .await;
1198
1199 let (host, addrs) = server_host_and_addr(&server);
1200 let url = format!("{}/start", server.uri());
1201 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1202 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1203 assert_eq!(result.unwrap(), "<p>final</p>");
1204 }
1205
1206 #[tokio::test]
1207 async fn fetch_html_three_redirects_allowed() {
1208 use wiremock::matchers::{method, path};
1209 use wiremock::{Mock, ResponseTemplate};
1210
1211 let (executor, server) = mock_server_executor().await;
1212 let hop2 = format!("{}/hop2", server.uri());
1213 let hop3 = format!("{}/hop3", server.uri());
1214 let final_dest = format!("{}/done", server.uri());
1215
1216 Mock::given(method("GET"))
1217 .and(path("/hop1"))
1218 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1219 .mount(&server)
1220 .await;
1221 Mock::given(method("GET"))
1222 .and(path("/hop2"))
1223 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1224 .mount(&server)
1225 .await;
1226 Mock::given(method("GET"))
1227 .and(path("/hop3"))
1228 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1229 .mount(&server)
1230 .await;
1231 Mock::given(method("GET"))
1232 .and(path("/done"))
1233 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1234 .mount(&server)
1235 .await;
1236
1237 let (host, addrs) = server_host_and_addr(&server);
1238 let url = format!("{}/hop1", server.uri());
1239 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1240 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1241 assert_eq!(result.unwrap(), "<p>done</p>");
1242 }
1243
1244 #[tokio::test]
1245 async fn fetch_html_four_redirects_rejected() {
1246 use wiremock::matchers::{method, path};
1247 use wiremock::{Mock, ResponseTemplate};
1248
1249 let (executor, server) = mock_server_executor().await;
1250 let hop2 = format!("{}/r2", server.uri());
1251 let hop3 = format!("{}/r3", server.uri());
1252 let hop4 = format!("{}/r4", server.uri());
1253 let hop5 = format!("{}/r5", server.uri());
1254
1255 for (from, to) in [
1256 ("/r1", &hop2),
1257 ("/r2", &hop3),
1258 ("/r3", &hop4),
1259 ("/r4", &hop5),
1260 ] {
1261 Mock::given(method("GET"))
1262 .and(path(from))
1263 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1264 .mount(&server)
1265 .await;
1266 }
1267
1268 let (host, addrs) = server_host_and_addr(&server);
1269 let url = format!("{}/r1", server.uri());
1270 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1271 assert!(result.is_err(), "4 redirects should be rejected");
1272 let msg = result.unwrap_err().to_string();
1273 assert!(
1274 msg.contains("redirect"),
1275 "expected redirect-related error: {msg}"
1276 );
1277 }
1278
1279 #[tokio::test]
1280 async fn fetch_html_body_too_large_returns_error() {
1281 use wiremock::matchers::{method, path};
1282 use wiremock::{Mock, ResponseTemplate};
1283
1284 let small_limit_executor = WebScrapeExecutor {
1285 timeout: Duration::from_secs(5),
1286 max_body_bytes: 10,
1287 audit_logger: None,
1288 };
1289 let server = wiremock::MockServer::start().await;
1290 Mock::given(method("GET"))
1291 .and(path("/big"))
1292 .respond_with(
1293 ResponseTemplate::new(200)
1294 .set_body_string("this body is definitely longer than ten bytes"),
1295 )
1296 .mount(&server)
1297 .await;
1298
1299 let (host, addrs) = server_host_and_addr(&server);
1300 let url = format!("{}/big", server.uri());
1301 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1302 assert!(result.is_err());
1303 let msg = result.unwrap_err().to_string();
1304 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1305 }
1306
1307 #[test]
1308 fn extract_scrape_blocks_empty_block_content() {
1309 let text = "```scrape\n\n```";
1310 let blocks = extract_scrape_blocks(text);
1311 assert_eq!(blocks.len(), 1);
1312 assert!(blocks[0].is_empty());
1313 }
1314
1315 #[test]
1316 fn extract_scrape_blocks_whitespace_only() {
1317 let text = "```scrape\n \n```";
1318 let blocks = extract_scrape_blocks(text);
1319 assert_eq!(blocks.len(), 1);
1320 }
1321
1322 #[test]
1323 fn parse_and_extract_multiple_selectors() {
1324 let html = "<div><h1>Title</h1><p>Para</p></div>";
1325 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1326 assert!(result.contains("Title"));
1327 assert!(result.contains("Para"));
1328 }
1329
1330 #[test]
1331 fn webscrape_executor_new_with_custom_config() {
1332 let config = ScrapeConfig {
1333 timeout: 60,
1334 max_body_bytes: 512,
1335 };
1336 let executor = WebScrapeExecutor::new(&config);
1337 assert_eq!(executor.max_body_bytes, 512);
1338 }
1339
1340 #[test]
1341 fn webscrape_executor_debug() {
1342 let config = ScrapeConfig::default();
1343 let executor = WebScrapeExecutor::new(&config);
1344 let dbg = format!("{executor:?}");
1345 assert!(dbg.contains("WebScrapeExecutor"));
1346 }
1347
1348 #[test]
1349 fn extract_mode_attr_empty_name() {
1350 let mode = ExtractMode::parse("attr:");
1351 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1352 }
1353
1354 #[test]
1355 fn default_extract_returns_text() {
1356 assert_eq!(default_extract(), "text");
1357 }
1358
1359 #[test]
1360 fn scrape_instruction_debug() {
1361 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1362 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1363 let dbg = format!("{instr:?}");
1364 assert!(dbg.contains("ScrapeInstruction"));
1365 }
1366
1367 #[test]
1368 fn extract_mode_debug() {
1369 let mode = ExtractMode::Text;
1370 let dbg = format!("{mode:?}");
1371 assert!(dbg.contains("Text"));
1372 }
1373
1374 #[test]
1379 fn max_redirects_constant_is_three() {
1380 const MAX_REDIRECTS: usize = 3;
1384 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1385 }
1386
1387 #[test]
1390 fn redirect_no_location_error_message() {
1391 let err = std::io::Error::other("redirect with no Location");
1392 assert!(err.to_string().contains("redirect with no Location"));
1393 }
1394
1395 #[test]
1397 fn too_many_redirects_error_message() {
1398 let err = std::io::Error::other("too many redirects");
1399 assert!(err.to_string().contains("too many redirects"));
1400 }
1401
1402 #[test]
1404 fn non_2xx_status_error_format() {
1405 let status = reqwest::StatusCode::FORBIDDEN;
1406 let msg = format!("HTTP {status}");
1407 assert!(msg.contains("403"));
1408 }
1409
1410 #[test]
1412 fn not_found_status_error_format() {
1413 let status = reqwest::StatusCode::NOT_FOUND;
1414 let msg = format!("HTTP {status}");
1415 assert!(msg.contains("404"));
1416 }
1417
1418 #[test]
1420 fn relative_redirect_same_host_path() {
1421 let base = Url::parse("https://example.com/current").unwrap();
1422 let resolved = base.join("/other").unwrap();
1423 assert_eq!(resolved.as_str(), "https://example.com/other");
1424 }
1425
1426 #[test]
1428 fn relative_redirect_relative_path() {
1429 let base = Url::parse("https://example.com/a/b").unwrap();
1430 let resolved = base.join("c").unwrap();
1431 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1432 }
1433
1434 #[test]
1436 fn absolute_redirect_overrides_base() {
1437 let base = Url::parse("https://example.com/page").unwrap();
1438 let resolved = base.join("https://other.com/target").unwrap();
1439 assert_eq!(resolved.as_str(), "https://other.com/target");
1440 }
1441
1442 #[test]
1444 fn redirect_http_downgrade_rejected() {
1445 let location = "http://example.com/page";
1446 let base = Url::parse("https://example.com/start").unwrap();
1447 let next = base.join(location).unwrap();
1448 let err = validate_url(next.as_str()).unwrap_err();
1449 assert!(matches!(err, ToolError::Blocked { .. }));
1450 }
1451
1452 #[test]
1454 fn redirect_location_private_ip_blocked() {
1455 let location = "https://192.168.100.1/admin";
1456 let base = Url::parse("https://example.com/start").unwrap();
1457 let next = base.join(location).unwrap();
1458 let err = validate_url(next.as_str()).unwrap_err();
1459 assert!(matches!(err, ToolError::Blocked { .. }));
1460 let ToolError::Blocked { command: cmd } = err else {
1461 panic!("expected Blocked");
1462 };
1463 assert!(
1464 cmd.contains("private") || cmd.contains("scheme"),
1465 "error message should describe the block reason: {cmd}"
1466 );
1467 }
1468
1469 #[test]
1471 fn redirect_location_internal_domain_blocked() {
1472 let location = "https://metadata.internal/latest/meta-data/";
1473 let base = Url::parse("https://example.com/start").unwrap();
1474 let next = base.join(location).unwrap();
1475 let err = validate_url(next.as_str()).unwrap_err();
1476 assert!(matches!(err, ToolError::Blocked { .. }));
1477 }
1478
1479 #[test]
1481 fn redirect_chain_three_hops_all_public() {
1482 let hops = [
1483 "https://redirect1.example.com/hop1",
1484 "https://redirect2.example.com/hop2",
1485 "https://destination.example.com/final",
1486 ];
1487 for hop in hops {
1488 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1489 }
1490 }
1491
1492 #[test]
1497 fn redirect_to_private_ip_rejected_by_validate_url() {
1498 let private_targets = [
1500 "https://127.0.0.1/secret",
1501 "https://10.0.0.1/internal",
1502 "https://192.168.1.1/admin",
1503 "https://172.16.0.1/data",
1504 "https://[::1]/path",
1505 "https://[fe80::1]/path",
1506 "https://localhost/path",
1507 "https://service.internal/api",
1508 ];
1509 for target in private_targets {
1510 let result = validate_url(target);
1511 assert!(result.is_err(), "expected error for {target}");
1512 assert!(
1513 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1514 "expected Blocked for {target}"
1515 );
1516 }
1517 }
1518
1519 #[test]
1521 fn redirect_relative_url_resolves_correctly() {
1522 let base = Url::parse("https://example.com/page").unwrap();
1523 let relative = "/other";
1524 let resolved = base.join(relative).unwrap();
1525 assert_eq!(resolved.as_str(), "https://example.com/other");
1526 }
1527
1528 #[test]
1530 fn redirect_to_http_rejected() {
1531 let err = validate_url("http://example.com/page").unwrap_err();
1532 assert!(matches!(err, ToolError::Blocked { .. }));
1533 }
1534
1535 #[test]
1536 fn ipv4_mapped_ipv6_link_local_blocked() {
1537 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1538 assert!(matches!(err, ToolError::Blocked { .. }));
1539 }
1540
1541 #[test]
1542 fn ipv4_mapped_ipv6_public_allowed() {
1543 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1544 }
1545
1546 #[tokio::test]
1549 async fn fetch_http_scheme_blocked() {
1550 let config = ScrapeConfig::default();
1551 let executor = WebScrapeExecutor::new(&config);
1552 let call = crate::executor::ToolCall {
1553 tool_id: "fetch".to_owned(),
1554 params: {
1555 let mut m = serde_json::Map::new();
1556 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1557 m
1558 },
1559 };
1560 let result = executor.execute_tool_call(&call).await;
1561 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1562 }
1563
1564 #[tokio::test]
1565 async fn fetch_private_ip_blocked() {
1566 let config = ScrapeConfig::default();
1567 let executor = WebScrapeExecutor::new(&config);
1568 let call = crate::executor::ToolCall {
1569 tool_id: "fetch".to_owned(),
1570 params: {
1571 let mut m = serde_json::Map::new();
1572 m.insert(
1573 "url".to_owned(),
1574 serde_json::json!("https://192.168.1.1/secret"),
1575 );
1576 m
1577 },
1578 };
1579 let result = executor.execute_tool_call(&call).await;
1580 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1581 }
1582
1583 #[tokio::test]
1584 async fn fetch_localhost_blocked() {
1585 let config = ScrapeConfig::default();
1586 let executor = WebScrapeExecutor::new(&config);
1587 let call = crate::executor::ToolCall {
1588 tool_id: "fetch".to_owned(),
1589 params: {
1590 let mut m = serde_json::Map::new();
1591 m.insert(
1592 "url".to_owned(),
1593 serde_json::json!("https://localhost/page"),
1594 );
1595 m
1596 },
1597 };
1598 let result = executor.execute_tool_call(&call).await;
1599 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1600 }
1601
1602 #[tokio::test]
1603 async fn fetch_unknown_tool_returns_none() {
1604 let config = ScrapeConfig::default();
1605 let executor = WebScrapeExecutor::new(&config);
1606 let call = crate::executor::ToolCall {
1607 tool_id: "unknown_tool".to_owned(),
1608 params: serde_json::Map::new(),
1609 };
1610 let result = executor.execute_tool_call(&call).await;
1611 assert!(result.unwrap().is_none());
1612 }
1613
1614 #[tokio::test]
1615 async fn fetch_returns_body_via_mock() {
1616 use wiremock::matchers::{method, path};
1617 use wiremock::{Mock, ResponseTemplate};
1618
1619 let (executor, server) = mock_server_executor().await;
1620 Mock::given(method("GET"))
1621 .and(path("/content"))
1622 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1623 .mount(&server)
1624 .await;
1625
1626 let (host, addrs) = server_host_and_addr(&server);
1627 let url = format!("{}/content", server.uri());
1628 let result = executor.fetch_html(&url, &host, &addrs).await;
1629 assert!(result.is_ok());
1630 assert_eq!(result.unwrap(), "plain text content");
1631 }
1632
1633 #[test]
1634 fn tool_definitions_returns_web_scrape_and_fetch() {
1635 let config = ScrapeConfig::default();
1636 let executor = WebScrapeExecutor::new(&config);
1637 let defs = executor.tool_definitions();
1638 assert_eq!(defs.len(), 2);
1639 assert_eq!(defs[0].id, "web_scrape");
1640 assert_eq!(
1641 defs[0].invocation,
1642 crate::registry::InvocationHint::FencedBlock("scrape")
1643 );
1644 assert_eq!(defs[1].id, "fetch");
1645 assert_eq!(
1646 defs[1].invocation,
1647 crate::registry::InvocationHint::ToolCall
1648 );
1649 }
1650
1651 #[test]
1652 fn tool_definitions_schema_has_all_params() {
1653 let config = ScrapeConfig::default();
1654 let executor = WebScrapeExecutor::new(&config);
1655 let defs = executor.tool_definitions();
1656 let obj = defs[0].schema.as_object().unwrap();
1657 let props = obj["properties"].as_object().unwrap();
1658 assert!(props.contains_key("url"));
1659 assert!(props.contains_key("select"));
1660 assert!(props.contains_key("extract"));
1661 assert!(props.contains_key("limit"));
1662 let req = obj["required"].as_array().unwrap();
1663 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1664 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1665 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1666 }
1667
1668 #[test]
1671 fn subdomain_localhost_blocked() {
1672 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1673 assert!(is_private_host(&host));
1674 }
1675
1676 #[test]
1677 fn internal_tld_blocked() {
1678 let host: url::Host<&str> = url::Host::Domain("service.internal");
1679 assert!(is_private_host(&host));
1680 }
1681
1682 #[test]
1683 fn local_tld_blocked() {
1684 let host: url::Host<&str> = url::Host::Domain("printer.local");
1685 assert!(is_private_host(&host));
1686 }
1687
1688 #[test]
1689 fn public_domain_not_blocked() {
1690 let host: url::Host<&str> = url::Host::Domain("example.com");
1691 assert!(!is_private_host(&host));
1692 }
1693
1694 #[tokio::test]
1697 async fn resolve_loopback_rejected() {
1698 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1700 let result = resolve_and_validate(&url).await;
1702 assert!(
1703 result.is_err(),
1704 "loopback IP must be rejected by resolve_and_validate"
1705 );
1706 let err = result.unwrap_err();
1707 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1708 }
1709
1710 #[tokio::test]
1711 async fn resolve_private_10_rejected() {
1712 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1713 let result = resolve_and_validate(&url).await;
1714 assert!(result.is_err());
1715 assert!(matches!(
1716 result.unwrap_err(),
1717 crate::executor::ToolError::Blocked { .. }
1718 ));
1719 }
1720
1721 #[tokio::test]
1722 async fn resolve_private_192_rejected() {
1723 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1724 let result = resolve_and_validate(&url).await;
1725 assert!(result.is_err());
1726 assert!(matches!(
1727 result.unwrap_err(),
1728 crate::executor::ToolError::Blocked { .. }
1729 ));
1730 }
1731
1732 #[tokio::test]
1733 async fn resolve_ipv6_loopback_rejected() {
1734 let url = url::Url::parse("https://[::1]/path").unwrap();
1735 let result = resolve_and_validate(&url).await;
1736 assert!(result.is_err());
1737 assert!(matches!(
1738 result.unwrap_err(),
1739 crate::executor::ToolError::Blocked { .. }
1740 ));
1741 }
1742
1743 #[tokio::test]
1744 async fn resolve_no_host_returns_ok() {
1745 let url = url::Url::parse("https://example.com/path").unwrap();
1747 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1749 let result = resolve_and_validate(&url_no_host).await;
1751 assert!(result.is_ok());
1752 let (host, addrs) = result.unwrap();
1753 assert!(host.is_empty());
1754 assert!(addrs.is_empty());
1755 drop(url);
1756 drop(url_no_host);
1757 }
1758
1759 async fn make_file_audit_logger(
1763 dir: &tempfile::TempDir,
1764 ) -> (
1765 std::sync::Arc<crate::audit::AuditLogger>,
1766 std::path::PathBuf,
1767 ) {
1768 use crate::audit::AuditLogger;
1769 use crate::config::AuditConfig;
1770 let path = dir.path().join("audit.log");
1771 let config = AuditConfig {
1772 enabled: true,
1773 destination: path.display().to_string(),
1774 };
1775 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1776 (logger, path)
1777 }
1778
1779 #[tokio::test]
1780 async fn with_audit_sets_logger() {
1781 let config = ScrapeConfig::default();
1782 let executor = WebScrapeExecutor::new(&config);
1783 assert!(executor.audit_logger.is_none());
1784
1785 let dir = tempfile::tempdir().unwrap();
1786 let (logger, _path) = make_file_audit_logger(&dir).await;
1787 let executor = executor.with_audit(logger);
1788 assert!(executor.audit_logger.is_some());
1789 }
1790
1791 #[test]
1792 fn tool_error_to_audit_result_blocked_maps_correctly() {
1793 let err = ToolError::Blocked {
1794 command: "scheme not allowed: http".into(),
1795 };
1796 let result = tool_error_to_audit_result(&err);
1797 assert!(
1798 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1799 );
1800 }
1801
1802 #[test]
1803 fn tool_error_to_audit_result_timeout_maps_correctly() {
1804 let err = ToolError::Timeout { timeout_secs: 15 };
1805 let result = tool_error_to_audit_result(&err);
1806 assert!(matches!(result, AuditResult::Timeout));
1807 }
1808
1809 #[test]
1810 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1811 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1812 let result = tool_error_to_audit_result(&err);
1813 assert!(
1814 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1815 );
1816 }
1817
1818 #[tokio::test]
1819 async fn fetch_audit_blocked_url_logged() {
1820 let dir = tempfile::tempdir().unwrap();
1821 let (logger, log_path) = make_file_audit_logger(&dir).await;
1822
1823 let config = ScrapeConfig::default();
1824 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1825
1826 let call = crate::executor::ToolCall {
1827 tool_id: "fetch".to_owned(),
1828 params: {
1829 let mut m = serde_json::Map::new();
1830 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1831 m
1832 },
1833 };
1834 let result = executor.execute_tool_call(&call).await;
1835 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1836
1837 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1838 assert!(
1839 content.contains("\"tool\":\"fetch\""),
1840 "expected tool=fetch in audit: {content}"
1841 );
1842 assert!(
1843 content.contains("\"type\":\"blocked\""),
1844 "expected type=blocked in audit: {content}"
1845 );
1846 assert!(
1847 content.contains("http://example.com"),
1848 "expected URL in audit command field: {content}"
1849 );
1850 }
1851
1852 #[tokio::test]
1853 async fn log_audit_success_writes_to_file() {
1854 let dir = tempfile::tempdir().unwrap();
1855 let (logger, log_path) = make_file_audit_logger(&dir).await;
1856
1857 let config = ScrapeConfig::default();
1858 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1859
1860 executor
1861 .log_audit(
1862 "fetch",
1863 "https://example.com/page",
1864 AuditResult::Success,
1865 42,
1866 None,
1867 )
1868 .await;
1869
1870 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1871 assert!(
1872 content.contains("\"tool\":\"fetch\""),
1873 "expected tool=fetch in audit: {content}"
1874 );
1875 assert!(
1876 content.contains("\"type\":\"success\""),
1877 "expected type=success in audit: {content}"
1878 );
1879 assert!(
1880 content.contains("\"command\":\"https://example.com/page\""),
1881 "expected command URL in audit: {content}"
1882 );
1883 assert!(
1884 content.contains("\"duration_ms\":42"),
1885 "expected duration_ms in audit: {content}"
1886 );
1887 }
1888
1889 #[tokio::test]
1890 async fn log_audit_blocked_writes_to_file() {
1891 let dir = tempfile::tempdir().unwrap();
1892 let (logger, log_path) = make_file_audit_logger(&dir).await;
1893
1894 let config = ScrapeConfig::default();
1895 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1896
1897 executor
1898 .log_audit(
1899 "web_scrape",
1900 "http://evil.com/page",
1901 AuditResult::Blocked {
1902 reason: "scheme not allowed: http".into(),
1903 },
1904 0,
1905 None,
1906 )
1907 .await;
1908
1909 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1910 assert!(
1911 content.contains("\"tool\":\"web_scrape\""),
1912 "expected tool=web_scrape in audit: {content}"
1913 );
1914 assert!(
1915 content.contains("\"type\":\"blocked\""),
1916 "expected type=blocked in audit: {content}"
1917 );
1918 assert!(
1919 content.contains("scheme not allowed"),
1920 "expected block reason in audit: {content}"
1921 );
1922 }
1923
1924 #[tokio::test]
1925 async fn web_scrape_audit_blocked_url_logged() {
1926 let dir = tempfile::tempdir().unwrap();
1927 let (logger, log_path) = make_file_audit_logger(&dir).await;
1928
1929 let config = ScrapeConfig::default();
1930 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1931
1932 let call = crate::executor::ToolCall {
1933 tool_id: "web_scrape".to_owned(),
1934 params: {
1935 let mut m = serde_json::Map::new();
1936 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1937 m.insert("select".to_owned(), serde_json::json!("h1"));
1938 m
1939 },
1940 };
1941 let result = executor.execute_tool_call(&call).await;
1942 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1943
1944 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1945 assert!(
1946 content.contains("\"tool\":\"web_scrape\""),
1947 "expected tool=web_scrape in audit: {content}"
1948 );
1949 assert!(
1950 content.contains("\"type\":\"blocked\""),
1951 "expected type=blocked in audit: {content}"
1952 );
1953 }
1954
1955 #[tokio::test]
1956 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1957 let config = ScrapeConfig::default();
1958 let executor = WebScrapeExecutor::new(&config);
1959 assert!(executor.audit_logger.is_none());
1960
1961 let call = crate::executor::ToolCall {
1962 tool_id: "fetch".to_owned(),
1963 params: {
1964 let mut m = serde_json::Map::new();
1965 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1966 m
1967 },
1968 };
1969 let result = executor.execute_tool_call(&call).await;
1971 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1972 }
1973
1974 #[tokio::test]
1976 async fn fetch_execute_tool_call_end_to_end() {
1977 use wiremock::matchers::{method, path};
1978 use wiremock::{Mock, ResponseTemplate};
1979
1980 let (executor, server) = mock_server_executor().await;
1981 Mock::given(method("GET"))
1982 .and(path("/e2e"))
1983 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1984 .mount(&server)
1985 .await;
1986
1987 let (host, addrs) = server_host_and_addr(&server);
1988 let result = executor
1990 .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1991 .await;
1992 assert!(result.is_ok());
1993 assert!(result.unwrap().contains("end-to-end"));
1994 }
1995}