1use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
24use schemars::JsonSchema;
25use serde::Deserialize;
26use url::Url;
27
28use zeph_common::ToolName;
29
30use crate::audit::{AuditEntry, AuditLogger, AuditResult, EgressEvent, chrono_now};
31use crate::config::{EgressConfig, ScrapeConfig};
32use crate::executor::{
33 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
34};
35use crate::net::is_private_ip;
36
37fn redact_url_for_log(url: &str) -> String {
41 let Ok(mut parsed) = Url::parse(url) else {
42 return url.to_owned();
43 };
44 let _ = parsed.set_username("");
46 let _ = parsed.set_password(None);
47 let sensitive = [
49 "token", "key", "secret", "password", "auth", "sig", "api_key", "apikey",
50 ];
51 let filtered: Vec<(String, String)> = parsed
52 .query_pairs()
53 .filter(|(k, _)| {
54 let lower = k.to_lowercase();
55 !sensitive.iter().any(|s| lower.contains(s))
56 })
57 .map(|(k, v)| (k.into_owned(), v.into_owned()))
58 .collect();
59 if filtered.is_empty() {
60 parsed.set_query(None);
61 } else {
62 let q: String = filtered
63 .iter()
64 .map(|(k, v)| format!("{k}={v}"))
65 .collect::<Vec<_>>()
66 .join("&");
67 parsed.set_query(Some(&q));
68 }
69 parsed.to_string()
70}
71
// Parameters accepted by the `fetch` tool call.
// NOTE: plain `//` comments on purpose — `///` field docs would be picked up
// by the JsonSchema derive and change the generated schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // URL to fetch; must be HTTPS and pass domain/SSRF policy (see
    // `validate_url` / `resolve_and_validate`).
    url: String,
}
77
// A single scrape request, deserialized either from a fenced `scrape` JSON
// block in a model response or from structured tool-call params.
// NOTE: plain `//` comments on purpose — `///` field docs would be picked up
// by the JsonSchema derive and change the generated schema.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // Page to fetch; must be HTTPS and pass domain/SSRF policy.
    url: String,
    // CSS selector applied to the fetched document.
    select: String,
    // Output mode: "text", "html", or "attr:<name>". Defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    // Maximum number of matches returned; the use site falls back to 10.
    limit: Option<usize>,
}
90
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
94
/// How a matched element is rendered into the tool output.
#[derive(Debug)]
enum ExtractMode {
    Text,
    Html,
    Attr(String),
}

impl ExtractMode {
    /// Maps the user-facing `extract` string onto a mode.
    ///
    /// `"html"` selects inner HTML, `"attr:<name>"` selects an attribute,
    /// and everything else (including `"text"` and unknown values) falls
    /// back to plain text.
    fn parse(s: &str) -> Self {
        if s == "html" {
            return Self::Html;
        }
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        Self::Text
    }
}
114
/// Executor backing the `web_scrape` and `fetch` tools: an HTTPS-only HTTP
/// client with domain allow/deny lists, DNS-level SSRF protection, and
/// optional audit logging plus egress telemetry.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request timeout applied to the built reqwest client.
    timeout: Duration,
    // Hard cap on response body size; larger responses are rejected.
    max_body_bytes: usize,
    // When non-empty, only hosts matching these patterns may be fetched.
    allowed_domains: Vec<String>,
    // Hosts matching these patterns are always refused.
    denied_domains: Vec<String>,
    // Optional sink for per-call audit entries and egress events.
    audit_logger: Option<Arc<AuditLogger>>,
    // Controls whether egress telemetry is emitted and what it includes.
    egress_config: EgressConfig,
    // Optional live channel for egress events (best-effort, non-blocking).
    egress_tx: Option<tokio::sync::mpsc::Sender<EgressEvent>>,
    // Count of egress events dropped because the channel was full.
    egress_dropped: Arc<AtomicU64>,
}
164
165impl WebScrapeExecutor {
166 #[must_use]
170 pub fn new(config: &ScrapeConfig) -> Self {
171 Self {
172 timeout: Duration::from_secs(config.timeout),
173 max_body_bytes: config.max_body_bytes,
174 allowed_domains: config.allowed_domains.clone(),
175 denied_domains: config.denied_domains.clone(),
176 audit_logger: None,
177 egress_config: EgressConfig::default(),
178 egress_tx: None,
179 egress_dropped: Arc::new(AtomicU64::new(0)),
180 }
181 }
182
183 #[must_use]
185 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
186 self.audit_logger = Some(logger);
187 self
188 }
189
190 #[must_use]
192 pub fn with_egress_config(mut self, config: EgressConfig) -> Self {
193 self.egress_config = config;
194 self
195 }
196
197 #[must_use]
202 pub fn with_egress_tx(
203 mut self,
204 tx: tokio::sync::mpsc::Sender<EgressEvent>,
205 dropped: Arc<AtomicU64>,
206 ) -> Self {
207 self.egress_tx = Some(tx);
208 self.egress_dropped = dropped;
209 self
210 }
211
212 #[must_use]
214 pub fn egress_dropped(&self) -> Arc<AtomicU64> {
215 Arc::clone(&self.egress_dropped)
216 }
217
218 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
219 let mut builder = reqwest::Client::builder()
220 .timeout(self.timeout)
221 .redirect(reqwest::redirect::Policy::none());
222 builder = builder.resolve_to_addrs(host, addrs);
223 builder.build().unwrap_or_default()
224 }
225}
226
impl ToolExecutor for WebScrapeExecutor {
    // Advertises the two tools this executor provides: `web_scrape`
    // (invoked via fenced `scrape` blocks) and `fetch` (structured call).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
                output_schema: None,
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
                output_schema: None,
            },
        ]
    }

    // Scans a free-form model response for fenced `scrape` blocks and runs
    // each one in order, auditing every attempt. Returns Ok(None) when the
    // response contains no blocks. The first failing block aborts the call;
    // earlier successes are audited but their output is discarded.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON in a block fails the whole call with
            // InvalidData rather than being silently skipped.
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            // One correlation id ties together the audit entry and every
            // egress event emitted for this block.
            let correlation_id = EgressEvent::new_correlation_id();
            let start = Instant::now();
            let scrape_result = self
                .scrape_instruction(&instruction, &correlation_id, None)
                .await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                        None,
                        Some(correlation_id),
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Audit the failure before propagating it.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                        None,
                        Some(correlation_id),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            tool_name: ToolName::new("web-scrape"),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    // Structured tool-call entry point for `web_scrape` and `fetch`.
    // Unknown tool ids return Ok(None) so other executors can claim them.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "tool.web_scrape", skip_all)
    )]
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let correlation_id = EgressEvent::new_correlation_id();
                let start = Instant::now();
                let result = self
                    .scrape_instruction(&instruction, &correlation_id, call.caller_id.clone())
                    .await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                // Auditing + conversion to ToolOutput is shared with `fetch`.
                self.run_with_audit(
                    "web_scrape",
                    "web-scrape",
                    &instruction.url,
                    call.caller_id.clone(),
                    correlation_id,
                    duration_ms,
                    result,
                )
                .await
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let correlation_id = EgressEvent::new_correlation_id();
                let start = Instant::now();
                let result = self
                    .handle_fetch(&p, &correlation_id, call.caller_id.clone())
                    .await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                self.run_with_audit(
                    "fetch",
                    "fetch",
                    &p.url,
                    call.caller_id.clone(),
                    correlation_id,
                    duration_ms,
                    result,
                )
                .await
            }
            _ => Ok(None),
        }
    }

    // Both tools are read-only GETs, so retrying on transient failure is
    // safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
