1use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
24use schemars::JsonSchema;
25use serde::Deserialize;
26use url::Url;
27
28use zeph_common::ToolName;
29
30use crate::audit::{AuditEntry, AuditLogger, AuditResult, EgressEvent, chrono_now};
31use crate::config::{EgressConfig, ScrapeConfig};
32use crate::executor::{
33 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
34};
35use crate::net::is_private_ip;
36
37fn redact_url_for_log(url: &str) -> String {
41 let Ok(mut parsed) = Url::parse(url) else {
42 return url.to_owned();
43 };
44 let _ = parsed.set_username("");
46 let _ = parsed.set_password(None);
47 let sensitive = [
49 "token", "key", "secret", "password", "auth", "sig", "api_key", "apikey",
50 ];
51 let filtered: Vec<(String, String)> = parsed
52 .query_pairs()
53 .filter(|(k, _)| {
54 let lower = k.to_lowercase();
55 !sensitive.iter().any(|s| lower.contains(s))
56 })
57 .map(|(k, v)| (k.into_owned(), v.into_owned()))
58 .collect();
59 if filtered.is_empty() {
60 parsed.set_query(None);
61 } else {
62 let q: String = filtered
63 .iter()
64 .map(|(k, v)| format!("{k}={v}"))
65 .collect::<Vec<_>>()
66 .join("&");
67 parsed.set_query(Some(&q));
68 }
69 parsed.to_string()
70}
71
// Parameters accepted by the `fetch` tool call.
// NOTE: plain `//` comments on purpose — schemars turns `///` doc comments
// into schema descriptions, which would change the advertised tool schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // HTTPS URL to fetch; validated by `validate_url` before any request.
    url: String,
}
77
// One scrape request, parsed from a ```scrape fenced block or a tool call.
// NOTE: plain `//` comments on purpose — schemars turns `///` doc comments
// into schema descriptions, which would change the advertised tool schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // HTTPS URL of the page to scrape.
    url: String,
    // CSS selector choosing which elements to extract.
    select: String,
    // Extraction mode: "text", "html", or "attr:<name>"; see `ExtractMode::parse`.
    #[serde(default = "default_extract")]
    extract: String,
    // Maximum number of matched elements to return; defaults to 10 at the
    // point of use in `scrape_instruction`.
    limit: Option<usize>,
}
90
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
94
/// What to pull out of each element matched by the CSS selector.
#[derive(Debug)]
enum ExtractMode {
    /// Concatenated text content of the element.
    Text,
    /// Inner HTML of the element.
    Html,
    /// Value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Interpret the user-supplied extract string.
    ///
    /// Recognizes `"text"`, `"html"`, and `"attr:<name>"`; anything else
    /// falls back to `Text`.
    fn parse(s: &str) -> Self {
        if s == "html" {
            Self::Html
        } else if let Some(name) = s.strip_prefix("attr:") {
            Self::Attr(name.to_owned())
        } else {
            Self::Text
        }
    }
}
114
/// Tool executor for outbound web access (`web_scrape` and `fetch`).
///
/// Enforces HTTPS-only URLs, a domain allow/deny policy, and SSRF protection
/// (private host/IP rejection with per-redirect re-validation), and can
/// optionally record audit entries and egress telemetry.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Per-request timeout applied to the HTTP client.
    timeout: Duration,
    /// Maximum accepted response body size in bytes.
    max_body_bytes: usize,
    /// If non-empty, only hosts matching these patterns (via
    /// `domain_matches`) may be contacted; bare IPs are then also refused.
    allowed_domains: Vec<String>,
    /// Hosts matching these patterns are always refused; checked before the
    /// allowlist.
    denied_domains: Vec<String>,
    /// Optional audit sink for tool invocations and egress events.
    audit_logger: Option<Arc<AuditLogger>>,
    /// Controls whether and how egress telemetry is emitted.
    egress_config: EgressConfig,
    /// Optional non-blocking channel for egress events.
    egress_tx: Option<tokio::sync::mpsc::Sender<EgressEvent>>,
    /// Count of egress events dropped because the channel was full.
    egress_dropped: Arc<AtomicU64>,
}
165
166impl WebScrapeExecutor {
167 #[must_use]
171 pub fn new(config: &ScrapeConfig) -> Self {
172 Self {
173 timeout: Duration::from_secs(config.timeout),
174 max_body_bytes: config.max_body_bytes,
175 allowed_domains: config.allowed_domains.clone(),
176 denied_domains: config.denied_domains.clone(),
177 audit_logger: None,
178 egress_config: EgressConfig::default(),
179 egress_tx: None,
180 egress_dropped: Arc::new(AtomicU64::new(0)),
181 }
182 }
183
184 #[must_use]
186 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
187 self.audit_logger = Some(logger);
188 self
189 }
190
191 #[must_use]
193 pub fn with_egress_config(mut self, config: EgressConfig) -> Self {
194 self.egress_config = config;
195 self
196 }
197
198 #[must_use]
203 pub fn with_egress_tx(
204 mut self,
205 tx: tokio::sync::mpsc::Sender<EgressEvent>,
206 dropped: Arc<AtomicU64>,
207 ) -> Self {
208 self.egress_tx = Some(tx);
209 self.egress_dropped = dropped;
210 self
211 }
212
213 #[must_use]
215 pub fn egress_dropped(&self) -> Arc<AtomicU64> {
216 Arc::clone(&self.egress_dropped)
217 }
218
219 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
220 let mut builder = reqwest::Client::builder()
221 .timeout(self.timeout)
222 .redirect(reqwest::redirect::Policy::none());
223 builder = builder.resolve_to_addrs(host, addrs);
224 builder.build().unwrap_or_default()
225 }
226}
227
impl ToolExecutor for WebScrapeExecutor {
    /// Advertise the two tools: `web_scrape` (invoked via ```scrape fenced
    /// blocks) and `fetch` (invoked via structured tool calls).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
                output_schema: None,
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
                output_schema: None,
            },
        ]
    }

    /// Execute every ```scrape fenced block found in a raw model response.
    ///
    /// Blocks run sequentially. Each successful block is audited and its
    /// output collected; the first failing block is audited and aborts the
    /// whole run, discarding earlier successes. Returns `Ok(None)` when the
    /// response contains no scrape blocks.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // NOTE(review): a block that fails JSON parsing aborts the run
            // without an audit entry (no URL is known yet) — confirm intended.
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let correlation_id = EgressEvent::new_correlation_id();
            let start = Instant::now();
            // Fenced-block invocations carry no caller id.
            let scrape_result = self
                .scrape_instruction(&instruction, &correlation_id, None)
                .await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                        None,
                        Some(correlation_id),
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                        None,
                        Some(correlation_id),
                    )
                    .await;
                    // Fail fast: earlier block outputs are dropped.
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            tool_name: ToolName::new("web-scrape"),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Dispatch a structured tool call to `web_scrape` or `fetch`.
    ///
    /// Both branches time the operation and route the result through
    /// `run_with_audit`, which emits the audit entry and builds the
    /// `ToolOutput`. Unknown tool ids return `Ok(None)`.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "tool.web_scrape", skip_all)
    )]
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let correlation_id = EgressEvent::new_correlation_id();
                let start = Instant::now();
                let result = self
                    .scrape_instruction(&instruction, &correlation_id, call.caller_id.clone())
                    .await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                self.run_with_audit(
                    "web_scrape",
                    "web-scrape",
                    &instruction.url,
                    call.caller_id.clone(),
                    correlation_id,
                    duration_ms,
                    result,
                )
                .await
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let correlation_id = EgressEvent::new_correlation_id();
                let start = Instant::now();
                let result = self
                    .handle_fetch(&p, &correlation_id, call.caller_id.clone())
                    .await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                self.run_with_audit(
                    "fetch",
                    "fetch",
                    &p.url,
                    call.caller_id.clone(),
                    correlation_id,
                    duration_ms,
                    result,
                )
                .await
            }
            _ => Ok(None),
        }
    }

    /// Both tools are idempotent GETs, so retrying on failure is safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
