1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
15
/// Parameters for the `fetch` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // HTTPS URL to fetch; validated (scheme, private-host, DNS) before any I/O.
    url: String,
}
21
/// A single scrape request: which URL to fetch, which CSS selector to
/// apply, and how to extract values from the matched elements. Parsed
/// from a fenced `scrape` block or from `web_scrape` tool-call params.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // Target HTTPS URL.
    url: String,
    // CSS selector applied to the fetched document.
    select: String,
    // Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    // Maximum number of matches returned; a missing value means 10.
    limit: Option<usize>,
}
34
/// Serde default for `ScrapeInstruction::extract`: plain text extraction.
fn default_extract() -> String {
    String::from("text")
}
38
/// How values are pulled out of matched elements.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Interpret an extraction-mode string: "text", "html", or
    /// "attr:<name>". Unrecognized values fall back to `Text`.
    fn parse(s: &str) -> Self {
        if s == "html" {
            return Self::Html;
        }
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        // "text" and any unknown mode both mean text extraction.
        Self::Text
    }
}
58
/// Tool executor that fetches and scrapes HTTPS pages with SSRF
/// protections: private-host blocking, DNS pinning, and manual
/// per-hop redirect re-validation.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request timeout applied to the underlying HTTP client.
    timeout: Duration,
    // Responses larger than this many bytes are rejected.
    max_body_bytes: usize,
    // Optional sink for per-invocation audit entries; None disables auditing.
    audit_logger: Option<Arc<AuditLogger>>,
}
69
70impl WebScrapeExecutor {
71 #[must_use]
72 pub fn new(config: &ScrapeConfig) -> Self {
73 Self {
74 timeout: Duration::from_secs(config.timeout),
75 max_body_bytes: config.max_body_bytes,
76 audit_logger: None,
77 }
78 }
79
80 #[must_use]
81 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
82 self.audit_logger = Some(logger);
83 self
84 }
85
86 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
87 let mut builder = reqwest::Client::builder()
88 .timeout(self.timeout)
89 .redirect(reqwest::redirect::Policy::none());
90 builder = builder.resolve_to_addrs(host, addrs);
91 builder.build().unwrap_or_default()
92 }
93}
94
impl ToolExecutor for WebScrapeExecutor {
    /// Advertise the `web_scrape` (fenced-block) and `fetch` (tool-call)
    /// tools along with their JSON schemas.
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures\nExample: {\"url\": \"https://example.com\", \"select\": \"h1\", \"extract\": \"text\"}".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures\nExample: {\"url\": \"https://api.example.com/data.json\"}".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Run every fenced `scrape` block found in a free-form model response.
    ///
    /// Returns `Ok(None)` when the response contains no blocks. Each block is
    /// scraped in order; results are audited per block, and the first failure
    /// aborts the whole call (blocks already scraped are discarded).
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON in any block fails the whole response; this is
            // not audited since no URL was ever derived from it.
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Audit the failure before propagating it.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
                        .await;
                    return Err(e);
                }
            }
        }

        // NOTE(review): the output name "web-scrape" differs from the tool id
        // "web_scrape" — confirm downstream consumers expect the hyphen form.
        Ok(Some(ToolOutput {
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }))
    }

    /// Dispatch a structured tool call to `web_scrape` or `fetch`.
    ///
    /// Unknown tool ids yield `Ok(None)` so other executors can claim them.
    /// Both branches audit success/failure with the request duration.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
                            .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms)
                            .await;
                        Err(e)
                    }
                }
            }
            // Not ours — let another executor handle it.
            _ => Ok(None),
        }
    }
}
239
240fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
241 match e {
242 ToolError::Blocked { command } => AuditResult::Blocked {
243 reason: command.clone(),
244 },
245 ToolError::Timeout { .. } => AuditResult::Timeout,
246 _ => AuditResult::Error {
247 message: e.to_string(),
248 },
249 }
250}
251
impl WebScrapeExecutor {
    /// Record one audit entry; a no-op when no logger is configured.
    async fn log_audit(&self, tool: &str, command: &str, result: AuditResult, duration_ms: u64) {
        if let Some(ref logger) = self.audit_logger {
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
            };
            logger.log(&entry).await;
        }
    }

    /// `fetch` tool: validate the URL, pin its resolved addresses, and
    /// return the (UTF-8) body.
    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
        let parsed = validate_url(&params.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        self.fetch_html(&params.url, &host, &addrs).await
    }

    /// `web_scrape` tool: fetch the page, then run selector extraction on a
    /// blocking thread (HTML parsing is CPU-bound).
    async fn scrape_instruction(
        &self,
        instruction: &ScrapeInstruction,
    ) -> Result<String, ToolError> {
        let parsed = validate_url(&instruction.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
        let selector = instruction.select.clone();
        let extract = ExtractMode::parse(&instruction.extract);
        // Default result cap when the instruction omits `limit`.
        let limit = instruction.limit.unwrap_or(10);
        // Outer `?` surfaces a panicked/cancelled task; inner Result is the
        // extraction outcome itself.
        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
            .await
            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
    }

    /// Fetch `url` using the pinned `addrs` for `host`, following up to
    /// `MAX_REDIRECTS` redirects manually.
    ///
    /// Automatic redirect following is disabled in `build_client`; each hop's
    /// Location is re-validated (scheme, private-host, DNS resolution) before
    /// it is fetched, so a redirect cannot smuggle the request to a private
    /// address.
    async fn fetch_html(
        &self,
        url: &str,
        host: &str,
        addrs: &[SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;

        let mut current_url = url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            // Fresh client per hop: the DNS override is per-host.
            let client = self.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                // A redirect on the final allowed hop is one too many.
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                // Resolve relative Location values against the current URL.
                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                // Re-run the full SSRF gauntlet on the redirect target.
                let validated = validate_url(next_url.as_str())?;
                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;

                current_url = next_url.to_string();
                current_host = next_host;
                current_addrs = next_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "HTTP {status}",
                ))));
            }

            // NOTE(review): the body is fully buffered before the size check,
            // so an oversized response still costs its full download.
            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > self.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    self.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        // Unreachable in practice (every iteration returns or continues, and
        // the final hop's redirect returns above); kept for the compiler.
        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
}
378
// Collect the contents of every fenced `scrape` code block in a response.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
382
383fn validate_url(raw: &str) -> Result<Url, ToolError> {
384 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
385 command: format!("invalid URL: {raw}"),
386 })?;
387
388 if parsed.scheme() != "https" {
389 return Err(ToolError::Blocked {
390 command: format!("scheme not allowed: {}", parsed.scheme()),
391 });
392 }
393
394 if let Some(host) = parsed.host()
395 && is_private_host(&host)
396 {
397 return Err(ToolError::Blocked {
398 command: format!(
399 "private/local host blocked: {}",
400 parsed.host_str().unwrap_or("")
401 ),
402 });
403 }
404
405 Ok(parsed)
406}
407
/// True when `ip` must never be fetched from: loopback, RFC 1918 private,
/// link-local, unspecified, broadcast, IPv6 unique-local, and the
/// IPv4-mapped IPv6 forms of the IPv4 ranges.
pub(crate) fn is_private_ip(ip: IpAddr) -> bool {
    // IPv4 policy, shared by plain and IPv4-mapped-in-IPv6 addresses.
    fn blocked_v4(v4: std::net::Ipv4Addr) -> bool {
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_unspecified()
            || v4.is_broadcast()
    }

    match ip {
        IpAddr::V4(v4) => blocked_v4(v4),
        IpAddr::V6(v6) => {
            if v6.is_loopback() || v6.is_unspecified() {
                return true;
            }
            let first = v6.segments()[0];
            // fe80::/10 — link-local.
            if first & 0xffc0 == 0xfe80 {
                return true;
            }
            // fc00::/7 — unique-local.
            if first & 0xfe00 == 0xfc00 {
                return true;
            }
            // ::ffff:a.b.c.d — apply the IPv4 policy to the embedded address.
            if let Some(v4) = v6.to_ipv4_mapped() {
                return blocked_v4(v4);
            }
            false
        }
    }
}
445
446fn is_private_host(host: &url::Host<&str>) -> bool {
447 match host {
448 url::Host::Domain(d) => {
449 #[allow(clippy::case_sensitive_file_extension_comparisons)]
452 {
453 *d == "localhost"
454 || d.ends_with(".localhost")
455 || d.ends_with(".internal")
456 || d.ends_with(".local")
457 }
458 }
459 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
460 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
461 }
462}
463
464async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
470 let Some(host) = url.host_str() else {
471 return Ok((String::new(), vec![]));
472 };
473 let port = url.port_or_known_default().unwrap_or(443);
474 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
475 .await
476 .map_err(|e| ToolError::Blocked {
477 command: format!("DNS resolution failed: {e}"),
478 })?
479 .collect();
480 for addr in &addrs {
481 if is_private_ip(addr.ip()) {
482 return Err(ToolError::Blocked {
483 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
484 });
485 }
486 }
487 Ok((host.to_owned(), addrs))
488}
489
490fn parse_and_extract(
491 html: &str,
492 selector: &str,
493 extract: &ExtractMode,
494 limit: usize,
495) -> Result<String, ToolError> {
496 let soup = scrape_core::Soup::parse(html);
497
498 let tags = soup.find_all(selector).map_err(|e| {
499 ToolError::Execution(std::io::Error::new(
500 std::io::ErrorKind::InvalidData,
501 format!("invalid selector: {e}"),
502 ))
503 })?;
504
505 let mut results = Vec::new();
506
507 for tag in tags.into_iter().take(limit) {
508 let value = match extract {
509 ExtractMode::Text => tag.text(),
510 ExtractMode::Html => tag.inner_html(),
511 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
512 };
513 if !value.trim().is_empty() {
514 results.push(value.trim().to_owned());
515 }
516 }
517
518 if results.is_empty() {
519 Ok(format!("No results for selector: {selector}"))
520 } else {
521 Ok(results.join("\n"))
522 }
523}
524
525#[cfg(test)]
526mod tests {
527 use super::*;
528
529 #[test]
532 fn extract_single_block() {
533 let text =
534 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
535 let blocks = extract_scrape_blocks(text);
536 assert_eq!(blocks.len(), 1);
537 assert!(blocks[0].contains("example.com"));
538 }
539
540 #[test]
541 fn extract_multiple_blocks() {
542 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
543 let blocks = extract_scrape_blocks(text);
544 assert_eq!(blocks.len(), 2);
545 }
546
547 #[test]
548 fn no_blocks_returns_empty() {
549 let blocks = extract_scrape_blocks("plain text, no code blocks");
550 assert!(blocks.is_empty());
551 }
552
553 #[test]
554 fn unclosed_block_ignored() {
555 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
556 assert!(blocks.is_empty());
557 }
558
559 #[test]
560 fn non_scrape_block_ignored() {
561 let text =
562 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
563 let blocks = extract_scrape_blocks(text);
564 assert_eq!(blocks.len(), 1);
565 assert!(blocks[0].contains("x.com"));
566 }
567
568 #[test]
569 fn multiline_json_block() {
570 let text =
571 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
572 let blocks = extract_scrape_blocks(text);
573 assert_eq!(blocks.len(), 1);
574 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
575 assert_eq!(instr.url, "https://example.com");
576 }
577
578 #[test]
581 fn parse_valid_instruction() {
582 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
583 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
584 assert_eq!(instr.url, "https://example.com");
585 assert_eq!(instr.select, "h1");
586 assert_eq!(instr.extract, "text");
587 assert_eq!(instr.limit, Some(5));
588 }
589
590 #[test]
591 fn parse_minimal_instruction() {
592 let json = r#"{"url":"https://example.com","select":"p"}"#;
593 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
594 assert_eq!(instr.extract, "text");
595 assert!(instr.limit.is_none());
596 }
597
598 #[test]
599 fn parse_attr_extract() {
600 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
601 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
602 assert_eq!(instr.extract, "attr:href");
603 }
604
605 #[test]
606 fn parse_invalid_json_errors() {
607 let result = serde_json::from_str::<ScrapeInstruction>("not json");
608 assert!(result.is_err());
609 }
610
611 #[test]
614 fn extract_mode_text() {
615 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
616 }
617
618 #[test]
619 fn extract_mode_html() {
620 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
621 }
622
623 #[test]
624 fn extract_mode_attr() {
625 let mode = ExtractMode::parse("attr:href");
626 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
627 }
628
629 #[test]
630 fn extract_mode_unknown_defaults_to_text() {
631 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
632 }
633
634 #[test]
637 fn valid_https_url() {
638 assert!(validate_url("https://example.com").is_ok());
639 }
640
641 #[test]
642 fn http_rejected() {
643 let err = validate_url("http://example.com").unwrap_err();
644 assert!(matches!(err, ToolError::Blocked { .. }));
645 }
646
647 #[test]
648 fn ftp_rejected() {
649 let err = validate_url("ftp://files.example.com").unwrap_err();
650 assert!(matches!(err, ToolError::Blocked { .. }));
651 }
652
653 #[test]
654 fn file_rejected() {
655 let err = validate_url("file:///etc/passwd").unwrap_err();
656 assert!(matches!(err, ToolError::Blocked { .. }));
657 }
658
659 #[test]
660 fn invalid_url_rejected() {
661 let err = validate_url("not a url").unwrap_err();
662 assert!(matches!(err, ToolError::Blocked { .. }));
663 }
664
665 #[test]
666 fn localhost_blocked() {
667 let err = validate_url("https://localhost/path").unwrap_err();
668 assert!(matches!(err, ToolError::Blocked { .. }));
669 }
670
671 #[test]
672 fn loopback_ip_blocked() {
673 let err = validate_url("https://127.0.0.1/path").unwrap_err();
674 assert!(matches!(err, ToolError::Blocked { .. }));
675 }
676
677 #[test]
678 fn private_10_blocked() {
679 let err = validate_url("https://10.0.0.1/api").unwrap_err();
680 assert!(matches!(err, ToolError::Blocked { .. }));
681 }
682
683 #[test]
684 fn private_172_blocked() {
685 let err = validate_url("https://172.16.0.1/api").unwrap_err();
686 assert!(matches!(err, ToolError::Blocked { .. }));
687 }
688
689 #[test]
690 fn private_192_blocked() {
691 let err = validate_url("https://192.168.1.1/api").unwrap_err();
692 assert!(matches!(err, ToolError::Blocked { .. }));
693 }
694
695 #[test]
696 fn ipv6_loopback_blocked() {
697 let err = validate_url("https://[::1]/path").unwrap_err();
698 assert!(matches!(err, ToolError::Blocked { .. }));
699 }
700
701 #[test]
702 fn public_ip_allowed() {
703 assert!(validate_url("https://93.184.216.34/page").is_ok());
704 }
705
706 #[test]
709 fn extract_text_from_html() {
710 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
711 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
712 assert_eq!(result, "Hello World");
713 }
714
715 #[test]
716 fn extract_multiple_elements() {
717 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
718 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
719 assert_eq!(result, "A\nB\nC");
720 }
721
722 #[test]
723 fn extract_with_limit() {
724 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
725 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
726 assert_eq!(result, "A\nB");
727 }
728
729 #[test]
730 fn extract_attr_href() {
731 let html = r#"<a href="https://example.com">Link</a>"#;
732 let result =
733 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
734 assert_eq!(result, "https://example.com");
735 }
736
737 #[test]
738 fn extract_inner_html() {
739 let html = "<div><span>inner</span></div>";
740 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
741 assert!(result.contains("<span>inner</span>"));
742 }
743
744 #[test]
745 fn no_matches_returns_message() {
746 let html = "<html><body><p>text</p></body></html>";
747 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
748 assert!(result.starts_with("No results for selector:"));
749 }
750
751 #[test]
752 fn empty_text_skipped() {
753 let html = "<ul><li> </li><li>A</li></ul>";
754 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
755 assert_eq!(result, "A");
756 }
757
758 #[test]
759 fn invalid_selector_errors() {
760 let html = "<html><body></body></html>";
761 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
762 assert!(result.is_err());
763 }
764
765 #[test]
766 fn empty_html_returns_no_results() {
767 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
768 assert!(result.starts_with("No results for selector:"));
769 }
770
771 #[test]
772 fn nested_selector() {
773 let html = "<div><span>inner</span></div><span>outer</span>";
774 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
775 assert_eq!(result, "inner");
776 }
777
778 #[test]
779 fn attr_missing_returns_empty() {
780 let html = r#"<a>No href</a>"#;
781 let result =
782 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
783 assert!(result.starts_with("No results for selector:"));
784 }
785
786 #[test]
787 fn extract_html_mode() {
788 let html = "<div><b>bold</b> text</div>";
789 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
790 assert!(result.contains("<b>bold</b>"));
791 }
792
793 #[test]
794 fn limit_zero_returns_no_results() {
795 let html = "<ul><li>A</li><li>B</li></ul>";
796 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
797 assert!(result.starts_with("No results for selector:"));
798 }
799
800 #[test]
803 fn url_with_port_allowed() {
804 assert!(validate_url("https://example.com:8443/path").is_ok());
805 }
806
807 #[test]
808 fn link_local_ip_blocked() {
809 let err = validate_url("https://169.254.1.1/path").unwrap_err();
810 assert!(matches!(err, ToolError::Blocked { .. }));
811 }
812
813 #[test]
814 fn url_no_scheme_rejected() {
815 let err = validate_url("example.com/path").unwrap_err();
816 assert!(matches!(err, ToolError::Blocked { .. }));
817 }
818
819 #[test]
820 fn unspecified_ipv4_blocked() {
821 let err = validate_url("https://0.0.0.0/path").unwrap_err();
822 assert!(matches!(err, ToolError::Blocked { .. }));
823 }
824
825 #[test]
826 fn broadcast_ipv4_blocked() {
827 let err = validate_url("https://255.255.255.255/path").unwrap_err();
828 assert!(matches!(err, ToolError::Blocked { .. }));
829 }
830
831 #[test]
832 fn ipv6_link_local_blocked() {
833 let err = validate_url("https://[fe80::1]/path").unwrap_err();
834 assert!(matches!(err, ToolError::Blocked { .. }));
835 }
836
837 #[test]
838 fn ipv6_unique_local_blocked() {
839 let err = validate_url("https://[fd12::1]/path").unwrap_err();
840 assert!(matches!(err, ToolError::Blocked { .. }));
841 }
842
843 #[test]
844 fn ipv4_mapped_ipv6_loopback_blocked() {
845 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
846 assert!(matches!(err, ToolError::Blocked { .. }));
847 }
848
849 #[test]
850 fn ipv4_mapped_ipv6_private_blocked() {
851 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
852 assert!(matches!(err, ToolError::Blocked { .. }));
853 }
854
855 #[tokio::test]
858 async fn executor_no_blocks_returns_none() {
859 let config = ScrapeConfig::default();
860 let executor = WebScrapeExecutor::new(&config);
861 let result = executor.execute("plain text").await;
862 assert!(result.unwrap().is_none());
863 }
864
865 #[tokio::test]
866 async fn executor_invalid_json_errors() {
867 let config = ScrapeConfig::default();
868 let executor = WebScrapeExecutor::new(&config);
869 let response = "```scrape\nnot json\n```";
870 let result = executor.execute(response).await;
871 assert!(matches!(result, Err(ToolError::Execution(_))));
872 }
873
874 #[tokio::test]
875 async fn executor_blocked_url_errors() {
876 let config = ScrapeConfig::default();
877 let executor = WebScrapeExecutor::new(&config);
878 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
879 let result = executor.execute(response).await;
880 assert!(matches!(result, Err(ToolError::Blocked { .. })));
881 }
882
883 #[tokio::test]
884 async fn executor_private_ip_blocked() {
885 let config = ScrapeConfig::default();
886 let executor = WebScrapeExecutor::new(&config);
887 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
888 let result = executor.execute(response).await;
889 assert!(matches!(result, Err(ToolError::Blocked { .. })));
890 }
891
892 #[tokio::test]
893 async fn executor_unreachable_host_returns_error() {
894 let config = ScrapeConfig {
895 timeout: 1,
896 max_body_bytes: 1_048_576,
897 };
898 let executor = WebScrapeExecutor::new(&config);
899 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
900 let result = executor.execute(response).await;
901 assert!(matches!(result, Err(ToolError::Execution(_))));
902 }
903
904 #[tokio::test]
905 async fn executor_localhost_url_blocked() {
906 let config = ScrapeConfig::default();
907 let executor = WebScrapeExecutor::new(&config);
908 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
909 let result = executor.execute(response).await;
910 assert!(matches!(result, Err(ToolError::Blocked { .. })));
911 }
912
913 #[tokio::test]
914 async fn executor_empty_text_returns_none() {
915 let config = ScrapeConfig::default();
916 let executor = WebScrapeExecutor::new(&config);
917 let result = executor.execute("").await;
918 assert!(result.unwrap().is_none());
919 }
920
921 #[tokio::test]
922 async fn executor_multiple_blocks_first_blocked() {
923 let config = ScrapeConfig::default();
924 let executor = WebScrapeExecutor::new(&config);
925 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
926 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
927 let result = executor.execute(response).await;
928 assert!(result.is_err());
929 }
930
931 #[test]
932 fn validate_url_empty_string() {
933 let err = validate_url("").unwrap_err();
934 assert!(matches!(err, ToolError::Blocked { .. }));
935 }
936
937 #[test]
938 fn validate_url_javascript_scheme_blocked() {
939 let err = validate_url("javascript:alert(1)").unwrap_err();
940 assert!(matches!(err, ToolError::Blocked { .. }));
941 }
942
943 #[test]
944 fn validate_url_data_scheme_blocked() {
945 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
946 assert!(matches!(err, ToolError::Blocked { .. }));
947 }
948
949 #[test]
950 fn is_private_host_public_domain_is_false() {
951 let host: url::Host<&str> = url::Host::Domain("example.com");
952 assert!(!is_private_host(&host));
953 }
954
955 #[test]
956 fn is_private_host_localhost_is_true() {
957 let host: url::Host<&str> = url::Host::Domain("localhost");
958 assert!(is_private_host(&host));
959 }
960
961 #[test]
962 fn is_private_host_ipv6_unspecified_is_true() {
963 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
964 assert!(is_private_host(&host));
965 }
966
967 #[test]
968 fn is_private_host_public_ipv6_is_false() {
969 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
970 assert!(!is_private_host(&host));
971 }
972
973 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
984 let server = wiremock::MockServer::start().await;
985 let executor = WebScrapeExecutor {
986 timeout: Duration::from_secs(5),
987 max_body_bytes: 1_048_576,
988 audit_logger: None,
989 };
990 (executor, server)
991 }
992
993 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
995 let uri = server.uri();
996 let url = Url::parse(&uri).unwrap();
997 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
998 let port = url.port().unwrap_or(80);
999 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1000 (host, vec![addr])
1001 }
1002
1003 async fn follow_redirects_raw(
1007 executor: &WebScrapeExecutor,
1008 start_url: &str,
1009 host: &str,
1010 addrs: &[std::net::SocketAddr],
1011 ) -> Result<String, ToolError> {
1012 const MAX_REDIRECTS: usize = 3;
1013 let mut current_url = start_url.to_owned();
1014 let mut current_host = host.to_owned();
1015 let mut current_addrs = addrs.to_vec();
1016
1017 for hop in 0..=MAX_REDIRECTS {
1018 let client = executor.build_client(¤t_host, ¤t_addrs);
1019 let resp = client
1020 .get(¤t_url)
1021 .send()
1022 .await
1023 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1024
1025 let status = resp.status();
1026
1027 if status.is_redirection() {
1028 if hop == MAX_REDIRECTS {
1029 return Err(ToolError::Execution(std::io::Error::other(
1030 "too many redirects",
1031 )));
1032 }
1033
1034 let location = resp
1035 .headers()
1036 .get(reqwest::header::LOCATION)
1037 .and_then(|v| v.to_str().ok())
1038 .ok_or_else(|| {
1039 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1040 })?;
1041
1042 let base = Url::parse(¤t_url)
1043 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1044 let next_url = base
1045 .join(location)
1046 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1047
1048 current_url = next_url.to_string();
1050 let _ = &mut current_host;
1052 let _ = &mut current_addrs;
1053 continue;
1054 }
1055
1056 if !status.is_success() {
1057 return Err(ToolError::Execution(std::io::Error::other(format!(
1058 "HTTP {status}",
1059 ))));
1060 }
1061
1062 let bytes = resp
1063 .bytes()
1064 .await
1065 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1066
1067 if bytes.len() > executor.max_body_bytes {
1068 return Err(ToolError::Execution(std::io::Error::other(format!(
1069 "response too large: {} bytes (max: {})",
1070 bytes.len(),
1071 executor.max_body_bytes,
1072 ))));
1073 }
1074
1075 return String::from_utf8(bytes.to_vec())
1076 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1077 }
1078
1079 Err(ToolError::Execution(std::io::Error::other(
1080 "too many redirects",
1081 )))
1082 }
1083
1084 #[tokio::test]
1085 async fn fetch_html_success_returns_body() {
1086 use wiremock::matchers::{method, path};
1087 use wiremock::{Mock, ResponseTemplate};
1088
1089 let (executor, server) = mock_server_executor().await;
1090 Mock::given(method("GET"))
1091 .and(path("/page"))
1092 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1093 .mount(&server)
1094 .await;
1095
1096 let (host, addrs) = server_host_and_addr(&server);
1097 let url = format!("{}/page", server.uri());
1098 let result = executor.fetch_html(&url, &host, &addrs).await;
1099 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1100 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1101 }
1102
1103 #[tokio::test]
1104 async fn fetch_html_non_2xx_returns_error() {
1105 use wiremock::matchers::{method, path};
1106 use wiremock::{Mock, ResponseTemplate};
1107
1108 let (executor, server) = mock_server_executor().await;
1109 Mock::given(method("GET"))
1110 .and(path("/forbidden"))
1111 .respond_with(ResponseTemplate::new(403))
1112 .mount(&server)
1113 .await;
1114
1115 let (host, addrs) = server_host_and_addr(&server);
1116 let url = format!("{}/forbidden", server.uri());
1117 let result = executor.fetch_html(&url, &host, &addrs).await;
1118 assert!(result.is_err());
1119 let msg = result.unwrap_err().to_string();
1120 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1121 }
1122
1123 #[tokio::test]
1124 async fn fetch_html_404_returns_error() {
1125 use wiremock::matchers::{method, path};
1126 use wiremock::{Mock, ResponseTemplate};
1127
1128 let (executor, server) = mock_server_executor().await;
1129 Mock::given(method("GET"))
1130 .and(path("/missing"))
1131 .respond_with(ResponseTemplate::new(404))
1132 .mount(&server)
1133 .await;
1134
1135 let (host, addrs) = server_host_and_addr(&server);
1136 let url = format!("{}/missing", server.uri());
1137 let result = executor.fetch_html(&url, &host, &addrs).await;
1138 assert!(result.is_err());
1139 let msg = result.unwrap_err().to_string();
1140 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1141 }
1142
1143 #[tokio::test]
1144 async fn fetch_html_redirect_no_location_returns_error() {
1145 use wiremock::matchers::{method, path};
1146 use wiremock::{Mock, ResponseTemplate};
1147
1148 let (executor, server) = mock_server_executor().await;
1149 Mock::given(method("GET"))
1151 .and(path("/redirect-no-loc"))
1152 .respond_with(ResponseTemplate::new(302))
1153 .mount(&server)
1154 .await;
1155
1156 let (host, addrs) = server_host_and_addr(&server);
1157 let url = format!("{}/redirect-no-loc", server.uri());
1158 let result = executor.fetch_html(&url, &host, &addrs).await;
1159 assert!(result.is_err());
1160 let msg = result.unwrap_err().to_string();
1161 assert!(
1162 msg.contains("Location") || msg.contains("location"),
1163 "expected Location-related error: {msg}"
1164 );
1165 }
1166
1167 #[tokio::test]
1168 async fn fetch_html_single_redirect_followed() {
1169 use wiremock::matchers::{method, path};
1170 use wiremock::{Mock, ResponseTemplate};
1171
1172 let (executor, server) = mock_server_executor().await;
1173 let final_url = format!("{}/final", server.uri());
1174
1175 Mock::given(method("GET"))
1176 .and(path("/start"))
1177 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1178 .mount(&server)
1179 .await;
1180
1181 Mock::given(method("GET"))
1182 .and(path("/final"))
1183 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1184 .mount(&server)
1185 .await;
1186
1187 let (host, addrs) = server_host_and_addr(&server);
1188 let url = format!("{}/start", server.uri());
1189 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1190 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1191 assert_eq!(result.unwrap(), "<p>final</p>");
1192 }
1193
1194 #[tokio::test]
1195 async fn fetch_html_three_redirects_allowed() {
1196 use wiremock::matchers::{method, path};
1197 use wiremock::{Mock, ResponseTemplate};
1198
1199 let (executor, server) = mock_server_executor().await;
1200 let hop2 = format!("{}/hop2", server.uri());
1201 let hop3 = format!("{}/hop3", server.uri());
1202 let final_dest = format!("{}/done", server.uri());
1203
1204 Mock::given(method("GET"))
1205 .and(path("/hop1"))
1206 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1207 .mount(&server)
1208 .await;
1209 Mock::given(method("GET"))
1210 .and(path("/hop2"))
1211 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1212 .mount(&server)
1213 .await;
1214 Mock::given(method("GET"))
1215 .and(path("/hop3"))
1216 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1217 .mount(&server)
1218 .await;
1219 Mock::given(method("GET"))
1220 .and(path("/done"))
1221 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1222 .mount(&server)
1223 .await;
1224
1225 let (host, addrs) = server_host_and_addr(&server);
1226 let url = format!("{}/hop1", server.uri());
1227 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1228 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1229 assert_eq!(result.unwrap(), "<p>done</p>");
1230 }
1231
1232 #[tokio::test]
1233 async fn fetch_html_four_redirects_rejected() {
1234 use wiremock::matchers::{method, path};
1235 use wiremock::{Mock, ResponseTemplate};
1236
1237 let (executor, server) = mock_server_executor().await;
1238 let hop2 = format!("{}/r2", server.uri());
1239 let hop3 = format!("{}/r3", server.uri());
1240 let hop4 = format!("{}/r4", server.uri());
1241 let hop5 = format!("{}/r5", server.uri());
1242
1243 for (from, to) in [
1244 ("/r1", &hop2),
1245 ("/r2", &hop3),
1246 ("/r3", &hop4),
1247 ("/r4", &hop5),
1248 ] {
1249 Mock::given(method("GET"))
1250 .and(path(from))
1251 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1252 .mount(&server)
1253 .await;
1254 }
1255
1256 let (host, addrs) = server_host_and_addr(&server);
1257 let url = format!("{}/r1", server.uri());
1258 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1259 assert!(result.is_err(), "4 redirects should be rejected");
1260 let msg = result.unwrap_err().to_string();
1261 assert!(
1262 msg.contains("redirect"),
1263 "expected redirect-related error: {msg}"
1264 );
1265 }
1266
1267 #[tokio::test]
1268 async fn fetch_html_body_too_large_returns_error() {
1269 use wiremock::matchers::{method, path};
1270 use wiremock::{Mock, ResponseTemplate};
1271
1272 let small_limit_executor = WebScrapeExecutor {
1273 timeout: Duration::from_secs(5),
1274 max_body_bytes: 10,
1275 audit_logger: None,
1276 };
1277 let server = wiremock::MockServer::start().await;
1278 Mock::given(method("GET"))
1279 .and(path("/big"))
1280 .respond_with(
1281 ResponseTemplate::new(200)
1282 .set_body_string("this body is definitely longer than ten bytes"),
1283 )
1284 .mount(&server)
1285 .await;
1286
1287 let (host, addrs) = server_host_and_addr(&server);
1288 let url = format!("{}/big", server.uri());
1289 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1290 assert!(result.is_err());
1291 let msg = result.unwrap_err().to_string();
1292 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1293 }
1294
1295 #[test]
1296 fn extract_scrape_blocks_empty_block_content() {
1297 let text = "```scrape\n\n```";
1298 let blocks = extract_scrape_blocks(text);
1299 assert_eq!(blocks.len(), 1);
1300 assert!(blocks[0].is_empty());
1301 }
1302
1303 #[test]
1304 fn extract_scrape_blocks_whitespace_only() {
1305 let text = "```scrape\n \n```";
1306 let blocks = extract_scrape_blocks(text);
1307 assert_eq!(blocks.len(), 1);
1308 }
1309
1310 #[test]
1311 fn parse_and_extract_multiple_selectors() {
1312 let html = "<div><h1>Title</h1><p>Para</p></div>";
1313 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1314 assert!(result.contains("Title"));
1315 assert!(result.contains("Para"));
1316 }
1317
1318 #[test]
1319 fn webscrape_executor_new_with_custom_config() {
1320 let config = ScrapeConfig {
1321 timeout: 60,
1322 max_body_bytes: 512,
1323 };
1324 let executor = WebScrapeExecutor::new(&config);
1325 assert_eq!(executor.max_body_bytes, 512);
1326 }
1327
1328 #[test]
1329 fn webscrape_executor_debug() {
1330 let config = ScrapeConfig::default();
1331 let executor = WebScrapeExecutor::new(&config);
1332 let dbg = format!("{executor:?}");
1333 assert!(dbg.contains("WebScrapeExecutor"));
1334 }
1335
1336 #[test]
1337 fn extract_mode_attr_empty_name() {
1338 let mode = ExtractMode::parse("attr:");
1339 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1340 }
1341
1342 #[test]
1343 fn default_extract_returns_text() {
1344 assert_eq!(default_extract(), "text");
1345 }
1346
1347 #[test]
1348 fn scrape_instruction_debug() {
1349 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1350 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1351 let dbg = format!("{instr:?}");
1352 assert!(dbg.contains("ScrapeInstruction"));
1353 }
1354
1355 #[test]
1356 fn extract_mode_debug() {
1357 let mode = ExtractMode::Text;
1358 let dbg = format!("{mode:?}");
1359 assert!(dbg.contains("Text"));
1360 }
1361
    #[test]
    // NOTE(review): this test re-declares a local constant and asserts it
    // against itself, so it is a tautology — it cannot catch drift if the
    // real redirect limit inside fetch_html changes. Consider exposing the
    // implementation's limit (e.g. a pub(crate) const) and asserting that.
    fn max_redirects_constant_is_three() {
        const MAX_REDIRECTS: usize = 3;
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }
1374
1375 #[test]
1378 fn redirect_no_location_error_message() {
1379 let err = std::io::Error::other("redirect with no Location");
1380 assert!(err.to_string().contains("redirect with no Location"));
1381 }
1382
1383 #[test]
1385 fn too_many_redirects_error_message() {
1386 let err = std::io::Error::other("too many redirects");
1387 assert!(err.to_string().contains("too many redirects"));
1388 }
1389
1390 #[test]
1392 fn non_2xx_status_error_format() {
1393 let status = reqwest::StatusCode::FORBIDDEN;
1394 let msg = format!("HTTP {status}");
1395 assert!(msg.contains("403"));
1396 }
1397
1398 #[test]
1400 fn not_found_status_error_format() {
1401 let status = reqwest::StatusCode::NOT_FOUND;
1402 let msg = format!("HTTP {status}");
1403 assert!(msg.contains("404"));
1404 }
1405
1406 #[test]
1408 fn relative_redirect_same_host_path() {
1409 let base = Url::parse("https://example.com/current").unwrap();
1410 let resolved = base.join("/other").unwrap();
1411 assert_eq!(resolved.as_str(), "https://example.com/other");
1412 }
1413
1414 #[test]
1416 fn relative_redirect_relative_path() {
1417 let base = Url::parse("https://example.com/a/b").unwrap();
1418 let resolved = base.join("c").unwrap();
1419 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1420 }
1421
1422 #[test]
1424 fn absolute_redirect_overrides_base() {
1425 let base = Url::parse("https://example.com/page").unwrap();
1426 let resolved = base.join("https://other.com/target").unwrap();
1427 assert_eq!(resolved.as_str(), "https://other.com/target");
1428 }
1429
1430 #[test]
1432 fn redirect_http_downgrade_rejected() {
1433 let location = "http://example.com/page";
1434 let base = Url::parse("https://example.com/start").unwrap();
1435 let next = base.join(location).unwrap();
1436 let err = validate_url(next.as_str()).unwrap_err();
1437 assert!(matches!(err, ToolError::Blocked { .. }));
1438 }
1439
1440 #[test]
1442 fn redirect_location_private_ip_blocked() {
1443 let location = "https://192.168.100.1/admin";
1444 let base = Url::parse("https://example.com/start").unwrap();
1445 let next = base.join(location).unwrap();
1446 let err = validate_url(next.as_str()).unwrap_err();
1447 assert!(matches!(err, ToolError::Blocked { .. }));
1448 let cmd = match err {
1449 ToolError::Blocked { command } => command,
1450 _ => panic!("expected Blocked"),
1451 };
1452 assert!(
1453 cmd.contains("private") || cmd.contains("scheme"),
1454 "error message should describe the block reason: {cmd}"
1455 );
1456 }
1457
1458 #[test]
1460 fn redirect_location_internal_domain_blocked() {
1461 let location = "https://metadata.internal/latest/meta-data/";
1462 let base = Url::parse("https://example.com/start").unwrap();
1463 let next = base.join(location).unwrap();
1464 let err = validate_url(next.as_str()).unwrap_err();
1465 assert!(matches!(err, ToolError::Blocked { .. }));
1466 }
1467
1468 #[test]
1470 fn redirect_chain_three_hops_all_public() {
1471 let hops = [
1472 "https://redirect1.example.com/hop1",
1473 "https://redirect2.example.com/hop2",
1474 "https://destination.example.com/final",
1475 ];
1476 for hop in hops {
1477 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1478 }
1479 }
1480
1481 #[test]
1486 fn redirect_to_private_ip_rejected_by_validate_url() {
1487 let private_targets = [
1489 "https://127.0.0.1/secret",
1490 "https://10.0.0.1/internal",
1491 "https://192.168.1.1/admin",
1492 "https://172.16.0.1/data",
1493 "https://[::1]/path",
1494 "https://[fe80::1]/path",
1495 "https://localhost/path",
1496 "https://service.internal/api",
1497 ];
1498 for target in private_targets {
1499 let result = validate_url(target);
1500 assert!(result.is_err(), "expected error for {target}");
1501 assert!(
1502 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1503 "expected Blocked for {target}"
1504 );
1505 }
1506 }
1507
    #[test]
    // NOTE(review): this duplicates relative_redirect_same_host_path
    // (same base URL, same relative path, same expected result). Consider
    // removing one of the two, or varying the inputs so each covers a
    // distinct resolution case.
    fn redirect_relative_url_resolves_correctly() {
        let base = Url::parse("https://example.com/page").unwrap();
        let relative = "/other";
        let resolved = base.join(relative).unwrap();
        assert_eq!(resolved.as_str(), "https://example.com/other");
    }