371
372fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
373 match e {
374 ToolError::Blocked { command } => AuditResult::Blocked {
375 reason: command.clone(),
376 },
377 ToolError::Timeout { .. } => AuditResult::Timeout,
378 _ => AuditResult::Error {
379 message: e.to_string(),
380 },
381 }
382}
383
384impl WebScrapeExecutor {
385 #[allow(clippy::too_many_arguments)]
386 async fn run_with_audit(
387 &self,
388 audit_tool_name: &str,
389 public_tool_name: &str,
390 audit_command: &str,
391 caller_id: Option<String>,
392 correlation_id: String,
393 duration_ms: u64,
394 result: Result<String, ToolError>,
395 ) -> Result<Option<ToolOutput>, ToolError> {
396 match result {
397 Ok(output) => {
398 self.log_audit(
399 audit_tool_name,
400 audit_command,
401 AuditResult::Success,
402 duration_ms,
403 None,
404 caller_id,
405 Some(correlation_id),
406 )
407 .await;
408 Ok(Some(ToolOutput {
409 tool_name: ToolName::new(public_tool_name),
410 summary: output,
411 blocks_executed: 1,
412 filter_stats: None,
413 diff: None,
414 streamed: false,
415 terminal_id: None,
416 locations: None,
417 raw_response: None,
418 claim_source: Some(ClaimSource::WebScrape),
419 }))
420 }
421 Err(e) => {
422 let audit_result = tool_error_to_audit_result(&e);
423 self.log_audit(
424 audit_tool_name,
425 audit_command,
426 audit_result,
427 duration_ms,
428 Some(&e),
429 caller_id,
430 Some(correlation_id),
431 )
432 .await;
433 Err(e)
434 }
435 }
436 }
437
438 #[allow(clippy::too_many_arguments)] async fn log_audit(
440 &self,
441 tool: &str,
442 command: &str,
443 result: AuditResult,
444 duration_ms: u64,
445 error: Option<&ToolError>,
446 caller_id: Option<String>,
447 correlation_id: Option<String>,
448 ) {
449 if let Some(ref logger) = self.audit_logger {
450 let (error_category, error_domain, error_phase) =
451 error.map_or((None, None, None), |e| {
452 let cat = e.category();
453 (
454 Some(cat.label().to_owned()),
455 Some(cat.domain().label().to_owned()),
456 Some(cat.phase().label().to_owned()),
457 )
458 });
459 let entry = AuditEntry {
460 timestamp: chrono_now(),
461 tool: tool.into(),
462 command: command.into(),
463 result,
464 duration_ms,
465 error_category,
466 error_domain,
467 error_phase,
468 claim_source: Some(ClaimSource::WebScrape),
469 mcp_server_id: None,
470 injection_flagged: false,
471 embedding_anomalous: false,
472 cross_boundary_mcp_to_acp: false,
473 adversarial_policy_decision: None,
474 exit_code: None,
475 truncated: false,
476 caller_id,
477 policy_match: None,
478 correlation_id,
479 vigil_risk: None,
480 };
481 logger.log(&entry).await;
482 }
483 }
484
485 fn send_egress_event(&self, event: EgressEvent) {
486 if let Some(ref tx) = self.egress_tx {
487 match tx.try_send(event) {
488 Ok(()) => {}
489 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
490 self.egress_dropped.fetch_add(1, Ordering::Relaxed);
491 }
492 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
493 tracing::debug!("egress channel closed; executor continuing without telemetry");
494 }
495 }
496 }
497 }
498
499 async fn log_egress_event(&self, event: &EgressEvent) {
500 if let Some(ref logger) = self.audit_logger {
501 logger.log_egress(event).await;
502 }
503 self.send_egress_event(event.clone());
504 }
505
506 async fn handle_fetch(
507 &self,
508 params: &FetchParams,
509 correlation_id: &str,
510 caller_id: Option<String>,
511 ) -> Result<String, ToolError> {
512 let parsed = validate_url(¶ms.url);
513 let host_str = parsed
514 .as_ref()
515 .map(|u| u.host_str().unwrap_or("").to_owned())
516 .unwrap_or_default();
517
518 if let Err(ref _e) = parsed {
519 if self.egress_config.enabled && self.egress_config.log_blocked {
520 let event = Self::make_blocked_event(
521 "fetch",
522 ¶ms.url,
523 &host_str,
524 correlation_id,
525 caller_id.clone(),
526 "scheme",
527 );
528 self.log_egress_event(&event).await;
529 }
530 return Err(parsed.unwrap_err());
531 }
532 let parsed = parsed.unwrap();
533
534 if let Err(e) = check_domain_policy(
535 parsed.host_str().unwrap_or(""),
536 &self.allowed_domains,
537 &self.denied_domains,
538 ) {
539 if self.egress_config.enabled && self.egress_config.log_blocked {
540 let event = Self::make_blocked_event(
541 "fetch",
542 ¶ms.url,
543 parsed.host_str().unwrap_or(""),
544 correlation_id,
545 caller_id.clone(),
546 "blocklist",
547 );
548 self.log_egress_event(&event).await;
549 }
550 return Err(e);
551 }
552
553 let (host, addrs) = match resolve_and_validate(&parsed).await {
554 Ok(v) => v,
555 Err(e) => {
556 if self.egress_config.enabled && self.egress_config.log_blocked {
557 let event = Self::make_blocked_event(
558 "fetch",
559 ¶ms.url,
560 parsed.host_str().unwrap_or(""),
561 correlation_id,
562 caller_id.clone(),
563 "ssrf",
564 );
565 self.log_egress_event(&event).await;
566 }
567 return Err(e);
568 }
569 };
570
571 self.fetch_html(
572 ¶ms.url,
573 &host,
574 &addrs,
575 "fetch",
576 correlation_id,
577 caller_id,
578 )
579 .await
580 }
581
582 async fn scrape_instruction(
583 &self,
584 instruction: &ScrapeInstruction,
585 correlation_id: &str,
586 caller_id: Option<String>,
587 ) -> Result<String, ToolError> {
588 let parsed = validate_url(&instruction.url);
589 let host_str = parsed
590 .as_ref()
591 .map(|u| u.host_str().unwrap_or("").to_owned())
592 .unwrap_or_default();
593
594 if let Err(ref _e) = parsed {
595 if self.egress_config.enabled && self.egress_config.log_blocked {
596 let event = Self::make_blocked_event(
597 "web_scrape",
598 &instruction.url,
599 &host_str,
600 correlation_id,
601 caller_id.clone(),
602 "scheme",
603 );
604 self.log_egress_event(&event).await;
605 }
606 return Err(parsed.unwrap_err());
607 }
608 let parsed = parsed.unwrap();
609
610 if let Err(e) = check_domain_policy(
611 parsed.host_str().unwrap_or(""),
612 &self.allowed_domains,
613 &self.denied_domains,
614 ) {
615 if self.egress_config.enabled && self.egress_config.log_blocked {
616 let event = Self::make_blocked_event(
617 "web_scrape",
618 &instruction.url,
619 parsed.host_str().unwrap_or(""),
620 correlation_id,
621 caller_id.clone(),
622 "blocklist",
623 );
624 self.log_egress_event(&event).await;
625 }
626 return Err(e);
627 }
628
629 let (host, addrs) = match resolve_and_validate(&parsed).await {
630 Ok(v) => v,
631 Err(e) => {
632 if self.egress_config.enabled && self.egress_config.log_blocked {
633 let event = Self::make_blocked_event(
634 "web_scrape",
635 &instruction.url,
636 parsed.host_str().unwrap_or(""),
637 correlation_id,
638 caller_id.clone(),
639 "ssrf",
640 );
641 self.log_egress_event(&event).await;
642 }
643 return Err(e);
644 }
645 };
646
647 let html = self
648 .fetch_html(
649 &instruction.url,
650 &host,
651 &addrs,
652 "web_scrape",
653 correlation_id,
654 caller_id,
655 )
656 .await?;
657 let selector = instruction.select.clone();
658 let extract = ExtractMode::parse(&instruction.extract);
659 let limit = instruction.limit.unwrap_or(10);
660 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
661 .await
662 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
663 }
664
665 fn make_blocked_event(
666 tool: &str,
667 url: &str,
668 host: &str,
669 correlation_id: &str,
670 caller_id: Option<String>,
671 block_reason: &'static str,
672 ) -> EgressEvent {
673 EgressEvent {
674 timestamp: chrono_now(),
675 kind: "egress",
676 correlation_id: correlation_id.to_owned(),
677 tool: tool.into(),
678 url: redact_url_for_log(url),
679 host: host.to_owned(),
680 method: "GET".to_owned(),
681 status: None,
682 duration_ms: 0,
683 response_bytes: 0,
684 blocked: true,
685 block_reason: Some(block_reason),
686 caller_id,
687 hop: 0,
688 }
689 }
690
691 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
702 async fn fetch_html(
703 &self,
704 url: &str,
705 host: &str,
706 addrs: &[SocketAddr],
707 tool: &str,
708 correlation_id: &str,
709 caller_id: Option<String>,
710 ) -> Result<String, ToolError> {
711 const MAX_REDIRECTS: usize = 3;
712
713 let mut current_url = url.to_owned();
714 let mut current_host = host.to_owned();
715 let mut current_addrs = addrs.to_vec();
716
717 for hop in 0..=MAX_REDIRECTS {
718 let hop_start = Instant::now();
719 let client = self.build_client(¤t_host, ¤t_addrs);
721 let resp = client
722 .get(¤t_url)
723 .send()
724 .await
725 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
726
727 let resp = match resp {
728 Ok(r) => r,
729 Err(e) => {
730 if self.egress_config.enabled {
731 #[allow(clippy::cast_possible_truncation)]
732 let duration_ms = hop_start.elapsed().as_millis() as u64;
733 let event = EgressEvent {
734 timestamp: chrono_now(),
735 kind: "egress",
736 correlation_id: correlation_id.to_owned(),
737 tool: tool.into(),
738 url: redact_url_for_log(¤t_url),
739 host: current_host.clone(),
740 method: "GET".to_owned(),
741 status: None,
742 duration_ms,
743 response_bytes: 0,
744 blocked: false,
745 block_reason: None,
746 caller_id: caller_id.clone(),
747 #[allow(clippy::cast_possible_truncation)]
748 hop: hop as u8,
749 };
750 self.log_egress_event(&event).await;
751 }
752 return Err(e);
753 }
754 };
755
756 let status = resp.status();
757
758 if status.is_redirection() {
759 if hop == MAX_REDIRECTS {
760 return Err(ToolError::Execution(std::io::Error::other(
761 "too many redirects",
762 )));
763 }
764
765 let location = resp
766 .headers()
767 .get(reqwest::header::LOCATION)
768 .and_then(|v| v.to_str().ok())
769 .ok_or_else(|| {
770 ToolError::Execution(std::io::Error::other("redirect with no Location"))
771 })?;
772
773 let base = Url::parse(¤t_url)
775 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
776 let next_url = base
777 .join(location)
778 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
779
780 let validated = validate_url(next_url.as_str());
781 if let Err(ref _e) = validated {
782 if self.egress_config.enabled && self.egress_config.log_blocked {
783 #[allow(clippy::cast_possible_truncation)]
784 let duration_ms = hop_start.elapsed().as_millis() as u64;
785 let next_host = next_url.host_str().unwrap_or("").to_owned();
786 let event = EgressEvent {
787 timestamp: chrono_now(),
788 kind: "egress",
789 correlation_id: correlation_id.to_owned(),
790 tool: tool.into(),
791 url: redact_url_for_log(next_url.as_str()),
792 host: next_host,
793 method: "GET".to_owned(),
794 status: None,
795 duration_ms,
796 response_bytes: 0,
797 blocked: true,
798 block_reason: Some("ssrf"),
799 caller_id: caller_id.clone(),
800 #[allow(clippy::cast_possible_truncation)]
801 hop: (hop + 1) as u8,
802 };
803 self.log_egress_event(&event).await;
804 }
805 return Err(validated.unwrap_err());
806 }
807 let validated = validated.unwrap();
808 let resolve_result = resolve_and_validate(&validated).await;
809 if let Err(ref _e) = resolve_result {
810 if self.egress_config.enabled && self.egress_config.log_blocked {
811 #[allow(clippy::cast_possible_truncation)]
812 let duration_ms = hop_start.elapsed().as_millis() as u64;
813 let next_host = next_url.host_str().unwrap_or("").to_owned();
814 let event = EgressEvent {
815 timestamp: chrono_now(),
816 kind: "egress",
817 correlation_id: correlation_id.to_owned(),
818 tool: tool.into(),
819 url: redact_url_for_log(next_url.as_str()),
820 host: next_host,
821 method: "GET".to_owned(),
822 status: None,
823 duration_ms,
824 response_bytes: 0,
825 blocked: true,
826 block_reason: Some("ssrf"),
827 caller_id: caller_id.clone(),
828 #[allow(clippy::cast_possible_truncation)]
829 hop: (hop + 1) as u8,
830 };
831 self.log_egress_event(&event).await;
832 }
833 return Err(resolve_result.unwrap_err());
834 }
835 let (next_host, next_addrs) = resolve_result.unwrap();
836
837 current_url = next_url.to_string();
838 current_host = next_host;
839 current_addrs = next_addrs;
840 continue;
841 }
842
843 if !status.is_success() {
844 if self.egress_config.enabled {
845 #[allow(clippy::cast_possible_truncation)]
846 let duration_ms = hop_start.elapsed().as_millis() as u64;
847 let event = EgressEvent {
848 timestamp: chrono_now(),
849 kind: "egress",
850 correlation_id: correlation_id.to_owned(),
851 tool: tool.into(),
852 url: current_url.clone(),
853 host: current_host.clone(),
854 method: "GET".to_owned(),
855 status: Some(status.as_u16()),
856 duration_ms,
857 response_bytes: 0,
858 blocked: false,
859 block_reason: None,
860 caller_id: caller_id.clone(),
861 #[allow(clippy::cast_possible_truncation)]
862 hop: hop as u8,
863 };
864 self.log_egress_event(&event).await;
865 }
866 return Err(ToolError::Http {
867 status: status.as_u16(),
868 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
869 });
870 }
871
872 let bytes = resp
873 .bytes()
874 .await
875 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
876
877 if bytes.len() > self.max_body_bytes {
878 if self.egress_config.enabled {
879 #[allow(clippy::cast_possible_truncation)]
880 let duration_ms = hop_start.elapsed().as_millis() as u64;
881 let event = EgressEvent {
882 timestamp: chrono_now(),
883 kind: "egress",
884 correlation_id: correlation_id.to_owned(),
885 tool: tool.into(),
886 url: current_url.clone(),
887 host: current_host.clone(),
888 method: "GET".to_owned(),
889 status: Some(status.as_u16()),
890 duration_ms,
891 response_bytes: bytes.len(),
892 blocked: false,
893 block_reason: None,
894 caller_id: caller_id.clone(),
895 #[allow(clippy::cast_possible_truncation)]
896 hop: hop as u8,
897 };
898 self.log_egress_event(&event).await;
899 }
900 return Err(ToolError::Execution(std::io::Error::other(format!(
901 "response too large: {} bytes (max: {})",
902 bytes.len(),
903 self.max_body_bytes,
904 ))));
905 }
906
907 if self.egress_config.enabled {
909 #[allow(clippy::cast_possible_truncation)]
910 let duration_ms = hop_start.elapsed().as_millis() as u64;
911 let response_bytes = if self.egress_config.log_response_bytes {
912 bytes.len()
913 } else {
914 0
915 };
916 let event = EgressEvent {
917 timestamp: chrono_now(),
918 kind: "egress",
919 correlation_id: correlation_id.to_owned(),
920 tool: tool.into(),
921 url: current_url.clone(),
922 host: current_host.clone(),
923 method: "GET".to_owned(),
924 status: Some(status.as_u16()),
925 duration_ms,
926 response_bytes,
927 blocked: false,
928 block_reason: None,
929 caller_id: caller_id.clone(),
930 #[allow(clippy::cast_possible_truncation)]
931 hop: hop as u8,
932 };
933 self.log_egress_event(&event).await;
934 }
935
936 return String::from_utf8(bytes.to_vec())
937 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
938 }
939
940 Err(ToolError::Execution(std::io::Error::other(
941 "too many redirects",
942 )))
943 }
944}
945
/// Pulls the raw contents of every fenced `scrape` code block out of a
/// model response; JSON parsing of each block happens at the call site.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
949
950fn check_domain_policy(
962 host: &str,
963 allowed_domains: &[String],
964 denied_domains: &[String],
965) -> Result<(), ToolError> {
966 if denied_domains.iter().any(|p| domain_matches(p, host)) {
967 return Err(ToolError::Blocked {
968 command: format!("domain blocked by denylist: {host}"),
969 });
970 }
971 if !allowed_domains.is_empty() {
972 let is_ip = host.parse::<std::net::IpAddr>().is_ok()
974 || (host.starts_with('[') && host.ends_with(']'));
975 if is_ip {
976 return Err(ToolError::Blocked {
977 command: format!(
978 "bare IP address not allowed when domain allowlist is active: {host}"
979 ),
980 });
981 }
982 if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
983 return Err(ToolError::Blocked {
984 command: format!("domain not in allowlist: {host}"),
985 });
986 }
987 }
988 Ok(())
989}
990
991use crate::domain_match::domain_matches;
993
994fn validate_url(raw: &str) -> Result<Url, ToolError> {
995 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
996 command: format!("invalid URL: {raw}"),
997 })?;
998
999 if parsed.scheme() != "https" {
1000 return Err(ToolError::Blocked {
1001 command: format!("scheme not allowed: {}", parsed.scheme()),
1002 });
1003 }
1004
1005 if let Some(host) = parsed.host()
1006 && is_private_host(&host)
1007 {
1008 return Err(ToolError::Blocked {
1009 command: format!(
1010 "private/local host blocked: {}",
1011 parsed.host_str().unwrap_or("")
1012 ),
1013 });
1014 }
1015
1016 Ok(parsed)
1017}
1018
1019fn is_private_host(host: &url::Host<&str>) -> bool {
1020 match host {
1021 url::Host::Domain(d) => {
1022 #[allow(clippy::case_sensitive_file_extension_comparisons)]
1025 {
1026 *d == "localhost"
1027 || d.ends_with(".localhost")
1028 || d.ends_with(".internal")
1029 || d.ends_with(".local")
1030 }
1031 }
1032 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
1033 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
1034 }
1035}
1036
1037async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
1043 let Some(host) = url.host_str() else {
1044 return Ok((String::new(), vec![]));
1045 };
1046 let port = url.port_or_known_default().unwrap_or(443);
1047 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
1048 .await
1049 .map_err(|e| ToolError::Blocked {
1050 command: format!("DNS resolution failed: {e}"),
1051 })?
1052 .collect();
1053 for addr in &addrs {
1054 if is_private_ip(addr.ip()) {
1055 return Err(ToolError::Blocked {
1056 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
1057 });
1058 }
1059 }
1060 Ok((host.to_owned(), addrs))
1061}
1062
1063fn parse_and_extract(
1064 html: &str,
1065 selector: &str,
1066 extract: &ExtractMode,
1067 limit: usize,
1068) -> Result<String, ToolError> {
1069 let soup = scrape_core::Soup::parse(html);
1070
1071 let tags = soup.find_all(selector).map_err(|e| {
1072 ToolError::Execution(std::io::Error::new(
1073 std::io::ErrorKind::InvalidData,
1074 format!("invalid selector: {e}"),
1075 ))
1076 })?;
1077
1078 let mut results = Vec::new();
1079
1080 for tag in tags.into_iter().take(limit) {
1081 let value = match extract {
1082 ExtractMode::Text => tag.text(),
1083 ExtractMode::Html => tag.inner_html(),
1084 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
1085 };
1086 if !value.trim().is_empty() {
1087 results.push(value.trim().to_owned());
1088 }
1089 }
1090
1091 if results.is_empty() {
1092 Ok(format!("No results for selector: {selector}"))
1093 } else {
1094 Ok(results.join("\n"))
1095 }
1096}
1097
1098#[cfg(test)]
1099mod tests {
1100 use super::*;
1101
1102 #[test]
1105 fn extract_single_block() {
1106 let text =
1107 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
1108 let blocks = extract_scrape_blocks(text);
1109 assert_eq!(blocks.len(), 1);
1110 assert!(blocks[0].contains("example.com"));
1111 }
1112
1113 #[test]
1114 fn extract_multiple_blocks() {
1115 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
1116 let blocks = extract_scrape_blocks(text);
1117 assert_eq!(blocks.len(), 2);
1118 }
1119
1120 #[test]
1121 fn no_blocks_returns_empty() {
1122 let blocks = extract_scrape_blocks("plain text, no code blocks");
1123 assert!(blocks.is_empty());
1124 }
1125
1126 #[test]
1127 fn unclosed_block_ignored() {
1128 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
1129 assert!(blocks.is_empty());
1130 }
1131
1132 #[test]
1133 fn non_scrape_block_ignored() {
1134 let text =
1135 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
1136 let blocks = extract_scrape_blocks(text);
1137 assert_eq!(blocks.len(), 1);
1138 assert!(blocks[0].contains("x.com"));
1139 }
1140
1141 #[test]
1142 fn multiline_json_block() {
1143 let text =
1144 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
1145 let blocks = extract_scrape_blocks(text);
1146 assert_eq!(blocks.len(), 1);
1147 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
1148 assert_eq!(instr.url, "https://example.com");
1149 }
1150
1151 #[test]
1154 fn parse_valid_instruction() {
1155 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
1156 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1157 assert_eq!(instr.url, "https://example.com");
1158 assert_eq!(instr.select, "h1");
1159 assert_eq!(instr.extract, "text");
1160 assert_eq!(instr.limit, Some(5));
1161 }
1162
1163 #[test]
1164 fn parse_minimal_instruction() {
1165 let json = r#"{"url":"https://example.com","select":"p"}"#;
1166 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1167 assert_eq!(instr.extract, "text");
1168 assert!(instr.limit.is_none());
1169 }
1170
1171 #[test]
1172 fn parse_attr_extract() {
1173 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
1174 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1175 assert_eq!(instr.extract, "attr:href");
1176 }
1177
1178 #[test]
1179 fn parse_invalid_json_errors() {
1180 let result = serde_json::from_str::<ScrapeInstruction>("not json");
1181 assert!(result.is_err());
1182 }
1183
1184 #[test]
1187 fn extract_mode_text() {
1188 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
1189 }
1190
1191 #[test]
1192 fn extract_mode_html() {
1193 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
1194 }
1195
1196 #[test]
1197 fn extract_mode_attr() {
1198 let mode = ExtractMode::parse("attr:href");
1199 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
1200 }
1201
1202 #[test]
1203 fn extract_mode_unknown_defaults_to_text() {
1204 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
1205 }
1206
1207 #[test]
1210 fn valid_https_url() {
1211 assert!(validate_url("https://example.com").is_ok());
1212 }
1213
1214 #[test]
1215 fn http_rejected() {
1216 let err = validate_url("http://example.com").unwrap_err();
1217 assert!(matches!(err, ToolError::Blocked { .. }));
1218 }
1219
1220 #[test]
1221 fn ftp_rejected() {
1222 let err = validate_url("ftp://files.example.com").unwrap_err();
1223 assert!(matches!(err, ToolError::Blocked { .. }));
1224 }
1225
1226 #[test]
1227 fn file_rejected() {
1228 let err = validate_url("file:///etc/passwd").unwrap_err();
1229 assert!(matches!(err, ToolError::Blocked { .. }));
1230 }
1231
1232 #[test]
1233 fn invalid_url_rejected() {
1234 let err = validate_url("not a url").unwrap_err();
1235 assert!(matches!(err, ToolError::Blocked { .. }));
1236 }
1237
1238 #[test]
1239 fn localhost_blocked() {
1240 let err = validate_url("https://localhost/path").unwrap_err();
1241 assert!(matches!(err, ToolError::Blocked { .. }));
1242 }
1243
1244 #[test]
1245 fn loopback_ip_blocked() {
1246 let err = validate_url("https://127.0.0.1/path").unwrap_err();
1247 assert!(matches!(err, ToolError::Blocked { .. }));
1248 }
1249
1250 #[test]
1251 fn private_10_blocked() {
1252 let err = validate_url("https://10.0.0.1/api").unwrap_err();
1253 assert!(matches!(err, ToolError::Blocked { .. }));
1254 }
1255
1256 #[test]
1257 fn private_172_blocked() {
1258 let err = validate_url("https://172.16.0.1/api").unwrap_err();
1259 assert!(matches!(err, ToolError::Blocked { .. }));
1260 }
1261
1262 #[test]
1263 fn private_192_blocked() {
1264 let err = validate_url("https://192.168.1.1/api").unwrap_err();
1265 assert!(matches!(err, ToolError::Blocked { .. }));
1266 }
1267
1268 #[test]
1269 fn ipv6_loopback_blocked() {
1270 let err = validate_url("https://[::1]/path").unwrap_err();
1271 assert!(matches!(err, ToolError::Blocked { .. }));
1272 }
1273
1274 #[test]
1275 fn public_ip_allowed() {
1276 assert!(validate_url("https://93.184.216.34/page").is_ok());
1277 }
1278
1279 #[test]
1282 fn extract_text_from_html() {
1283 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
1284 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1285 assert_eq!(result, "Hello World");
1286 }
1287
1288 #[test]
1289 fn extract_multiple_elements() {
1290 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1291 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1292 assert_eq!(result, "A\nB\nC");
1293 }
1294
1295 #[test]
1296 fn extract_with_limit() {
1297 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1298 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
1299 assert_eq!(result, "A\nB");
1300 }
1301
1302 #[test]
1303 fn extract_attr_href() {
1304 let html = r#"<a href="https://example.com">Link</a>"#;
1305 let result =
1306 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1307 assert_eq!(result, "https://example.com");
1308 }
1309
1310 #[test]
1311 fn extract_inner_html() {
1312 let html = "<div><span>inner</span></div>";
1313 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1314 assert!(result.contains("<span>inner</span>"));
1315 }
1316
1317 #[test]
1318 fn no_matches_returns_message() {
1319 let html = "<html><body><p>text</p></body></html>";
1320 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1321 assert!(result.starts_with("No results for selector:"));
1322 }
1323
1324 #[test]
1325 fn empty_text_skipped() {
1326 let html = "<ul><li> </li><li>A</li></ul>";
1327 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1328 assert_eq!(result, "A");
1329 }
1330
1331 #[test]
1332 fn invalid_selector_errors() {
1333 let html = "<html><body></body></html>";
1334 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
1335 assert!(result.is_err());
1336 }
1337
1338 #[test]
1339 fn empty_html_returns_no_results() {
1340 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
1341 assert!(result.starts_with("No results for selector:"));
1342 }
1343
1344 #[test]
1345 fn nested_selector() {
1346 let html = "<div><span>inner</span></div><span>outer</span>";
1347 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
1348 assert_eq!(result, "inner");
1349 }
1350
1351 #[test]
1352 fn attr_missing_returns_empty() {
1353 let html = r"<a>No href</a>";
1354 let result =
1355 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1356 assert!(result.starts_with("No results for selector:"));
1357 }
1358
1359 #[test]
1360 fn extract_html_mode() {
1361 let html = "<div><b>bold</b> text</div>";
1362 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1363 assert!(result.contains("<b>bold</b>"));
1364 }
1365
1366 #[test]
1367 fn limit_zero_returns_no_results() {
1368 let html = "<ul><li>A</li><li>B</li></ul>";
1369 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
1370 assert!(result.starts_with("No results for selector:"));
1371 }
1372
1373 #[test]
1376 fn url_with_port_allowed() {
1377 assert!(validate_url("https://example.com:8443/path").is_ok());
1378 }
1379
1380 #[test]
1381 fn link_local_ip_blocked() {
1382 let err = validate_url("https://169.254.1.1/path").unwrap_err();
1383 assert!(matches!(err, ToolError::Blocked { .. }));
1384 }
1385
1386 #[test]
1387 fn url_no_scheme_rejected() {
1388 let err = validate_url("example.com/path").unwrap_err();
1389 assert!(matches!(err, ToolError::Blocked { .. }));
1390 }
1391
1392 #[test]
1393 fn unspecified_ipv4_blocked() {
1394 let err = validate_url("https://0.0.0.0/path").unwrap_err();
1395 assert!(matches!(err, ToolError::Blocked { .. }));
1396 }
1397
1398 #[test]
1399 fn broadcast_ipv4_blocked() {
1400 let err = validate_url("https://255.255.255.255/path").unwrap_err();
1401 assert!(matches!(err, ToolError::Blocked { .. }));
1402 }
1403
1404 #[test]
1405 fn ipv6_link_local_blocked() {
1406 let err = validate_url("https://[fe80::1]/path").unwrap_err();
1407 assert!(matches!(err, ToolError::Blocked { .. }));
1408 }
1409
1410 #[test]
1411 fn ipv6_unique_local_blocked() {
1412 let err = validate_url("https://[fd12::1]/path").unwrap_err();
1413 assert!(matches!(err, ToolError::Blocked { .. }));
1414 }
1415
1416 #[test]
1417 fn ipv4_mapped_ipv6_loopback_blocked() {
1418 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
1419 assert!(matches!(err, ToolError::Blocked { .. }));
1420 }
1421
1422 #[test]
1423 fn ipv4_mapped_ipv6_private_blocked() {
1424 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
1425 assert!(matches!(err, ToolError::Blocked { .. }));
1426 }
1427
1428 #[tokio::test]
1431 async fn executor_no_blocks_returns_none() {
1432 let config = ScrapeConfig::default();
1433 let executor = WebScrapeExecutor::new(&config);
1434 let result = executor.execute("plain text").await;
1435 assert!(result.unwrap().is_none());
1436 }
1437
1438 #[tokio::test]
1439 async fn executor_invalid_json_errors() {
1440 let config = ScrapeConfig::default();
1441 let executor = WebScrapeExecutor::new(&config);
1442 let response = "```scrape\nnot json\n```";
1443 let result = executor.execute(response).await;
1444 assert!(matches!(result, Err(ToolError::Execution(_))));
1445 }
1446
1447 #[tokio::test]
1448 async fn executor_blocked_url_errors() {
1449 let config = ScrapeConfig::default();
1450 let executor = WebScrapeExecutor::new(&config);
1451 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
1452 let result = executor.execute(response).await;
1453 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1454 }
1455
1456 #[tokio::test]
1457 async fn executor_private_ip_blocked() {
1458 let config = ScrapeConfig::default();
1459 let executor = WebScrapeExecutor::new(&config);
1460 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
1461 let result = executor.execute(response).await;
1462 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1463 }
1464
1465 #[tokio::test]
1466 async fn executor_unreachable_host_returns_error() {
1467 let config = ScrapeConfig {
1468 timeout: 1,
1469 max_body_bytes: 1_048_576,
1470 ..Default::default()
1471 };
1472 let executor = WebScrapeExecutor::new(&config);
1473 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
1474 let result = executor.execute(response).await;
1475 assert!(matches!(result, Err(ToolError::Execution(_))));
1476 }
1477
1478 #[tokio::test]
1479 async fn executor_localhost_url_blocked() {
1480 let config = ScrapeConfig::default();
1481 let executor = WebScrapeExecutor::new(&config);
1482 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
1483 let result = executor.execute(response).await;
1484 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1485 }
1486
1487 #[tokio::test]
1488 async fn executor_empty_text_returns_none() {
1489 let config = ScrapeConfig::default();
1490 let executor = WebScrapeExecutor::new(&config);
1491 let result = executor.execute("").await;
1492 assert!(result.unwrap().is_none());
1493 }
1494
1495 #[tokio::test]
1496 async fn executor_multiple_blocks_first_blocked() {
1497 let config = ScrapeConfig::default();
1498 let executor = WebScrapeExecutor::new(&config);
1499 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1500 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1501 let result = executor.execute(response).await;
1502 assert!(result.is_err());
1503 }
1504
1505 #[test]
1506 fn validate_url_empty_string() {
1507 let err = validate_url("").unwrap_err();
1508 assert!(matches!(err, ToolError::Blocked { .. }));
1509 }
1510
1511 #[test]
1512 fn validate_url_javascript_scheme_blocked() {
1513 let err = validate_url("javascript:alert(1)").unwrap_err();
1514 assert!(matches!(err, ToolError::Blocked { .. }));
1515 }
1516
1517 #[test]
1518 fn validate_url_data_scheme_blocked() {
1519 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1520 assert!(matches!(err, ToolError::Blocked { .. }));
1521 }
1522
1523 #[test]
1524 fn is_private_host_public_domain_is_false() {
1525 let host: url::Host<&str> = url::Host::Domain("example.com");
1526 assert!(!is_private_host(&host));
1527 }
1528
1529 #[test]
1530 fn is_private_host_localhost_is_true() {
1531 let host: url::Host<&str> = url::Host::Domain("localhost");
1532 assert!(is_private_host(&host));
1533 }
1534
1535 #[test]
1536 fn is_private_host_ipv6_unspecified_is_true() {
1537 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1538 assert!(is_private_host(&host));
1539 }
1540
1541 #[test]
1542 fn is_private_host_public_ipv6_is_false() {
1543 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1544 assert!(!is_private_host(&host));
1545 }
1546
1547 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1558 let server = wiremock::MockServer::start().await;
1559 let executor = WebScrapeExecutor {
1560 timeout: Duration::from_secs(5),
1561 max_body_bytes: 1_048_576,
1562 allowed_domains: vec![],
1563 denied_domains: vec![],
1564 audit_logger: None,
1565 egress_config: EgressConfig::default(),
1566 egress_tx: None,
1567 egress_dropped: Arc::new(AtomicU64::new(0)),
1568 };
1569 (executor, server)
1570 }
1571
1572 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1574 let uri = server.uri();
1575 let url = Url::parse(&uri).unwrap();
1576 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1577 let port = url.port().unwrap_or(80);
1578 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1579 (host, vec![addr])
1580 }
1581
1582 async fn follow_redirects_raw(
1586 executor: &WebScrapeExecutor,
1587 start_url: &str,
1588 host: &str,
1589 addrs: &[std::net::SocketAddr],
1590 ) -> Result<String, ToolError> {
1591 const MAX_REDIRECTS: usize = 3;
1592 let mut current_url = start_url.to_owned();
1593 let mut current_host = host.to_owned();
1594 let mut current_addrs = addrs.to_vec();
1595
1596 for hop in 0..=MAX_REDIRECTS {
1597 let client = executor.build_client(¤t_host, ¤t_addrs);
1598 let resp = client
1599 .get(¤t_url)
1600 .send()
1601 .await
1602 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1603
1604 let status = resp.status();
1605
1606 if status.is_redirection() {
1607 if hop == MAX_REDIRECTS {
1608 return Err(ToolError::Execution(std::io::Error::other(
1609 "too many redirects",
1610 )));
1611 }
1612
1613 let location = resp
1614 .headers()
1615 .get(reqwest::header::LOCATION)
1616 .and_then(|v| v.to_str().ok())
1617 .ok_or_else(|| {
1618 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1619 })?;
1620
1621 let base = Url::parse(¤t_url)
1622 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1623 let next_url = base
1624 .join(location)
1625 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1626
1627 current_url = next_url.to_string();
1629 let _ = &mut current_host;
1631 let _ = &mut current_addrs;
1632 continue;
1633 }
1634
1635 if !status.is_success() {
1636 return Err(ToolError::Execution(std::io::Error::other(format!(
1637 "HTTP {status}",
1638 ))));
1639 }
1640
1641 let bytes = resp
1642 .bytes()
1643 .await
1644 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1645
1646 if bytes.len() > executor.max_body_bytes {
1647 return Err(ToolError::Execution(std::io::Error::other(format!(
1648 "response too large: {} bytes (max: {})",
1649 bytes.len(),
1650 executor.max_body_bytes,
1651 ))));
1652 }
1653
1654 return String::from_utf8(bytes.to_vec())
1655 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1656 }
1657
1658 Err(ToolError::Execution(std::io::Error::other(
1659 "too many redirects",
1660 )))
1661 }
1662
1663 #[tokio::test]
1664 async fn fetch_html_success_returns_body() {
1665 use wiremock::matchers::{method, path};
1666 use wiremock::{Mock, ResponseTemplate};
1667
1668 let (executor, server) = mock_server_executor().await;
1669 Mock::given(method("GET"))
1670 .and(path("/page"))
1671 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1672 .mount(&server)
1673 .await;
1674
1675 let (host, addrs) = server_host_and_addr(&server);
1676 let url = format!("{}/page", server.uri());
1677 let result = executor
1678 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1679 .await;
1680 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1681 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1682 }
1683
1684 #[tokio::test]
1685 async fn fetch_html_non_2xx_returns_error() {
1686 use wiremock::matchers::{method, path};
1687 use wiremock::{Mock, ResponseTemplate};
1688
1689 let (executor, server) = mock_server_executor().await;
1690 Mock::given(method("GET"))
1691 .and(path("/forbidden"))
1692 .respond_with(ResponseTemplate::new(403))
1693 .mount(&server)
1694 .await;
1695
1696 let (host, addrs) = server_host_and_addr(&server);
1697 let url = format!("{}/forbidden", server.uri());
1698 let result = executor
1699 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1700 .await;
1701 assert!(result.is_err());
1702 let msg = result.unwrap_err().to_string();
1703 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1704 }
1705
1706 #[tokio::test]
1707 async fn fetch_html_404_returns_error() {
1708 use wiremock::matchers::{method, path};
1709 use wiremock::{Mock, ResponseTemplate};
1710
1711 let (executor, server) = mock_server_executor().await;
1712 Mock::given(method("GET"))
1713 .and(path("/missing"))
1714 .respond_with(ResponseTemplate::new(404))
1715 .mount(&server)
1716 .await;
1717
1718 let (host, addrs) = server_host_and_addr(&server);
1719 let url = format!("{}/missing", server.uri());
1720 let result = executor
1721 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1722 .await;
1723 assert!(result.is_err());
1724 let msg = result.unwrap_err().to_string();
1725 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1726 }
1727
1728 #[tokio::test]
1729 async fn fetch_html_redirect_no_location_returns_error() {
1730 use wiremock::matchers::{method, path};
1731 use wiremock::{Mock, ResponseTemplate};
1732
1733 let (executor, server) = mock_server_executor().await;
1734 Mock::given(method("GET"))
1736 .and(path("/redirect-no-loc"))
1737 .respond_with(ResponseTemplate::new(302))
1738 .mount(&server)
1739 .await;
1740
1741 let (host, addrs) = server_host_and_addr(&server);
1742 let url = format!("{}/redirect-no-loc", server.uri());
1743 let result = executor
1744 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1745 .await;
1746 assert!(result.is_err());
1747 let msg = result.unwrap_err().to_string();
1748 assert!(
1749 msg.contains("Location") || msg.contains("location"),
1750 "expected Location-related error: {msg}"
1751 );
1752 }
1753
1754 #[tokio::test]
1755 async fn fetch_html_single_redirect_followed() {
1756 use wiremock::matchers::{method, path};
1757 use wiremock::{Mock, ResponseTemplate};
1758
1759 let (executor, server) = mock_server_executor().await;
1760 let final_url = format!("{}/final", server.uri());
1761
1762 Mock::given(method("GET"))
1763 .and(path("/start"))
1764 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1765 .mount(&server)
1766 .await;
1767
1768 Mock::given(method("GET"))
1769 .and(path("/final"))
1770 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1771 .mount(&server)
1772 .await;
1773
1774 let (host, addrs) = server_host_and_addr(&server);
1775 let url = format!("{}/start", server.uri());
1776 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1777 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1778 assert_eq!(result.unwrap(), "<p>final</p>");
1779 }
1780
1781 #[tokio::test]
1782 async fn fetch_html_three_redirects_allowed() {
1783 use wiremock::matchers::{method, path};
1784 use wiremock::{Mock, ResponseTemplate};
1785
1786 let (executor, server) = mock_server_executor().await;
1787 let hop2 = format!("{}/hop2", server.uri());
1788 let hop3 = format!("{}/hop3", server.uri());
1789 let final_dest = format!("{}/done", server.uri());
1790
1791 Mock::given(method("GET"))
1792 .and(path("/hop1"))
1793 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1794 .mount(&server)
1795 .await;
1796 Mock::given(method("GET"))
1797 .and(path("/hop2"))
1798 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1799 .mount(&server)
1800 .await;
1801 Mock::given(method("GET"))
1802 .and(path("/hop3"))
1803 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1804 .mount(&server)
1805 .await;
1806 Mock::given(method("GET"))
1807 .and(path("/done"))
1808 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1809 .mount(&server)
1810 .await;
1811
1812 let (host, addrs) = server_host_and_addr(&server);
1813 let url = format!("{}/hop1", server.uri());
1814 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1815 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1816 assert_eq!(result.unwrap(), "<p>done</p>");
1817 }
1818
1819 #[tokio::test]
1820 async fn fetch_html_four_redirects_rejected() {
1821 use wiremock::matchers::{method, path};
1822 use wiremock::{Mock, ResponseTemplate};
1823
1824 let (executor, server) = mock_server_executor().await;
1825 let hop2 = format!("{}/r2", server.uri());
1826 let hop3 = format!("{}/r3", server.uri());
1827 let hop4 = format!("{}/r4", server.uri());
1828 let hop5 = format!("{}/r5", server.uri());
1829
1830 for (from, to) in [
1831 ("/r1", &hop2),
1832 ("/r2", &hop3),
1833 ("/r3", &hop4),
1834 ("/r4", &hop5),
1835 ] {
1836 Mock::given(method("GET"))
1837 .and(path(from))
1838 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1839 .mount(&server)
1840 .await;
1841 }
1842
1843 let (host, addrs) = server_host_and_addr(&server);
1844 let url = format!("{}/r1", server.uri());
1845 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1846 assert!(result.is_err(), "4 redirects should be rejected");
1847 let msg = result.unwrap_err().to_string();
1848 assert!(
1849 msg.contains("redirect"),
1850 "expected redirect-related error: {msg}"
1851 );
1852 }
1853
1854 #[tokio::test]
1855 async fn fetch_html_body_too_large_returns_error() {
1856 use wiremock::matchers::{method, path};
1857 use wiremock::{Mock, ResponseTemplate};
1858
1859 let small_limit_executor = WebScrapeExecutor {
1860 timeout: Duration::from_secs(5),
1861 max_body_bytes: 10,
1862 allowed_domains: vec![],
1863 denied_domains: vec![],
1864 audit_logger: None,
1865 egress_config: EgressConfig::default(),
1866 egress_tx: None,
1867 egress_dropped: Arc::new(AtomicU64::new(0)),
1868 };
1869 let server = wiremock::MockServer::start().await;
1870 Mock::given(method("GET"))
1871 .and(path("/big"))
1872 .respond_with(
1873 ResponseTemplate::new(200)
1874 .set_body_string("this body is definitely longer than ten bytes"),
1875 )
1876 .mount(&server)
1877 .await;
1878
1879 let (host, addrs) = server_host_and_addr(&server);
1880 let url = format!("{}/big", server.uri());
1881 let result = small_limit_executor
1882 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1883 .await;
1884 assert!(result.is_err());
1885 let msg = result.unwrap_err().to_string();
1886 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1887 }
1888
1889 #[test]
1890 fn extract_scrape_blocks_empty_block_content() {
1891 let text = "```scrape\n\n```";
1892 let blocks = extract_scrape_blocks(text);
1893 assert_eq!(blocks.len(), 1);
1894 assert!(blocks[0].is_empty());
1895 }
1896
1897 #[test]
1898 fn extract_scrape_blocks_whitespace_only() {
1899 let text = "```scrape\n \n```";
1900 let blocks = extract_scrape_blocks(text);
1901 assert_eq!(blocks.len(), 1);
1902 }
1903
1904 #[test]
1905 fn parse_and_extract_multiple_selectors() {
1906 let html = "<div><h1>Title</h1><p>Para</p></div>";
1907 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1908 assert!(result.contains("Title"));
1909 assert!(result.contains("Para"));
1910 }
1911
1912 #[test]
1913 fn webscrape_executor_new_with_custom_config() {
1914 let config = ScrapeConfig {
1915 timeout: 60,
1916 max_body_bytes: 512,
1917 ..Default::default()
1918 };
1919 let executor = WebScrapeExecutor::new(&config);
1920 assert_eq!(executor.max_body_bytes, 512);
1921 }
1922
1923 #[test]
1924 fn webscrape_executor_debug() {
1925 let config = ScrapeConfig::default();
1926 let executor = WebScrapeExecutor::new(&config);
1927 let dbg = format!("{executor:?}");
1928 assert!(dbg.contains("WebScrapeExecutor"));
1929 }
1930
1931 #[test]
1932 fn extract_mode_attr_empty_name() {
1933 let mode = ExtractMode::parse("attr:");
1934 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1935 }
1936
1937 #[test]
1938 fn default_extract_returns_text() {
1939 assert_eq!(default_extract(), "text");
1940 }
1941
1942 #[test]
1943 fn scrape_instruction_debug() {
1944 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1945 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1946 let dbg = format!("{instr:?}");
1947 assert!(dbg.contains("ScrapeInstruction"));
1948 }
1949
1950 #[test]
1951 fn extract_mode_debug() {
1952 let mode = ExtractMode::Text;
1953 let dbg = format!("{mode:?}");
1954 assert!(dbg.contains("Text"));
1955 }
1956
    #[test]
    fn max_redirects_constant_is_three() {
        // NOTE(review): this redeclares the limit locally instead of
        // referencing the MAX_REDIRECTS constant inside fetch_html, so the
        // assertion is a tautology — it documents the intended value but
        // cannot detect drift. Consider exposing the real constant (e.g.
        // pub(crate)) and asserting against it.
        const MAX_REDIRECTS: usize = 3;
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }
1969
1970 #[test]
1973 fn redirect_no_location_error_message() {
1974 let err = std::io::Error::other("redirect with no Location");
1975 assert!(err.to_string().contains("redirect with no Location"));
1976 }
1977
1978 #[test]
1980 fn too_many_redirects_error_message() {
1981 let err = std::io::Error::other("too many redirects");
1982 assert!(err.to_string().contains("too many redirects"));
1983 }
1984
1985 #[test]
1987 fn non_2xx_status_error_format() {
1988 let status = reqwest::StatusCode::FORBIDDEN;
1989 let msg = format!("HTTP {status}");
1990 assert!(msg.contains("403"));
1991 }
1992
1993 #[test]
1995 fn not_found_status_error_format() {
1996 let status = reqwest::StatusCode::NOT_FOUND;
1997 let msg = format!("HTTP {status}");
1998 assert!(msg.contains("404"));
1999 }
2000
2001 #[test]
2003 fn relative_redirect_same_host_path() {
2004 let base = Url::parse("https://example.com/current").unwrap();
2005 let resolved = base.join("/other").unwrap();
2006 assert_eq!(resolved.as_str(), "https://example.com/other");
2007 }
2008
2009 #[test]
2011 fn relative_redirect_relative_path() {
2012 let base = Url::parse("https://example.com/a/b").unwrap();
2013 let resolved = base.join("c").unwrap();
2014 assert_eq!(resolved.as_str(), "https://example.com/a/c");
2015 }
2016
2017 #[test]
2019 fn absolute_redirect_overrides_base() {
2020 let base = Url::parse("https://example.com/page").unwrap();
2021 let resolved = base.join("https://other.com/target").unwrap();
2022 assert_eq!(resolved.as_str(), "https://other.com/target");
2023 }
2024
2025 #[test]
2027 fn redirect_http_downgrade_rejected() {
2028 let location = "http://example.com/page";
2029 let base = Url::parse("https://example.com/start").unwrap();
2030 let next = base.join(location).unwrap();
2031 let err = validate_url(next.as_str()).unwrap_err();
2032 assert!(matches!(err, ToolError::Blocked { .. }));
2033 }
2034
2035 #[test]
2037 fn redirect_location_private_ip_blocked() {
2038 let location = "https://192.168.100.1/admin";
2039 let base = Url::parse("https://example.com/start").unwrap();
2040 let next = base.join(location).unwrap();
2041 let err = validate_url(next.as_str()).unwrap_err();
2042 assert!(matches!(err, ToolError::Blocked { .. }));
2043 let ToolError::Blocked { command: cmd } = err else {
2044 panic!("expected Blocked");
2045 };
2046 assert!(
2047 cmd.contains("private") || cmd.contains("scheme"),
2048 "error message should describe the block reason: {cmd}"
2049 );
2050 }
2051
2052 #[test]
2054 fn redirect_location_internal_domain_blocked() {
2055 let location = "https://metadata.internal/latest/meta-data/";
2056 let base = Url::parse("https://example.com/start").unwrap();
2057 let next = base.join(location).unwrap();
2058 let err = validate_url(next.as_str()).unwrap_err();
2059 assert!(matches!(err, ToolError::Blocked { .. }));
2060 }
2061
2062 #[test]
2064 fn redirect_chain_three_hops_all_public() {
2065 let hops = [
2066 "https://redirect1.example.com/hop1",
2067 "https://redirect2.example.com/hop2",
2068 "https://destination.example.com/final",
2069 ];
2070 for hop in hops {
2071 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
2072 }
2073 }
2074
2075 #[test]
2080 fn redirect_to_private_ip_rejected_by_validate_url() {
2081 let private_targets = [
2083 "https://127.0.0.1/secret",
2084 "https://10.0.0.1/internal",
2085 "https://192.168.1.1/admin",
2086 "https://172.16.0.1/data",
2087 "https://[::1]/path",
2088 "https://[fe80::1]/path",
2089 "https://localhost/path",
2090 "https://service.internal/api",
2091 ];
2092 for target in private_targets {
2093 let result = validate_url(target);
2094 assert!(result.is_err(), "expected error for {target}");
2095 assert!(
2096 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
2097 "expected Blocked for {target}"
2098 );
2099 }
2100 }
2101
2102 #[test]
2104 fn redirect_relative_url_resolves_correctly() {
2105 let base = Url::parse("https://example.com/page").unwrap();
2106 let relative = "/other";
2107 let resolved = base.join(relative).unwrap();
2108 assert_eq!(resolved.as_str(), "https://example.com/other");
2109 }
2110
2111 #[test]
2113 fn redirect_to_http_rejected() {
2114 let err = validate_url("http://example.com/page").unwrap_err();
2115 assert!(matches!(err, ToolError::Blocked { .. }));
2116 }
2117
2118 #[test]
2119 fn ipv4_mapped_ipv6_link_local_blocked() {
2120 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
2121 assert!(matches!(err, ToolError::Blocked { .. }));
2122 }
2123
2124 #[test]
2125 fn ipv4_mapped_ipv6_public_allowed() {
2126 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
2127 }
2128
2129 #[tokio::test]
2132 async fn fetch_http_scheme_blocked() {
2133 let config = ScrapeConfig::default();
2134 let executor = WebScrapeExecutor::new(&config);
2135 let call = crate::executor::ToolCall {
2136 tool_id: ToolName::new("fetch"),
2137 params: {
2138 let mut m = serde_json::Map::new();
2139 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2140 m
2141 },
2142 caller_id: None,
2143 };
2144 let result = executor.execute_tool_call(&call).await;
2145 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2146 }
2147
2148 #[tokio::test]
2149 async fn fetch_private_ip_blocked() {
2150 let config = ScrapeConfig::default();
2151 let executor = WebScrapeExecutor::new(&config);
2152 let call = crate::executor::ToolCall {
2153 tool_id: ToolName::new("fetch"),
2154 params: {
2155 let mut m = serde_json::Map::new();
2156 m.insert(
2157 "url".to_owned(),
2158 serde_json::json!("https://192.168.1.1/secret"),
2159 );
2160 m
2161 },
2162 caller_id: None,
2163 };
2164 let result = executor.execute_tool_call(&call).await;
2165 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2166 }
2167
2168 #[tokio::test]
2169 async fn fetch_localhost_blocked() {
2170 let config = ScrapeConfig::default();
2171 let executor = WebScrapeExecutor::new(&config);
2172 let call = crate::executor::ToolCall {
2173 tool_id: ToolName::new("fetch"),
2174 params: {
2175 let mut m = serde_json::Map::new();
2176 m.insert(
2177 "url".to_owned(),
2178 serde_json::json!("https://localhost/page"),
2179 );
2180 m
2181 },
2182 caller_id: None,
2183 };
2184 let result = executor.execute_tool_call(&call).await;
2185 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2186 }
2187
2188 #[tokio::test]
2189 async fn fetch_unknown_tool_returns_none() {
2190 let config = ScrapeConfig::default();
2191 let executor = WebScrapeExecutor::new(&config);
2192 let call = crate::executor::ToolCall {
2193 tool_id: ToolName::new("unknown_tool"),
2194 params: serde_json::Map::new(),
2195 caller_id: None,
2196 };
2197 let result = executor.execute_tool_call(&call).await;
2198 assert!(result.unwrap().is_none());
2199 }
2200
2201 #[tokio::test]
2202 async fn fetch_returns_body_via_mock() {
2203 use wiremock::matchers::{method, path};
2204 use wiremock::{Mock, ResponseTemplate};
2205
2206 let (executor, server) = mock_server_executor().await;
2207 Mock::given(method("GET"))
2208 .and(path("/content"))
2209 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
2210 .mount(&server)
2211 .await;
2212
2213 let (host, addrs) = server_host_and_addr(&server);
2214 let url = format!("{}/content", server.uri());
2215 let result = executor
2216 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
2217 .await;
2218 assert!(result.is_ok());
2219 assert_eq!(result.unwrap(), "plain text content");
2220 }
2221
2222 #[test]
2223 fn tool_definitions_returns_web_scrape_and_fetch() {
2224 let config = ScrapeConfig::default();
2225 let executor = WebScrapeExecutor::new(&config);
2226 let defs = executor.tool_definitions();
2227 assert_eq!(defs.len(), 2);
2228 assert_eq!(defs[0].id, "web_scrape");
2229 assert_eq!(
2230 defs[0].invocation,
2231 crate::registry::InvocationHint::FencedBlock("scrape")
2232 );
2233 assert_eq!(defs[1].id, "fetch");
2234 assert_eq!(
2235 defs[1].invocation,
2236 crate::registry::InvocationHint::ToolCall
2237 );
2238 }
2239
2240 #[test]
2241 fn tool_definitions_schema_has_all_params() {
2242 let config = ScrapeConfig::default();
2243 let executor = WebScrapeExecutor::new(&config);
2244 let defs = executor.tool_definitions();
2245 let obj = defs[0].schema.as_object().unwrap();
2246 let props = obj["properties"].as_object().unwrap();
2247 assert!(props.contains_key("url"));
2248 assert!(props.contains_key("select"));
2249 assert!(props.contains_key("extract"));
2250 assert!(props.contains_key("limit"));
2251 let req = obj["required"].as_array().unwrap();
2252 assert!(req.iter().any(|v| v.as_str() == Some("url")));
2253 assert!(req.iter().any(|v| v.as_str() == Some("select")));
2254 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
2255 }
2256
2257 #[test]
2260 fn subdomain_localhost_blocked() {
2261 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
2262 assert!(is_private_host(&host));
2263 }
2264
2265 #[test]
2266 fn internal_tld_blocked() {
2267 let host: url::Host<&str> = url::Host::Domain("service.internal");
2268 assert!(is_private_host(&host));
2269 }
2270
2271 #[test]
2272 fn local_tld_blocked() {
2273 let host: url::Host<&str> = url::Host::Domain("printer.local");
2274 assert!(is_private_host(&host));
2275 }
2276
2277 #[test]
2278 fn public_domain_not_blocked() {
2279 let host: url::Host<&str> = url::Host::Domain("example.com");
2280 assert!(!is_private_host(&host));
2281 }
2282
2283 #[tokio::test]
2286 async fn resolve_loopback_rejected() {
2287 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
2289 let result = resolve_and_validate(&url).await;
2291 assert!(
2292 result.is_err(),
2293 "loopback IP must be rejected by resolve_and_validate"
2294 );
2295 let err = result.unwrap_err();
2296 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
2297 }
2298
2299 #[tokio::test]
2300 async fn resolve_private_10_rejected() {
2301 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
2302 let result = resolve_and_validate(&url).await;
2303 assert!(result.is_err());
2304 assert!(matches!(
2305 result.unwrap_err(),
2306 crate::executor::ToolError::Blocked { .. }
2307 ));
2308 }
2309
2310 #[tokio::test]
2311 async fn resolve_private_192_rejected() {
2312 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
2313 let result = resolve_and_validate(&url).await;
2314 assert!(result.is_err());
2315 assert!(matches!(
2316 result.unwrap_err(),
2317 crate::executor::ToolError::Blocked { .. }
2318 ));
2319 }
2320
2321 #[tokio::test]
2322 async fn resolve_ipv6_loopback_rejected() {
2323 let url = url::Url::parse("https://[::1]/path").unwrap();
2324 let result = resolve_and_validate(&url).await;
2325 assert!(result.is_err());
2326 assert!(matches!(
2327 result.unwrap_err(),
2328 crate::executor::ToolError::Blocked { .. }
2329 ));
2330 }
2331
2332 #[tokio::test]
2333 async fn resolve_no_host_returns_ok() {
2334 let url = url::Url::parse("https://example.com/path").unwrap();
2336 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
2338 let result = resolve_and_validate(&url_no_host).await;
2340 assert!(result.is_ok());
2341 let (host, addrs) = result.unwrap();
2342 assert!(host.is_empty());
2343 assert!(addrs.is_empty());
2344 drop(url);
2345 drop(url_no_host);
2346 }
2347
2348 async fn make_file_audit_logger(
2352 dir: &tempfile::TempDir,
2353 ) -> (
2354 std::sync::Arc<crate::audit::AuditLogger>,
2355 std::path::PathBuf,
2356 ) {
2357 use crate::audit::AuditLogger;
2358 use crate::config::AuditConfig;
2359 let path = dir.path().join("audit.log");
2360 let config = AuditConfig {
2361 enabled: true,
2362 destination: path.display().to_string(),
2363 ..Default::default()
2364 };
2365 let logger = std::sync::Arc::new(AuditLogger::from_config(&config, false).await.unwrap());
2366 (logger, path)
2367 }
2368
2369 #[tokio::test]
2370 async fn with_audit_sets_logger() {
2371 let config = ScrapeConfig::default();
2372 let executor = WebScrapeExecutor::new(&config);
2373 assert!(executor.audit_logger.is_none());
2374
2375 let dir = tempfile::tempdir().unwrap();
2376 let (logger, _path) = make_file_audit_logger(&dir).await;
2377 let executor = executor.with_audit(logger);
2378 assert!(executor.audit_logger.is_some());
2379 }
2380
2381 #[test]
2382 fn tool_error_to_audit_result_blocked_maps_correctly() {
2383 let err = ToolError::Blocked {
2384 command: "scheme not allowed: http".into(),
2385 };
2386 let result = tool_error_to_audit_result(&err);
2387 assert!(
2388 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
2389 );
2390 }
2391
2392 #[test]
2393 fn tool_error_to_audit_result_timeout_maps_correctly() {
2394 let err = ToolError::Timeout { timeout_secs: 15 };
2395 let result = tool_error_to_audit_result(&err);
2396 assert!(matches!(result, AuditResult::Timeout));
2397 }
2398
2399 #[test]
2400 fn tool_error_to_audit_result_execution_error_maps_correctly() {
2401 let err = ToolError::Execution(std::io::Error::other("connection refused"));
2402 let result = tool_error_to_audit_result(&err);
2403 assert!(
2404 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
2405 );
2406 }
2407
2408 #[tokio::test]
2409 async fn fetch_audit_blocked_url_logged() {
2410 let dir = tempfile::tempdir().unwrap();
2411 let (logger, log_path) = make_file_audit_logger(&dir).await;
2412
2413 let config = ScrapeConfig::default();
2414 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2415
2416 let call = crate::executor::ToolCall {
2417 tool_id: ToolName::new("fetch"),
2418 params: {
2419 let mut m = serde_json::Map::new();
2420 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2421 m
2422 },
2423 caller_id: None,
2424 };
2425 let result = executor.execute_tool_call(&call).await;
2426 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2427
2428 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2429 assert!(
2430 content.contains("\"tool\":\"fetch\""),
2431 "expected tool=fetch in audit: {content}"
2432 );
2433 assert!(
2434 content.contains("\"type\":\"blocked\""),
2435 "expected type=blocked in audit: {content}"
2436 );
2437 assert!(
2438 content.contains("http://example.com"),
2439 "expected URL in audit command field: {content}"
2440 );
2441 }
2442
2443 #[tokio::test]
2444 async fn log_audit_success_writes_to_file() {
2445 let dir = tempfile::tempdir().unwrap();
2446 let (logger, log_path) = make_file_audit_logger(&dir).await;
2447
2448 let config = ScrapeConfig::default();
2449 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2450
2451 executor
2452 .log_audit(
2453 "fetch",
2454 "https://example.com/page",
2455 AuditResult::Success,
2456 42,
2457 None,
2458 None,
2459 None,
2460 )
2461 .await;
2462
2463 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2464 assert!(
2465 content.contains("\"tool\":\"fetch\""),
2466 "expected tool=fetch in audit: {content}"
2467 );
2468 assert!(
2469 content.contains("\"type\":\"success\""),
2470 "expected type=success in audit: {content}"
2471 );
2472 assert!(
2473 content.contains("\"command\":\"https://example.com/page\""),
2474 "expected command URL in audit: {content}"
2475 );
2476 assert!(
2477 content.contains("\"duration_ms\":42"),
2478 "expected duration_ms in audit: {content}"
2479 );
2480 }
2481
2482 #[tokio::test]
2483 async fn log_audit_blocked_writes_to_file() {
2484 let dir = tempfile::tempdir().unwrap();
2485 let (logger, log_path) = make_file_audit_logger(&dir).await;
2486
2487 let config = ScrapeConfig::default();
2488 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2489
2490 executor
2491 .log_audit(
2492 "web_scrape",
2493 "http://evil.com/page",
2494 AuditResult::Blocked {
2495 reason: "scheme not allowed: http".into(),
2496 },
2497 0,
2498 None,
2499 None,
2500 None,
2501 )
2502 .await;
2503
2504 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2505 assert!(
2506 content.contains("\"tool\":\"web_scrape\""),
2507 "expected tool=web_scrape in audit: {content}"
2508 );
2509 assert!(
2510 content.contains("\"type\":\"blocked\""),
2511 "expected type=blocked in audit: {content}"
2512 );
2513 assert!(
2514 content.contains("scheme not allowed"),
2515 "expected block reason in audit: {content}"
2516 );
2517 }
2518
2519 #[tokio::test]
2520 async fn web_scrape_audit_blocked_url_logged() {
2521 let dir = tempfile::tempdir().unwrap();
2522 let (logger, log_path) = make_file_audit_logger(&dir).await;
2523
2524 let config = ScrapeConfig::default();
2525 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2526
2527 let call = crate::executor::ToolCall {
2528 tool_id: ToolName::new("web_scrape"),
2529 params: {
2530 let mut m = serde_json::Map::new();
2531 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2532 m.insert("select".to_owned(), serde_json::json!("h1"));
2533 m
2534 },
2535 caller_id: None,
2536 };
2537 let result = executor.execute_tool_call(&call).await;
2538 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2539
2540 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2541 assert!(
2542 content.contains("\"tool\":\"web_scrape\""),
2543 "expected tool=web_scrape in audit: {content}"
2544 );
2545 assert!(
2546 content.contains("\"type\":\"blocked\""),
2547 "expected type=blocked in audit: {content}"
2548 );
2549 }
2550
2551 #[tokio::test]
2552 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2553 let config = ScrapeConfig::default();
2554 let executor = WebScrapeExecutor::new(&config);
2555 assert!(executor.audit_logger.is_none());
2556
2557 let call = crate::executor::ToolCall {
2558 tool_id: ToolName::new("fetch"),
2559 params: {
2560 let mut m = serde_json::Map::new();
2561 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2562 m
2563 },
2564 caller_id: None,
2565 };
2566 let result = executor.execute_tool_call(&call).await;
2568 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2569 }
2570
    /// Serves a page from a local wiremock server and checks the HTML body
    /// round-trips through `fetch_html`.
    ///
    /// NOTE(review): despite the name, this calls `fetch_html` directly with
    /// pre-resolved host/addrs rather than going through `execute_tool_call`
    /// (whose validation would presumably block the mock server's
    /// `http://127.0.0.1` URL). Consider renaming to reflect what it
    /// exercises — confirm intent with the author.
    #[tokio::test]
    async fn fetch_execute_tool_call_end_to_end() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/e2e"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        // Bypasses scheme/IP policy by supplying host and resolved addrs directly.
        let result = executor
            .fetch_html(
                &format!("{}/e2e", server.uri()),
                &host,
                &addrs,
                "fetch",
                "test-cid",
                None,
            )
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("end-to-end"));
    }