372
373fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
374 match e {
375 ToolError::Blocked { command } => AuditResult::Blocked {
376 reason: command.clone(),
377 },
378 ToolError::Timeout { .. } => AuditResult::Timeout,
379 _ => AuditResult::Error {
380 message: e.to_string(),
381 },
382 }
383}
384
385impl WebScrapeExecutor {
    /// Audit a finished tool-call result and convert it into a `ToolOutput`.
    ///
    /// On success, logs a `Success` audit entry and wraps the output string;
    /// on failure, logs the mapped audit result and propagates the error.
    /// `audit_tool_name` is the internal tool id recorded in the audit log,
    /// `public_tool_name` is the name stamped on the returned `ToolOutput`,
    /// and `audit_command` is the (URL) string recorded as the command.
    #[allow(clippy::too_many_arguments)]
    async fn run_with_audit(
        &self,
        audit_tool_name: &str,
        public_tool_name: &str,
        audit_command: &str,
        caller_id: Option<String>,
        correlation_id: String,
        duration_ms: u64,
        result: Result<String, ToolError>,
    ) -> Result<Option<ToolOutput>, ToolError> {
        match result {
            Ok(output) => {
                self.log_audit(
                    audit_tool_name,
                    audit_command,
                    AuditResult::Success,
                    duration_ms,
                    None,
                    caller_id,
                    Some(correlation_id),
                )
                .await;
                Ok(Some(ToolOutput {
                    tool_name: ToolName::new(public_tool_name),
                    summary: output,
                    // Structured tool calls always execute exactly one block.
                    blocks_executed: 1,
                    filter_stats: None,
                    diff: None,
                    streamed: false,
                    terminal_id: None,
                    locations: None,
                    raw_response: None,
                    claim_source: Some(ClaimSource::WebScrape),
                }))
            }
            Err(e) => {
                let audit_result = tool_error_to_audit_result(&e);
                self.log_audit(
                    audit_tool_name,
                    audit_command,
                    audit_result,
                    duration_ms,
                    Some(&e),
                    caller_id,
                    Some(correlation_id),
                )
                .await;
                Err(e)
            }
        }
    }
438
    /// Write one `AuditEntry` for a tool invocation, if a logger is attached.
    ///
    /// No-op when no audit logger is configured. When `error` is present, its
    /// category/domain/phase labels are recorded alongside the result; fields
    /// that do not apply to web tools are left at their inert defaults.
    #[allow(clippy::too_many_arguments)] async fn log_audit(
        &self,
        tool: &str,
        command: &str,
        result: AuditResult,
        duration_ms: u64,
        error: Option<&ToolError>,
        caller_id: Option<String>,
        correlation_id: Option<String>,
    ) {
        if let Some(ref logger) = self.audit_logger {
            // Expand the error taxonomy labels only when an error occurred.
            let (error_category, error_domain, error_phase) =
                error.map_or((None, None, None), |e| {
                    let cat = e.category();
                    (
                        Some(cat.label().to_owned()),
                        Some(cat.domain().label().to_owned()),
                        Some(cat.phase().label().to_owned()),
                    )
                });
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
                error_category,
                error_domain,
                error_phase,
                claim_source: Some(ClaimSource::WebScrape),
                mcp_server_id: None,
                injection_flagged: false,
                embedding_anomalous: false,
                cross_boundary_mcp_to_acp: false,
                adversarial_policy_decision: None,
                exit_code: None,
                truncated: false,
                caller_id,
                policy_match: None,
                correlation_id,
                vigil_risk: None,
                execution_env: None,
                resolved_cwd: None,
                scope_at_definition: None,
                scope_at_dispatch: None,
            };
            logger.log(&entry).await;
        }
    }
489
490 fn send_egress_event(&self, event: EgressEvent) {
491 if let Some(ref tx) = self.egress_tx {
492 match tx.try_send(event) {
493 Ok(()) => {}
494 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
495 self.egress_dropped.fetch_add(1, Ordering::Relaxed);
496 }
497 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
498 tracing::debug!("egress channel closed; executor continuing without telemetry");
499 }
500 }
501 }
502 }
503
504 async fn log_egress_event(&self, event: &EgressEvent) {
505 if let Some(ref logger) = self.audit_logger {
506 logger.log_egress(event).await;
507 }
508 self.send_egress_event(event.clone());
509 }
510
511 async fn handle_fetch(
512 &self,
513 params: &FetchParams,
514 correlation_id: &str,
515 caller_id: Option<String>,
516 ) -> Result<String, ToolError> {
517 let parsed = validate_url(¶ms.url);
518 let host_str = parsed
519 .as_ref()
520 .map(|u| u.host_str().unwrap_or("").to_owned())
521 .unwrap_or_default();
522
523 if let Err(ref _e) = parsed {
524 if self.egress_config.enabled && self.egress_config.log_blocked {
525 let event = Self::make_blocked_event(
526 "fetch",
527 ¶ms.url,
528 &host_str,
529 correlation_id,
530 caller_id.clone(),
531 "scheme",
532 );
533 self.log_egress_event(&event).await;
534 }
535 return Err(parsed.unwrap_err());
536 }
537 let parsed = parsed.unwrap();
538
539 if let Err(e) = check_domain_policy(
540 parsed.host_str().unwrap_or(""),
541 &self.allowed_domains,
542 &self.denied_domains,
543 ) {
544 if self.egress_config.enabled && self.egress_config.log_blocked {
545 let event = Self::make_blocked_event(
546 "fetch",
547 ¶ms.url,
548 parsed.host_str().unwrap_or(""),
549 correlation_id,
550 caller_id.clone(),
551 "blocklist",
552 );
553 self.log_egress_event(&event).await;
554 }
555 return Err(e);
556 }
557
558 let (host, addrs) = match resolve_and_validate(&parsed).await {
559 Ok(v) => v,
560 Err(e) => {
561 if self.egress_config.enabled && self.egress_config.log_blocked {
562 let event = Self::make_blocked_event(
563 "fetch",
564 ¶ms.url,
565 parsed.host_str().unwrap_or(""),
566 correlation_id,
567 caller_id.clone(),
568 "ssrf",
569 );
570 self.log_egress_event(&event).await;
571 }
572 return Err(e);
573 }
574 };
575
576 self.fetch_html(
577 ¶ms.url,
578 &host,
579 &addrs,
580 "fetch",
581 correlation_id,
582 caller_id,
583 )
584 .await
585 }
586
587 async fn scrape_instruction(
588 &self,
589 instruction: &ScrapeInstruction,
590 correlation_id: &str,
591 caller_id: Option<String>,
592 ) -> Result<String, ToolError> {
593 let parsed = validate_url(&instruction.url);
594 let host_str = parsed
595 .as_ref()
596 .map(|u| u.host_str().unwrap_or("").to_owned())
597 .unwrap_or_default();
598
599 if let Err(ref _e) = parsed {
600 if self.egress_config.enabled && self.egress_config.log_blocked {
601 let event = Self::make_blocked_event(
602 "web_scrape",
603 &instruction.url,
604 &host_str,
605 correlation_id,
606 caller_id.clone(),
607 "scheme",
608 );
609 self.log_egress_event(&event).await;
610 }
611 return Err(parsed.unwrap_err());
612 }
613 let parsed = parsed.unwrap();
614
615 if let Err(e) = check_domain_policy(
616 parsed.host_str().unwrap_or(""),
617 &self.allowed_domains,
618 &self.denied_domains,
619 ) {
620 if self.egress_config.enabled && self.egress_config.log_blocked {
621 let event = Self::make_blocked_event(
622 "web_scrape",
623 &instruction.url,
624 parsed.host_str().unwrap_or(""),
625 correlation_id,
626 caller_id.clone(),
627 "blocklist",
628 );
629 self.log_egress_event(&event).await;
630 }
631 return Err(e);
632 }
633
634 let (host, addrs) = match resolve_and_validate(&parsed).await {
635 Ok(v) => v,
636 Err(e) => {
637 if self.egress_config.enabled && self.egress_config.log_blocked {
638 let event = Self::make_blocked_event(
639 "web_scrape",
640 &instruction.url,
641 parsed.host_str().unwrap_or(""),
642 correlation_id,
643 caller_id.clone(),
644 "ssrf",
645 );
646 self.log_egress_event(&event).await;
647 }
648 return Err(e);
649 }
650 };
651
652 let html = self
653 .fetch_html(
654 &instruction.url,
655 &host,
656 &addrs,
657 "web_scrape",
658 correlation_id,
659 caller_id,
660 )
661 .await?;
662 let selector = instruction.select.clone();
663 let extract = ExtractMode::parse(&instruction.extract);
664 let limit = instruction.limit.unwrap_or(10);
665 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
666 .await
667 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
668 }
669
    /// Build an `EgressEvent` for a request denied before any network I/O.
    ///
    /// `block_reason` is one of `"scheme"`, `"blocklist"`, or `"ssrf"`. The
    /// URL is redacted before being stored so secrets never enter the log;
    /// hop 0 means the original request, not a redirect.
    fn make_blocked_event(
        tool: &str,
        url: &str,
        host: &str,
        correlation_id: &str,
        caller_id: Option<String>,
        block_reason: &'static str,
    ) -> EgressEvent {
        EgressEvent {
            timestamp: chrono_now(),
            kind: "egress",
            correlation_id: correlation_id.to_owned(),
            tool: tool.into(),
            url: redact_url_for_log(url),
            host: host.to_owned(),
            method: "GET".to_owned(),
            status: None,
            duration_ms: 0,
            response_bytes: 0,
            blocked: true,
            block_reason: Some(block_reason),
            caller_id,
            hop: 0,
        }
    }
