1use std::collections::HashSet;
13use std::fmt::Write as _;
14use std::sync::LazyLock;
15
16use regex::Regex;
17
18pub use zeph_config::ExfiltrationGuardConfig;
19
20static MARKDOWN_IMAGE_RE: LazyLock<Regex> = LazyLock::new(|| {
30 Regex::new(r"!\[([^\]]*)\]\((https?://[^)]+)\)").expect("valid MARKDOWN_IMAGE_RE")
31});
32
33static REFERENCE_DEF_RE: LazyLock<Regex> = LazyLock::new(|| {
36 Regex::new(r"(?m)^\[([^\]]+)\]:\s*(https?://\S+)").expect("valid REFERENCE_DEF_RE")
37});
38
39static REFERENCE_USAGE_RE: LazyLock<Regex> =
41 LazyLock::new(|| Regex::new(r"!\[([^\]]*)\]\[([^\]]+)\]").expect("valid REFERENCE_USAGE_RE"));
42
43static URL_EXTRACT_RE: LazyLock<Regex> =
45 LazyLock::new(|| Regex::new(r#"https?://[^\s"'<>]+"#).expect("valid URL_EXTRACT_RE"));
46
/// A single finding reported by `ExfiltrationGuard`.
#[derive(Debug, Clone, PartialEq)]
pub enum ExfiltrationEvent {
    /// An external markdown image was stripped from scanned output.
    MarkdownImageBlocked {
        /// Image URL (percent-decoded by the scanner).
        url: String,
    },
    /// A tool call's arguments contained a URL from the flagged set.
    SuspiciousToolUrl {
        /// The flagged URL as found in the arguments.
        url: String,
        /// Name of the tool being called.
        tool_name: String,
    },
    /// A memory write should be guarded because its content carried injection flags.
    MemoryWriteGuarded {
        /// Human-readable explanation.
        reason: String,
    },
}
61
62#[derive(Debug, Clone)]
70pub struct ExfiltrationGuard {
71 config: ExfiltrationGuardConfig,
72}
73
74impl ExfiltrationGuard {
75 #[must_use]
77 pub fn new(config: ExfiltrationGuardConfig) -> Self {
78 Self { config }
79 }
80
81 #[must_use]
106 pub fn scan_output(&self, text: &str) -> (String, Vec<ExfiltrationEvent>) {
107 if !self.config.block_markdown_images {
108 return (text.to_owned(), vec![]);
109 }
110
111 let mut events = Vec::new();
112 let mut result = text.to_owned();
113
114 let mut replacement = String::new();
116 let mut last_end = 0usize;
117 for cap in MARKDOWN_IMAGE_RE.captures_iter(text) {
118 let m = cap.get(0).expect("full match");
119 let raw_url = cap.get(2).expect("url group").as_str();
120 let url = percent_decode_url(raw_url);
121
122 if is_external_url(&url) {
123 replacement.push_str(&text[last_end..m.start()]);
124 let _ = write!(replacement, "[image removed: {url}]");
125 last_end = m.end();
126 events.push(ExfiltrationEvent::MarkdownImageBlocked { url });
127 }
128 }
129 if !events.is_empty() || last_end > 0 {
130 replacement.push_str(&text[last_end..]);
131 result = replacement;
132 }
133
134 let mut ref_defs: std::collections::HashMap<String, String> =
137 std::collections::HashMap::new();
138 for cap in REFERENCE_DEF_RE.captures_iter(&result) {
139 let label = cap.get(1).expect("label").as_str().to_lowercase();
140 let raw_url = cap.get(2).expect("url").as_str();
141 let url = percent_decode_url(raw_url);
142 if is_external_url(&url) {
143 ref_defs.insert(label, url);
144 }
145 }
146
147 if !ref_defs.is_empty() {
148 let mut cleaned = String::with_capacity(result.len());
150 let mut last_end = 0usize;
151 for cap in REFERENCE_USAGE_RE.captures_iter(&result) {
152 let m = cap.get(0).expect("full match");
153 let label = cap.get(2).expect("label").as_str().to_lowercase();
154 if let Some(url) = ref_defs.get(&label) {
155 cleaned.push_str(&result[last_end..m.start()]);
156 let _ = write!(cleaned, "[image removed: {url}]");
157 last_end = m.end();
158 events.push(ExfiltrationEvent::MarkdownImageBlocked { url: url.clone() });
159 }
160 }
161 cleaned.push_str(&result[last_end..]);
162 result = cleaned;
163
164 let mut def_cleaned = String::with_capacity(result.len());
169 for line in result.split('\n') {
170 let mut keep = true;
171 for cap in REFERENCE_DEF_RE.captures_iter(line) {
172 let label = cap.get(1).expect("label").as_str().to_lowercase();
173 if ref_defs.contains_key(&label) {
174 keep = false;
175 break;
176 }
177 }
178 if keep {
179 def_cleaned.push_str(line);
180 def_cleaned.push('\n');
181 }
182 }
183 if !text.ends_with('\n') && def_cleaned.ends_with('\n') {
185 def_cleaned.pop();
186 }
187 result = def_cleaned;
188 }
189
190 (result, events)
191 }
192
193 #[must_use]
206 pub fn validate_tool_call(
207 &self,
208 tool_name: &str,
209 args_json: &str,
210 flagged_urls: &HashSet<String>,
211 ) -> Vec<ExfiltrationEvent> {
212 if !self.config.validate_tool_urls || flagged_urls.is_empty() {
213 return vec![];
214 }
215
216 let parsed: serde_json::Value = match serde_json::from_str(args_json) {
217 Ok(v) => v,
218 Err(_) => {
219 return Self::scan_raw_args(tool_name, args_json, flagged_urls);
221 }
222 };
223
224 let mut events = Vec::new();
225 let mut strings = Vec::new();
226 collect_strings(&parsed, &mut strings);
227
228 for s in &strings {
229 for url_match in URL_EXTRACT_RE.find_iter(s) {
230 let url = url_match.as_str();
231 if flagged_urls.contains(url) {
232 events.push(ExfiltrationEvent::SuspiciousToolUrl {
233 url: url.to_owned(),
234 tool_name: tool_name.to_owned(),
235 });
236 }
237 }
238 }
239
240 events
241 }
242
243 #[must_use]
252 pub fn should_guard_memory_write(
253 &self,
254 has_injection_flags: bool,
255 ) -> Option<ExfiltrationEvent> {
256 if !self.config.guard_memory_writes || !has_injection_flags {
257 return None;
258 }
259 Some(ExfiltrationEvent::MemoryWriteGuarded {
260 reason: "content contained injection patterns flagged by ContentSanitizer".to_owned(),
261 })
262 }
263
264 fn scan_raw_args(
267 tool_name: &str,
268 args: &str,
269 flagged_urls: &HashSet<String>,
270 ) -> Vec<ExfiltrationEvent> {
271 URL_EXTRACT_RE
272 .find_iter(args)
273 .filter(|m| flagged_urls.contains(m.as_str()))
274 .map(|m| ExfiltrationEvent::SuspiciousToolUrl {
275 url: m.as_str().to_owned(),
276 tool_name: tool_name.to_owned(),
277 })
278 .collect()
279 }
280}
281
282#[must_use]
288pub fn extract_flagged_urls(content: &str) -> HashSet<String> {
289 URL_EXTRACT_RE
290 .find_iter(content)
291 .map(|m| m.as_str().to_owned())
292 .collect()
293}
294
/// Percent-decode `raw` so obfuscated URLs can be inspected and reported.
///
/// Every well-formed `%XX` escape (two hex digits, either case) is replaced by the byte it
/// encodes. Malformed or truncated escapes are copied through literally. The resulting bytes
/// are interpreted as UTF-8, with invalid sequences replaced by U+FFFD. As a result, literal
/// non-ASCII characters are preserved, and percent-encoded UTF-8 (e.g. `%E4%BE%8B`) decodes to
/// the intended character instead of Latin-1 mojibake.
fn percent_decode_url(raw: &str) -> String {
    let bytes = raw.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            let hi = (bytes[i + 1] as char).to_digit(16);
            let lo = (bytes[i + 2] as char).to_digit(16);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                // Each hex digit is < 16, so the combined value always fits in a u8.
                #[allow(clippy::cast_possible_truncation)]
                let byte = ((hi << 4) | lo) as u8;
                decoded.push(byte);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
327
/// Report whether `url` starts with `http://` or `https://` (case-sensitive prefix check).
fn is_external_url(url: &str) -> bool {
    const REMOTE_PREFIXES: [&str; 2] = ["http://", "https://"];
    REMOTE_PREFIXES.iter().any(|prefix| url.starts_with(*prefix))
}
331
332fn collect_strings<'a>(value: &'a serde_json::Value, out: &mut Vec<&'a str>) {
334 match value {
335 serde_json::Value::String(s) => out.push(s.as_str()),
336 serde_json::Value::Array(arr) => {
337 for v in arr {
338 collect_strings(v, out);
339 }
340 }
341 serde_json::Value::Object(map) => {
342 for v in map.values() {
343 collect_strings(v, out);
344 }
345 }
346 _ => {}
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    fn guard() -> ExfiltrationGuard {
        ExfiltrationGuard::new(ExfiltrationGuardConfig::default())
    }

    fn guard_disabled() -> ExfiltrationGuard {
        ExfiltrationGuard::new(ExfiltrationGuardConfig {
            block_markdown_images: false,
            validate_tool_urls: false,
            guard_memory_writes: false,
        })
    }

    // --- scan_output ---

    #[test]
    fn strips_external_inline_image() {
        let (cleaned, events) =
            guard().scan_output("Before ![alt](https://evil.com/p.gif) after");
        assert_eq!(
            cleaned,
            "Before [image removed: https://evil.com/p.gif] after"
        );
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], ExfiltrationEvent::MarkdownImageBlocked { url } if url == "https://evil.com/p.gif")
        );
    }

    #[test]
    fn preserves_local_image() {
        let text = "Look: ![logo](./logo.png) — local";
        let (cleaned, events) = guard().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn preserves_data_uri() {
        let text = "Inline: ![dot](data:image/png;base64,iVBORw0KGgo=)";
        let (cleaned, events) = guard().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn strips_multiple_external_images() {
        let text = "![a](https://a.evil.com/1.gif) text ![b](https://b.evil.com/2.gif)";
        let (cleaned, events) = guard().scan_output(text);
        assert!(
            !cleaned.contains("![a](https://a.evil.com/1.gif)"),
            "first image syntax must be removed: {cleaned}"
        );
        assert!(
            !cleaned.contains("![b](https://b.evil.com/2.gif)"),
            "second image syntax must be removed: {cleaned}"
        );
        assert_eq!(events.len(), 2);
    }

    #[test]
    fn scan_output_noop_when_disabled() {
        let text = "![x](https://evil.com/x.gif)";
        let (cleaned, events) = guard_disabled().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn strips_reference_style_image() {
        let text = "Here is the image: ![alt][ref]\n[ref]: https://evil.com/track.gif\nend";
        let (cleaned, events) = guard().scan_output(text);
        assert!(
            !cleaned.contains("![alt][ref]"),
            "image usage syntax must be removed: {cleaned}"
        );
        assert!(
            !cleaned.contains("[ref]:"),
            "reference definition must be removed: {cleaned}"
        );
        assert!(
            cleaned.contains("[image removed:"),
            "replacement label must be present: {cleaned}"
        );
        assert!(!events.is_empty(), "must generate event");
    }

    #[test]
    fn preserves_local_reference_image() {
        let text = "![alt][ref]\n[ref]: ./local.png\n";
        let (cleaned, events) = guard().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn decodes_percent_encoded_url_in_inline_image() {
        // `%68ttps://` does not begin with a literal `http(s)://` scheme, so the inline-image
        // regex does not match it and the text passes through unchanged.
        let text = "![t](%68ttps://evil.com/x)";
        let (cleaned, _events) = guard().scan_output(text);
        assert_eq!(
            cleaned, text,
            "percent-encoded scheme not detected by inline regex"
        );

        let normal = "![t](https://evil.com/x)";
        let (normal_cleaned, normal_events) = guard().scan_output(normal);
        assert!(
            !normal_cleaned.contains("![t](https://evil.com/x)"),
            "normal URL must be removed"
        );
        assert_eq!(normal_events.len(), 1);
    }

    #[test]
    fn empty_alt_text_still_blocked() {
        let text = "![](https://evil.com/track.gif)";
        let (cleaned, events) = guard().scan_output(text);
        assert!(
            !cleaned.contains("![](https://evil.com/track.gif)"),
            "markdown image syntax must be removed: {cleaned}"
        );
        assert!(
            cleaned.contains("[image removed:"),
            "replacement label must be present: {cleaned}"
        );
        assert_eq!(events.len(), 1);
    }

    // --- validate_tool_call ---

    #[test]
    fn detects_flagged_url_in_json_string() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/payload".to_owned());
        let args = r#"{"url": "https://evil.com/payload"}"#;
        let events = guard().validate_tool_call("fetch", args, &flagged);
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], ExfiltrationEvent::SuspiciousToolUrl { url, tool_name }
                if url == "https://evil.com/payload" && tool_name == "fetch")
        );
    }

    #[test]
    fn no_event_when_url_not_flagged() {
        let mut flagged = HashSet::new();
        flagged.insert("https://other.com/benign".to_owned());
        let args = r#"{"url": "https://legitimate.com/page"}"#;
        let events = guard().validate_tool_call("fetch", args, &flagged);
        assert!(events.is_empty());
    }

    #[test]
    fn validate_tool_call_noop_when_disabled() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/x".to_owned());
        let args = r#"{"url": "https://evil.com/x"}"#;
        let events = guard_disabled().validate_tool_call("fetch", args, &flagged);
        assert!(events.is_empty());
    }

    #[test]
    fn validate_tool_call_noop_with_empty_flagged() {
        let args = r#"{"url": "https://evil.com/x"}"#;
        let events = guard().validate_tool_call("fetch", args, &HashSet::new());
        assert!(events.is_empty());
    }

    #[test]
    fn extracts_urls_from_nested_json() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/deep".to_owned());
        let args = r#"{"nested": {"inner": ["https://evil.com/deep"]}}"#;
        let events = guard().validate_tool_call("tool", args, &flagged);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn handles_escaped_slashes_in_json() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/path".to_owned());
        let args = r#"{"url": "https:\/\/evil.com\/path"}"#;
        // serde_json resolves the `\/` escapes before the URL search runs.
        let parsed: serde_json::Value = serde_json::from_str(args).unwrap();
        assert_eq!(parsed["url"], "https://evil.com/path");
        let events = guard().validate_tool_call("fetch", args, &flagged);
        assert_eq!(events.len(), 1, "JSON-escaped URL must be caught");
    }

    // --- should_guard_memory_write ---

    #[test]
    fn guards_when_injection_flags_set() {
        let event = guard().should_guard_memory_write(true);
        assert!(event.is_some());
        assert!(matches!(
            event.unwrap(),
            ExfiltrationEvent::MemoryWriteGuarded { .. }
        ));
    }

    #[test]
    fn passes_when_no_injection_flags() {
        let event = guard().should_guard_memory_write(false);
        assert!(event.is_none());
    }

    #[test]
    fn guard_memory_write_noop_when_disabled() {
        let event = guard_disabled().should_guard_memory_write(true);
        assert!(event.is_none());
    }

    // --- helpers ---

    #[test]
    fn percent_decode_roundtrip() {
        assert_eq!(
            percent_decode_url("https://example.com"),
            "https://example.com"
        );
        assert_eq!(
            percent_decode_url("%68ttps://example.com"),
            "https://example.com"
        );
        assert_eq!(percent_decode_url("hello%20world"), "hello world");
    }

    #[test]
    fn extracts_urls_from_plain_text() {
        let content = "check https://evil.com/x and https://other.com/y for details";
        let urls = extract_flagged_urls(content);
        assert!(urls.contains("https://evil.com/x"));
        assert!(urls.contains("https://other.com/y"));
    }
}
612}