1use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
24use schemars::JsonSchema;
25use serde::Deserialize;
26use url::Url;
27
28use zeph_common::ToolName;
29
30use crate::audit::{AuditEntry, AuditLogger, AuditResult, EgressEvent, chrono_now};
31use crate::config::{EgressConfig, ScrapeConfig};
32use crate::executor::{
33 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
34};
35use crate::net::is_private_ip;
36
37fn redact_url_for_log(url: &str) -> String {
41 let Ok(mut parsed) = Url::parse(url) else {
42 return url.to_owned();
43 };
44 let _ = parsed.set_username("");
46 let _ = parsed.set_password(None);
47 let sensitive = [
49 "token", "key", "secret", "password", "auth", "sig", "api_key", "apikey",
50 ];
51 let filtered: Vec<(String, String)> = parsed
52 .query_pairs()
53 .filter(|(k, _)| {
54 let lower = k.to_lowercase();
55 !sensitive.iter().any(|s| lower.contains(s))
56 })
57 .map(|(k, v)| (k.into_owned(), v.into_owned()))
58 .collect();
59 if filtered.is_empty() {
60 parsed.set_query(None);
61 } else {
62 let q: String = filtered
63 .iter()
64 .map(|(k, v)| format!("{k}={v}"))
65 .collect::<Vec<_>>()
66 .join("&");
67 parsed.set_query(Some(&q));
68 }
69 parsed.to_string()
70}
71
// Parameters of the `fetch` tool call. Deserialized from `ToolCall::params` and also used to
// generate the tool's JSON schema via `schemars::schema_for!`.
// Plain `//` comments on purpose: `///` doc comments would be copied by `schemars` into the
// schema advertised to the model.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // Target URL; it is vetted (HTTPS only, no private/local host, domain policy, resolved
    // IPs) before any request is made.
    url: String,
}
77
// One scrape request: the page to fetch and what to extract from it. Parsed from a fenced
// `scrape` block (`execute`) or from tool-call params (`execute_tool_call`), and used to
// generate the `web_scrape` JSON schema.
// Plain `//` comments on purpose: `///` doc comments would be copied into that schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // Page to fetch; vetted the same way as `FetchParams::url`.
    url: String,
    // CSS selector handed to `parse_and_extract`.
    select: String,
    // "text" (default when omitted), "html", or "attr:<name>"; unrecognized values fall
    // back to text extraction (see `ExtractMode::parse`).
    #[serde(default = "default_extract")]
    extract: String,
    // Result cap forwarded to `parse_and_extract`; `scrape_instruction` substitutes 10 when
    // it is absent.
    limit: Option<usize>,
}
90
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
94
/// What to pull out of each element matched by a scrape selector.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute (from `"attr:<name>"`).
    Attr(String),
}