695
696 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
707 async fn fetch_html(
708 &self,
709 url: &str,
710 host: &str,
711 addrs: &[SocketAddr],
712 tool: &str,
713 correlation_id: &str,
714 caller_id: Option<String>,
715 ) -> Result<String, ToolError> {
716 const MAX_REDIRECTS: usize = 3;
717
718 let mut current_url = url.to_owned();
719 let mut current_host = host.to_owned();
720 let mut current_addrs = addrs.to_vec();
721
722 for hop in 0..=MAX_REDIRECTS {
723 let hop_start = Instant::now();
724 let client = self.build_client(¤t_host, ¤t_addrs);
726 let resp = client
727 .get(¤t_url)
728 .send()
729 .await
730 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
731
732 let resp = match resp {
733 Ok(r) => r,
734 Err(e) => {
735 if self.egress_config.enabled {
736 #[allow(clippy::cast_possible_truncation)]
737 let duration_ms = hop_start.elapsed().as_millis() as u64;
738 let event = EgressEvent {
739 timestamp: chrono_now(),
740 kind: "egress",
741 correlation_id: correlation_id.to_owned(),
742 tool: tool.into(),
743 url: redact_url_for_log(¤t_url),
744 host: current_host.clone(),
745 method: "GET".to_owned(),
746 status: None,
747 duration_ms,
748 response_bytes: 0,
749 blocked: false,
750 block_reason: None,
751 caller_id: caller_id.clone(),
752 #[allow(clippy::cast_possible_truncation)]
753 hop: hop as u8,
754 };
755 self.log_egress_event(&event).await;
756 }
757 return Err(e);
758 }
759 };
760
761 let status = resp.status();
762
763 if status.is_redirection() {
764 if hop == MAX_REDIRECTS {
765 return Err(ToolError::Execution(std::io::Error::other(
766 "too many redirects",
767 )));
768 }
769
770 let location = resp
771 .headers()
772 .get(reqwest::header::LOCATION)
773 .and_then(|v| v.to_str().ok())
774 .ok_or_else(|| {
775 ToolError::Execution(std::io::Error::other("redirect with no Location"))
776 })?;
777
778 let base = Url::parse(¤t_url)
780 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
781 let next_url = base
782 .join(location)
783 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
784
785 let validated = validate_url(next_url.as_str());
786 if let Err(ref _e) = validated {
787 if self.egress_config.enabled && self.egress_config.log_blocked {
788 #[allow(clippy::cast_possible_truncation)]
789 let duration_ms = hop_start.elapsed().as_millis() as u64;
790 let next_host = next_url.host_str().unwrap_or("").to_owned();
791 let event = EgressEvent {
792 timestamp: chrono_now(),
793 kind: "egress",
794 correlation_id: correlation_id.to_owned(),
795 tool: tool.into(),
796 url: redact_url_for_log(next_url.as_str()),
797 host: next_host,
798 method: "GET".to_owned(),
799 status: None,
800 duration_ms,
801 response_bytes: 0,
802 blocked: true,
803 block_reason: Some("ssrf"),
804 caller_id: caller_id.clone(),
805 #[allow(clippy::cast_possible_truncation)]
806 hop: (hop + 1) as u8,
807 };
808 self.log_egress_event(&event).await;
809 }
810 return Err(validated.unwrap_err());
811 }
812 let validated = validated.unwrap();
813 let resolve_result = resolve_and_validate(&validated).await;
814 if let Err(ref _e) = resolve_result {
815 if self.egress_config.enabled && self.egress_config.log_blocked {
816 #[allow(clippy::cast_possible_truncation)]
817 let duration_ms = hop_start.elapsed().as_millis() as u64;
818 let next_host = next_url.host_str().unwrap_or("").to_owned();
819 let event = EgressEvent {
820 timestamp: chrono_now(),
821 kind: "egress",
822 correlation_id: correlation_id.to_owned(),
823 tool: tool.into(),
824 url: redact_url_for_log(next_url.as_str()),
825 host: next_host,
826 method: "GET".to_owned(),
827 status: None,
828 duration_ms,
829 response_bytes: 0,
830 blocked: true,
831 block_reason: Some("ssrf"),
832 caller_id: caller_id.clone(),
833 #[allow(clippy::cast_possible_truncation)]
834 hop: (hop + 1) as u8,
835 };
836 self.log_egress_event(&event).await;
837 }
838 return Err(resolve_result.unwrap_err());
839 }
840 let (next_host, next_addrs) = resolve_result.unwrap();
841
842 current_url = next_url.to_string();
843 current_host = next_host;
844 current_addrs = next_addrs;
845 continue;
846 }
847
848 if !status.is_success() {
849 if self.egress_config.enabled {
850 #[allow(clippy::cast_possible_truncation)]
851 let duration_ms = hop_start.elapsed().as_millis() as u64;
852 let event = EgressEvent {
853 timestamp: chrono_now(),
854 kind: "egress",
855 correlation_id: correlation_id.to_owned(),
856 tool: tool.into(),
857 url: current_url.clone(),
858 host: current_host.clone(),
859 method: "GET".to_owned(),
860 status: Some(status.as_u16()),
861 duration_ms,
862 response_bytes: 0,
863 blocked: false,
864 block_reason: None,
865 caller_id: caller_id.clone(),
866 #[allow(clippy::cast_possible_truncation)]
867 hop: hop as u8,
868 };
869 self.log_egress_event(&event).await;
870 }
871 return Err(ToolError::Http {
872 status: status.as_u16(),
873 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
874 });
875 }
876
877 let bytes = resp
878 .bytes()
879 .await
880 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
881
882 if bytes.len() > self.max_body_bytes {
883 if self.egress_config.enabled {
884 #[allow(clippy::cast_possible_truncation)]
885 let duration_ms = hop_start.elapsed().as_millis() as u64;
886 let event = EgressEvent {
887 timestamp: chrono_now(),
888 kind: "egress",
889 correlation_id: correlation_id.to_owned(),
890 tool: tool.into(),
891 url: current_url.clone(),
892 host: current_host.clone(),
893 method: "GET".to_owned(),
894 status: Some(status.as_u16()),
895 duration_ms,
896 response_bytes: bytes.len(),
897 blocked: false,
898 block_reason: None,
899 caller_id: caller_id.clone(),
900 #[allow(clippy::cast_possible_truncation)]
901 hop: hop as u8,
902 };
903 self.log_egress_event(&event).await;
904 }
905 return Err(ToolError::Execution(std::io::Error::other(format!(
906 "response too large: {} bytes (max: {})",
907 bytes.len(),
908 self.max_body_bytes,
909 ))));
910 }
911
912 if self.egress_config.enabled {
914 #[allow(clippy::cast_possible_truncation)]
915 let duration_ms = hop_start.elapsed().as_millis() as u64;
916 let response_bytes = if self.egress_config.log_response_bytes {
917 bytes.len()
918 } else {
919 0
920 };
921 let event = EgressEvent {
922 timestamp: chrono_now(),
923 kind: "egress",
924 correlation_id: correlation_id.to_owned(),
925 tool: tool.into(),
926 url: current_url.clone(),
927 host: current_host.clone(),
928 method: "GET".to_owned(),
929 status: Some(status.as_u16()),
930 duration_ms,
931 response_bytes,
932 blocked: false,
933 block_reason: None,
934 caller_id: caller_id.clone(),
935 #[allow(clippy::cast_possible_truncation)]
936 hop: hop as u8,
937 };
938 self.log_egress_event(&event).await;
939 }
940
941 return String::from_utf8(bytes.to_vec())
942 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
943 }
944
945 Err(ToolError::Execution(std::io::Error::other(
946 "too many redirects",
947 )))
948 }
949}
950
/// Return the contents of every ```scrape fenced code block in `text`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
954
955fn check_domain_policy(
967 host: &str,
968 allowed_domains: &[String],
969 denied_domains: &[String],
970) -> Result<(), ToolError> {
971 if denied_domains.iter().any(|p| domain_matches(p, host)) {
972 return Err(ToolError::Blocked {
973 command: format!("domain blocked by denylist: {host}"),
974 });
975 }
976 if !allowed_domains.is_empty() {
977 let is_ip = host.parse::<std::net::IpAddr>().is_ok()
979 || (host.starts_with('[') && host.ends_with(']'));
980 if is_ip {
981 return Err(ToolError::Blocked {
982 command: format!(
983 "bare IP address not allowed when domain allowlist is active: {host}"
984 ),
985 });
986 }
987 if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
988 return Err(ToolError::Blocked {
989 command: format!("domain not in allowlist: {host}"),
990 });
991 }
992 }
993 Ok(())
994}
995
996use crate::domain_match::domain_matches;
998
999fn validate_url(raw: &str) -> Result<Url, ToolError> {
1000 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
1001 command: format!("invalid URL: {raw}"),
1002 })?;
1003
1004 if parsed.scheme() != "https" {
1005 return Err(ToolError::Blocked {
1006 command: format!("scheme not allowed: {}", parsed.scheme()),
1007 });
1008 }
1009
1010 if let Some(host) = parsed.host()
1011 && is_private_host(&host)
1012 {
1013 return Err(ToolError::Blocked {
1014 command: format!(
1015 "private/local host blocked: {}",
1016 parsed.host_str().unwrap_or("")
1017 ),
1018 });
1019 }
1020
1021 Ok(parsed)
1022}
1023
1024fn is_private_host(host: &url::Host<&str>) -> bool {
1025 match host {
1026 url::Host::Domain(d) => {
1027 #[allow(clippy::case_sensitive_file_extension_comparisons)]
1030 {
1031 *d == "localhost"
1032 || d.ends_with(".localhost")
1033 || d.ends_with(".internal")
1034 || d.ends_with(".local")
1035 }
1036 }
1037 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
1038 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
1039 }
1040}
1041
1042async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
1048 let Some(host) = url.host_str() else {
1049 return Ok((String::new(), vec![]));
1050 };
1051 let port = url.port_or_known_default().unwrap_or(443);
1052 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
1053 .await
1054 .map_err(|e| ToolError::Blocked {
1055 command: format!("DNS resolution failed: {e}"),
1056 })?
1057 .collect();
1058 for addr in &addrs {
1059 if is_private_ip(addr.ip()) {
1060 return Err(ToolError::Blocked {
1061 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
1062 });
1063 }
1064 }
1065 Ok((host.to_owned(), addrs))
1066}
1067
1068fn parse_and_extract(
1069 html: &str,
1070 selector: &str,
1071 extract: &ExtractMode,
1072 limit: usize,
1073) -> Result<String, ToolError> {
1074 let soup = scrape_core::Soup::parse(html);
1075
1076 let tags = soup.find_all(selector).map_err(|e| {
1077 ToolError::Execution(std::io::Error::new(
1078 std::io::ErrorKind::InvalidData,
1079 format!("invalid selector: {e}"),
1080 ))
1081 })?;
1082
1083 let mut results = Vec::new();
1084
1085 for tag in tags.into_iter().take(limit) {
1086 let value = match extract {
1087 ExtractMode::Text => tag.text(),
1088 ExtractMode::Html => tag.inner_html(),
1089 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
1090 };
1091 if !value.trim().is_empty() {
1092 results.push(value.trim().to_owned());
1093 }
1094 }
1095
1096 if results.is_empty() {
1097 Ok(format!("No results for selector: {selector}"))
1098 } else {
1099 Ok(results.join("\n"))
1100 }
1101}
1102
1103#[cfg(test)]
1104mod tests {
1105 use super::*;
1106
1107 #[test]
1110 fn extract_single_block() {
1111 let text =
1112 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
1113 let blocks = extract_scrape_blocks(text);
1114 assert_eq!(blocks.len(), 1);
1115 assert!(blocks[0].contains("example.com"));
1116 }
1117
1118 #[test]
1119 fn extract_multiple_blocks() {
1120 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
1121 let blocks = extract_scrape_blocks(text);
1122 assert_eq!(blocks.len(), 2);
1123 }
1124
1125 #[test]
1126 fn no_blocks_returns_empty() {
1127 let blocks = extract_scrape_blocks("plain text, no code blocks");
1128 assert!(blocks.is_empty());
1129 }
1130
1131 #[test]
1132 fn unclosed_block_ignored() {
1133 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
1134 assert!(blocks.is_empty());
1135 }
1136
1137 #[test]
1138 fn non_scrape_block_ignored() {
1139 let text =
1140 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
1141 let blocks = extract_scrape_blocks(text);
1142 assert_eq!(blocks.len(), 1);
1143 assert!(blocks[0].contains("x.com"));
1144 }
1145
1146 #[test]
1147 fn multiline_json_block() {
1148 let text =
1149 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
1150 let blocks = extract_scrape_blocks(text);
1151 assert_eq!(blocks.len(), 1);
1152 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
1153 assert_eq!(instr.url, "https://example.com");
1154 }
1155
1156 #[test]
1159 fn parse_valid_instruction() {
1160 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
1161 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1162 assert_eq!(instr.url, "https://example.com");
1163 assert_eq!(instr.select, "h1");
1164 assert_eq!(instr.extract, "text");
1165 assert_eq!(instr.limit, Some(5));
1166 }
1167
1168 #[test]
1169 fn parse_minimal_instruction() {
1170 let json = r#"{"url":"https://example.com","select":"p"}"#;
1171 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1172 assert_eq!(instr.extract, "text");
1173 assert!(instr.limit.is_none());
1174 }
1175
1176 #[test]
1177 fn parse_attr_extract() {
1178 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
1179 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1180 assert_eq!(instr.extract, "attr:href");
1181 }
1182
1183 #[test]
1184 fn parse_invalid_json_errors() {
1185 let result = serde_json::from_str::<ScrapeInstruction>("not json");
1186 assert!(result.is_err());
1187 }
1188
1189 #[test]
1192 fn extract_mode_text() {
1193 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
1194 }
1195
1196 #[test]
1197 fn extract_mode_html() {
1198 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
1199 }
1200
1201 #[test]
1202 fn extract_mode_attr() {
1203 let mode = ExtractMode::parse("attr:href");
1204 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
1205 }
1206
1207 #[test]
1208 fn extract_mode_unknown_defaults_to_text() {
1209 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
1210 }
1211
1212 #[test]
1215 fn valid_https_url() {
1216 assert!(validate_url("https://example.com").is_ok());
1217 }
1218
1219 #[test]
1220 fn http_rejected() {
1221 let err = validate_url("http://example.com").unwrap_err();
1222 assert!(matches!(err, ToolError::Blocked { .. }));
1223 }
1224
1225 #[test]
1226 fn ftp_rejected() {
1227 let err = validate_url("ftp://files.example.com").unwrap_err();
1228 assert!(matches!(err, ToolError::Blocked { .. }));
1229 }
1230
1231 #[test]
1232 fn file_rejected() {
1233 let err = validate_url("file:///etc/passwd").unwrap_err();
1234 assert!(matches!(err, ToolError::Blocked { .. }));
1235 }
1236
1237 #[test]
1238 fn invalid_url_rejected() {
1239 let err = validate_url("not a url").unwrap_err();
1240 assert!(matches!(err, ToolError::Blocked { .. }));
1241 }
1242
1243 #[test]
1244 fn localhost_blocked() {
1245 let err = validate_url("https://localhost/path").unwrap_err();
1246 assert!(matches!(err, ToolError::Blocked { .. }));
1247 }
1248
1249 #[test]
1250 fn loopback_ip_blocked() {
1251 let err = validate_url("https://127.0.0.1/path").unwrap_err();
1252 assert!(matches!(err, ToolError::Blocked { .. }));
1253 }
1254
1255 #[test]
1256 fn private_10_blocked() {
1257 let err = validate_url("https://10.0.0.1/api").unwrap_err();
1258 assert!(matches!(err, ToolError::Blocked { .. }));
1259 }
1260
1261 #[test]
1262 fn private_172_blocked() {
1263 let err = validate_url("https://172.16.0.1/api").unwrap_err();
1264 assert!(matches!(err, ToolError::Blocked { .. }));
1265 }
1266
1267 #[test]
1268 fn private_192_blocked() {
1269 let err = validate_url("https://192.168.1.1/api").unwrap_err();
1270 assert!(matches!(err, ToolError::Blocked { .. }));
1271 }
1272
1273 #[test]
1274 fn ipv6_loopback_blocked() {
1275 let err = validate_url("https://[::1]/path").unwrap_err();
1276 assert!(matches!(err, ToolError::Blocked { .. }));
1277 }
1278
1279 #[test]
1280 fn public_ip_allowed() {
1281 assert!(validate_url("https://93.184.216.34/page").is_ok());
1282 }
1283
1284 #[test]
1287 fn extract_text_from_html() {
1288 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
1289 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1290 assert_eq!(result, "Hello World");
1291 }
1292
1293 #[test]
1294 fn extract_multiple_elements() {
1295 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1296 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1297 assert_eq!(result, "A\nB\nC");
1298 }
1299
1300 #[test]
1301 fn extract_with_limit() {
1302 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1303 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
1304 assert_eq!(result, "A\nB");
1305 }
1306
1307 #[test]
1308 fn extract_attr_href() {
1309 let html = r#"<a href="https://example.com">Link</a>"#;
1310 let result =
1311 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1312 assert_eq!(result, "https://example.com");
1313 }
1314
1315 #[test]
1316 fn extract_inner_html() {
1317 let html = "<div><span>inner</span></div>";
1318 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1319 assert!(result.contains("<span>inner</span>"));
1320 }
1321
1322 #[test]
1323 fn no_matches_returns_message() {
1324 let html = "<html><body><p>text</p></body></html>";
1325 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1326 assert!(result.starts_with("No results for selector:"));
1327 }
1328
1329 #[test]
1330 fn empty_text_skipped() {
1331 let html = "<ul><li> </li><li>A</li></ul>";
1332 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1333 assert_eq!(result, "A");
1334 }
1335
1336 #[test]
1337 fn invalid_selector_errors() {
1338 let html = "<html><body></body></html>";
1339 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
1340 assert!(result.is_err());
1341 }
1342
1343 #[test]
1344 fn empty_html_returns_no_results() {
1345 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
1346 assert!(result.starts_with("No results for selector:"));
1347 }
1348
1349 #[test]
1350 fn nested_selector() {
1351 let html = "<div><span>inner</span></div><span>outer</span>";
1352 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
1353 assert_eq!(result, "inner");
1354 }
1355
1356 #[test]
1357 fn attr_missing_returns_empty() {
1358 let html = r"<a>No href</a>";
1359 let result =
1360 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1361 assert!(result.starts_with("No results for selector:"));
1362 }
1363
1364 #[test]
1365 fn extract_html_mode() {
1366 let html = "<div><b>bold</b> text</div>";
1367 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1368 assert!(result.contains("<b>bold</b>"));
1369 }
1370
1371 #[test]
1372 fn limit_zero_returns_no_results() {
1373 let html = "<ul><li>A</li><li>B</li></ul>";
1374 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
1375 assert!(result.starts_with("No results for selector:"));
1376 }
1377
1378 #[test]
1381 fn url_with_port_allowed() {
1382 assert!(validate_url("https://example.com:8443/path").is_ok());
1383 }
1384
1385 #[test]
1386 fn link_local_ip_blocked() {
1387 let err = validate_url("https://169.254.1.1/path").unwrap_err();
1388 assert!(matches!(err, ToolError::Blocked { .. }));
1389 }
1390
1391 #[test]
1392 fn url_no_scheme_rejected() {
1393 let err = validate_url("example.com/path").unwrap_err();
1394 assert!(matches!(err, ToolError::Blocked { .. }));
1395 }
1396
1397 #[test]
1398 fn unspecified_ipv4_blocked() {
1399 let err = validate_url("https://0.0.0.0/path").unwrap_err();
1400 assert!(matches!(err, ToolError::Blocked { .. }));
1401 }
1402
1403 #[test]
1404 fn broadcast_ipv4_blocked() {
1405 let err = validate_url("https://255.255.255.255/path").unwrap_err();
1406 assert!(matches!(err, ToolError::Blocked { .. }));
1407 }
1408
1409 #[test]
1410 fn ipv6_link_local_blocked() {
1411 let err = validate_url("https://[fe80::1]/path").unwrap_err();
1412 assert!(matches!(err, ToolError::Blocked { .. }));
1413 }
1414
1415 #[test]
1416 fn ipv6_unique_local_blocked() {
1417 let err = validate_url("https://[fd12::1]/path").unwrap_err();
1418 assert!(matches!(err, ToolError::Blocked { .. }));
1419 }
1420
1421 #[test]
1422 fn ipv4_mapped_ipv6_loopback_blocked() {
1423 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
1424 assert!(matches!(err, ToolError::Blocked { .. }));
1425 }
1426
1427 #[test]
1428 fn ipv4_mapped_ipv6_private_blocked() {
1429 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
1430 assert!(matches!(err, ToolError::Blocked { .. }));
1431 }
1432
1433 #[tokio::test]
1436 async fn executor_no_blocks_returns_none() {
1437 let config = ScrapeConfig::default();
1438 let executor = WebScrapeExecutor::new(&config);
1439 let result = executor.execute("plain text").await;
1440 assert!(result.unwrap().is_none());
1441 }
1442
1443 #[tokio::test]
1444 async fn executor_invalid_json_errors() {
1445 let config = ScrapeConfig::default();
1446 let executor = WebScrapeExecutor::new(&config);
1447 let response = "```scrape\nnot json\n```";
1448 let result = executor.execute(response).await;
1449 assert!(matches!(result, Err(ToolError::Execution(_))));
1450 }
1451
1452 #[tokio::test]
1453 async fn executor_blocked_url_errors() {
1454 let config = ScrapeConfig::default();
1455 let executor = WebScrapeExecutor::new(&config);
1456 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
1457 let result = executor.execute(response).await;
1458 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1459 }
1460
1461 #[tokio::test]
1462 async fn executor_private_ip_blocked() {
1463 let config = ScrapeConfig::default();
1464 let executor = WebScrapeExecutor::new(&config);
1465 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
1466 let result = executor.execute(response).await;
1467 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1468 }
1469
1470 #[tokio::test]
1471 async fn executor_unreachable_host_returns_error() {
1472 let config = ScrapeConfig {
1473 timeout: 1,
1474 max_body_bytes: 1_048_576,
1475 ..Default::default()
1476 };
1477 let executor = WebScrapeExecutor::new(&config);
1478 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
1479 let result = executor.execute(response).await;
1480 assert!(matches!(result, Err(ToolError::Execution(_))));
1481 }
1482
1483 #[tokio::test]
1484 async fn executor_localhost_url_blocked() {
1485 let config = ScrapeConfig::default();
1486 let executor = WebScrapeExecutor::new(&config);
1487 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
1488 let result = executor.execute(response).await;
1489 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1490 }
1491
1492 #[tokio::test]
1493 async fn executor_empty_text_returns_none() {
1494 let config = ScrapeConfig::default();
1495 let executor = WebScrapeExecutor::new(&config);
1496 let result = executor.execute("").await;
1497 assert!(result.unwrap().is_none());
1498 }
1499
1500 #[tokio::test]
1501 async fn executor_multiple_blocks_first_blocked() {
1502 let config = ScrapeConfig::default();
1503 let executor = WebScrapeExecutor::new(&config);
1504 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1505 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1506 let result = executor.execute(response).await;
1507 assert!(result.is_err());
1508 }
1509
1510 #[test]
1511 fn validate_url_empty_string() {
1512 let err = validate_url("").unwrap_err();
1513 assert!(matches!(err, ToolError::Blocked { .. }));
1514 }
1515
1516 #[test]
1517 fn validate_url_javascript_scheme_blocked() {
1518 let err = validate_url("javascript:alert(1)").unwrap_err();
1519 assert!(matches!(err, ToolError::Blocked { .. }));
1520 }
1521
1522 #[test]
1523 fn validate_url_data_scheme_blocked() {
1524 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1525 assert!(matches!(err, ToolError::Blocked { .. }));
1526 }
1527
1528 #[test]
1529 fn is_private_host_public_domain_is_false() {
1530 let host: url::Host<&str> = url::Host::Domain("example.com");
1531 assert!(!is_private_host(&host));
1532 }
1533
1534 #[test]
1535 fn is_private_host_localhost_is_true() {
1536 let host: url::Host<&str> = url::Host::Domain("localhost");
1537 assert!(is_private_host(&host));
1538 }
1539
1540 #[test]
1541 fn is_private_host_ipv6_unspecified_is_true() {
1542 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1543 assert!(is_private_host(&host));
1544 }
1545
1546 #[test]
1547 fn is_private_host_public_ipv6_is_false() {
1548 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1549 assert!(!is_private_host(&host));
1550 }
1551
1552 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1563 let server = wiremock::MockServer::start().await;
1564 let executor = WebScrapeExecutor {
1565 timeout: Duration::from_secs(5),
1566 max_body_bytes: 1_048_576,
1567 allowed_domains: vec![],
1568 denied_domains: vec![],
1569 audit_logger: None,
1570 egress_config: EgressConfig::default(),
1571 egress_tx: None,
1572 egress_dropped: Arc::new(AtomicU64::new(0)),
1573 };
1574 (executor, server)
1575 }
1576
1577 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1579 let uri = server.uri();
1580 let url = Url::parse(&uri).unwrap();
1581 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1582 let port = url.port().unwrap_or(80);
1583 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1584 (host, vec![addr])
1585 }
1586
1587 async fn follow_redirects_raw(
1591 executor: &WebScrapeExecutor,
1592 start_url: &str,
1593 host: &str,
1594 addrs: &[std::net::SocketAddr],
1595 ) -> Result<String, ToolError> {
1596 const MAX_REDIRECTS: usize = 3;
1597 let mut current_url = start_url.to_owned();
1598 let mut current_host = host.to_owned();
1599 let mut current_addrs = addrs.to_vec();
1600
1601 for hop in 0..=MAX_REDIRECTS {
1602 let client = executor.build_client(¤t_host, ¤t_addrs);
1603 let resp = client
1604 .get(¤t_url)
1605 .send()
1606 .await
1607 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1608
1609 let status = resp.status();
1610
1611 if status.is_redirection() {
1612 if hop == MAX_REDIRECTS {
1613 return Err(ToolError::Execution(std::io::Error::other(
1614 "too many redirects",
1615 )));
1616 }
1617
1618 let location = resp
1619 .headers()
1620 .get(reqwest::header::LOCATION)
1621 .and_then(|v| v.to_str().ok())
1622 .ok_or_else(|| {
1623 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1624 })?;
1625
1626 let base = Url::parse(¤t_url)
1627 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1628 let next_url = base
1629 .join(location)
1630 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1631
1632 current_url = next_url.to_string();
1634 let _ = &mut current_host;
1636 let _ = &mut current_addrs;
1637 continue;
1638 }
1639
1640 if !status.is_success() {
1641 return Err(ToolError::Execution(std::io::Error::other(format!(
1642 "HTTP {status}",
1643 ))));
1644 }
1645
1646 let bytes = resp
1647 .bytes()
1648 .await
1649 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1650
1651 if bytes.len() > executor.max_body_bytes {
1652 return Err(ToolError::Execution(std::io::Error::other(format!(
1653 "response too large: {} bytes (max: {})",
1654 bytes.len(),
1655 executor.max_body_bytes,
1656 ))));
1657 }
1658
1659 return String::from_utf8(bytes.to_vec())
1660 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1661 }
1662
1663 Err(ToolError::Execution(std::io::Error::other(
1664 "too many redirects",
1665 )))
1666 }
1667
1668 #[tokio::test]
1669 async fn fetch_html_success_returns_body() {
1670 use wiremock::matchers::{method, path};
1671 use wiremock::{Mock, ResponseTemplate};
1672
1673 let (executor, server) = mock_server_executor().await;
1674 Mock::given(method("GET"))
1675 .and(path("/page"))
1676 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1677 .mount(&server)
1678 .await;
1679
1680 let (host, addrs) = server_host_and_addr(&server);
1681 let url = format!("{}/page", server.uri());
1682 let result = executor
1683 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1684 .await;
1685 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1686 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1687 }
1688
1689 #[tokio::test]
1690 async fn fetch_html_non_2xx_returns_error() {
1691 use wiremock::matchers::{method, path};
1692 use wiremock::{Mock, ResponseTemplate};
1693
1694 let (executor, server) = mock_server_executor().await;
1695 Mock::given(method("GET"))
1696 .and(path("/forbidden"))
1697 .respond_with(ResponseTemplate::new(403))
1698 .mount(&server)
1699 .await;
1700
1701 let (host, addrs) = server_host_and_addr(&server);
1702 let url = format!("{}/forbidden", server.uri());
1703 let result = executor
1704 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1705 .await;
1706 assert!(result.is_err());
1707 let msg = result.unwrap_err().to_string();
1708 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1709 }
1710
1711 #[tokio::test]
1712 async fn fetch_html_404_returns_error() {
1713 use wiremock::matchers::{method, path};
1714 use wiremock::{Mock, ResponseTemplate};
1715
1716 let (executor, server) = mock_server_executor().await;
1717 Mock::given(method("GET"))
1718 .and(path("/missing"))
1719 .respond_with(ResponseTemplate::new(404))
1720 .mount(&server)
1721 .await;
1722
1723 let (host, addrs) = server_host_and_addr(&server);
1724 let url = format!("{}/missing", server.uri());
1725 let result = executor
1726 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1727 .await;
1728 assert!(result.is_err());
1729 let msg = result.unwrap_err().to_string();
1730 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1731 }
1732
1733 #[tokio::test]
1734 async fn fetch_html_redirect_no_location_returns_error() {
1735 use wiremock::matchers::{method, path};
1736 use wiremock::{Mock, ResponseTemplate};
1737
1738 let (executor, server) = mock_server_executor().await;
1739 Mock::given(method("GET"))
1741 .and(path("/redirect-no-loc"))
1742 .respond_with(ResponseTemplate::new(302))
1743 .mount(&server)
1744 .await;
1745
1746 let (host, addrs) = server_host_and_addr(&server);
1747 let url = format!("{}/redirect-no-loc", server.uri());
1748 let result = executor
1749 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1750 .await;
1751 assert!(result.is_err());
1752 let msg = result.unwrap_err().to_string();
1753 assert!(
1754 msg.contains("Location") || msg.contains("location"),
1755 "expected Location-related error: {msg}"
1756 );
1757 }
1758
1759 #[tokio::test]
1760 async fn fetch_html_single_redirect_followed() {
1761 use wiremock::matchers::{method, path};
1762 use wiremock::{Mock, ResponseTemplate};
1763
1764 let (executor, server) = mock_server_executor().await;
1765 let final_url = format!("{}/final", server.uri());
1766
1767 Mock::given(method("GET"))
1768 .and(path("/start"))
1769 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1770 .mount(&server)
1771 .await;
1772
1773 Mock::given(method("GET"))
1774 .and(path("/final"))
1775 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1776 .mount(&server)
1777 .await;
1778
1779 let (host, addrs) = server_host_and_addr(&server);
1780 let url = format!("{}/start", server.uri());
1781 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1782 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1783 assert_eq!(result.unwrap(), "<p>final</p>");
1784 }
1785
1786 #[tokio::test]
1787 async fn fetch_html_three_redirects_allowed() {
1788 use wiremock::matchers::{method, path};
1789 use wiremock::{Mock, ResponseTemplate};
1790
1791 let (executor, server) = mock_server_executor().await;
1792 let hop2 = format!("{}/hop2", server.uri());
1793 let hop3 = format!("{}/hop3", server.uri());
1794 let final_dest = format!("{}/done", server.uri());
1795
1796 Mock::given(method("GET"))
1797 .and(path("/hop1"))
1798 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1799 .mount(&server)
1800 .await;
1801 Mock::given(method("GET"))
1802 .and(path("/hop2"))
1803 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1804 .mount(&server)
1805 .await;
1806 Mock::given(method("GET"))
1807 .and(path("/hop3"))
1808 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1809 .mount(&server)
1810 .await;
1811 Mock::given(method("GET"))
1812 .and(path("/done"))
1813 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1814 .mount(&server)
1815 .await;
1816
1817 let (host, addrs) = server_host_and_addr(&server);
1818 let url = format!("{}/hop1", server.uri());
1819 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1820 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1821 assert_eq!(result.unwrap(), "<p>done</p>");
1822 }
1823
1824 #[tokio::test]
1825 async fn fetch_html_four_redirects_rejected() {
1826 use wiremock::matchers::{method, path};
1827 use wiremock::{Mock, ResponseTemplate};
1828
1829 let (executor, server) = mock_server_executor().await;
1830 let hop2 = format!("{}/r2", server.uri());
1831 let hop3 = format!("{}/r3", server.uri());
1832 let hop4 = format!("{}/r4", server.uri());
1833 let hop5 = format!("{}/r5", server.uri());
1834
1835 for (from, to) in [
1836 ("/r1", &hop2),
1837 ("/r2", &hop3),
1838 ("/r3", &hop4),
1839 ("/r4", &hop5),
1840 ] {
1841 Mock::given(method("GET"))
1842 .and(path(from))
1843 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1844 .mount(&server)
1845 .await;
1846 }
1847
1848 let (host, addrs) = server_host_and_addr(&server);
1849 let url = format!("{}/r1", server.uri());
1850 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1851 assert!(result.is_err(), "4 redirects should be rejected");
1852 let msg = result.unwrap_err().to_string();
1853 assert!(
1854 msg.contains("redirect"),
1855 "expected redirect-related error: {msg}"
1856 );
1857 }
1858
1859 #[tokio::test]
1860 async fn fetch_html_body_too_large_returns_error() {
1861 use wiremock::matchers::{method, path};
1862 use wiremock::{Mock, ResponseTemplate};
1863
1864 let small_limit_executor = WebScrapeExecutor {
1865 timeout: Duration::from_secs(5),
1866 max_body_bytes: 10,
1867 allowed_domains: vec![],
1868 denied_domains: vec![],
1869 audit_logger: None,
1870 egress_config: EgressConfig::default(),
1871 egress_tx: None,
1872 egress_dropped: Arc::new(AtomicU64::new(0)),
1873 };
1874 let server = wiremock::MockServer::start().await;
1875 Mock::given(method("GET"))
1876 .and(path("/big"))
1877 .respond_with(
1878 ResponseTemplate::new(200)
1879 .set_body_string("this body is definitely longer than ten bytes"),
1880 )
1881 .mount(&server)
1882 .await;
1883
1884 let (host, addrs) = server_host_and_addr(&server);
1885 let url = format!("{}/big", server.uri());
1886 let result = small_limit_executor
1887 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1888 .await;
1889 assert!(result.is_err());
1890 let msg = result.unwrap_err().to_string();
1891 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1892 }
1893
1894 #[test]
1895 fn extract_scrape_blocks_empty_block_content() {
1896 let text = "```scrape\n\n```";
1897 let blocks = extract_scrape_blocks(text);
1898 assert_eq!(blocks.len(), 1);
1899 assert!(blocks[0].is_empty());
1900 }
1901
1902 #[test]
1903 fn extract_scrape_blocks_whitespace_only() {
1904 let text = "```scrape\n \n```";
1905 let blocks = extract_scrape_blocks(text);
1906 assert_eq!(blocks.len(), 1);
1907 }
1908
1909 #[test]
1910 fn parse_and_extract_multiple_selectors() {
1911 let html = "<div><h1>Title</h1><p>Para</p></div>";
1912 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1913 assert!(result.contains("Title"));
1914 assert!(result.contains("Para"));
1915 }
1916
1917 #[test]
1918 fn webscrape_executor_new_with_custom_config() {
1919 let config = ScrapeConfig {
1920 timeout: 60,
1921 max_body_bytes: 512,
1922 ..Default::default()
1923 };
1924 let executor = WebScrapeExecutor::new(&config);
1925 assert_eq!(executor.max_body_bytes, 512);
1926 }
1927
1928 #[test]
1929 fn webscrape_executor_debug() {
1930 let config = ScrapeConfig::default();
1931 let executor = WebScrapeExecutor::new(&config);
1932 let dbg = format!("{executor:?}");
1933 assert!(dbg.contains("WebScrapeExecutor"));
1934 }
1935
1936 #[test]
1937 fn extract_mode_attr_empty_name() {
1938 let mode = ExtractMode::parse("attr:");
1939 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1940 }
1941
1942 #[test]
1943 fn default_extract_returns_text() {
1944 assert_eq!(default_extract(), "text");
1945 }
1946
1947 #[test]
1948 fn scrape_instruction_debug() {
1949 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1950 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1951 let dbg = format!("{instr:?}");
1952 assert!(dbg.contains("ScrapeInstruction"));
1953 }
1954
1955 #[test]
1956 fn extract_mode_debug() {
1957 let mode = ExtractMode::Text;
1958 let dbg = format!("{mode:?}");
1959 assert!(dbg.contains("Text"));
1960 }
1961
    /// NOTE(review): tautological — this asserts a locally re-declared
    /// constant against itself rather than the MAX_REDIRECTS actually used
    /// inside `fetch_html`, so it cannot catch drift between the two.
    /// Consider exposing the real constant (e.g. `pub(crate)`) and asserting
    /// against that instead.
    #[test]
    fn max_redirects_constant_is_three() {
        const MAX_REDIRECTS: usize = 3;
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }

    /// NOTE(review): documents expected wording only — it constructs the
    /// error locally instead of driving `fetch_html` through a redirect
    /// missing its Location header.
    #[test]
    fn redirect_no_location_error_message() {
        let err = std::io::Error::other("redirect with no Location");
        assert!(err.to_string().contains("redirect with no Location"));
    }

    /// NOTE(review): same caveat — wording documentation only, the real
    /// "too many redirects" path is exercised by the wiremock tests above.
    #[test]
    fn too_many_redirects_error_message() {
        let err = std::io::Error::other("too many redirects");
        assert!(err.to_string().contains("too many redirects"));
    }

    /// Confirms `Display` for a reqwest status includes the numeric code,
    /// which the "HTTP {status}" error format relies on.
    #[test]
    fn non_2xx_status_error_format() {
        let status = reqwest::StatusCode::FORBIDDEN;
        let msg = format!("HTTP {status}");
        assert!(msg.contains("403"));
    }

    /// Same check for 404.
    #[test]
    fn not_found_status_error_format() {
        let status = reqwest::StatusCode::NOT_FOUND;
        let msg = format!("HTTP {status}");
        assert!(msg.contains("404"));
    }