1516
1517 #[test]
1519 fn redirect_to_http_rejected() {
1520 let err = validate_url("http://example.com/page").unwrap_err();
1521 assert!(matches!(err, ToolError::Blocked { .. }));
1522 }
1523
1524 #[test]
1525 fn ipv4_mapped_ipv6_link_local_blocked() {
1526 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1527 assert!(matches!(err, ToolError::Blocked { .. }));
1528 }
1529
1530 #[test]
1531 fn ipv4_mapped_ipv6_public_allowed() {
1532 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1533 }
1534
1535 #[tokio::test]
1538 async fn fetch_http_scheme_blocked() {
1539 let config = ScrapeConfig::default();
1540 let executor = WebScrapeExecutor::new(&config);
1541 let call = crate::executor::ToolCall {
1542 tool_id: "fetch".to_owned(),
1543 params: {
1544 let mut m = serde_json::Map::new();
1545 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1546 m
1547 },
1548 };
1549 let result = executor.execute_tool_call(&call).await;
1550 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1551 }
1552
1553 #[tokio::test]
1554 async fn fetch_private_ip_blocked() {
1555 let config = ScrapeConfig::default();
1556 let executor = WebScrapeExecutor::new(&config);
1557 let call = crate::executor::ToolCall {
1558 tool_id: "fetch".to_owned(),
1559 params: {
1560 let mut m = serde_json::Map::new();
1561 m.insert(
1562 "url".to_owned(),
1563 serde_json::json!("https://192.168.1.1/secret"),
1564 );
1565 m
1566 },
1567 };
1568 let result = executor.execute_tool_call(&call).await;
1569 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1570 }
1571
1572 #[tokio::test]
1573 async fn fetch_localhost_blocked() {
1574 let config = ScrapeConfig::default();
1575 let executor = WebScrapeExecutor::new(&config);
1576 let call = crate::executor::ToolCall {
1577 tool_id: "fetch".to_owned(),
1578 params: {
1579 let mut m = serde_json::Map::new();
1580 m.insert(
1581 "url".to_owned(),
1582 serde_json::json!("https://localhost/page"),
1583 );
1584 m
1585 },
1586 };
1587 let result = executor.execute_tool_call(&call).await;
1588 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1589 }
1590
1591 #[tokio::test]
1592 async fn fetch_unknown_tool_returns_none() {
1593 let config = ScrapeConfig::default();
1594 let executor = WebScrapeExecutor::new(&config);
1595 let call = crate::executor::ToolCall {
1596 tool_id: "unknown_tool".to_owned(),
1597 params: serde_json::Map::new(),
1598 };
1599 let result = executor.execute_tool_call(&call).await;
1600 assert!(result.unwrap().is_none());
1601 }
1602
1603 #[tokio::test]
1604 async fn fetch_returns_body_via_mock() {
1605 use wiremock::matchers::{method, path};
1606 use wiremock::{Mock, ResponseTemplate};
1607
1608 let (executor, server) = mock_server_executor().await;
1609 Mock::given(method("GET"))
1610 .and(path("/content"))
1611 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1612 .mount(&server)
1613 .await;
1614
1615 let (host, addrs) = server_host_and_addr(&server);
1616 let url = format!("{}/content", server.uri());
1617 let result = executor.fetch_html(&url, &host, &addrs).await;
1618 assert!(result.is_ok());
1619 assert_eq!(result.unwrap(), "plain text content");
1620 }
1621
1622 #[test]
1623 fn tool_definitions_returns_web_scrape_and_fetch() {
1624 let config = ScrapeConfig::default();
1625 let executor = WebScrapeExecutor::new(&config);
1626 let defs = executor.tool_definitions();
1627 assert_eq!(defs.len(), 2);
1628 assert_eq!(defs[0].id, "web_scrape");
1629 assert_eq!(
1630 defs[0].invocation,
1631 crate::registry::InvocationHint::FencedBlock("scrape")
1632 );
1633 assert_eq!(defs[1].id, "fetch");
1634 assert_eq!(
1635 defs[1].invocation,
1636 crate::registry::InvocationHint::ToolCall
1637 );
1638 }
1639
1640 #[test]
1641 fn tool_definitions_schema_has_all_params() {
1642 let config = ScrapeConfig::default();
1643 let executor = WebScrapeExecutor::new(&config);
1644 let defs = executor.tool_definitions();
1645 let obj = defs[0].schema.as_object().unwrap();
1646 let props = obj["properties"].as_object().unwrap();
1647 assert!(props.contains_key("url"));
1648 assert!(props.contains_key("select"));
1649 assert!(props.contains_key("extract"));
1650 assert!(props.contains_key("limit"));
1651 let req = obj["required"].as_array().unwrap();
1652 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1653 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1654 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1655 }
1656
1657 #[test]
1660 fn subdomain_localhost_blocked() {
1661 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1662 assert!(is_private_host(&host));
1663 }
1664
1665 #[test]
1666 fn internal_tld_blocked() {
1667 let host: url::Host<&str> = url::Host::Domain("service.internal");
1668 assert!(is_private_host(&host));
1669 }
1670
1671 #[test]
1672 fn local_tld_blocked() {
1673 let host: url::Host<&str> = url::Host::Domain("printer.local");
1674 assert!(is_private_host(&host));
1675 }
1676
1677 #[test]
1678 fn public_domain_not_blocked() {
1679 let host: url::Host<&str> = url::Host::Domain("example.com");
1680 assert!(!is_private_host(&host));
1681 }
1682
1683 #[tokio::test]
1686 async fn resolve_loopback_rejected() {
1687 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1689 let result = resolve_and_validate(&url).await;
1691 assert!(
1692 result.is_err(),
1693 "loopback IP must be rejected by resolve_and_validate"
1694 );
1695 let err = result.unwrap_err();
1696 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1697 }
1698
1699 #[tokio::test]
1700 async fn resolve_private_10_rejected() {
1701 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1702 let result = resolve_and_validate(&url).await;
1703 assert!(result.is_err());
1704 assert!(matches!(
1705 result.unwrap_err(),
1706 crate::executor::ToolError::Blocked { .. }
1707 ));
1708 }
1709
1710 #[tokio::test]
1711 async fn resolve_private_192_rejected() {
1712 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1713 let result = resolve_and_validate(&url).await;
1714 assert!(result.is_err());
1715 assert!(matches!(
1716 result.unwrap_err(),
1717 crate::executor::ToolError::Blocked { .. }
1718 ));
1719 }
1720
1721 #[tokio::test]
1722 async fn resolve_ipv6_loopback_rejected() {
1723 let url = url::Url::parse("https://[::1]/path").unwrap();
1724 let result = resolve_and_validate(&url).await;
1725 assert!(result.is_err());
1726 assert!(matches!(
1727 result.unwrap_err(),
1728 crate::executor::ToolError::Blocked { .. }
1729 ));
1730 }
1731
1732 #[tokio::test]
1733 async fn resolve_no_host_returns_ok() {
1734 let url = url::Url::parse("https://example.com/path").unwrap();
1736 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1738 let result = resolve_and_validate(&url_no_host).await;
1740 assert!(result.is_ok());
1741 let (host, addrs) = result.unwrap();
1742 assert!(host.is_empty());
1743 assert!(addrs.is_empty());
1744 drop(url);
1745 drop(url_no_host);
1746 }
1747
1748 async fn make_file_audit_logger(
1752 dir: &tempfile::TempDir,
1753 ) -> (
1754 std::sync::Arc<crate::audit::AuditLogger>,
1755 std::path::PathBuf,
1756 ) {
1757 use crate::audit::AuditLogger;
1758 use crate::config::AuditConfig;
1759 let path = dir.path().join("audit.log");
1760 let config = AuditConfig {
1761 enabled: true,
1762 destination: path.display().to_string(),
1763 };
1764 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1765 (logger, path)
1766 }
1767
1768 #[tokio::test]
1769 async fn with_audit_sets_logger() {
1770 let config = ScrapeConfig::default();
1771 let executor = WebScrapeExecutor::new(&config);
1772 assert!(executor.audit_logger.is_none());
1773
1774 let dir = tempfile::tempdir().unwrap();
1775 let (logger, _path) = make_file_audit_logger(&dir).await;
1776 let executor = executor.with_audit(logger);
1777 assert!(executor.audit_logger.is_some());
1778 }
1779
1780 #[test]
1781 fn tool_error_to_audit_result_blocked_maps_correctly() {
1782 let err = ToolError::Blocked {
1783 command: "scheme not allowed: http".into(),
1784 };
1785 let result = tool_error_to_audit_result(&err);
1786 assert!(
1787 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1788 );
1789 }
1790
1791 #[test]
1792 fn tool_error_to_audit_result_timeout_maps_correctly() {
1793 let err = ToolError::Timeout { timeout_secs: 15 };
1794 let result = tool_error_to_audit_result(&err);
1795 assert!(matches!(result, AuditResult::Timeout));
1796 }
1797
1798 #[test]
1799 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1800 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1801 let result = tool_error_to_audit_result(&err);
1802 assert!(
1803 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1804 );
1805 }
1806
1807 #[tokio::test]
1808 async fn fetch_audit_blocked_url_logged() {
1809 let dir = tempfile::tempdir().unwrap();
1810 let (logger, log_path) = make_file_audit_logger(&dir).await;
1811
1812 let config = ScrapeConfig::default();
1813 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1814
1815 let call = crate::executor::ToolCall {
1816 tool_id: "fetch".to_owned(),
1817 params: {
1818 let mut m = serde_json::Map::new();
1819 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1820 m
1821 },
1822 };
1823 let result = executor.execute_tool_call(&call).await;
1824 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1825
1826 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1827 assert!(
1828 content.contains("\"tool\":\"fetch\""),
1829 "expected tool=fetch in audit: {content}"
1830 );
1831 assert!(
1832 content.contains("\"type\":\"blocked\""),
1833 "expected type=blocked in audit: {content}"
1834 );
1835 assert!(
1836 content.contains("http://example.com"),
1837 "expected URL in audit command field: {content}"
1838 );
1839 }
1840
1841 #[tokio::test]
1842 async fn log_audit_success_writes_to_file() {
1843 let dir = tempfile::tempdir().unwrap();
1844 let (logger, log_path) = make_file_audit_logger(&dir).await;
1845
1846 let config = ScrapeConfig::default();
1847 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1848
1849 executor
1850 .log_audit(
1851 "fetch",
1852 "https://example.com/page",
1853 AuditResult::Success,
1854 42,
1855 )
1856 .await;
1857
1858 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1859 assert!(
1860 content.contains("\"tool\":\"fetch\""),
1861 "expected tool=fetch in audit: {content}"
1862 );
1863 assert!(
1864 content.contains("\"type\":\"success\""),
1865 "expected type=success in audit: {content}"
1866 );
1867 assert!(
1868 content.contains("\"command\":\"https://example.com/page\""),
1869 "expected command URL in audit: {content}"
1870 );
1871 assert!(
1872 content.contains("\"duration_ms\":42"),
1873 "expected duration_ms in audit: {content}"
1874 );
1875 }
1876
1877 #[tokio::test]
1878 async fn log_audit_blocked_writes_to_file() {
1879 let dir = tempfile::tempdir().unwrap();
1880 let (logger, log_path) = make_file_audit_logger(&dir).await;
1881
1882 let config = ScrapeConfig::default();
1883 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1884
1885 executor
1886 .log_audit(
1887 "web_scrape",
1888 "http://evil.com/page",
1889 AuditResult::Blocked {
1890 reason: "scheme not allowed: http".into(),
1891 },
1892 0,
1893 )
1894 .await;
1895
1896 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1897 assert!(
1898 content.contains("\"tool\":\"web_scrape\""),
1899 "expected tool=web_scrape in audit: {content}"
1900 );
1901 assert!(
1902 content.contains("\"type\":\"blocked\""),
1903 "expected type=blocked in audit: {content}"
1904 );
1905 assert!(
1906 content.contains("scheme not allowed"),
1907 "expected block reason in audit: {content}"
1908 );
1909 }
1910
1911 #[tokio::test]
1912 async fn web_scrape_audit_blocked_url_logged() {
1913 let dir = tempfile::tempdir().unwrap();
1914 let (logger, log_path) = make_file_audit_logger(&dir).await;
1915
1916 let config = ScrapeConfig::default();
1917 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1918
1919 let call = crate::executor::ToolCall {
1920 tool_id: "web_scrape".to_owned(),
1921 params: {
1922 let mut m = serde_json::Map::new();
1923 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1924 m.insert("select".to_owned(), serde_json::json!("h1"));
1925 m
1926 },
1927 };
1928 let result = executor.execute_tool_call(&call).await;
1929 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1930
1931 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1932 assert!(
1933 content.contains("\"tool\":\"web_scrape\""),
1934 "expected tool=web_scrape in audit: {content}"
1935 );
1936 assert!(
1937 content.contains("\"type\":\"blocked\""),
1938 "expected type=blocked in audit: {content}"
1939 );
1940 }
1941
1942 #[tokio::test]
1943 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1944 let config = ScrapeConfig::default();
1945 let executor = WebScrapeExecutor::new(&config);
1946 assert!(executor.audit_logger.is_none());
1947
1948 let call = crate::executor::ToolCall {
1949 tool_id: "fetch".to_owned(),
1950 params: {
1951 let mut m = serde_json::Map::new();
1952 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1953 m
1954 },
1955 };
1956 let result = executor.execute_tool_call(&call).await;
1958 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1959 }
1960
    #[tokio::test]
    // NOTE(review): despite the "execute_tool_call" name, this test calls
    // fetch_html directly and never goes through execute_tool_call —
    // presumably because the wiremock server listens on a loopback address
    // that the SSRF validation inside execute_tool_call would block (TODO
    // confirm). Either rename the test or route it through execute_tool_call
    // once a resolver override makes that possible.
    async fn fetch_execute_tool_call_end_to_end() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/e2e"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let result = executor
            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("end-to-end"));
    }
1982}