2599
2600 #[test]
2603 fn domain_matches_exact() {
2604 assert!(domain_matches("example.com", "example.com"));
2605 assert!(!domain_matches("example.com", "other.com"));
2606 assert!(!domain_matches("example.com", "sub.example.com"));
2607 }
2608
2609 #[test]
2610 fn domain_matches_wildcard_single_subdomain() {
2611 assert!(domain_matches("*.example.com", "sub.example.com"));
2612 assert!(!domain_matches("*.example.com", "example.com"));
2613 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2614 }
2615
2616 #[test]
2617 fn domain_matches_wildcard_does_not_match_empty_label() {
2618 assert!(!domain_matches("*.example.com", ".example.com"));
2620 }
2621
2622 #[test]
2623 fn domain_matches_multi_wildcard_treated_as_exact() {
2624 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2626 }
2627
2628 #[test]
2631 fn check_domain_policy_empty_lists_allow_all() {
2632 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2633 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2634 }
2635
2636 #[test]
2637 fn check_domain_policy_denylist_blocks() {
2638 let denied = vec!["evil.com".to_string()];
2639 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2640 assert!(matches!(err, ToolError::Blocked { .. }));
2641 }
2642
2643 #[test]
2644 fn check_domain_policy_denylist_does_not_block_other_domains() {
2645 let denied = vec!["evil.com".to_string()];
2646 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2647 }
2648
2649 #[test]
2650 fn check_domain_policy_allowlist_permits_matching() {
2651 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2652 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2653 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2654 }
2655
2656 #[test]
2657 fn check_domain_policy_allowlist_blocks_unknown() {
2658 let allowed = vec!["docs.rs".to_string()];
2659 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2660 assert!(matches!(err, ToolError::Blocked { .. }));
2661 }
2662
2663 #[test]
2664 fn check_domain_policy_deny_overrides_allow() {
2665 let allowed = vec!["example.com".to_string()];
2666 let denied = vec!["example.com".to_string()];
2667 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2668 assert!(matches!(err, ToolError::Blocked { .. }));
2669 }
2670
2671 #[test]
2672 fn check_domain_policy_wildcard_in_denylist() {
2673 let denied = vec!["*.evil.com".to_string()];
2674 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2675 assert!(matches!(err, ToolError::Blocked { .. }));
2676 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2678 }
2679}