impl ExtractMode {
    /// Parses the `extract` field of a scrape instruction.
    ///
    /// `"attr:<name>"` selects an attribute (the name may be empty), `"html"` selects inner
    /// HTML, and anything else, including `"text"`, means text extraction.
    fn parse(s: &str) -> Self {
        if let Some(attribute) = s.strip_prefix("attr:") {
            return Self::Attr(attribute.to_owned());
        }
        if s == "html" { Self::Html } else { Self::Text }
    }
}
114
/// Executor for the `web_scrape` and `fetch` tools: HTTPS-only GET requests with SSRF
/// protection, domain allow/deny lists, auditing and egress telemetry.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Per-request timeout given to every HTTP client this executor builds.
    timeout: Duration,
    /// Responses with a body larger than this many bytes are rejected with an error.
    max_body_bytes: usize,
    /// When non-empty, only hosts matching one of these patterns are allowed and bare IP
    /// hosts are rejected (see `check_domain_policy`).
    allowed_domains: Vec<String>,
    /// Hosts matching any of these patterns are always rejected; checked before the
    /// allowlist.
    denied_domains: Vec<String>,
    /// Optional sink for tool-call audit entries and egress events.
    audit_logger: Option<Arc<AuditLogger>>,
    /// Controls whether egress events are emitted and which details they carry.
    egress_config: EgressConfig,
    /// Optional channel for egress events; sends are non-blocking (`try_send`).
    egress_tx: Option<tokio::sync::mpsc::Sender<EgressEvent>>,
    /// Number of egress events dropped because `egress_tx` was full.
    egress_dropped: Arc<AtomicU64>,
}
166
167impl WebScrapeExecutor {
168 #[must_use]
172 pub fn new(config: &ScrapeConfig) -> Self {
173 Self {
174 timeout: Duration::from_secs(config.timeout),
175 max_body_bytes: config.max_body_bytes,
176 allowed_domains: config.allowed_domains.clone(),
177 denied_domains: config.denied_domains.clone(),
178 audit_logger: None,
179 egress_config: EgressConfig::default(),
180 egress_tx: None,
181 egress_dropped: Arc::new(AtomicU64::new(0)),
182 }
183 }
184
185 #[must_use]
187 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
188 self.audit_logger = Some(logger);
189 self
190 }
191
192 #[must_use]
194 pub fn with_egress_config(mut self, config: EgressConfig) -> Self {
195 self.egress_config = config;
196 self
197 }
198
199 #[must_use]
204 pub fn with_egress_tx(
205 mut self,
206 tx: tokio::sync::mpsc::Sender<EgressEvent>,
207 dropped: Arc<AtomicU64>,
208 ) -> Self {
209 self.egress_tx = Some(tx);
210 self.egress_dropped = dropped;
211 self
212 }
213
214 #[must_use]
216 pub fn egress_dropped(&self) -> Arc<AtomicU64> {
217 Arc::clone(&self.egress_dropped)
218 }
219
220 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
221 let mut builder = reqwest::Client::builder()
222 .timeout(self.timeout)
223 .redirect(reqwest::redirect::Policy::none());
224 builder = builder.resolve_to_addrs(host, addrs);
225 builder.build().unwrap_or_default()
226 }
227}
228
229impl ToolExecutor for WebScrapeExecutor {
230 fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
231 use crate::registry::{InvocationHint, ToolDef};
232 vec![
233 ToolDef {
234 id: "web_scrape".into(),
235 description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
236 schema: schemars::schema_for!(ScrapeInstruction),
237 invocation: InvocationHint::FencedBlock("scrape"),
238 output_schema: None,
239 },
240 ToolDef {
241 id: "fetch".into(),
242 description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
243 schema: schemars::schema_for!(FetchParams),
244 invocation: InvocationHint::ToolCall,
245 output_schema: None,
246 },
247 ]
248 }
249
250 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
251 let blocks = extract_scrape_blocks(response);
252 if blocks.is_empty() {
253 return Ok(None);
254 }
255
256 let mut outputs = Vec::with_capacity(blocks.len());
257 #[allow(clippy::cast_possible_truncation)]
258 let blocks_executed = blocks.len() as u32;
259
260 for block in &blocks {
261 let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
262 ToolError::Execution(std::io::Error::new(
263 std::io::ErrorKind::InvalidData,
264 e.to_string(),
265 ))
266 })?;
267 let correlation_id = EgressEvent::new_correlation_id();
268 let start = Instant::now();
269 let scrape_result = self
270 .scrape_instruction(&instruction, &correlation_id, None)
271 .await;
272 #[allow(clippy::cast_possible_truncation)]
273 let duration_ms = start.elapsed().as_millis() as u64;
274 match scrape_result {
275 Ok(output) => {
276 self.log_audit(
277 "web_scrape",
278 &instruction.url,
279 AuditResult::Success,
280 duration_ms,
281 None,
282 None,
283 Some(correlation_id),
284 )
285 .await;
286 outputs.push(output);
287 }
288 Err(e) => {
289 let audit_result = tool_error_to_audit_result(&e);
290 self.log_audit(
291 "web_scrape",
292 &instruction.url,
293 audit_result,
294 duration_ms,
295 Some(&e),
296 None,
297 Some(correlation_id),
298 )
299 .await;
300 return Err(e);
301 }
302 }
303 }
304
305 Ok(Some(ToolOutput {
306 tool_name: ToolName::new("web-scrape"),
307 summary: outputs.join("\n\n"),
308 blocks_executed,
309 filter_stats: None,
310 diff: None,
311 streamed: false,
312 terminal_id: None,
313 locations: None,
314 raw_response: None,
315 claim_source: Some(ClaimSource::WebScrape),
316 }))
317 }
318
319 #[cfg_attr(
320 feature = "profiling",
321 tracing::instrument(name = "tool.web_scrape", skip_all)
322 )]
323 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
324 match call.tool_id.as_str() {
325 "web_scrape" => {
326 let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
327 let correlation_id = EgressEvent::new_correlation_id();
328 let start = Instant::now();
329 let result = self
330 .scrape_instruction(&instruction, &correlation_id, call.caller_id.clone())
331 .await;
332 #[allow(clippy::cast_possible_truncation)]
333 let duration_ms = start.elapsed().as_millis() as u64;
334 self.run_with_audit(
335 "web_scrape",
336 "web-scrape",
337 &instruction.url,
338 call.caller_id.clone(),
339 correlation_id,
340 duration_ms,
341 result,
342 )
343 .await
344 }
345 "fetch" => {
346 let p: FetchParams = deserialize_params(&call.params)?;
347 let correlation_id = EgressEvent::new_correlation_id();
348 let start = Instant::now();
349 let result = self
350 .handle_fetch(&p, &correlation_id, call.caller_id.clone())
351 .await;
352 #[allow(clippy::cast_possible_truncation)]
353 let duration_ms = start.elapsed().as_millis() as u64;
354 self.run_with_audit(
355 "fetch",
356 "fetch",
357 &p.url,
358 call.caller_id.clone(),
359 correlation_id,
360 duration_ms,
361 result,
362 )
363 .await
364 }
365 _ => Ok(None),
366 }
367 }
368
369 fn is_tool_retryable(&self, tool_id: &str) -> bool {
370 matches!(tool_id, "web_scrape" | "fetch")
371 }
372}
373
374fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
375 match e {
376 ToolError::Blocked { command } => AuditResult::Blocked {
377 reason: command.clone(),
378 },
379 ToolError::Timeout { .. } => AuditResult::Timeout,
380 _ => AuditResult::Error {
381 message: e.to_string(),
382 },
383 }
384}
385
386impl WebScrapeExecutor {
387 #[allow(clippy::too_many_arguments)]
388 async fn run_with_audit(
389 &self,
390 audit_tool_name: &str,
391 public_tool_name: &str,
392 audit_command: &str,
393 caller_id: Option<String>,
394 correlation_id: String,
395 duration_ms: u64,
396 result: Result<String, ToolError>,
397 ) -> Result<Option<ToolOutput>, ToolError> {
398 match result {
399 Ok(output) => {
400 self.log_audit(
401 audit_tool_name,
402 audit_command,
403 AuditResult::Success,
404 duration_ms,
405 None,
406 caller_id,
407 Some(correlation_id),
408 )
409 .await;
410 Ok(Some(ToolOutput {
411 tool_name: ToolName::new(public_tool_name),
412 summary: output,
413 blocks_executed: 1,
414 filter_stats: None,
415 diff: None,
416 streamed: false,
417 terminal_id: None,
418 locations: None,
419 raw_response: None,
420 claim_source: Some(ClaimSource::WebScrape),
421 }))
422 }
423 Err(e) => {
424 let audit_result = tool_error_to_audit_result(&e);
425 self.log_audit(
426 audit_tool_name,
427 audit_command,
428 audit_result,
429 duration_ms,
430 Some(&e),
431 caller_id,
432 Some(correlation_id),
433 )
434 .await;
435 Err(e)
436 }
437 }
438 }
439
440 #[allow(clippy::too_many_arguments)] async fn log_audit(
442 &self,
443 tool: &str,
444 command: &str,
445 result: AuditResult,
446 duration_ms: u64,
447 error: Option<&ToolError>,
448 caller_id: Option<String>,
449 correlation_id: Option<String>,
450 ) {
451 if let Some(ref logger) = self.audit_logger {
452 let (error_category, error_domain, error_phase) =
453 error.map_or((None, None, None), |e| {
454 let cat = e.category();
455 (
456 Some(cat.label().to_owned()),
457 Some(cat.domain().label().to_owned()),
458 Some(cat.phase().label().to_owned()),
459 )
460 });
461 let entry = AuditEntry {
462 timestamp: chrono_now(),
463 tool: tool.into(),
464 command: command.into(),
465 result,
466 duration_ms,
467 error_category,
468 error_domain,
469 error_phase,
470 claim_source: Some(ClaimSource::WebScrape),
471 mcp_server_id: None,
472 injection_flagged: false,
473 embedding_anomalous: false,
474 cross_boundary_mcp_to_acp: false,
475 adversarial_policy_decision: None,
476 exit_code: None,
477 truncated: false,
478 caller_id,
479 policy_match: None,
480 correlation_id,
481 vigil_risk: None,
482 execution_env: None,
483 resolved_cwd: None,
484 scope_at_definition: None,
485 scope_at_dispatch: None,
486 };
487 logger.log(&entry).await;
488 }
489 }
490
491 fn send_egress_event(&self, event: EgressEvent) {
492 if let Some(ref tx) = self.egress_tx {
493 match tx.try_send(event) {
494 Ok(()) => {}
495 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
496 self.egress_dropped.fetch_add(1, Ordering::Relaxed);
497 }
498 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
499 tracing::debug!("egress channel closed; executor continuing without telemetry");
500 }
501 }
502 }
503 }
504
505 async fn log_egress_event(&self, event: &EgressEvent) {
506 if let Some(ref logger) = self.audit_logger {
507 logger.log_egress(event).await;
508 }
509 self.send_egress_event(event.clone());
510 }
511
512 async fn handle_fetch(
513 &self,
514 params: &FetchParams,
515 correlation_id: &str,
516 caller_id: Option<String>,
517 ) -> Result<String, ToolError> {
518 let parsed = validate_url(¶ms.url);
519 let host_str = parsed
520 .as_ref()
521 .map(|u| u.host_str().unwrap_or("").to_owned())
522 .unwrap_or_default();
523
524 if let Err(ref _e) = parsed {
525 if self.egress_config.enabled && self.egress_config.log_blocked {
526 let event = Self::make_blocked_event(
527 "fetch",
528 ¶ms.url,
529 &host_str,
530 correlation_id,
531 caller_id.clone(),
532 "scheme",
533 );
534 self.log_egress_event(&event).await;
535 }
536 return Err(parsed.unwrap_err());
537 }
538 let parsed = parsed.unwrap();
539
540 if let Err(e) = check_domain_policy(
541 parsed.host_str().unwrap_or(""),
542 &self.allowed_domains,
543 &self.denied_domains,
544 ) {
545 if self.egress_config.enabled && self.egress_config.log_blocked {
546 let event = Self::make_blocked_event(
547 "fetch",
548 ¶ms.url,
549 parsed.host_str().unwrap_or(""),
550 correlation_id,
551 caller_id.clone(),
552 "blocklist",
553 );
554 self.log_egress_event(&event).await;
555 }
556 return Err(e);
557 }
558
559 let (host, addrs) = match resolve_and_validate(&parsed).await {
560 Ok(v) => v,
561 Err(e) => {
562 if self.egress_config.enabled && self.egress_config.log_blocked {
563 let event = Self::make_blocked_event(
564 "fetch",
565 ¶ms.url,
566 parsed.host_str().unwrap_or(""),
567 correlation_id,
568 caller_id.clone(),
569 "ssrf",
570 );
571 self.log_egress_event(&event).await;
572 }
573 return Err(e);
574 }
575 };
576
577 self.fetch_html(
578 ¶ms.url,
579 &host,
580 &addrs,
581 "fetch",
582 correlation_id,
583 caller_id,
584 )
585 .await
586 }
587
588 async fn scrape_instruction(
589 &self,
590 instruction: &ScrapeInstruction,
591 correlation_id: &str,
592 caller_id: Option<String>,
593 ) -> Result<String, ToolError> {
594 let parsed = validate_url(&instruction.url);
595 let host_str = parsed
596 .as_ref()
597 .map(|u| u.host_str().unwrap_or("").to_owned())
598 .unwrap_or_default();
599
600 if let Err(ref _e) = parsed {
601 if self.egress_config.enabled && self.egress_config.log_blocked {
602 let event = Self::make_blocked_event(
603 "web_scrape",
604 &instruction.url,
605 &host_str,
606 correlation_id,
607 caller_id.clone(),
608 "scheme",
609 );
610 self.log_egress_event(&event).await;
611 }
612 return Err(parsed.unwrap_err());
613 }
614 let parsed = parsed.unwrap();
615
616 if let Err(e) = check_domain_policy(
617 parsed.host_str().unwrap_or(""),
618 &self.allowed_domains,
619 &self.denied_domains,
620 ) {
621 if self.egress_config.enabled && self.egress_config.log_blocked {
622 let event = Self::make_blocked_event(
623 "web_scrape",
624 &instruction.url,
625 parsed.host_str().unwrap_or(""),
626 correlation_id,
627 caller_id.clone(),
628 "blocklist",
629 );
630 self.log_egress_event(&event).await;
631 }
632 return Err(e);
633 }
634
635 let (host, addrs) = match resolve_and_validate(&parsed).await {
636 Ok(v) => v,
637 Err(e) => {
638 if self.egress_config.enabled && self.egress_config.log_blocked {
639 let event = Self::make_blocked_event(
640 "web_scrape",
641 &instruction.url,
642 parsed.host_str().unwrap_or(""),
643 correlation_id,
644 caller_id.clone(),
645 "ssrf",
646 );
647 self.log_egress_event(&event).await;
648 }
649 return Err(e);
650 }
651 };
652
653 let html = self
654 .fetch_html(
655 &instruction.url,
656 &host,
657 &addrs,
658 "web_scrape",
659 correlation_id,
660 caller_id,
661 )
662 .await?;
663 let selector = instruction.select.clone();
664 let extract = ExtractMode::parse(&instruction.extract);
665 let limit = instruction.limit.unwrap_or(10);
666 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
667 .await
668 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
669 }
670
671 fn make_blocked_event(
672 tool: &str,
673 url: &str,
674 host: &str,
675 correlation_id: &str,
676 caller_id: Option<String>,
677 block_reason: &'static str,
678 ) -> EgressEvent {
679 EgressEvent {
680 timestamp: chrono_now(),
681 kind: "egress",
682 correlation_id: correlation_id.to_owned(),
683 tool: tool.into(),
684 url: redact_url_for_log(url),
685 host: host.to_owned(),
686 method: "GET".to_owned(),
687 status: None,
688 duration_ms: 0,
689 response_bytes: 0,
690 blocked: true,
691 block_reason: Some(block_reason),
692 caller_id,
693 hop: 0,
694 }
695 }
696
697 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
708 async fn fetch_html(
709 &self,
710 url: &str,
711 host: &str,
712 addrs: &[SocketAddr],
713 tool: &str,
714 correlation_id: &str,
715 caller_id: Option<String>,
716 ) -> Result<String, ToolError> {
717 const MAX_REDIRECTS: usize = 3;
718
719 let mut current_url = url.to_owned();
720 let mut current_host = host.to_owned();
721 let mut current_addrs = addrs.to_vec();
722
723 for hop in 0..=MAX_REDIRECTS {
724 let hop_start = Instant::now();
725 let client = self.build_client(¤t_host, ¤t_addrs);
727 let resp = client
728 .get(¤t_url)
729 .send()
730 .await
731 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
732
733 let resp = match resp {
734 Ok(r) => r,
735 Err(e) => {
736 if self.egress_config.enabled {
737 #[allow(clippy::cast_possible_truncation)]
738 let duration_ms = hop_start.elapsed().as_millis() as u64;
739 let event = EgressEvent {
740 timestamp: chrono_now(),
741 kind: "egress",
742 correlation_id: correlation_id.to_owned(),
743 tool: tool.into(),
744 url: redact_url_for_log(¤t_url),
745 host: current_host.clone(),
746 method: "GET".to_owned(),
747 status: None,
748 duration_ms,
749 response_bytes: 0,
750 blocked: false,
751 block_reason: None,
752 caller_id: caller_id.clone(),
753 #[allow(clippy::cast_possible_truncation)]
754 hop: hop as u8,
755 };
756 self.log_egress_event(&event).await;
757 }
758 return Err(e);
759 }
760 };
761
762 let status = resp.status();
763
764 if status.is_redirection() {
765 if hop == MAX_REDIRECTS {
766 return Err(ToolError::Execution(std::io::Error::other(
767 "too many redirects",
768 )));
769 }
770
771 let location = resp
772 .headers()
773 .get(reqwest::header::LOCATION)
774 .and_then(|v| v.to_str().ok())
775 .ok_or_else(|| {
776 ToolError::Execution(std::io::Error::other("redirect with no Location"))
777 })?;
778
779 let base = Url::parse(¤t_url)
781 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
782 let next_url = base
783 .join(location)
784 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
785
786 let validated = validate_url(next_url.as_str());
787 if let Err(ref _e) = validated {
788 if self.egress_config.enabled && self.egress_config.log_blocked {
789 #[allow(clippy::cast_possible_truncation)]
790 let duration_ms = hop_start.elapsed().as_millis() as u64;
791 let next_host = next_url.host_str().unwrap_or("").to_owned();
792 let event = EgressEvent {
793 timestamp: chrono_now(),
794 kind: "egress",
795 correlation_id: correlation_id.to_owned(),
796 tool: tool.into(),
797 url: redact_url_for_log(next_url.as_str()),
798 host: next_host,
799 method: "GET".to_owned(),
800 status: None,
801 duration_ms,
802 response_bytes: 0,
803 blocked: true,
804 block_reason: Some("ssrf"),
805 caller_id: caller_id.clone(),
806 #[allow(clippy::cast_possible_truncation)]
807 hop: (hop + 1) as u8,
808 };
809 self.log_egress_event(&event).await;
810 }
811 return Err(validated.unwrap_err());
812 }
813 let validated = validated.unwrap();
814 let resolve_result = resolve_and_validate(&validated).await;
815 if let Err(ref _e) = resolve_result {
816 if self.egress_config.enabled && self.egress_config.log_blocked {
817 #[allow(clippy::cast_possible_truncation)]
818 let duration_ms = hop_start.elapsed().as_millis() as u64;
819 let next_host = next_url.host_str().unwrap_or("").to_owned();
820 let event = EgressEvent {
821 timestamp: chrono_now(),
822 kind: "egress",
823 correlation_id: correlation_id.to_owned(),
824 tool: tool.into(),
825 url: redact_url_for_log(next_url.as_str()),
826 host: next_host,
827 method: "GET".to_owned(),
828 status: None,
829 duration_ms,
830 response_bytes: 0,
831 blocked: true,
832 block_reason: Some("ssrf"),
833 caller_id: caller_id.clone(),
834 #[allow(clippy::cast_possible_truncation)]
835 hop: (hop + 1) as u8,
836 };
837 self.log_egress_event(&event).await;
838 }
839 return Err(resolve_result.unwrap_err());
840 }
841 let (next_host, next_addrs) = resolve_result.unwrap();
842
843 current_url = next_url.to_string();
844 current_host = next_host;
845 current_addrs = next_addrs;
846 continue;
847 }
848
849 if !status.is_success() {
850 if self.egress_config.enabled {
851 #[allow(clippy::cast_possible_truncation)]
852 let duration_ms = hop_start.elapsed().as_millis() as u64;
853 let event = EgressEvent {
854 timestamp: chrono_now(),
855 kind: "egress",
856 correlation_id: correlation_id.to_owned(),
857 tool: tool.into(),
858 url: current_url.clone(),
859 host: current_host.clone(),
860 method: "GET".to_owned(),
861 status: Some(status.as_u16()),
862 duration_ms,
863 response_bytes: 0,
864 blocked: false,
865 block_reason: None,
866 caller_id: caller_id.clone(),
867 #[allow(clippy::cast_possible_truncation)]
868 hop: hop as u8,
869 };
870 self.log_egress_event(&event).await;
871 }
872 return Err(ToolError::Http {
873 status: status.as_u16(),
874 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
875 });
876 }
877
878 let bytes = resp
879 .bytes()
880 .await
881 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
882
883 if bytes.len() > self.max_body_bytes {
884 if self.egress_config.enabled {
885 #[allow(clippy::cast_possible_truncation)]
886 let duration_ms = hop_start.elapsed().as_millis() as u64;
887 let event = EgressEvent {
888 timestamp: chrono_now(),
889 kind: "egress",
890 correlation_id: correlation_id.to_owned(),
891 tool: tool.into(),
892 url: current_url.clone(),
893 host: current_host.clone(),
894 method: "GET".to_owned(),
895 status: Some(status.as_u16()),
896 duration_ms,
897 response_bytes: bytes.len(),
898 blocked: false,
899 block_reason: None,
900 caller_id: caller_id.clone(),
901 #[allow(clippy::cast_possible_truncation)]
902 hop: hop as u8,
903 };
904 self.log_egress_event(&event).await;
905 }
906 return Err(ToolError::Execution(std::io::Error::other(format!(
907 "response too large: {} bytes (max: {})",
908 bytes.len(),
909 self.max_body_bytes,
910 ))));
911 }
912
913 if self.egress_config.enabled {
915 #[allow(clippy::cast_possible_truncation)]
916 let duration_ms = hop_start.elapsed().as_millis() as u64;
917 let response_bytes = if self.egress_config.log_response_bytes {
918 bytes.len()
919 } else {
920 0
921 };
922 let event = EgressEvent {
923 timestamp: chrono_now(),
924 kind: "egress",
925 correlation_id: correlation_id.to_owned(),
926 tool: tool.into(),
927 url: current_url.clone(),
928 host: current_host.clone(),
929 method: "GET".to_owned(),
930 status: Some(status.as_u16()),
931 duration_ms,
932 response_bytes,
933 blocked: false,
934 block_reason: None,
935 caller_id: caller_id.clone(),
936 #[allow(clippy::cast_possible_truncation)]
937 hop: hop as u8,
938 };
939 self.log_egress_event(&event).await;
940 }
941
942 return String::from_utf8(bytes.to_vec())
943 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
944 }
945
946 Err(ToolError::Execution(std::io::Error::other(
947 "too many redirects",
948 )))
949 }
950}
951
952fn extract_scrape_blocks(text: &str) -> Vec<&str> {
953 crate::executor::extract_fenced_blocks(text, "scrape")
954}
955
956fn check_domain_policy(
968 host: &str,
969 allowed_domains: &[String],
970 denied_domains: &[String],
971) -> Result<(), ToolError> {
972 if denied_domains.iter().any(|p| domain_matches(p, host)) {
973 return Err(ToolError::Blocked {
974 command: format!("domain blocked by denylist: {host}"),
975 });
976 }
977 if !allowed_domains.is_empty() {
978 let is_ip = host.parse::<std::net::IpAddr>().is_ok()
980 || (host.starts_with('[') && host.ends_with(']'));
981 if is_ip {
982 return Err(ToolError::Blocked {
983 command: format!(
984 "bare IP address not allowed when domain allowlist is active: {host}"
985 ),
986 });
987 }
988 if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
989 return Err(ToolError::Blocked {
990 command: format!("domain not in allowlist: {host}"),
991 });
992 }
993 }
994 Ok(())
995}
996
997use crate::domain_match::domain_matches;
999
1000fn validate_url(raw: &str) -> Result<Url, ToolError> {
1001 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
1002 command: format!("invalid URL: {raw}"),
1003 })?;
1004
1005 if parsed.scheme() != "https" {
1006 return Err(ToolError::Blocked {
1007 command: format!("scheme not allowed: {}", parsed.scheme()),
1008 });
1009 }
1010
1011 if let Some(host) = parsed.host()
1012 && is_private_host(&host)
1013 {
1014 return Err(ToolError::Blocked {
1015 command: format!(
1016 "private/local host blocked: {}",
1017 parsed.host_str().unwrap_or("")
1018 ),
1019 });
1020 }
1021
1022 Ok(parsed)
1023}
1024
1025fn is_private_host(host: &url::Host<&str>) -> bool {
1026 match host {
1027 url::Host::Domain(d) => {
1028 #[allow(clippy::case_sensitive_file_extension_comparisons)]
1031 {
1032 *d == "localhost"
1033 || d.ends_with(".localhost")
1034 || d.ends_with(".internal")
1035 || d.ends_with(".local")
1036 }
1037 }
1038 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
1039 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
1040 }
1041}
1042
1043async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
1049 let Some(host) = url.host_str() else {
1050 return Ok((String::new(), vec![]));
1051 };
1052 let port = url.port_or_known_default().unwrap_or(443);
1053 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
1054 .await
1055 .map_err(|e| ToolError::Blocked {
1056 command: format!("DNS resolution failed: {e}"),
1057 })?
1058 .collect();
1059 for addr in &addrs {
1060 if is_private_ip(addr.ip()) {
1061 return Err(ToolError::Blocked {
1062 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
1063 });
1064 }
1065 }
1066 Ok((host.to_owned(), addrs))
1067}
1068
1069fn parse_and_extract(
1070 html: &str,
1071 selector: &str,
1072 extract: &ExtractMode,
1073 limit: usize,
1074) -> Result<String, ToolError> {
1075 let soup = scrape_core::Soup::parse(html);
1076
1077 let tags = soup.find_all(selector).map_err(|e| {
1078 ToolError::Execution(std::io::Error::new(
1079 std::io::ErrorKind::InvalidData,
1080 format!("invalid selector: {e}"),
1081 ))
1082 })?;
1083
1084 let mut results = Vec::new();
1085
1086 for tag in tags.into_iter().take(limit) {
1087 let value = match extract {
1088 ExtractMode::Text => tag.text(),
1089 ExtractMode::Html => tag.inner_html(),
1090 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
1091 };
1092 if !value.trim().is_empty() {
1093 results.push(value.trim().to_owned());
1094 }
1095 }
1096
1097 if results.is_empty() {
1098 Ok(format!("No results for selector: {selector}"))
1099 } else {
1100 Ok(results.join("\n"))
1101 }
1102}
1103
1104#[cfg(test)]
1105mod tests {
1106 use super::*;
1107
1108 #[test]
1111 fn extract_single_block() {
1112 let text =
1113 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
1114 let blocks = extract_scrape_blocks(text);
1115 assert_eq!(blocks.len(), 1);
1116 assert!(blocks[0].contains("example.com"));
1117 }
1118
1119 #[test]
1120 fn extract_multiple_blocks() {
1121 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
1122 let blocks = extract_scrape_blocks(text);
1123 assert_eq!(blocks.len(), 2);
1124 }
1125
1126 #[test]
1127 fn no_blocks_returns_empty() {
1128 let blocks = extract_scrape_blocks("plain text, no code blocks");
1129 assert!(blocks.is_empty());
1130 }
1131
1132 #[test]
1133 fn unclosed_block_ignored() {
1134 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
1135 assert!(blocks.is_empty());
1136 }
1137
1138 #[test]
1139 fn non_scrape_block_ignored() {
1140 let text =
1141 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
1142 let blocks = extract_scrape_blocks(text);
1143 assert_eq!(blocks.len(), 1);
1144 assert!(blocks[0].contains("x.com"));
1145 }
1146
1147 #[test]
1148 fn multiline_json_block() {
1149 let text =
1150 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
1151 let blocks = extract_scrape_blocks(text);
1152 assert_eq!(blocks.len(), 1);
1153 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
1154 assert_eq!(instr.url, "https://example.com");
1155 }
1156
1157 #[test]
1160 fn parse_valid_instruction() {
1161 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
1162 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1163 assert_eq!(instr.url, "https://example.com");
1164 assert_eq!(instr.select, "h1");
1165 assert_eq!(instr.extract, "text");
1166 assert_eq!(instr.limit, Some(5));
1167 }
1168
1169 #[test]
1170 fn parse_minimal_instruction() {
1171 let json = r#"{"url":"https://example.com","select":"p"}"#;
1172 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1173 assert_eq!(instr.extract, "text");
1174 assert!(instr.limit.is_none());
1175 }
1176
1177 #[test]
1178 fn parse_attr_extract() {
1179 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
1180 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1181 assert_eq!(instr.extract, "attr:href");
1182 }
1183
1184 #[test]
1185 fn parse_invalid_json_errors() {
1186 let result = serde_json::from_str::<ScrapeInstruction>("not json");
1187 assert!(result.is_err());
1188 }
1189
1190 #[test]
1193 fn extract_mode_text() {
1194 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
1195 }
1196
1197 #[test]
1198 fn extract_mode_html() {
1199 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
1200 }
1201
1202 #[test]
1203 fn extract_mode_attr() {
1204 let mode = ExtractMode::parse("attr:href");
1205 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
1206 }
1207
1208 #[test]
1209 fn extract_mode_unknown_defaults_to_text() {
1210 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
1211 }
1212
1213 #[test]
1216 fn valid_https_url() {
1217 assert!(validate_url("https://example.com").is_ok());
1218 }
1219
1220 #[test]
1221 fn http_rejected() {
1222 let err = validate_url("http://example.com").unwrap_err();
1223 assert!(matches!(err, ToolError::Blocked { .. }));
1224 }
1225
1226 #[test]
1227 fn ftp_rejected() {
1228 let err = validate_url("ftp://files.example.com").unwrap_err();
1229 assert!(matches!(err, ToolError::Blocked { .. }));
1230 }
1231
1232 #[test]
1233 fn file_rejected() {
1234 let err = validate_url("file:///etc/passwd").unwrap_err();
1235 assert!(matches!(err, ToolError::Blocked { .. }));
1236 }
1237
1238 #[test]
1239 fn invalid_url_rejected() {
1240 let err = validate_url("not a url").unwrap_err();
1241 assert!(matches!(err, ToolError::Blocked { .. }));
1242 }
1243
1244 #[test]
1245 fn localhost_blocked() {
1246 let err = validate_url("https://localhost/path").unwrap_err();
1247 assert!(matches!(err, ToolError::Blocked { .. }));
1248 }
1249
1250 #[test]
1251 fn loopback_ip_blocked() {
1252 let err = validate_url("https://127.0.0.1/path").unwrap_err();
1253 assert!(matches!(err, ToolError::Blocked { .. }));
1254 }
1255
1256 #[test]
1257 fn private_10_blocked() {
1258 let err = validate_url("https://10.0.0.1/api").unwrap_err();
1259 assert!(matches!(err, ToolError::Blocked { .. }));
1260 }
1261
1262 #[test]
1263 fn private_172_blocked() {
1264 let err = validate_url("https://172.16.0.1/api").unwrap_err();
1265 assert!(matches!(err, ToolError::Blocked { .. }));
1266 }
1267
1268 #[test]
1269 fn private_192_blocked() {
1270 let err = validate_url("https://192.168.1.1/api").unwrap_err();
1271 assert!(matches!(err, ToolError::Blocked { .. }));
1272 }
1273
1274 #[test]
1275 fn ipv6_loopback_blocked() {
1276 let err = validate_url("https://[::1]/path").unwrap_err();
1277 assert!(matches!(err, ToolError::Blocked { .. }));
1278 }
1279
1280 #[test]
1281 fn public_ip_allowed() {
1282 assert!(validate_url("https://93.184.216.34/page").is_ok());
1283 }
1284
1285 #[test]
1288 fn extract_text_from_html() {
1289 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
1290 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1291 assert_eq!(result, "Hello World");
1292 }
1293
1294 #[test]
1295 fn extract_multiple_elements() {
1296 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1297 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1298 assert_eq!(result, "A\nB\nC");
1299 }
1300
1301 #[test]
1302 fn extract_with_limit() {
1303 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1304 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
1305 assert_eq!(result, "A\nB");
1306 }
1307
1308 #[test]
1309 fn extract_attr_href() {
1310 let html = r#"<a href="https://example.com">Link</a>"#;
1311 let result =
1312 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1313 assert_eq!(result, "https://example.com");
1314 }
1315
1316 #[test]
1317 fn extract_inner_html() {
1318 let html = "<div><span>inner</span></div>";
1319 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1320 assert!(result.contains("<span>inner</span>"));
1321 }
1322
1323 #[test]
1324 fn no_matches_returns_message() {
1325 let html = "<html><body><p>text</p></body></html>";
1326 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1327 assert!(result.starts_with("No results for selector:"));
1328 }
1329
1330 #[test]
1331 fn empty_text_skipped() {
1332 let html = "<ul><li> </li><li>A</li></ul>";
1333 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1334 assert_eq!(result, "A");
1335 }
1336
1337 #[test]
1338 fn invalid_selector_errors() {
1339 let html = "<html><body></body></html>";
1340 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
1341 assert!(result.is_err());
1342 }
1343
1344 #[test]
1345 fn empty_html_returns_no_results() {
1346 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
1347 assert!(result.starts_with("No results for selector:"));
1348 }
1349
1350 #[test]
1351 fn nested_selector() {
1352 let html = "<div><span>inner</span></div><span>outer</span>";
1353 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
1354 assert_eq!(result, "inner");
1355 }
1356
1357 #[test]
1358 fn attr_missing_returns_empty() {
1359 let html = r"<a>No href</a>";
1360 let result =
1361 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1362 assert!(result.starts_with("No results for selector:"));
1363 }
1364
1365 #[test]
1366 fn extract_html_mode() {
1367 let html = "<div><b>bold</b> text</div>";
1368 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1369 assert!(result.contains("<b>bold</b>"));
1370 }
1371
1372 #[test]
1373 fn limit_zero_returns_no_results() {
1374 let html = "<ul><li>A</li><li>B</li></ul>";
1375 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
1376 assert!(result.starts_with("No results for selector:"));
1377 }
1378
1379 #[test]
1382 fn url_with_port_allowed() {
1383 assert!(validate_url("https://example.com:8443/path").is_ok());
1384 }
1385
1386 #[test]
1387 fn link_local_ip_blocked() {
1388 let err = validate_url("https://169.254.1.1/path").unwrap_err();
1389 assert!(matches!(err, ToolError::Blocked { .. }));
1390 }
1391
1392 #[test]
1393 fn url_no_scheme_rejected() {
1394 let err = validate_url("example.com/path").unwrap_err();
1395 assert!(matches!(err, ToolError::Blocked { .. }));
1396 }
1397
1398 #[test]
1399 fn unspecified_ipv4_blocked() {
1400 let err = validate_url("https://0.0.0.0/path").unwrap_err();
1401 assert!(matches!(err, ToolError::Blocked { .. }));
1402 }
1403
1404 #[test]
1405 fn broadcast_ipv4_blocked() {
1406 let err = validate_url("https://255.255.255.255/path").unwrap_err();
1407 assert!(matches!(err, ToolError::Blocked { .. }));
1408 }
1409
1410 #[test]
1411 fn ipv6_link_local_blocked() {
1412 let err = validate_url("https://[fe80::1]/path").unwrap_err();
1413 assert!(matches!(err, ToolError::Blocked { .. }));
1414 }
1415
1416 #[test]
1417 fn ipv6_unique_local_blocked() {
1418 let err = validate_url("https://[fd12::1]/path").unwrap_err();
1419 assert!(matches!(err, ToolError::Blocked { .. }));
1420 }
1421
1422 #[test]
1423 fn ipv4_mapped_ipv6_loopback_blocked() {
1424 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
1425 assert!(matches!(err, ToolError::Blocked { .. }));
1426 }
1427
1428 #[test]
1429 fn ipv4_mapped_ipv6_private_blocked() {
1430 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
1431 assert!(matches!(err, ToolError::Blocked { .. }));
1432 }
1433
1434 #[tokio::test]
1437 async fn executor_no_blocks_returns_none() {
1438 let config = ScrapeConfig::default();
1439 let executor = WebScrapeExecutor::new(&config);
1440 let result = executor.execute("plain text").await;
1441 assert!(result.unwrap().is_none());
1442 }
1443
1444 #[tokio::test]
1445 async fn executor_invalid_json_errors() {
1446 let config = ScrapeConfig::default();
1447 let executor = WebScrapeExecutor::new(&config);
1448 let response = "```scrape\nnot json\n```";
1449 let result = executor.execute(response).await;
1450 assert!(matches!(result, Err(ToolError::Execution(_))));
1451 }
1452
1453 #[tokio::test]
1454 async fn executor_blocked_url_errors() {
1455 let config = ScrapeConfig::default();
1456 let executor = WebScrapeExecutor::new(&config);
1457 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
1458 let result = executor.execute(response).await;
1459 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1460 }
1461
1462 #[tokio::test]
1463 async fn executor_private_ip_blocked() {
1464 let config = ScrapeConfig::default();
1465 let executor = WebScrapeExecutor::new(&config);
1466 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
1467 let result = executor.execute(response).await;
1468 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1469 }
1470
1471 #[tokio::test]
1472 async fn executor_unreachable_host_returns_error() {
1473 let config = ScrapeConfig {
1474 timeout: 1,
1475 max_body_bytes: 1_048_576,
1476 ..Default::default()
1477 };
1478 let executor = WebScrapeExecutor::new(&config);
1479 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
1480 let result = executor.execute(response).await;
1481 assert!(matches!(result, Err(ToolError::Execution(_))));
1482 }
1483
1484 #[tokio::test]
1485 async fn executor_localhost_url_blocked() {
1486 let config = ScrapeConfig::default();
1487 let executor = WebScrapeExecutor::new(&config);
1488 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
1489 let result = executor.execute(response).await;
1490 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1491 }
1492
1493 #[tokio::test]
1494 async fn executor_empty_text_returns_none() {
1495 let config = ScrapeConfig::default();
1496 let executor = WebScrapeExecutor::new(&config);
1497 let result = executor.execute("").await;
1498 assert!(result.unwrap().is_none());
1499 }
1500
1501 #[tokio::test]
1502 async fn executor_multiple_blocks_first_blocked() {
1503 let config = ScrapeConfig::default();
1504 let executor = WebScrapeExecutor::new(&config);
1505 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1506 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1507 let result = executor.execute(response).await;
1508 assert!(result.is_err());
1509 }
1510
1511 #[test]
1512 fn validate_url_empty_string() {
1513 let err = validate_url("").unwrap_err();
1514 assert!(matches!(err, ToolError::Blocked { .. }));
1515 }
1516
1517 #[test]
1518 fn validate_url_javascript_scheme_blocked() {
1519 let err = validate_url("javascript:alert(1)").unwrap_err();
1520 assert!(matches!(err, ToolError::Blocked { .. }));
1521 }
1522
1523 #[test]
1524 fn validate_url_data_scheme_blocked() {
1525 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1526 assert!(matches!(err, ToolError::Blocked { .. }));
1527 }
1528
1529 #[test]
1530 fn is_private_host_public_domain_is_false() {
1531 let host: url::Host<&str> = url::Host::Domain("example.com");
1532 assert!(!is_private_host(&host));
1533 }
1534
1535 #[test]
1536 fn is_private_host_localhost_is_true() {
1537 let host: url::Host<&str> = url::Host::Domain("localhost");
1538 assert!(is_private_host(&host));
1539 }
1540
1541 #[test]
1542 fn is_private_host_ipv6_unspecified_is_true() {
1543 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1544 assert!(is_private_host(&host));
1545 }
1546
1547 #[test]
1548 fn is_private_host_public_ipv6_is_false() {
1549 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1550 assert!(!is_private_host(&host));
1551 }
1552
1553 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1564 let server = wiremock::MockServer::start().await;
1565 let executor = WebScrapeExecutor {
1566 timeout: Duration::from_secs(5),
1567 max_body_bytes: 1_048_576,
1568 allowed_domains: vec![],
1569 denied_domains: vec![],
1570 audit_logger: None,
1571 egress_config: EgressConfig::default(),
1572 egress_tx: None,
1573 egress_dropped: Arc::new(AtomicU64::new(0)),
1574 };
1575 (executor, server)
1576 }
1577
1578 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1580 let uri = server.uri();
1581 let url = Url::parse(&uri).unwrap();
1582 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1583 let port = url.port().unwrap_or(80);
1584 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1585 (host, vec![addr])
1586 }
1587
1588 async fn follow_redirects_raw(
1592 executor: &WebScrapeExecutor,
1593 start_url: &str,
1594 host: &str,
1595 addrs: &[std::net::SocketAddr],
1596 ) -> Result<String, ToolError> {
1597 const MAX_REDIRECTS: usize = 3;
1598 let mut current_url = start_url.to_owned();
1599 let mut current_host = host.to_owned();
1600 let mut current_addrs = addrs.to_vec();
1601
1602 for hop in 0..=MAX_REDIRECTS {
1603 let client = executor.build_client(¤t_host, ¤t_addrs);
1604 let resp = client
1605 .get(¤t_url)
1606 .send()
1607 .await
1608 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1609
1610 let status = resp.status();
1611
1612 if status.is_redirection() {
1613 if hop == MAX_REDIRECTS {
1614 return Err(ToolError::Execution(std::io::Error::other(
1615 "too many redirects",
1616 )));
1617 }
1618
1619 let location = resp
1620 .headers()
1621 .get(reqwest::header::LOCATION)
1622 .and_then(|v| v.to_str().ok())
1623 .ok_or_else(|| {
1624 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1625 })?;
1626
1627 let base = Url::parse(¤t_url)
1628 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1629 let next_url = base
1630 .join(location)
1631 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1632
1633 current_url = next_url.to_string();
1635 let _ = &mut current_host;
1637 let _ = &mut current_addrs;
1638 continue;
1639 }
1640
1641 if !status.is_success() {
1642 return Err(ToolError::Execution(std::io::Error::other(format!(
1643 "HTTP {status}",
1644 ))));
1645 }
1646
1647 let bytes = resp
1648 .bytes()
1649 .await
1650 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1651
1652 if bytes.len() > executor.max_body_bytes {
1653 return Err(ToolError::Execution(std::io::Error::other(format!(
1654 "response too large: {} bytes (max: {})",
1655 bytes.len(),
1656 executor.max_body_bytes,
1657 ))));
1658 }
1659
1660 return String::from_utf8(bytes.to_vec())
1661 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1662 }
1663
1664 Err(ToolError::Execution(std::io::Error::other(
1665 "too many redirects",
1666 )))
1667 }
1668
1669 #[tokio::test]
1670 async fn fetch_html_success_returns_body() {
1671 use wiremock::matchers::{method, path};
1672 use wiremock::{Mock, ResponseTemplate};
1673
1674 let (executor, server) = mock_server_executor().await;
1675 Mock::given(method("GET"))
1676 .and(path("/page"))
1677 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1678 .mount(&server)
1679 .await;
1680
1681 let (host, addrs) = server_host_and_addr(&server);
1682 let url = format!("{}/page", server.uri());
1683 let result = executor
1684 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1685 .await;
1686 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1687 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1688 }
1689
1690 #[tokio::test]
1691 async fn fetch_html_non_2xx_returns_error() {
1692 use wiremock::matchers::{method, path};
1693 use wiremock::{Mock, ResponseTemplate};
1694
1695 let (executor, server) = mock_server_executor().await;
1696 Mock::given(method("GET"))
1697 .and(path("/forbidden"))
1698 .respond_with(ResponseTemplate::new(403))
1699 .mount(&server)
1700 .await;
1701
1702 let (host, addrs) = server_host_and_addr(&server);
1703 let url = format!("{}/forbidden", server.uri());
1704 let result = executor
1705 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1706 .await;
1707 assert!(result.is_err());
1708 let msg = result.unwrap_err().to_string();
1709 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1710 }
1711
1712 #[tokio::test]
1713 async fn fetch_html_404_returns_error() {
1714 use wiremock::matchers::{method, path};
1715 use wiremock::{Mock, ResponseTemplate};
1716
1717 let (executor, server) = mock_server_executor().await;
1718 Mock::given(method("GET"))
1719 .and(path("/missing"))
1720 .respond_with(ResponseTemplate::new(404))
1721 .mount(&server)
1722 .await;
1723
1724 let (host, addrs) = server_host_and_addr(&server);
1725 let url = format!("{}/missing", server.uri());
1726 let result = executor
1727 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1728 .await;
1729 assert!(result.is_err());
1730 let msg = result.unwrap_err().to_string();
1731 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1732 }
1733
1734 #[tokio::test]
1735 async fn fetch_html_redirect_no_location_returns_error() {
1736 use wiremock::matchers::{method, path};
1737 use wiremock::{Mock, ResponseTemplate};
1738
1739 let (executor, server) = mock_server_executor().await;
1740 Mock::given(method("GET"))
1742 .and(path("/redirect-no-loc"))
1743 .respond_with(ResponseTemplate::new(302))
1744 .mount(&server)
1745 .await;
1746
1747 let (host, addrs) = server_host_and_addr(&server);
1748 let url = format!("{}/redirect-no-loc", server.uri());
1749 let result = executor
1750 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1751 .await;
1752 assert!(result.is_err());
1753 let msg = result.unwrap_err().to_string();
1754 assert!(
1755 msg.contains("Location") || msg.contains("location"),
1756 "expected Location-related error: {msg}"
1757 );
1758 }
1759
1760 #[tokio::test]
1761 async fn fetch_html_single_redirect_followed() {
1762 use wiremock::matchers::{method, path};
1763 use wiremock::{Mock, ResponseTemplate};
1764
1765 let (executor, server) = mock_server_executor().await;
1766 let final_url = format!("{}/final", server.uri());
1767
1768 Mock::given(method("GET"))
1769 .and(path("/start"))
1770 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1771 .mount(&server)
1772 .await;
1773
1774 Mock::given(method("GET"))
1775 .and(path("/final"))
1776 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1777 .mount(&server)
1778 .await;
1779
1780 let (host, addrs) = server_host_and_addr(&server);
1781 let url = format!("{}/start", server.uri());
1782 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1783 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1784 assert_eq!(result.unwrap(), "<p>final</p>");
1785 }
1786
1787 #[tokio::test]
1788 async fn fetch_html_three_redirects_allowed() {
1789 use wiremock::matchers::{method, path};
1790 use wiremock::{Mock, ResponseTemplate};
1791
1792 let (executor, server) = mock_server_executor().await;
1793 let hop2 = format!("{}/hop2", server.uri());
1794 let hop3 = format!("{}/hop3", server.uri());
1795 let final_dest = format!("{}/done", server.uri());
1796
1797 Mock::given(method("GET"))
1798 .and(path("/hop1"))
1799 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1800 .mount(&server)
1801 .await;
1802 Mock::given(method("GET"))
1803 .and(path("/hop2"))
1804 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1805 .mount(&server)
1806 .await;
1807 Mock::given(method("GET"))
1808 .and(path("/hop3"))
1809 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1810 .mount(&server)
1811 .await;
1812 Mock::given(method("GET"))
1813 .and(path("/done"))
1814 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1815 .mount(&server)
1816 .await;
1817
1818 let (host, addrs) = server_host_and_addr(&server);
1819 let url = format!("{}/hop1", server.uri());
1820 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1821 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1822 assert_eq!(result.unwrap(), "<p>done</p>");
1823 }
1824
1825 #[tokio::test]
1826 async fn fetch_html_four_redirects_rejected() {
1827 use wiremock::matchers::{method, path};
1828 use wiremock::{Mock, ResponseTemplate};
1829
1830 let (executor, server) = mock_server_executor().await;
1831 let hop2 = format!("{}/r2", server.uri());
1832 let hop3 = format!("{}/r3", server.uri());
1833 let hop4 = format!("{}/r4", server.uri());
1834 let hop5 = format!("{}/r5", server.uri());
1835
1836 for (from, to) in [
1837 ("/r1", &hop2),
1838 ("/r2", &hop3),
1839 ("/r3", &hop4),
1840 ("/r4", &hop5),
1841 ] {
1842 Mock::given(method("GET"))
1843 .and(path(from))
1844 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1845 .mount(&server)
1846 .await;
1847 }
1848
1849 let (host, addrs) = server_host_and_addr(&server);
1850 let url = format!("{}/r1", server.uri());
1851 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1852 assert!(result.is_err(), "4 redirects should be rejected");
1853 let msg = result.unwrap_err().to_string();
1854 assert!(
1855 msg.contains("redirect"),
1856 "expected redirect-related error: {msg}"
1857 );
1858 }
1859
1860 #[tokio::test]
1861 async fn fetch_html_body_too_large_returns_error() {
1862 use wiremock::matchers::{method, path};
1863 use wiremock::{Mock, ResponseTemplate};
1864
1865 let small_limit_executor = WebScrapeExecutor {
1866 timeout: Duration::from_secs(5),
1867 max_body_bytes: 10,
1868 allowed_domains: vec![],
1869 denied_domains: vec![],
1870 audit_logger: None,
1871 egress_config: EgressConfig::default(),
1872 egress_tx: None,
1873 egress_dropped: Arc::new(AtomicU64::new(0)),
1874 };
1875 let server = wiremock::MockServer::start().await;
1876 Mock::given(method("GET"))
1877 .and(path("/big"))
1878 .respond_with(
1879 ResponseTemplate::new(200)
1880 .set_body_string("this body is definitely longer than ten bytes"),
1881 )
1882 .mount(&server)
1883 .await;
1884
1885 let (host, addrs) = server_host_and_addr(&server);
1886 let url = format!("{}/big", server.uri());
1887 let result = small_limit_executor
1888 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1889 .await;
1890 assert!(result.is_err());
1891 let msg = result.unwrap_err().to_string();
1892 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1893 }
1894
1895 #[test]
1896 fn extract_scrape_blocks_empty_block_content() {
1897 let text = "```scrape\n\n```";
1898 let blocks = extract_scrape_blocks(text);
1899 assert_eq!(blocks.len(), 1);
1900 assert!(blocks[0].is_empty());
1901 }
1902
1903 #[test]
1904 fn extract_scrape_blocks_whitespace_only() {
1905 let text = "```scrape\n \n```";
1906 let blocks = extract_scrape_blocks(text);
1907 assert_eq!(blocks.len(), 1);
1908 }
1909
1910 #[test]
1911 fn parse_and_extract_multiple_selectors() {
1912 let html = "<div><h1>Title</h1><p>Para</p></div>";
1913 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1914 assert!(result.contains("Title"));
1915 assert!(result.contains("Para"));
1916 }
1917
1918 #[test]
1919 fn webscrape_executor_new_with_custom_config() {
1920 let config = ScrapeConfig {
1921 timeout: 60,
1922 max_body_bytes: 512,
1923 ..Default::default()
1924 };
1925 let executor = WebScrapeExecutor::new(&config);
1926 assert_eq!(executor.max_body_bytes, 512);
1927 }
1928
1929 #[test]
1930 fn webscrape_executor_debug() {
1931 let config = ScrapeConfig::default();
1932 let executor = WebScrapeExecutor::new(&config);
1933 let dbg = format!("{executor:?}");
1934 assert!(dbg.contains("WebScrapeExecutor"));
1935 }
1936
1937 #[test]
1938 fn extract_mode_attr_empty_name() {
1939 let mode = ExtractMode::parse("attr:");
1940 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1941 }
1942
1943 #[test]
1944 fn default_extract_returns_text() {
1945 assert_eq!(default_extract(), "text");
1946 }
1947
1948 #[test]
1949 fn scrape_instruction_debug() {
1950 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1951 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1952 let dbg = format!("{instr:?}");
1953 assert!(dbg.contains("ScrapeInstruction"));
1954 }
1955
1956 #[test]
1957 fn extract_mode_debug() {
1958 let mode = ExtractMode::Text;
1959 let dbg = format!("{mode:?}");
1960 assert!(dbg.contains("Text"));
1961 }
1962
1963 #[test]
1968 fn max_redirects_constant_is_three() {
1969 const MAX_REDIRECTS: usize = 3;
1973 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1974 }
1975
1976 #[test]
1979 fn redirect_no_location_error_message() {
1980 let err = std::io::Error::other("redirect with no Location");
1981 assert!(err.to_string().contains("redirect with no Location"));
1982 }
1983
1984 #[test]
1986 fn too_many_redirects_error_message() {
1987 let err = std::io::Error::other("too many redirects");
1988 assert!(err.to_string().contains("too many redirects"));
1989 }
1990
1991 #[test]
1993 fn non_2xx_status_error_format() {
1994 let status = reqwest::StatusCode::FORBIDDEN;
1995 let msg = format!("HTTP {status}");
1996 assert!(msg.contains("403"));
1997 }
1998
1999 #[test]
2001 fn not_found_status_error_format() {
2002 let status = reqwest::StatusCode::NOT_FOUND;
2003 let msg = format!("HTTP {status}");
2004 assert!(msg.contains("404"));
2005 }
2006
2007 #[test]
2009 fn relative_redirect_same_host_path() {
2010 let base = Url::parse("https://example.com/current").unwrap();
2011 let resolved = base.join("/other").unwrap();
2012 assert_eq!(resolved.as_str(), "https://example.com/other");
2013 }
2014
2015 #[test]
2017 fn relative_redirect_relative_path() {
2018 let base = Url::parse("https://example.com/a/b").unwrap();
2019 let resolved = base.join("c").unwrap();
2020 assert_eq!(resolved.as_str(), "https://example.com/a/c");
2021 }
2022
2023 #[test]
2025 fn absolute_redirect_overrides_base() {
2026 let base = Url::parse("https://example.com/page").unwrap();
2027 let resolved = base.join("https://other.com/target").unwrap();
2028 assert_eq!(resolved.as_str(), "https://other.com/target");
2029 }
2030
2031 #[test]
2033 fn redirect_http_downgrade_rejected() {
2034 let location = "http://example.com/page";
2035 let base = Url::parse("https://example.com/start").unwrap();
2036 let next = base.join(location).unwrap();
2037 let err = validate_url(next.as_str()).unwrap_err();
2038 assert!(matches!(err, ToolError::Blocked { .. }));
2039 }
2040
2041 #[test]
2043 fn redirect_location_private_ip_blocked() {
2044 let location = "https://192.168.100.1/admin";
2045 let base = Url::parse("https://example.com/start").unwrap();
2046 let next = base.join(location).unwrap();
2047 let err = validate_url(next.as_str()).unwrap_err();
2048 assert!(matches!(err, ToolError::Blocked { .. }));
2049 let ToolError::Blocked { command: cmd } = err else {
2050 panic!("expected Blocked");
2051 };
2052 assert!(
2053 cmd.contains("private") || cmd.contains("scheme"),
2054 "error message should describe the block reason: {cmd}"
2055 );
2056 }
2057
2058 #[test]
2060 fn redirect_location_internal_domain_blocked() {
2061 let location = "https://metadata.internal/latest/meta-data/";
2062 let base = Url::parse("https://example.com/start").unwrap();
2063 let next = base.join(location).unwrap();
2064 let err = validate_url(next.as_str()).unwrap_err();
2065 assert!(matches!(err, ToolError::Blocked { .. }));
2066 }
2067
2068 #[test]
2070 fn redirect_chain_three_hops_all_public() {
2071 let hops = [
2072 "https://redirect1.example.com/hop1",
2073 "https://redirect2.example.com/hop2",
2074 "https://destination.example.com/final",
2075 ];
2076 for hop in hops {
2077 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
2078 }
2079 }
2080
2081 #[test]
2086 fn redirect_to_private_ip_rejected_by_validate_url() {
2087 let private_targets = [
2089 "https://127.0.0.1/secret",
2090 "https://10.0.0.1/internal",
2091 "https://192.168.1.1/admin",
2092 "https://172.16.0.1/data",
2093 "https://[::1]/path",
2094 "https://[fe80::1]/path",
2095 "https://localhost/path",
2096 "https://service.internal/api",
2097 ];
2098 for target in private_targets {
2099 let result = validate_url(target);
2100 assert!(result.is_err(), "expected error for {target}");
2101 assert!(
2102 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
2103 "expected Blocked for {target}"
2104 );
2105 }
2106 }
2107
2108 #[test]
2110 fn redirect_relative_url_resolves_correctly() {
2111 let base = Url::parse("https://example.com/page").unwrap();
2112 let relative = "/other";
2113 let resolved = base.join(relative).unwrap();
2114 assert_eq!(resolved.as_str(), "https://example.com/other");
2115 }
2116
2117 #[test]
2119 fn redirect_to_http_rejected() {
2120 let err = validate_url("http://example.com/page").unwrap_err();
2121 assert!(matches!(err, ToolError::Blocked { .. }));
2122 }
2123
2124 #[test]
2125 fn ipv4_mapped_ipv6_link_local_blocked() {
2126 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
2127 assert!(matches!(err, ToolError::Blocked { .. }));
2128 }
2129
2130 #[test]
2131 fn ipv4_mapped_ipv6_public_allowed() {
2132 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
2133 }
2134
2135 #[tokio::test]
2138 async fn fetch_http_scheme_blocked() {
2139 let config = ScrapeConfig::default();
2140 let executor = WebScrapeExecutor::new(&config);
2141 let call = crate::executor::ToolCall {
2142 tool_id: ToolName::new("fetch"),
2143 params: {
2144 let mut m = serde_json::Map::new();
2145 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2146 m
2147 },
2148 caller_id: None,
2149 context: None,
2150
2151 tool_call_id: String::new(),
2152 };
2153 let result = executor.execute_tool_call(&call).await;
2154 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2155 }
2156
2157 #[tokio::test]
2158 async fn fetch_private_ip_blocked() {
2159 let config = ScrapeConfig::default();
2160 let executor = WebScrapeExecutor::new(&config);
2161 let call = crate::executor::ToolCall {
2162 tool_id: ToolName::new("fetch"),
2163 params: {
2164 let mut m = serde_json::Map::new();
2165 m.insert(
2166 "url".to_owned(),
2167 serde_json::json!("https://192.168.1.1/secret"),
2168 );
2169 m
2170 },
2171 caller_id: None,
2172 context: None,
2173
2174 tool_call_id: String::new(),
2175 };
2176 let result = executor.execute_tool_call(&call).await;
2177 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2178 }
2179
2180 #[tokio::test]
2181 async fn fetch_localhost_blocked() {
2182 let config = ScrapeConfig::default();
2183 let executor = WebScrapeExecutor::new(&config);
2184 let call = crate::executor::ToolCall {
2185 tool_id: ToolName::new("fetch"),
2186 params: {
2187 let mut m = serde_json::Map::new();
2188 m.insert(
2189 "url".to_owned(),
2190 serde_json::json!("https://localhost/page"),
2191 );
2192 m
2193 },
2194 caller_id: None,
2195 context: None,
2196
2197 tool_call_id: String::new(),
2198 };
2199 let result = executor.execute_tool_call(&call).await;
2200 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2201 }
2202
2203 #[tokio::test]
2204 async fn fetch_unknown_tool_returns_none() {
2205 let config = ScrapeConfig::default();
2206 let executor = WebScrapeExecutor::new(&config);
2207 let call = crate::executor::ToolCall {
2208 tool_id: ToolName::new("unknown_tool"),
2209 params: serde_json::Map::new(),
2210 caller_id: None,
2211 context: None,
2212
2213 tool_call_id: String::new(),
2214 };
2215 let result = executor.execute_tool_call(&call).await;
2216 assert!(result.unwrap().is_none());
2217 }
2218
2219 #[tokio::test]
2220 async fn fetch_returns_body_via_mock() {
2221 use wiremock::matchers::{method, path};
2222 use wiremock::{Mock, ResponseTemplate};
2223
2224 let (executor, server) = mock_server_executor().await;
2225 Mock::given(method("GET"))
2226 .and(path("/content"))
2227 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
2228 .mount(&server)
2229 .await;
2230
2231 let (host, addrs) = server_host_and_addr(&server);
2232 let url = format!("{}/content", server.uri());
2233 let result = executor
2234 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
2235 .await;
2236 assert!(result.is_ok());
2237 assert_eq!(result.unwrap(), "plain text content");
2238 }
2239
2240 #[test]
2241 fn tool_definitions_returns_web_scrape_and_fetch() {
2242 let config = ScrapeConfig::default();
2243 let executor = WebScrapeExecutor::new(&config);
2244 let defs = executor.tool_definitions();
2245 assert_eq!(defs.len(), 2);
2246 assert_eq!(defs[0].id, "web_scrape");
2247 assert_eq!(
2248 defs[0].invocation,
2249 crate::registry::InvocationHint::FencedBlock("scrape")
2250 );
2251 assert_eq!(defs[1].id, "fetch");
2252 assert_eq!(
2253 defs[1].invocation,
2254 crate::registry::InvocationHint::ToolCall
2255 );
2256 }
2257
2258 #[test]
2259 fn tool_definitions_schema_has_all_params() {
2260 let config = ScrapeConfig::default();
2261 let executor = WebScrapeExecutor::new(&config);
2262 let defs = executor.tool_definitions();
2263 let obj = defs[0].schema.as_object().unwrap();
2264 let props = obj["properties"].as_object().unwrap();
2265 assert!(props.contains_key("url"));
2266 assert!(props.contains_key("select"));
2267 assert!(props.contains_key("extract"));
2268 assert!(props.contains_key("limit"));
2269 let req = obj["required"].as_array().unwrap();
2270 assert!(req.iter().any(|v| v.as_str() == Some("url")));
2271 assert!(req.iter().any(|v| v.as_str() == Some("select")));
2272 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
2273 }
2274
2275 #[test]
2278 fn subdomain_localhost_blocked() {
2279 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
2280 assert!(is_private_host(&host));
2281 }
2282
2283 #[test]
2284 fn internal_tld_blocked() {
2285 let host: url::Host<&str> = url::Host::Domain("service.internal");
2286 assert!(is_private_host(&host));
2287 }
2288
2289 #[test]
2290 fn local_tld_blocked() {
2291 let host: url::Host<&str> = url::Host::Domain("printer.local");
2292 assert!(is_private_host(&host));
2293 }
2294
2295 #[test]
2296 fn public_domain_not_blocked() {
2297 let host: url::Host<&str> = url::Host::Domain("example.com");
2298 assert!(!is_private_host(&host));
2299 }
2300
2301 #[tokio::test]
2304 async fn resolve_loopback_rejected() {
2305 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
2307 let result = resolve_and_validate(&url).await;
2309 assert!(
2310 result.is_err(),
2311 "loopback IP must be rejected by resolve_and_validate"
2312 );
2313 let err = result.unwrap_err();
2314 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
2315 }
2316
2317 #[tokio::test]
2318 async fn resolve_private_10_rejected() {
2319 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
2320 let result = resolve_and_validate(&url).await;
2321 assert!(result.is_err());
2322 assert!(matches!(
2323 result.unwrap_err(),
2324 crate::executor::ToolError::Blocked { .. }
2325 ));
2326 }
2327
2328 #[tokio::test]
2329 async fn resolve_private_192_rejected() {
2330 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
2331 let result = resolve_and_validate(&url).await;
2332 assert!(result.is_err());
2333 assert!(matches!(
2334 result.unwrap_err(),
2335 crate::executor::ToolError::Blocked { .. }
2336 ));
2337 }
2338
2339 #[tokio::test]
2340 async fn resolve_ipv6_loopback_rejected() {
2341 let url = url::Url::parse("https://[::1]/path").unwrap();
2342 let result = resolve_and_validate(&url).await;
2343 assert!(result.is_err());
2344 assert!(matches!(
2345 result.unwrap_err(),
2346 crate::executor::ToolError::Blocked { .. }
2347 ));
2348 }
2349
2350 #[tokio::test]
2351 async fn resolve_no_host_returns_ok() {
2352 let url = url::Url::parse("https://example.com/path").unwrap();
2354 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
2356 let result = resolve_and_validate(&url_no_host).await;
2358 assert!(result.is_ok());
2359 let (host, addrs) = result.unwrap();
2360 assert!(host.is_empty());
2361 assert!(addrs.is_empty());
2362 drop(url);
2363 drop(url_no_host);
2364 }
2365
2366 async fn make_file_audit_logger(
2370 dir: &tempfile::TempDir,
2371 ) -> (
2372 std::sync::Arc<crate::audit::AuditLogger>,
2373 std::path::PathBuf,
2374 ) {
2375 use crate::audit::AuditLogger;
2376 use crate::config::AuditConfig;
2377 let path = dir.path().join("audit.log");
2378 let config = AuditConfig {
2379 enabled: true,
2380 destination: path.display().to_string(),
2381 ..Default::default()
2382 };
2383 let logger = std::sync::Arc::new(AuditLogger::from_config(&config, false).await.unwrap());
2384 (logger, path)
2385 }
2386
2387 #[tokio::test]
2388 async fn with_audit_sets_logger() {
2389 let config = ScrapeConfig::default();
2390 let executor = WebScrapeExecutor::new(&config);
2391 assert!(executor.audit_logger.is_none());
2392
2393 let dir = tempfile::tempdir().unwrap();
2394 let (logger, _path) = make_file_audit_logger(&dir).await;
2395 let executor = executor.with_audit(logger);
2396 assert!(executor.audit_logger.is_some());
2397 }
2398
2399 #[test]
2400 fn tool_error_to_audit_result_blocked_maps_correctly() {
2401 let err = ToolError::Blocked {
2402 command: "scheme not allowed: http".into(),
2403 };
2404 let result = tool_error_to_audit_result(&err);
2405 assert!(
2406 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
2407 );
2408 }
2409
2410 #[test]
2411 fn tool_error_to_audit_result_timeout_maps_correctly() {
2412 let err = ToolError::Timeout { timeout_secs: 15 };
2413 let result = tool_error_to_audit_result(&err);
2414 assert!(matches!(result, AuditResult::Timeout));
2415 }
2416
2417 #[test]
2418 fn tool_error_to_audit_result_execution_error_maps_correctly() {
2419 let err = ToolError::Execution(std::io::Error::other("connection refused"));
2420 let result = tool_error_to_audit_result(&err);
2421 assert!(
2422 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
2423 );
2424 }
2425
2426 #[tokio::test]
2427 async fn fetch_audit_blocked_url_logged() {
2428 let dir = tempfile::tempdir().unwrap();
2429 let (logger, log_path) = make_file_audit_logger(&dir).await;
2430
2431 let config = ScrapeConfig::default();
2432 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2433
2434 let call = crate::executor::ToolCall {
2435 tool_id: ToolName::new("fetch"),
2436 params: {
2437 let mut m = serde_json::Map::new();
2438 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2439 m
2440 },
2441 caller_id: None,
2442 context: None,
2443
2444 tool_call_id: String::new(),
2445 };
2446 let result = executor.execute_tool_call(&call).await;
2447 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2448
2449 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2450 assert!(
2451 content.contains("\"tool\":\"fetch\""),
2452 "expected tool=fetch in audit: {content}"
2453 );
2454 assert!(
2455 content.contains("\"type\":\"blocked\""),
2456 "expected type=blocked in audit: {content}"
2457 );
2458 assert!(
2459 content.contains("http://example.com"),
2460 "expected URL in audit command field: {content}"
2461 );
2462 }
2463
2464 #[tokio::test]
2465 async fn log_audit_success_writes_to_file() {
2466 let dir = tempfile::tempdir().unwrap();
2467 let (logger, log_path) = make_file_audit_logger(&dir).await;
2468
2469 let config = ScrapeConfig::default();
2470 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2471
2472 executor
2473 .log_audit(
2474 "fetch",
2475 "https://example.com/page",
2476 AuditResult::Success,
2477 42,
2478 None,
2479 None,
2480 None,
2481 )
2482 .await;
2483
2484 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2485 assert!(
2486 content.contains("\"tool\":\"fetch\""),
2487 "expected tool=fetch in audit: {content}"
2488 );
2489 assert!(
2490 content.contains("\"type\":\"success\""),
2491 "expected type=success in audit: {content}"
2492 );
2493 assert!(
2494 content.contains("\"command\":\"https://example.com/page\""),
2495 "expected command URL in audit: {content}"
2496 );
2497 assert!(
2498 content.contains("\"duration_ms\":42"),
2499 "expected duration_ms in audit: {content}"
2500 );
2501 }
2502
2503 #[tokio::test]
2504 async fn log_audit_blocked_writes_to_file() {
2505 let dir = tempfile::tempdir().unwrap();
2506 let (logger, log_path) = make_file_audit_logger(&dir).await;
2507
2508 let config = ScrapeConfig::default();
2509 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2510
2511 executor
2512 .log_audit(
2513 "web_scrape",
2514 "http://evil.com/page",
2515 AuditResult::Blocked {
2516 reason: "scheme not allowed: http".into(),
2517 },
2518 0,
2519 None,
2520 None,
2521 None,
2522 )
2523 .await;
2524
2525 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2526 assert!(
2527 content.contains("\"tool\":\"web_scrape\""),
2528 "expected tool=web_scrape in audit: {content}"
2529 );
2530 assert!(
2531 content.contains("\"type\":\"blocked\""),
2532 "expected type=blocked in audit: {content}"
2533 );
2534 assert!(
2535 content.contains("scheme not allowed"),
2536 "expected block reason in audit: {content}"
2537 );
2538 }
2539
2540 #[tokio::test]
2541 async fn web_scrape_audit_blocked_url_logged() {
2542 let dir = tempfile::tempdir().unwrap();
2543 let (logger, log_path) = make_file_audit_logger(&dir).await;
2544
2545 let config = ScrapeConfig::default();
2546 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2547
2548 let call = crate::executor::ToolCall {
2549 tool_id: ToolName::new("web_scrape"),
2550 params: {
2551 let mut m = serde_json::Map::new();
2552 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2553 m.insert("select".to_owned(), serde_json::json!("h1"));
2554 m
2555 },
2556 caller_id: None,
2557 context: None,
2558
2559 tool_call_id: String::new(),
2560 };
2561 let result = executor.execute_tool_call(&call).await;
2562 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2563
2564 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2565 assert!(
2566 content.contains("\"tool\":\"web_scrape\""),
2567 "expected tool=web_scrape in audit: {content}"
2568 );
2569 assert!(
2570 content.contains("\"type\":\"blocked\""),
2571 "expected type=blocked in audit: {content}"
2572 );
2573 }
2574
2575 #[tokio::test]
2576 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2577 let config = ScrapeConfig::default();
2578 let executor = WebScrapeExecutor::new(&config);
2579 assert!(executor.audit_logger.is_none());
2580
2581 let call = crate::executor::ToolCall {
2582 tool_id: ToolName::new("fetch"),
2583 params: {
2584 let mut m = serde_json::Map::new();
2585 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2586 m
2587 },
2588 caller_id: None,
2589 context: None,
2590
2591 tool_call_id: String::new(),
2592 };
2593 let result = executor.execute_tool_call(&call).await;
2595 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2596 }
2597
2598 #[tokio::test]
2600 async fn fetch_execute_tool_call_end_to_end() {
2601 use wiremock::matchers::{method, path};
2602 use wiremock::{Mock, ResponseTemplate};
2603
2604 let (executor, server) = mock_server_executor().await;
2605 Mock::given(method("GET"))
2606 .and(path("/e2e"))
2607 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2608 .mount(&server)
2609 .await;
2610
2611 let (host, addrs) = server_host_and_addr(&server);
2612 let result = executor
2614 .fetch_html(
2615 &format!("{}/e2e", server.uri()),
2616 &host,
2617 &addrs,
2618 "fetch",
2619 "test-cid",
2620 None,
2621 )
2622 .await;
2623 assert!(result.is_ok());
2624 assert!(result.unwrap().contains("end-to-end"));
2625 }
2626
2627 #[test]
2630 fn domain_matches_exact() {
2631 assert!(domain_matches("example.com", "example.com"));
2632 assert!(!domain_matches("example.com", "other.com"));
2633 assert!(!domain_matches("example.com", "sub.example.com"));
2634 }
2635
2636 #[test]
2637 fn domain_matches_wildcard_single_subdomain() {
2638 assert!(domain_matches("*.example.com", "sub.example.com"));
2639 assert!(!domain_matches("*.example.com", "example.com"));
2640 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2641 }
2642
2643 #[test]
2644 fn domain_matches_wildcard_does_not_match_empty_label() {
2645 assert!(!domain_matches("*.example.com", ".example.com"));
2647 }
2648
2649 #[test]
2650 fn domain_matches_multi_wildcard_treated_as_exact() {
2651 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2653 }
2654
2655 #[test]
2658 fn check_domain_policy_empty_lists_allow_all() {
2659 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2660 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2661 }
2662
2663 #[test]
2664 fn check_domain_policy_denylist_blocks() {
2665 let denied = vec!["evil.com".to_string()];
2666 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2667 assert!(matches!(err, ToolError::Blocked { .. }));
2668 }
2669
2670 #[test]
2671 fn check_domain_policy_denylist_does_not_block_other_domains() {
2672 let denied = vec!["evil.com".to_string()];
2673 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2674 }
2675
2676 #[test]
2677 fn check_domain_policy_allowlist_permits_matching() {
2678 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2679 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2680 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2681 }
2682
2683 #[test]
2684 fn check_domain_policy_allowlist_blocks_unknown() {
2685 let allowed = vec!["docs.rs".to_string()];
2686 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2687 assert!(matches!(err, ToolError::Blocked { .. }));
2688 }
2689
2690 #[test]
2691 fn check_domain_policy_deny_overrides_allow() {
2692 let allowed = vec!["example.com".to_string()];
2693 let denied = vec!["example.com".to_string()];
2694 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2695 assert!(matches!(err, ToolError::Blocked { .. }));
2696 }
2697
2698 #[test]
2699 fn check_domain_policy_wildcard_in_denylist() {
2700 let denied = vec!["*.evil.com".to_string()];
2701 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2702 assert!(matches!(err, ToolError::Blocked { .. }));
2703 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2705 }
2706}