2005
2006 #[test]
2008 fn relative_redirect_same_host_path() {
2009 let base = Url::parse("https://example.com/current").unwrap();
2010 let resolved = base.join("/other").unwrap();
2011 assert_eq!(resolved.as_str(), "https://example.com/other");
2012 }
2013
2014 #[test]
2016 fn relative_redirect_relative_path() {
2017 let base = Url::parse("https://example.com/a/b").unwrap();
2018 let resolved = base.join("c").unwrap();
2019 assert_eq!(resolved.as_str(), "https://example.com/a/c");
2020 }
2021
2022 #[test]
2024 fn absolute_redirect_overrides_base() {
2025 let base = Url::parse("https://example.com/page").unwrap();
2026 let resolved = base.join("https://other.com/target").unwrap();
2027 assert_eq!(resolved.as_str(), "https://other.com/target");
2028 }
2029
2030 #[test]
2032 fn redirect_http_downgrade_rejected() {
2033 let location = "http://example.com/page";
2034 let base = Url::parse("https://example.com/start").unwrap();
2035 let next = base.join(location).unwrap();
2036 let err = validate_url(next.as_str()).unwrap_err();
2037 assert!(matches!(err, ToolError::Blocked { .. }));
2038 }
2039
2040 #[test]
2042 fn redirect_location_private_ip_blocked() {
2043 let location = "https://192.168.100.1/admin";
2044 let base = Url::parse("https://example.com/start").unwrap();
2045 let next = base.join(location).unwrap();
2046 let err = validate_url(next.as_str()).unwrap_err();
2047 assert!(matches!(err, ToolError::Blocked { .. }));
2048 let ToolError::Blocked { command: cmd } = err else {
2049 panic!("expected Blocked");
2050 };
2051 assert!(
2052 cmd.contains("private") || cmd.contains("scheme"),
2053 "error message should describe the block reason: {cmd}"
2054 );
2055 }
2056
2057 #[test]
2059 fn redirect_location_internal_domain_blocked() {
2060 let location = "https://metadata.internal/latest/meta-data/";
2061 let base = Url::parse("https://example.com/start").unwrap();
2062 let next = base.join(location).unwrap();
2063 let err = validate_url(next.as_str()).unwrap_err();
2064 assert!(matches!(err, ToolError::Blocked { .. }));
2065 }
2066
2067 #[test]
2069 fn redirect_chain_three_hops_all_public() {
2070 let hops = [
2071 "https://redirect1.example.com/hop1",
2072 "https://redirect2.example.com/hop2",
2073 "https://destination.example.com/final",
2074 ];
2075 for hop in hops {
2076 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
2077 }
2078 }
2079
2080 #[test]
2085 fn redirect_to_private_ip_rejected_by_validate_url() {
2086 let private_targets = [
2088 "https://127.0.0.1/secret",
2089 "https://10.0.0.1/internal",
2090 "https://192.168.1.1/admin",
2091 "https://172.16.0.1/data",
2092 "https://[::1]/path",
2093 "https://[fe80::1]/path",
2094 "https://localhost/path",
2095 "https://service.internal/api",
2096 ];
2097 for target in private_targets {
2098 let result = validate_url(target);
2099 assert!(result.is_err(), "expected error for {target}");
2100 assert!(
2101 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
2102 "expected Blocked for {target}"
2103 );
2104 }
2105 }
2106
2107 #[test]
2109 fn redirect_relative_url_resolves_correctly() {
2110 let base = Url::parse("https://example.com/page").unwrap();
2111 let relative = "/other";
2112 let resolved = base.join(relative).unwrap();
2113 assert_eq!(resolved.as_str(), "https://example.com/other");
2114 }
2115
2116 #[test]
2118 fn redirect_to_http_rejected() {
2119 let err = validate_url("http://example.com/page").unwrap_err();
2120 assert!(matches!(err, ToolError::Blocked { .. }));
2121 }
2122
2123 #[test]
2124 fn ipv4_mapped_ipv6_link_local_blocked() {
2125 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
2126 assert!(matches!(err, ToolError::Blocked { .. }));
2127 }
2128
2129 #[test]
2130 fn ipv4_mapped_ipv6_public_allowed() {
2131 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
2132 }
2133
2134 #[tokio::test]
2137 async fn fetch_http_scheme_blocked() {
2138 let config = ScrapeConfig::default();
2139 let executor = WebScrapeExecutor::new(&config);
2140 let call = crate::executor::ToolCall {
2141 tool_id: ToolName::new("fetch"),
2142 params: {
2143 let mut m = serde_json::Map::new();
2144 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2145 m
2146 },
2147 caller_id: None,
2148 context: None,
2149 };
2150 let result = executor.execute_tool_call(&call).await;
2151 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2152 }
2153
2154 #[tokio::test]
2155 async fn fetch_private_ip_blocked() {
2156 let config = ScrapeConfig::default();
2157 let executor = WebScrapeExecutor::new(&config);
2158 let call = crate::executor::ToolCall {
2159 tool_id: ToolName::new("fetch"),
2160 params: {
2161 let mut m = serde_json::Map::new();
2162 m.insert(
2163 "url".to_owned(),
2164 serde_json::json!("https://192.168.1.1/secret"),
2165 );
2166 m
2167 },
2168 caller_id: None,
2169 context: None,
2170 };
2171 let result = executor.execute_tool_call(&call).await;
2172 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2173 }
2174
2175 #[tokio::test]
2176 async fn fetch_localhost_blocked() {
2177 let config = ScrapeConfig::default();
2178 let executor = WebScrapeExecutor::new(&config);
2179 let call = crate::executor::ToolCall {
2180 tool_id: ToolName::new("fetch"),
2181 params: {
2182 let mut m = serde_json::Map::new();
2183 m.insert(
2184 "url".to_owned(),
2185 serde_json::json!("https://localhost/page"),
2186 );
2187 m
2188 },
2189 caller_id: None,
2190 context: None,
2191 };
2192 let result = executor.execute_tool_call(&call).await;
2193 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2194 }
2195
2196 #[tokio::test]
2197 async fn fetch_unknown_tool_returns_none() {
2198 let config = ScrapeConfig::default();
2199 let executor = WebScrapeExecutor::new(&config);
2200 let call = crate::executor::ToolCall {
2201 tool_id: ToolName::new("unknown_tool"),
2202 params: serde_json::Map::new(),
2203 caller_id: None,
2204 context: None,
2205 };
2206 let result = executor.execute_tool_call(&call).await;
2207 assert!(result.unwrap().is_none());
2208 }
2209
2210 #[tokio::test]
2211 async fn fetch_returns_body_via_mock() {
2212 use wiremock::matchers::{method, path};
2213 use wiremock::{Mock, ResponseTemplate};
2214
2215 let (executor, server) = mock_server_executor().await;
2216 Mock::given(method("GET"))
2217 .and(path("/content"))
2218 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
2219 .mount(&server)
2220 .await;
2221
2222 let (host, addrs) = server_host_and_addr(&server);
2223 let url = format!("{}/content", server.uri());
2224 let result = executor
2225 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
2226 .await;
2227 assert!(result.is_ok());
2228 assert_eq!(result.unwrap(), "plain text content");
2229 }
2230
2231 #[test]
2232 fn tool_definitions_returns_web_scrape_and_fetch() {
2233 let config = ScrapeConfig::default();
2234 let executor = WebScrapeExecutor::new(&config);
2235 let defs = executor.tool_definitions();
2236 assert_eq!(defs.len(), 2);
2237 assert_eq!(defs[0].id, "web_scrape");
2238 assert_eq!(
2239 defs[0].invocation,
2240 crate::registry::InvocationHint::FencedBlock("scrape")
2241 );
2242 assert_eq!(defs[1].id, "fetch");
2243 assert_eq!(
2244 defs[1].invocation,
2245 crate::registry::InvocationHint::ToolCall
2246 );
2247 }
2248
2249 #[test]
2250 fn tool_definitions_schema_has_all_params() {
2251 let config = ScrapeConfig::default();
2252 let executor = WebScrapeExecutor::new(&config);
2253 let defs = executor.tool_definitions();
2254 let obj = defs[0].schema.as_object().unwrap();
2255 let props = obj["properties"].as_object().unwrap();
2256 assert!(props.contains_key("url"));
2257 assert!(props.contains_key("select"));
2258 assert!(props.contains_key("extract"));
2259 assert!(props.contains_key("limit"));
2260 let req = obj["required"].as_array().unwrap();
2261 assert!(req.iter().any(|v| v.as_str() == Some("url")));
2262 assert!(req.iter().any(|v| v.as_str() == Some("select")));
2263 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
2264 }
2265
2266 #[test]
2269 fn subdomain_localhost_blocked() {
2270 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
2271 assert!(is_private_host(&host));
2272 }
2273
2274 #[test]
2275 fn internal_tld_blocked() {
2276 let host: url::Host<&str> = url::Host::Domain("service.internal");
2277 assert!(is_private_host(&host));
2278 }
2279
2280 #[test]
2281 fn local_tld_blocked() {
2282 let host: url::Host<&str> = url::Host::Domain("printer.local");
2283 assert!(is_private_host(&host));
2284 }
2285
2286 #[test]
2287 fn public_domain_not_blocked() {
2288 let host: url::Host<&str> = url::Host::Domain("example.com");
2289 assert!(!is_private_host(&host));
2290 }
2291
2292 #[tokio::test]
2295 async fn resolve_loopback_rejected() {
2296 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
2298 let result = resolve_and_validate(&url).await;
2300 assert!(
2301 result.is_err(),
2302 "loopback IP must be rejected by resolve_and_validate"
2303 );
2304 let err = result.unwrap_err();
2305 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
2306 }
2307
2308 #[tokio::test]
2309 async fn resolve_private_10_rejected() {
2310 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
2311 let result = resolve_and_validate(&url).await;
2312 assert!(result.is_err());
2313 assert!(matches!(
2314 result.unwrap_err(),
2315 crate::executor::ToolError::Blocked { .. }
2316 ));
2317 }
2318
2319 #[tokio::test]
2320 async fn resolve_private_192_rejected() {
2321 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
2322 let result = resolve_and_validate(&url).await;
2323 assert!(result.is_err());
2324 assert!(matches!(
2325 result.unwrap_err(),
2326 crate::executor::ToolError::Blocked { .. }
2327 ));
2328 }
2329
2330 #[tokio::test]
2331 async fn resolve_ipv6_loopback_rejected() {
2332 let url = url::Url::parse("https://[::1]/path").unwrap();
2333 let result = resolve_and_validate(&url).await;
2334 assert!(result.is_err());
2335 assert!(matches!(
2336 result.unwrap_err(),
2337 crate::executor::ToolError::Blocked { .. }
2338 ));
2339 }
2340
2341 #[tokio::test]
2342 async fn resolve_no_host_returns_ok() {
2343 let url = url::Url::parse("https://example.com/path").unwrap();
2345 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
2347 let result = resolve_and_validate(&url_no_host).await;
2349 assert!(result.is_ok());
2350 let (host, addrs) = result.unwrap();
2351 assert!(host.is_empty());
2352 assert!(addrs.is_empty());
2353 drop(url);
2354 drop(url_no_host);
2355 }
2356
2357 async fn make_file_audit_logger(
2361 dir: &tempfile::TempDir,
2362 ) -> (
2363 std::sync::Arc<crate::audit::AuditLogger>,
2364 std::path::PathBuf,
2365 ) {
2366 use crate::audit::AuditLogger;
2367 use crate::config::AuditConfig;
2368 let path = dir.path().join("audit.log");
2369 let config = AuditConfig {
2370 enabled: true,
2371 destination: path.display().to_string(),
2372 ..Default::default()
2373 };
2374 let logger = std::sync::Arc::new(AuditLogger::from_config(&config, false).await.unwrap());
2375 (logger, path)
2376 }
2377
2378 #[tokio::test]
2379 async fn with_audit_sets_logger() {
2380 let config = ScrapeConfig::default();
2381 let executor = WebScrapeExecutor::new(&config);
2382 assert!(executor.audit_logger.is_none());
2383
2384 let dir = tempfile::tempdir().unwrap();
2385 let (logger, _path) = make_file_audit_logger(&dir).await;
2386 let executor = executor.with_audit(logger);
2387 assert!(executor.audit_logger.is_some());
2388 }
2389
2390 #[test]
2391 fn tool_error_to_audit_result_blocked_maps_correctly() {
2392 let err = ToolError::Blocked {
2393 command: "scheme not allowed: http".into(),
2394 };
2395 let result = tool_error_to_audit_result(&err);
2396 assert!(
2397 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
2398 );
2399 }
2400
2401 #[test]
2402 fn tool_error_to_audit_result_timeout_maps_correctly() {
2403 let err = ToolError::Timeout { timeout_secs: 15 };
2404 let result = tool_error_to_audit_result(&err);
2405 assert!(matches!(result, AuditResult::Timeout));
2406 }
2407
2408 #[test]
2409 fn tool_error_to_audit_result_execution_error_maps_correctly() {
2410 let err = ToolError::Execution(std::io::Error::other("connection refused"));
2411 let result = tool_error_to_audit_result(&err);
2412 assert!(
2413 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
2414 );
2415 }
2416
2417 #[tokio::test]
2418 async fn fetch_audit_blocked_url_logged() {
2419 let dir = tempfile::tempdir().unwrap();
2420 let (logger, log_path) = make_file_audit_logger(&dir).await;
2421
2422 let config = ScrapeConfig::default();
2423 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2424
2425 let call = crate::executor::ToolCall {
2426 tool_id: ToolName::new("fetch"),
2427 params: {
2428 let mut m = serde_json::Map::new();
2429 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2430 m
2431 },
2432 caller_id: None,
2433 context: None,
2434 };
2435 let result = executor.execute_tool_call(&call).await;
2436 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2437
2438 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2439 assert!(
2440 content.contains("\"tool\":\"fetch\""),
2441 "expected tool=fetch in audit: {content}"
2442 );
2443 assert!(
2444 content.contains("\"type\":\"blocked\""),
2445 "expected type=blocked in audit: {content}"
2446 );
2447 assert!(
2448 content.contains("http://example.com"),
2449 "expected URL in audit command field: {content}"
2450 );
2451 }
2452
2453 #[tokio::test]
2454 async fn log_audit_success_writes_to_file() {
2455 let dir = tempfile::tempdir().unwrap();
2456 let (logger, log_path) = make_file_audit_logger(&dir).await;
2457
2458 let config = ScrapeConfig::default();
2459 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2460
2461 executor
2462 .log_audit(
2463 "fetch",
2464 "https://example.com/page",
2465 AuditResult::Success,
2466 42,
2467 None,
2468 None,
2469 None,
2470 )
2471 .await;
2472
2473 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2474 assert!(
2475 content.contains("\"tool\":\"fetch\""),
2476 "expected tool=fetch in audit: {content}"
2477 );
2478 assert!(
2479 content.contains("\"type\":\"success\""),
2480 "expected type=success in audit: {content}"
2481 );
2482 assert!(
2483 content.contains("\"command\":\"https://example.com/page\""),
2484 "expected command URL in audit: {content}"
2485 );
2486 assert!(
2487 content.contains("\"duration_ms\":42"),
2488 "expected duration_ms in audit: {content}"
2489 );
2490 }
2491
2492 #[tokio::test]
2493 async fn log_audit_blocked_writes_to_file() {
2494 let dir = tempfile::tempdir().unwrap();
2495 let (logger, log_path) = make_file_audit_logger(&dir).await;
2496
2497 let config = ScrapeConfig::default();
2498 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2499
2500 executor
2501 .log_audit(
2502 "web_scrape",
2503 "http://evil.com/page",
2504 AuditResult::Blocked {
2505 reason: "scheme not allowed: http".into(),
2506 },
2507 0,
2508 None,
2509 None,
2510 None,
2511 )
2512 .await;
2513
2514 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2515 assert!(
2516 content.contains("\"tool\":\"web_scrape\""),
2517 "expected tool=web_scrape in audit: {content}"
2518 );
2519 assert!(
2520 content.contains("\"type\":\"blocked\""),
2521 "expected type=blocked in audit: {content}"
2522 );
2523 assert!(
2524 content.contains("scheme not allowed"),
2525 "expected block reason in audit: {content}"
2526 );
2527 }
2528
2529 #[tokio::test]
2530 async fn web_scrape_audit_blocked_url_logged() {
2531 let dir = tempfile::tempdir().unwrap();
2532 let (logger, log_path) = make_file_audit_logger(&dir).await;
2533
2534 let config = ScrapeConfig::default();
2535 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2536
2537 let call = crate::executor::ToolCall {
2538 tool_id: ToolName::new("web_scrape"),
2539 params: {
2540 let mut m = serde_json::Map::new();
2541 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2542 m.insert("select".to_owned(), serde_json::json!("h1"));
2543 m
2544 },
2545 caller_id: None,
2546 context: None,
2547 };
2548 let result = executor.execute_tool_call(&call).await;
2549 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2550
2551 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2552 assert!(
2553 content.contains("\"tool\":\"web_scrape\""),
2554 "expected tool=web_scrape in audit: {content}"
2555 );
2556 assert!(
2557 content.contains("\"type\":\"blocked\""),
2558 "expected type=blocked in audit: {content}"
2559 );
2560 }
2561
2562 #[tokio::test]
2563 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2564 let config = ScrapeConfig::default();
2565 let executor = WebScrapeExecutor::new(&config);
2566 assert!(executor.audit_logger.is_none());
2567
2568 let call = crate::executor::ToolCall {
2569 tool_id: ToolName::new("fetch"),
2570 params: {
2571 let mut m = serde_json::Map::new();
2572 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2573 m
2574 },
2575 caller_id: None,
2576 context: None,
2577 };
2578 let result = executor.execute_tool_call(&call).await;
2580 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2581 }
2582
2583 #[tokio::test]
2585 async fn fetch_execute_tool_call_end_to_end() {
2586 use wiremock::matchers::{method, path};
2587 use wiremock::{Mock, ResponseTemplate};
2588
2589 let (executor, server) = mock_server_executor().await;
2590 Mock::given(method("GET"))
2591 .and(path("/e2e"))
2592 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2593 .mount(&server)
2594 .await;
2595
2596 let (host, addrs) = server_host_and_addr(&server);
2597 let result = executor
2599 .fetch_html(
2600 &format!("{}/e2e", server.uri()),
2601 &host,
2602 &addrs,
2603 "fetch",
2604 "test-cid",
2605 None,
2606 )
2607 .await;
2608 assert!(result.is_ok());
2609 assert!(result.unwrap().contains("end-to-end"));
2610 }
2611
2612 #[test]
2615 fn domain_matches_exact() {
2616 assert!(domain_matches("example.com", "example.com"));
2617 assert!(!domain_matches("example.com", "other.com"));
2618 assert!(!domain_matches("example.com", "sub.example.com"));
2619 }
2620
2621 #[test]
2622 fn domain_matches_wildcard_single_subdomain() {
2623 assert!(domain_matches("*.example.com", "sub.example.com"));
2624 assert!(!domain_matches("*.example.com", "example.com"));
2625 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2626 }
2627
2628 #[test]
2629 fn domain_matches_wildcard_does_not_match_empty_label() {
2630 assert!(!domain_matches("*.example.com", ".example.com"));
2632 }
2633
2634 #[test]
2635 fn domain_matches_multi_wildcard_treated_as_exact() {
2636 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2638 }
2639
2640 #[test]
2643 fn check_domain_policy_empty_lists_allow_all() {
2644 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2645 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2646 }
2647
2648 #[test]
2649 fn check_domain_policy_denylist_blocks() {
2650 let denied = vec!["evil.com".to_string()];
2651 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2652 assert!(matches!(err, ToolError::Blocked { .. }));
2653 }
2654
2655 #[test]
2656 fn check_domain_policy_denylist_does_not_block_other_domains() {
2657 let denied = vec!["evil.com".to_string()];
2658 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2659 }
2660
2661 #[test]
2662 fn check_domain_policy_allowlist_permits_matching() {
2663 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2664 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2665 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2666 }
2667
2668 #[test]
2669 fn check_domain_policy_allowlist_blocks_unknown() {
2670 let allowed = vec!["docs.rs".to_string()];
2671 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2672 assert!(matches!(err, ToolError::Blocked { .. }));
2673 }
2674
2675 #[test]
2676 fn check_domain_policy_deny_overrides_allow() {
2677 let allowed = vec!["example.com".to_string()];
2678 let denied = vec!["example.com".to_string()];
2679 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2680 assert!(matches!(err, ToolError::Blocked { .. }));
2681 }
2682
2683 #[test]
2684 fn check_domain_policy_wildcard_in_denylist() {
2685 let denied = vec!["*.evil.com".to_string()];
2686 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2687 assert!(matches!(err, ToolError::Blocked { .. }));
2688 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2690 }
2691}