1use std::net::{IpAddr, SocketAddr};
5use std::time::Duration;
6
7use schemars::JsonSchema;
8use serde::Deserialize;
9use url::Url;
10
11use crate::config::ScrapeConfig;
12use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
13
/// Parameters for the `fetch` tool call: retrieve a URL and return its body.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// Target URL; must be https and pass the SSRF validation in `validate_url`.
    url: String,
}
19
/// A single scrape request, parsed from a fenced ```scrape block or a
/// `web_scrape` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// Page to fetch; must be https and pass the SSRF validation in `validate_url`.
    url: String,
    /// CSS selector applied to the fetched document.
    select: String,
    /// Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    /// Maximum number of matched elements to return (defaults to 10 at the use site).
    limit: Option<usize>,
}
32
/// Serde default for `ScrapeInstruction::extract`.
fn default_extract() -> String {
    String::from("text")
}
36
/// How a matched element is converted to an output string.
#[derive(Debug)]
enum ExtractMode {
    /// Text content of the element.
    Text,
    /// Inner HTML of the element.
    Html,
    /// Value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parses an extract spec: "text", "html", or "attr:<name>".
    ///
    /// Unknown specs fall back to `Text`. The original guard-plus-
    /// `strip_prefix(..).unwrap_or(..)` form scanned the prefix twice and
    /// carried an unreachable fallback; `strip_prefix` alone is sufficient.
    fn parse(s: &str) -> Self {
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        match s {
            "html" => Self::Html,
            // "text" and anything unrecognized both mean plain text.
            _ => Self::Text,
        }
    }
}
56
/// Tool executor that fetches pages over https and extracts data via CSS
/// selectors, with SSRF protections (scheme check, private-host/IP blocking,
/// DNS pinning, manual redirect validation).
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Per-request HTTP client timeout.
    timeout: Duration,
    /// Hard cap on response body size, in bytes.
    max_body_bytes: usize,
}
66
67impl WebScrapeExecutor {
68 #[must_use]
69 pub fn new(config: &ScrapeConfig) -> Self {
70 Self {
71 timeout: Duration::from_secs(config.timeout),
72 max_body_bytes: config.max_body_bytes,
73 }
74 }
75
76 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
77 let mut builder = reqwest::Client::builder()
78 .timeout(self.timeout)
79 .redirect(reqwest::redirect::Policy::none());
80 builder = builder.resolve_to_addrs(host, addrs);
81 builder.build().unwrap_or_default()
82 }
83}
84
85impl ToolExecutor for WebScrapeExecutor {
86 fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
87 use crate::registry::{InvocationHint, ToolDef};
88 vec![
89 ToolDef {
90 id: "web_scrape".into(),
91 description: "Scrape data from a web page via CSS selectors".into(),
92 schema: schemars::schema_for!(ScrapeInstruction),
93 invocation: InvocationHint::FencedBlock("scrape"),
94 },
95 ToolDef {
96 id: "fetch".into(),
97 description: "Fetch a URL and return content as plain text".into(),
98 schema: schemars::schema_for!(FetchParams),
99 invocation: InvocationHint::ToolCall,
100 },
101 ]
102 }
103
104 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
105 let blocks = extract_scrape_blocks(response);
106 if blocks.is_empty() {
107 return Ok(None);
108 }
109
110 let mut outputs = Vec::with_capacity(blocks.len());
111 #[allow(clippy::cast_possible_truncation)]
112 let blocks_executed = blocks.len() as u32;
113
114 for block in &blocks {
115 let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
116 ToolError::Execution(std::io::Error::new(
117 std::io::ErrorKind::InvalidData,
118 e.to_string(),
119 ))
120 })?;
121 outputs.push(self.scrape_instruction(&instruction).await?);
122 }
123
124 Ok(Some(ToolOutput {
125 tool_name: "web-scrape".to_owned(),
126 summary: outputs.join("\n\n"),
127 blocks_executed,
128 filter_stats: None,
129 diff: None,
130 streamed: false,
131 terminal_id: None,
132 locations: None,
133 raw_response: None,
134 }))
135 }
136
137 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
138 match call.tool_id.as_str() {
139 "web_scrape" => {
140 let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
141 let result = self.scrape_instruction(&instruction).await?;
142 Ok(Some(ToolOutput {
143 tool_name: "web-scrape".to_owned(),
144 summary: result,
145 blocks_executed: 1,
146 filter_stats: None,
147 diff: None,
148 streamed: false,
149 terminal_id: None,
150 locations: None,
151 raw_response: None,
152 }))
153 }
154 "fetch" => {
155 let p: FetchParams = deserialize_params(&call.params)?;
156 let result = self.handle_fetch(&p).await?;
157 Ok(Some(ToolOutput {
158 tool_name: "fetch".to_owned(),
159 summary: result,
160 blocks_executed: 1,
161 filter_stats: None,
162 diff: None,
163 streamed: false,
164 terminal_id: None,
165 locations: None,
166 raw_response: None,
167 }))
168 }
169 _ => Ok(None),
170 }
171 }
172}
173
174impl WebScrapeExecutor {
175 async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
176 let parsed = validate_url(¶ms.url)?;
177 let (host, addrs) = resolve_and_validate(&parsed).await?;
178 self.fetch_html(¶ms.url, &host, &addrs).await
179 }
180
181 async fn scrape_instruction(
182 &self,
183 instruction: &ScrapeInstruction,
184 ) -> Result<String, ToolError> {
185 let parsed = validate_url(&instruction.url)?;
186 let (host, addrs) = resolve_and_validate(&parsed).await?;
187 let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
188 let selector = instruction.select.clone();
189 let extract = ExtractMode::parse(&instruction.extract);
190 let limit = instruction.limit.unwrap_or(10);
191 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
192 .await
193 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
194 }
195
196 async fn fetch_html(
206 &self,
207 url: &str,
208 host: &str,
209 addrs: &[SocketAddr],
210 ) -> Result<String, ToolError> {
211 const MAX_REDIRECTS: usize = 3;
212
213 let mut current_url = url.to_owned();
214 let mut current_host = host.to_owned();
215 let mut current_addrs = addrs.to_vec();
216
217 for hop in 0..=MAX_REDIRECTS {
218 let client = self.build_client(¤t_host, ¤t_addrs);
220 let resp = client
221 .get(¤t_url)
222 .send()
223 .await
224 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
225
226 let status = resp.status();
227
228 if status.is_redirection() {
229 if hop == MAX_REDIRECTS {
230 return Err(ToolError::Execution(std::io::Error::other(
231 "too many redirects",
232 )));
233 }
234
235 let location = resp
236 .headers()
237 .get(reqwest::header::LOCATION)
238 .and_then(|v| v.to_str().ok())
239 .ok_or_else(|| {
240 ToolError::Execution(std::io::Error::other("redirect with no Location"))
241 })?;
242
243 let base = Url::parse(¤t_url)
245 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
246 let next_url = base
247 .join(location)
248 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
249
250 let validated = validate_url(next_url.as_str())?;
251 let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
252
253 current_url = next_url.to_string();
254 current_host = next_host;
255 current_addrs = next_addrs;
256 continue;
257 }
258
259 if !status.is_success() {
260 return Err(ToolError::Execution(std::io::Error::other(format!(
261 "HTTP {status}",
262 ))));
263 }
264
265 let bytes = resp
266 .bytes()
267 .await
268 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
269
270 if bytes.len() > self.max_body_bytes {
271 return Err(ToolError::Execution(std::io::Error::other(format!(
272 "response too large: {} bytes (max: {})",
273 bytes.len(),
274 self.max_body_bytes,
275 ))));
276 }
277
278 return String::from_utf8(bytes.to_vec())
279 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
280 }
281
282 Err(ToolError::Execution(std::io::Error::other(
283 "too many redirects",
284 )))
285 }
286}
287
/// Returns the contents of every fenced ```scrape block found in `text`.
/// Thin wrapper fixing the fence tag; extraction logic lives in `crate::executor`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
291
292fn validate_url(raw: &str) -> Result<Url, ToolError> {
293 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
294 command: format!("invalid URL: {raw}"),
295 })?;
296
297 if parsed.scheme() != "https" {
298 return Err(ToolError::Blocked {
299 command: format!("scheme not allowed: {}", parsed.scheme()),
300 });
301 }
302
303 if let Some(host) = parsed.host()
304 && is_private_host(&host)
305 {
306 return Err(ToolError::Blocked {
307 command: format!(
308 "private/local host blocked: {}",
309 parsed.host_str().unwrap_or("")
310 ),
311 });
312 }
313
314 Ok(parsed)
315}
316
/// Returns `true` for addresses that must never be scraped: loopback,
/// RFC 1918 private, link-local, unspecified, broadcast, IPv6 unique-local,
/// and IPv4-mapped IPv6 forms of any of the above.
pub(crate) fn is_private_ip(ip: IpAddr) -> bool {
    /// Shared IPv4 policy, applied both to plain v4 and v4-mapped v6 addresses.
    fn v4_is_private(addr: Ipv4Addr) -> bool {
        addr.is_loopback()
            || addr.is_private()
            || addr.is_link_local()
            || addr.is_unspecified()
            || addr.is_broadcast()
    }

    match ip {
        IpAddr::V4(addr) => v4_is_private(addr),
        IpAddr::V6(addr) => {
            // ::ffff:a.b.c.d — judge by the embedded IPv4 address.
            if let Some(mapped) = addr.to_ipv4_mapped() {
                return v4_is_private(mapped);
            }
            let first = addr.segments()[0];
            addr.is_loopback()
                || addr.is_unspecified()
                || first & 0xffc0 == 0xfe80 // link-local fe80::/10
                || first & 0xfe00 == 0xfc00 // unique-local fc00::/7
        }
    }
}
354
355fn is_private_host(host: &url::Host<&str>) -> bool {
356 match host {
357 url::Host::Domain(d) => {
358 #[allow(clippy::case_sensitive_file_extension_comparisons)]
361 {
362 *d == "localhost"
363 || d.ends_with(".localhost")
364 || d.ends_with(".internal")
365 || d.ends_with(".local")
366 }
367 }
368 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
369 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
370 }
371}
372
373async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
379 let Some(host) = url.host_str() else {
380 return Ok((String::new(), vec![]));
381 };
382 let port = url.port_or_known_default().unwrap_or(443);
383 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
384 .await
385 .map_err(|e| ToolError::Blocked {
386 command: format!("DNS resolution failed: {e}"),
387 })?
388 .collect();
389 for addr in &addrs {
390 if is_private_ip(addr.ip()) {
391 return Err(ToolError::Blocked {
392 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
393 });
394 }
395 }
396 Ok((host.to_owned(), addrs))
397}
398
399fn parse_and_extract(
400 html: &str,
401 selector: &str,
402 extract: &ExtractMode,
403 limit: usize,
404) -> Result<String, ToolError> {
405 let soup = scrape_core::Soup::parse(html);
406
407 let tags = soup.find_all(selector).map_err(|e| {
408 ToolError::Execution(std::io::Error::new(
409 std::io::ErrorKind::InvalidData,
410 format!("invalid selector: {e}"),
411 ))
412 })?;
413
414 let mut results = Vec::new();
415
416 for tag in tags.into_iter().take(limit) {
417 let value = match extract {
418 ExtractMode::Text => tag.text(),
419 ExtractMode::Html => tag.inner_html(),
420 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
421 };
422 if !value.trim().is_empty() {
423 results.push(value.trim().to_owned());
424 }
425 }
426
427 if results.is_empty() {
428 Ok(format!("No results for selector: {selector}"))
429 } else {
430 Ok(results.join("\n"))
431 }
432}
433
434#[cfg(test)]
435mod tests {
436 use super::*;
437
438 #[test]
441 fn extract_single_block() {
442 let text =
443 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
444 let blocks = extract_scrape_blocks(text);
445 assert_eq!(blocks.len(), 1);
446 assert!(blocks[0].contains("example.com"));
447 }
448
449 #[test]
450 fn extract_multiple_blocks() {
451 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
452 let blocks = extract_scrape_blocks(text);
453 assert_eq!(blocks.len(), 2);
454 }
455
456 #[test]
457 fn no_blocks_returns_empty() {
458 let blocks = extract_scrape_blocks("plain text, no code blocks");
459 assert!(blocks.is_empty());
460 }
461
462 #[test]
463 fn unclosed_block_ignored() {
464 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
465 assert!(blocks.is_empty());
466 }
467
468 #[test]
469 fn non_scrape_block_ignored() {
470 let text =
471 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
472 let blocks = extract_scrape_blocks(text);
473 assert_eq!(blocks.len(), 1);
474 assert!(blocks[0].contains("x.com"));
475 }
476
477 #[test]
478 fn multiline_json_block() {
479 let text =
480 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
481 let blocks = extract_scrape_blocks(text);
482 assert_eq!(blocks.len(), 1);
483 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
484 assert_eq!(instr.url, "https://example.com");
485 }
486
487 #[test]
490 fn parse_valid_instruction() {
491 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
492 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
493 assert_eq!(instr.url, "https://example.com");
494 assert_eq!(instr.select, "h1");
495 assert_eq!(instr.extract, "text");
496 assert_eq!(instr.limit, Some(5));
497 }
498
499 #[test]
500 fn parse_minimal_instruction() {
501 let json = r#"{"url":"https://example.com","select":"p"}"#;
502 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
503 assert_eq!(instr.extract, "text");
504 assert!(instr.limit.is_none());
505 }
506
507 #[test]
508 fn parse_attr_extract() {
509 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
510 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
511 assert_eq!(instr.extract, "attr:href");
512 }
513
514 #[test]
515 fn parse_invalid_json_errors() {
516 let result = serde_json::from_str::<ScrapeInstruction>("not json");
517 assert!(result.is_err());
518 }
519
520 #[test]
523 fn extract_mode_text() {
524 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
525 }
526
527 #[test]
528 fn extract_mode_html() {
529 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
530 }
531
532 #[test]
533 fn extract_mode_attr() {
534 let mode = ExtractMode::parse("attr:href");
535 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
536 }
537
538 #[test]
539 fn extract_mode_unknown_defaults_to_text() {
540 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
541 }
542
543 #[test]
546 fn valid_https_url() {
547 assert!(validate_url("https://example.com").is_ok());
548 }
549
550 #[test]
551 fn http_rejected() {
552 let err = validate_url("http://example.com").unwrap_err();
553 assert!(matches!(err, ToolError::Blocked { .. }));
554 }
555
556 #[test]
557 fn ftp_rejected() {
558 let err = validate_url("ftp://files.example.com").unwrap_err();
559 assert!(matches!(err, ToolError::Blocked { .. }));
560 }
561
562 #[test]
563 fn file_rejected() {
564 let err = validate_url("file:///etc/passwd").unwrap_err();
565 assert!(matches!(err, ToolError::Blocked { .. }));
566 }
567
568 #[test]
569 fn invalid_url_rejected() {
570 let err = validate_url("not a url").unwrap_err();
571 assert!(matches!(err, ToolError::Blocked { .. }));
572 }
573
574 #[test]
575 fn localhost_blocked() {
576 let err = validate_url("https://localhost/path").unwrap_err();
577 assert!(matches!(err, ToolError::Blocked { .. }));
578 }
579
580 #[test]
581 fn loopback_ip_blocked() {
582 let err = validate_url("https://127.0.0.1/path").unwrap_err();
583 assert!(matches!(err, ToolError::Blocked { .. }));
584 }
585
586 #[test]
587 fn private_10_blocked() {
588 let err = validate_url("https://10.0.0.1/api").unwrap_err();
589 assert!(matches!(err, ToolError::Blocked { .. }));
590 }
591
592 #[test]
593 fn private_172_blocked() {
594 let err = validate_url("https://172.16.0.1/api").unwrap_err();
595 assert!(matches!(err, ToolError::Blocked { .. }));
596 }
597
598 #[test]
599 fn private_192_blocked() {
600 let err = validate_url("https://192.168.1.1/api").unwrap_err();
601 assert!(matches!(err, ToolError::Blocked { .. }));
602 }
603
604 #[test]
605 fn ipv6_loopback_blocked() {
606 let err = validate_url("https://[::1]/path").unwrap_err();
607 assert!(matches!(err, ToolError::Blocked { .. }));
608 }
609
610 #[test]
611 fn public_ip_allowed() {
612 assert!(validate_url("https://93.184.216.34/page").is_ok());
613 }
614
615 #[test]
618 fn extract_text_from_html() {
619 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
620 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
621 assert_eq!(result, "Hello World");
622 }
623
624 #[test]
625 fn extract_multiple_elements() {
626 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
627 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
628 assert_eq!(result, "A\nB\nC");
629 }
630
631 #[test]
632 fn extract_with_limit() {
633 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
634 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
635 assert_eq!(result, "A\nB");
636 }
637
638 #[test]
639 fn extract_attr_href() {
640 let html = r#"<a href="https://example.com">Link</a>"#;
641 let result =
642 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
643 assert_eq!(result, "https://example.com");
644 }
645
646 #[test]
647 fn extract_inner_html() {
648 let html = "<div><span>inner</span></div>";
649 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
650 assert!(result.contains("<span>inner</span>"));
651 }
652
653 #[test]
654 fn no_matches_returns_message() {
655 let html = "<html><body><p>text</p></body></html>";
656 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
657 assert!(result.starts_with("No results for selector:"));
658 }
659
660 #[test]
661 fn empty_text_skipped() {
662 let html = "<ul><li> </li><li>A</li></ul>";
663 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
664 assert_eq!(result, "A");
665 }
666
667 #[test]
668 fn invalid_selector_errors() {
669 let html = "<html><body></body></html>";
670 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
671 assert!(result.is_err());
672 }
673
674 #[test]
675 fn empty_html_returns_no_results() {
676 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
677 assert!(result.starts_with("No results for selector:"));
678 }
679
680 #[test]
681 fn nested_selector() {
682 let html = "<div><span>inner</span></div><span>outer</span>";
683 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
684 assert_eq!(result, "inner");
685 }
686
687 #[test]
688 fn attr_missing_returns_empty() {
689 let html = r#"<a>No href</a>"#;
690 let result =
691 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
692 assert!(result.starts_with("No results for selector:"));
693 }
694
695 #[test]
696 fn extract_html_mode() {
697 let html = "<div><b>bold</b> text</div>";
698 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
699 assert!(result.contains("<b>bold</b>"));
700 }
701
702 #[test]
703 fn limit_zero_returns_no_results() {
704 let html = "<ul><li>A</li><li>B</li></ul>";
705 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
706 assert!(result.starts_with("No results for selector:"));
707 }
708
709 #[test]
712 fn url_with_port_allowed() {
713 assert!(validate_url("https://example.com:8443/path").is_ok());
714 }
715
716 #[test]
717 fn link_local_ip_blocked() {
718 let err = validate_url("https://169.254.1.1/path").unwrap_err();
719 assert!(matches!(err, ToolError::Blocked { .. }));
720 }
721
722 #[test]
723 fn url_no_scheme_rejected() {
724 let err = validate_url("example.com/path").unwrap_err();
725 assert!(matches!(err, ToolError::Blocked { .. }));
726 }
727
728 #[test]
729 fn unspecified_ipv4_blocked() {
730 let err = validate_url("https://0.0.0.0/path").unwrap_err();
731 assert!(matches!(err, ToolError::Blocked { .. }));
732 }
733
734 #[test]
735 fn broadcast_ipv4_blocked() {
736 let err = validate_url("https://255.255.255.255/path").unwrap_err();
737 assert!(matches!(err, ToolError::Blocked { .. }));
738 }
739
740 #[test]
741 fn ipv6_link_local_blocked() {
742 let err = validate_url("https://[fe80::1]/path").unwrap_err();
743 assert!(matches!(err, ToolError::Blocked { .. }));
744 }
745
746 #[test]
747 fn ipv6_unique_local_blocked() {
748 let err = validate_url("https://[fd12::1]/path").unwrap_err();
749 assert!(matches!(err, ToolError::Blocked { .. }));
750 }
751
752 #[test]
753 fn ipv4_mapped_ipv6_loopback_blocked() {
754 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
755 assert!(matches!(err, ToolError::Blocked { .. }));
756 }
757
758 #[test]
759 fn ipv4_mapped_ipv6_private_blocked() {
760 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
761 assert!(matches!(err, ToolError::Blocked { .. }));
762 }
763
764 #[tokio::test]
767 async fn executor_no_blocks_returns_none() {
768 let config = ScrapeConfig::default();
769 let executor = WebScrapeExecutor::new(&config);
770 let result = executor.execute("plain text").await;
771 assert!(result.unwrap().is_none());
772 }
773
774 #[tokio::test]
775 async fn executor_invalid_json_errors() {
776 let config = ScrapeConfig::default();
777 let executor = WebScrapeExecutor::new(&config);
778 let response = "```scrape\nnot json\n```";
779 let result = executor.execute(response).await;
780 assert!(matches!(result, Err(ToolError::Execution(_))));
781 }
782
783 #[tokio::test]
784 async fn executor_blocked_url_errors() {
785 let config = ScrapeConfig::default();
786 let executor = WebScrapeExecutor::new(&config);
787 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
788 let result = executor.execute(response).await;
789 assert!(matches!(result, Err(ToolError::Blocked { .. })));
790 }
791
792 #[tokio::test]
793 async fn executor_private_ip_blocked() {
794 let config = ScrapeConfig::default();
795 let executor = WebScrapeExecutor::new(&config);
796 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
797 let result = executor.execute(response).await;
798 assert!(matches!(result, Err(ToolError::Blocked { .. })));
799 }
800
801 #[tokio::test]
802 async fn executor_unreachable_host_returns_error() {
803 let config = ScrapeConfig {
804 timeout: 1,
805 max_body_bytes: 1_048_576,
806 };
807 let executor = WebScrapeExecutor::new(&config);
808 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
809 let result = executor.execute(response).await;
810 assert!(matches!(result, Err(ToolError::Execution(_))));
811 }
812
813 #[tokio::test]
814 async fn executor_localhost_url_blocked() {
815 let config = ScrapeConfig::default();
816 let executor = WebScrapeExecutor::new(&config);
817 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
818 let result = executor.execute(response).await;
819 assert!(matches!(result, Err(ToolError::Blocked { .. })));
820 }
821
822 #[tokio::test]
823 async fn executor_empty_text_returns_none() {
824 let config = ScrapeConfig::default();
825 let executor = WebScrapeExecutor::new(&config);
826 let result = executor.execute("").await;
827 assert!(result.unwrap().is_none());
828 }
829
830 #[tokio::test]
831 async fn executor_multiple_blocks_first_blocked() {
832 let config = ScrapeConfig::default();
833 let executor = WebScrapeExecutor::new(&config);
834 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
835 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
836 let result = executor.execute(response).await;
837 assert!(result.is_err());
838 }
839
840 #[test]
841 fn validate_url_empty_string() {
842 let err = validate_url("").unwrap_err();
843 assert!(matches!(err, ToolError::Blocked { .. }));
844 }
845
846 #[test]
847 fn validate_url_javascript_scheme_blocked() {
848 let err = validate_url("javascript:alert(1)").unwrap_err();
849 assert!(matches!(err, ToolError::Blocked { .. }));
850 }
851
852 #[test]
853 fn validate_url_data_scheme_blocked() {
854 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
855 assert!(matches!(err, ToolError::Blocked { .. }));
856 }
857
858 #[test]
859 fn is_private_host_public_domain_is_false() {
860 let host: url::Host<&str> = url::Host::Domain("example.com");
861 assert!(!is_private_host(&host));
862 }
863
864 #[test]
865 fn is_private_host_localhost_is_true() {
866 let host: url::Host<&str> = url::Host::Domain("localhost");
867 assert!(is_private_host(&host));
868 }
869
870 #[test]
871 fn is_private_host_ipv6_unspecified_is_true() {
872 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
873 assert!(is_private_host(&host));
874 }
875
876 #[test]
877 fn is_private_host_public_ipv6_is_false() {
878 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
879 assert!(!is_private_host(&host));
880 }
881
882 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
893 let server = wiremock::MockServer::start().await;
894 let executor = WebScrapeExecutor {
895 timeout: Duration::from_secs(5),
896 max_body_bytes: 1_048_576,
897 };
898 (executor, server)
899 }
900
901 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
903 let uri = server.uri();
904 let url = Url::parse(&uri).unwrap();
905 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
906 let port = url.port().unwrap_or(80);
907 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
908 (host, vec![addr])
909 }
910
911 async fn follow_redirects_raw(
915 executor: &WebScrapeExecutor,
916 start_url: &str,
917 host: &str,
918 addrs: &[std::net::SocketAddr],
919 ) -> Result<String, ToolError> {
920 const MAX_REDIRECTS: usize = 3;
921 let mut current_url = start_url.to_owned();
922 let mut current_host = host.to_owned();
923 let mut current_addrs = addrs.to_vec();
924
925 for hop in 0..=MAX_REDIRECTS {
926 let client = executor.build_client(¤t_host, ¤t_addrs);
927 let resp = client
928 .get(¤t_url)
929 .send()
930 .await
931 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
932
933 let status = resp.status();
934
935 if status.is_redirection() {
936 if hop == MAX_REDIRECTS {
937 return Err(ToolError::Execution(std::io::Error::other(
938 "too many redirects",
939 )));
940 }
941
942 let location = resp
943 .headers()
944 .get(reqwest::header::LOCATION)
945 .and_then(|v| v.to_str().ok())
946 .ok_or_else(|| {
947 ToolError::Execution(std::io::Error::other("redirect with no Location"))
948 })?;
949
950 let base = Url::parse(¤t_url)
951 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
952 let next_url = base
953 .join(location)
954 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
955
956 current_url = next_url.to_string();
958 let _ = &mut current_host;
960 let _ = &mut current_addrs;
961 continue;
962 }
963
964 if !status.is_success() {
965 return Err(ToolError::Execution(std::io::Error::other(format!(
966 "HTTP {status}",
967 ))));
968 }
969
970 let bytes = resp
971 .bytes()
972 .await
973 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
974
975 if bytes.len() > executor.max_body_bytes {
976 return Err(ToolError::Execution(std::io::Error::other(format!(
977 "response too large: {} bytes (max: {})",
978 bytes.len(),
979 executor.max_body_bytes,
980 ))));
981 }
982
983 return String::from_utf8(bytes.to_vec())
984 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
985 }
986
987 Err(ToolError::Execution(std::io::Error::other(
988 "too many redirects",
989 )))
990 }
991
992 #[tokio::test]
993 async fn fetch_html_success_returns_body() {
994 use wiremock::matchers::{method, path};
995 use wiremock::{Mock, ResponseTemplate};
996
997 let (executor, server) = mock_server_executor().await;
998 Mock::given(method("GET"))
999 .and(path("/page"))
1000 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1001 .mount(&server)
1002 .await;
1003
1004 let (host, addrs) = server_host_and_addr(&server);
1005 let url = format!("{}/page", server.uri());
1006 let result = executor.fetch_html(&url, &host, &addrs).await;
1007 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1008 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1009 }
1010
1011 #[tokio::test]
1012 async fn fetch_html_non_2xx_returns_error() {
1013 use wiremock::matchers::{method, path};
1014 use wiremock::{Mock, ResponseTemplate};
1015
1016 let (executor, server) = mock_server_executor().await;
1017 Mock::given(method("GET"))
1018 .and(path("/forbidden"))
1019 .respond_with(ResponseTemplate::new(403))
1020 .mount(&server)
1021 .await;
1022
1023 let (host, addrs) = server_host_and_addr(&server);
1024 let url = format!("{}/forbidden", server.uri());
1025 let result = executor.fetch_html(&url, &host, &addrs).await;
1026 assert!(result.is_err());
1027 let msg = result.unwrap_err().to_string();
1028 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1029 }
1030
1031 #[tokio::test]
1032 async fn fetch_html_404_returns_error() {
1033 use wiremock::matchers::{method, path};
1034 use wiremock::{Mock, ResponseTemplate};
1035
1036 let (executor, server) = mock_server_executor().await;
1037 Mock::given(method("GET"))
1038 .and(path("/missing"))
1039 .respond_with(ResponseTemplate::new(404))
1040 .mount(&server)
1041 .await;
1042
1043 let (host, addrs) = server_host_and_addr(&server);
1044 let url = format!("{}/missing", server.uri());
1045 let result = executor.fetch_html(&url, &host, &addrs).await;
1046 assert!(result.is_err());
1047 let msg = result.unwrap_err().to_string();
1048 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1049 }
1050
1051 #[tokio::test]
1052 async fn fetch_html_redirect_no_location_returns_error() {
1053 use wiremock::matchers::{method, path};
1054 use wiremock::{Mock, ResponseTemplate};
1055
1056 let (executor, server) = mock_server_executor().await;
1057 Mock::given(method("GET"))
1059 .and(path("/redirect-no-loc"))
1060 .respond_with(ResponseTemplate::new(302))
1061 .mount(&server)
1062 .await;
1063
1064 let (host, addrs) = server_host_and_addr(&server);
1065 let url = format!("{}/redirect-no-loc", server.uri());
1066 let result = executor.fetch_html(&url, &host, &addrs).await;
1067 assert!(result.is_err());
1068 let msg = result.unwrap_err().to_string();
1069 assert!(
1070 msg.contains("Location") || msg.contains("location"),
1071 "expected Location-related error: {msg}"
1072 );
1073 }
1074
1075 #[tokio::test]
1076 async fn fetch_html_single_redirect_followed() {
1077 use wiremock::matchers::{method, path};
1078 use wiremock::{Mock, ResponseTemplate};
1079
1080 let (executor, server) = mock_server_executor().await;
1081 let final_url = format!("{}/final", server.uri());
1082
1083 Mock::given(method("GET"))
1084 .and(path("/start"))
1085 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1086 .mount(&server)
1087 .await;
1088
1089 Mock::given(method("GET"))
1090 .and(path("/final"))
1091 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1092 .mount(&server)
1093 .await;
1094
1095 let (host, addrs) = server_host_and_addr(&server);
1096 let url = format!("{}/start", server.uri());
1097 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1098 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1099 assert_eq!(result.unwrap(), "<p>final</p>");
1100 }
1101
1102 #[tokio::test]
1103 async fn fetch_html_three_redirects_allowed() {
1104 use wiremock::matchers::{method, path};
1105 use wiremock::{Mock, ResponseTemplate};
1106
1107 let (executor, server) = mock_server_executor().await;
1108 let hop2 = format!("{}/hop2", server.uri());
1109 let hop3 = format!("{}/hop3", server.uri());
1110 let final_dest = format!("{}/done", server.uri());
1111
1112 Mock::given(method("GET"))
1113 .and(path("/hop1"))
1114 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1115 .mount(&server)
1116 .await;
1117 Mock::given(method("GET"))
1118 .and(path("/hop2"))
1119 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1120 .mount(&server)
1121 .await;
1122 Mock::given(method("GET"))
1123 .and(path("/hop3"))
1124 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1125 .mount(&server)
1126 .await;
1127 Mock::given(method("GET"))
1128 .and(path("/done"))
1129 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1130 .mount(&server)
1131 .await;
1132
1133 let (host, addrs) = server_host_and_addr(&server);
1134 let url = format!("{}/hop1", server.uri());
1135 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1136 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1137 assert_eq!(result.unwrap(), "<p>done</p>");
1138 }
1139
1140 #[tokio::test]
1141 async fn fetch_html_four_redirects_rejected() {
1142 use wiremock::matchers::{method, path};
1143 use wiremock::{Mock, ResponseTemplate};
1144
1145 let (executor, server) = mock_server_executor().await;
1146 let hop2 = format!("{}/r2", server.uri());
1147 let hop3 = format!("{}/r3", server.uri());
1148 let hop4 = format!("{}/r4", server.uri());
1149 let hop5 = format!("{}/r5", server.uri());
1150
1151 for (from, to) in [
1152 ("/r1", &hop2),
1153 ("/r2", &hop3),
1154 ("/r3", &hop4),
1155 ("/r4", &hop5),
1156 ] {
1157 Mock::given(method("GET"))
1158 .and(path(from))
1159 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1160 .mount(&server)
1161 .await;
1162 }
1163
1164 let (host, addrs) = server_host_and_addr(&server);
1165 let url = format!("{}/r1", server.uri());
1166 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1167 assert!(result.is_err(), "4 redirects should be rejected");
1168 let msg = result.unwrap_err().to_string();
1169 assert!(
1170 msg.contains("redirect"),
1171 "expected redirect-related error: {msg}"
1172 );
1173 }
1174
1175 #[tokio::test]
1176 async fn fetch_html_body_too_large_returns_error() {
1177 use wiremock::matchers::{method, path};
1178 use wiremock::{Mock, ResponseTemplate};
1179
1180 let small_limit_executor = WebScrapeExecutor {
1181 timeout: Duration::from_secs(5),
1182 max_body_bytes: 10,
1183 };
1184 let server = wiremock::MockServer::start().await;
1185 Mock::given(method("GET"))
1186 .and(path("/big"))
1187 .respond_with(
1188 ResponseTemplate::new(200)
1189 .set_body_string("this body is definitely longer than ten bytes"),
1190 )
1191 .mount(&server)
1192 .await;
1193
1194 let (host, addrs) = server_host_and_addr(&server);
1195 let url = format!("{}/big", server.uri());
1196 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1197 assert!(result.is_err());
1198 let msg = result.unwrap_err().to_string();
1199 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1200 }
1201
1202 #[test]
1203 fn extract_scrape_blocks_empty_block_content() {
1204 let text = "```scrape\n\n```";
1205 let blocks = extract_scrape_blocks(text);
1206 assert_eq!(blocks.len(), 1);
1207 assert!(blocks[0].is_empty());
1208 }
1209
1210 #[test]
1211 fn extract_scrape_blocks_whitespace_only() {
1212 let text = "```scrape\n \n```";
1213 let blocks = extract_scrape_blocks(text);
1214 assert_eq!(blocks.len(), 1);
1215 }
1216
1217 #[test]
1218 fn parse_and_extract_multiple_selectors() {
1219 let html = "<div><h1>Title</h1><p>Para</p></div>";
1220 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1221 assert!(result.contains("Title"));
1222 assert!(result.contains("Para"));
1223 }
1224
1225 #[test]
1226 fn webscrape_executor_new_with_custom_config() {
1227 let config = ScrapeConfig {
1228 timeout: 60,
1229 max_body_bytes: 512,
1230 };
1231 let executor = WebScrapeExecutor::new(&config);
1232 assert_eq!(executor.max_body_bytes, 512);
1233 }
1234
1235 #[test]
1236 fn webscrape_executor_debug() {
1237 let config = ScrapeConfig::default();
1238 let executor = WebScrapeExecutor::new(&config);
1239 let dbg = format!("{executor:?}");
1240 assert!(dbg.contains("WebScrapeExecutor"));
1241 }
1242
1243 #[test]
1244 fn extract_mode_attr_empty_name() {
1245 let mode = ExtractMode::parse("attr:");
1246 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1247 }
1248
1249 #[test]
1250 fn default_extract_returns_text() {
1251 assert_eq!(default_extract(), "text");
1252 }
1253
1254 #[test]
1255 fn scrape_instruction_debug() {
1256 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1257 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1258 let dbg = format!("{instr:?}");
1259 assert!(dbg.contains("ScrapeInstruction"));
1260 }
1261
1262 #[test]
1263 fn extract_mode_debug() {
1264 let mode = ExtractMode::Text;
1265 let dbg = format!("{mode:?}");
1266 assert!(dbg.contains("Text"));
1267 }
1268
1269 #[test]
1274 fn max_redirects_constant_is_three() {
1275 const MAX_REDIRECTS: usize = 3;
1279 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1280 }
1281
1282 #[test]
1285 fn redirect_no_location_error_message() {
1286 let err = std::io::Error::other("redirect with no Location");
1287 assert!(err.to_string().contains("redirect with no Location"));
1288 }
1289
1290 #[test]
1292 fn too_many_redirects_error_message() {
1293 let err = std::io::Error::other("too many redirects");
1294 assert!(err.to_string().contains("too many redirects"));
1295 }
1296
1297 #[test]
1299 fn non_2xx_status_error_format() {
1300 let status = reqwest::StatusCode::FORBIDDEN;
1301 let msg = format!("HTTP {status}");
1302 assert!(msg.contains("403"));
1303 }
1304
1305 #[test]
1307 fn not_found_status_error_format() {
1308 let status = reqwest::StatusCode::NOT_FOUND;
1309 let msg = format!("HTTP {status}");
1310 assert!(msg.contains("404"));
1311 }
1312
1313 #[test]
1315 fn relative_redirect_same_host_path() {
1316 let base = Url::parse("https://example.com/current").unwrap();
1317 let resolved = base.join("/other").unwrap();
1318 assert_eq!(resolved.as_str(), "https://example.com/other");
1319 }
1320
1321 #[test]
1323 fn relative_redirect_relative_path() {
1324 let base = Url::parse("https://example.com/a/b").unwrap();
1325 let resolved = base.join("c").unwrap();
1326 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1327 }
1328
1329 #[test]
1331 fn absolute_redirect_overrides_base() {
1332 let base = Url::parse("https://example.com/page").unwrap();
1333 let resolved = base.join("https://other.com/target").unwrap();
1334 assert_eq!(resolved.as_str(), "https://other.com/target");
1335 }
1336
1337 #[test]
1339 fn redirect_http_downgrade_rejected() {
1340 let location = "http://example.com/page";
1341 let base = Url::parse("https://example.com/start").unwrap();
1342 let next = base.join(location).unwrap();
1343 let err = validate_url(next.as_str()).unwrap_err();
1344 assert!(matches!(err, ToolError::Blocked { .. }));
1345 }
1346
1347 #[test]
1349 fn redirect_location_private_ip_blocked() {
1350 let location = "https://192.168.100.1/admin";
1351 let base = Url::parse("https://example.com/start").unwrap();
1352 let next = base.join(location).unwrap();
1353 let err = validate_url(next.as_str()).unwrap_err();
1354 assert!(matches!(err, ToolError::Blocked { .. }));
1355 let cmd = match err {
1356 ToolError::Blocked { command } => command,
1357 _ => panic!("expected Blocked"),
1358 };
1359 assert!(
1360 cmd.contains("private") || cmd.contains("scheme"),
1361 "error message should describe the block reason: {cmd}"
1362 );
1363 }
1364
1365 #[test]
1367 fn redirect_location_internal_domain_blocked() {
1368 let location = "https://metadata.internal/latest/meta-data/";
1369 let base = Url::parse("https://example.com/start").unwrap();
1370 let next = base.join(location).unwrap();
1371 let err = validate_url(next.as_str()).unwrap_err();
1372 assert!(matches!(err, ToolError::Blocked { .. }));
1373 }
1374
1375 #[test]
1377 fn redirect_chain_three_hops_all_public() {
1378 let hops = [
1379 "https://redirect1.example.com/hop1",
1380 "https://redirect2.example.com/hop2",
1381 "https://destination.example.com/final",
1382 ];
1383 for hop in hops {
1384 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1385 }
1386 }
1387
1388 #[test]
1393 fn redirect_to_private_ip_rejected_by_validate_url() {
1394 let private_targets = [
1396 "https://127.0.0.1/secret",
1397 "https://10.0.0.1/internal",
1398 "https://192.168.1.1/admin",
1399 "https://172.16.0.1/data",
1400 "https://[::1]/path",
1401 "https://[fe80::1]/path",
1402 "https://localhost/path",
1403 "https://service.internal/api",
1404 ];
1405 for target in private_targets {
1406 let result = validate_url(target);
1407 assert!(result.is_err(), "expected error for {target}");
1408 assert!(
1409 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1410 "expected Blocked for {target}"
1411 );
1412 }
1413 }
1414
1415 #[test]
1417 fn redirect_relative_url_resolves_correctly() {
1418 let base = Url::parse("https://example.com/page").unwrap();
1419 let relative = "/other";
1420 let resolved = base.join(relative).unwrap();
1421 assert_eq!(resolved.as_str(), "https://example.com/other");
1422 }
1423
1424 #[test]
1426 fn redirect_to_http_rejected() {
1427 let err = validate_url("http://example.com/page").unwrap_err();
1428 assert!(matches!(err, ToolError::Blocked { .. }));
1429 }
1430
1431 #[test]
1432 fn ipv4_mapped_ipv6_link_local_blocked() {
1433 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1434 assert!(matches!(err, ToolError::Blocked { .. }));
1435 }
1436
1437 #[test]
1438 fn ipv4_mapped_ipv6_public_allowed() {
1439 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1440 }
1441
1442 #[tokio::test]
1445 async fn fetch_http_scheme_blocked() {
1446 let config = ScrapeConfig::default();
1447 let executor = WebScrapeExecutor::new(&config);
1448 let call = crate::executor::ToolCall {
1449 tool_id: "fetch".to_owned(),
1450 params: {
1451 let mut m = serde_json::Map::new();
1452 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1453 m
1454 },
1455 };
1456 let result = executor.execute_tool_call(&call).await;
1457 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1458 }
1459
1460 #[tokio::test]
1461 async fn fetch_private_ip_blocked() {
1462 let config = ScrapeConfig::default();
1463 let executor = WebScrapeExecutor::new(&config);
1464 let call = crate::executor::ToolCall {
1465 tool_id: "fetch".to_owned(),
1466 params: {
1467 let mut m = serde_json::Map::new();
1468 m.insert(
1469 "url".to_owned(),
1470 serde_json::json!("https://192.168.1.1/secret"),
1471 );
1472 m
1473 },
1474 };
1475 let result = executor.execute_tool_call(&call).await;
1476 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1477 }
1478
1479 #[tokio::test]
1480 async fn fetch_localhost_blocked() {
1481 let config = ScrapeConfig::default();
1482 let executor = WebScrapeExecutor::new(&config);
1483 let call = crate::executor::ToolCall {
1484 tool_id: "fetch".to_owned(),
1485 params: {
1486 let mut m = serde_json::Map::new();
1487 m.insert(
1488 "url".to_owned(),
1489 serde_json::json!("https://localhost/page"),
1490 );
1491 m
1492 },
1493 };
1494 let result = executor.execute_tool_call(&call).await;
1495 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1496 }
1497
1498 #[tokio::test]
1499 async fn fetch_unknown_tool_returns_none() {
1500 let config = ScrapeConfig::default();
1501 let executor = WebScrapeExecutor::new(&config);
1502 let call = crate::executor::ToolCall {
1503 tool_id: "unknown_tool".to_owned(),
1504 params: serde_json::Map::new(),
1505 };
1506 let result = executor.execute_tool_call(&call).await;
1507 assert!(result.unwrap().is_none());
1508 }
1509
1510 #[tokio::test]
1511 async fn fetch_returns_body_via_mock() {
1512 use wiremock::matchers::{method, path};
1513 use wiremock::{Mock, ResponseTemplate};
1514
1515 let (executor, server) = mock_server_executor().await;
1516 Mock::given(method("GET"))
1517 .and(path("/content"))
1518 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1519 .mount(&server)
1520 .await;
1521
1522 let (host, addrs) = server_host_and_addr(&server);
1523 let url = format!("{}/content", server.uri());
1524 let result = executor.fetch_html(&url, &host, &addrs).await;
1525 assert!(result.is_ok());
1526 assert_eq!(result.unwrap(), "plain text content");
1527 }
1528
1529 #[test]
1530 fn tool_definitions_returns_web_scrape_and_fetch() {
1531 let config = ScrapeConfig::default();
1532 let executor = WebScrapeExecutor::new(&config);
1533 let defs = executor.tool_definitions();
1534 assert_eq!(defs.len(), 2);
1535 assert_eq!(defs[0].id, "web_scrape");
1536 assert_eq!(
1537 defs[0].invocation,
1538 crate::registry::InvocationHint::FencedBlock("scrape")
1539 );
1540 assert_eq!(defs[1].id, "fetch");
1541 assert_eq!(
1542 defs[1].invocation,
1543 crate::registry::InvocationHint::ToolCall
1544 );
1545 }
1546
1547 #[test]
1548 fn tool_definitions_schema_has_all_params() {
1549 let config = ScrapeConfig::default();
1550 let executor = WebScrapeExecutor::new(&config);
1551 let defs = executor.tool_definitions();
1552 let obj = defs[0].schema.as_object().unwrap();
1553 let props = obj["properties"].as_object().unwrap();
1554 assert!(props.contains_key("url"));
1555 assert!(props.contains_key("select"));
1556 assert!(props.contains_key("extract"));
1557 assert!(props.contains_key("limit"));
1558 let req = obj["required"].as_array().unwrap();
1559 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1560 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1561 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1562 }
1563
1564 #[test]
1567 fn subdomain_localhost_blocked() {
1568 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1569 assert!(is_private_host(&host));
1570 }
1571
1572 #[test]
1573 fn internal_tld_blocked() {
1574 let host: url::Host<&str> = url::Host::Domain("service.internal");
1575 assert!(is_private_host(&host));
1576 }
1577
1578 #[test]
1579 fn local_tld_blocked() {
1580 let host: url::Host<&str> = url::Host::Domain("printer.local");
1581 assert!(is_private_host(&host));
1582 }
1583
1584 #[test]
1585 fn public_domain_not_blocked() {
1586 let host: url::Host<&str> = url::Host::Domain("example.com");
1587 assert!(!is_private_host(&host));
1588 }
1589
1590 #[tokio::test]
1593 async fn resolve_loopback_rejected() {
1594 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1596 let result = resolve_and_validate(&url).await;
1598 assert!(
1599 result.is_err(),
1600 "loopback IP must be rejected by resolve_and_validate"
1601 );
1602 let err = result.unwrap_err();
1603 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1604 }
1605
1606 #[tokio::test]
1607 async fn resolve_private_10_rejected() {
1608 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1609 let result = resolve_and_validate(&url).await;
1610 assert!(result.is_err());
1611 assert!(matches!(
1612 result.unwrap_err(),
1613 crate::executor::ToolError::Blocked { .. }
1614 ));
1615 }
1616
1617 #[tokio::test]
1618 async fn resolve_private_192_rejected() {
1619 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1620 let result = resolve_and_validate(&url).await;
1621 assert!(result.is_err());
1622 assert!(matches!(
1623 result.unwrap_err(),
1624 crate::executor::ToolError::Blocked { .. }
1625 ));
1626 }
1627
1628 #[tokio::test]
1629 async fn resolve_ipv6_loopback_rejected() {
1630 let url = url::Url::parse("https://[::1]/path").unwrap();
1631 let result = resolve_and_validate(&url).await;
1632 assert!(result.is_err());
1633 assert!(matches!(
1634 result.unwrap_err(),
1635 crate::executor::ToolError::Blocked { .. }
1636 ));
1637 }
1638
1639 #[tokio::test]
1640 async fn resolve_no_host_returns_ok() {
1641 let url = url::Url::parse("https://example.com/path").unwrap();
1643 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1645 let result = resolve_and_validate(&url_no_host).await;
1647 assert!(result.is_ok());
1648 let (host, addrs) = result.unwrap();
1649 assert!(host.is_empty());
1650 assert!(addrs.is_empty());
1651 drop(url);
1652 drop(url_no_host);
1653 }
1654
1655 #[tokio::test]
1657 async fn fetch_execute_tool_call_end_to_end() {
1658 use wiremock::matchers::{method, path};
1659 use wiremock::{Mock, ResponseTemplate};
1660
1661 let (executor, server) = mock_server_executor().await;
1662 Mock::given(method("GET"))
1663 .and(path("/e2e"))
1664 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1665 .mount(&server)
1666 .await;
1667
1668 let (host, addrs) = server_host_and_addr(&server);
1669 let result = executor
1671 .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1672 .await;
1673 assert!(result.is_ok());
1674 assert!(result.unwrap().contains("end-to-end"));
1675 }
1676}