1use std::sync::{
33 Arc,
34 atomic::{AtomicU32, Ordering},
35};
36
37use serde_json::Value as JsonValue;
38use tracing::{Instrument as _, info_span};
39use zeph_db::DbPool;
40use zeph_llm::LlmProvider;
41use zeph_llm::any::AnyProvider;
42use zeph_llm::provider::{Message, Role};
43
44use crate::agent::error::AgentError;
45
/// Coarse risk bucket assigned to a tool call by [`ShadowSentinel::classify_tool`].
///
/// Every category except [`ToolRiskCategory::Low`] makes the call eligible for a safety
/// probe in [`ShadowSentinel::check_tool_call`]; `Low` calls are skipped outright.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRiskCategory {
    /// Shell / command-execution tools (e.g. `builtin:shell`, `bash`, `sh`), or tools matched
    /// by a probe pattern mentioning `shell` or `exec`.
    Shell,
    /// Local file mutation tools (`write`, `edit`, `delete` and their `builtin:` forms), or
    /// non-`mcp:` tools matched by a probe pattern mentioning `write`, `edit` or `file`.
    FileWrite,
    /// Any other tool matched by a configured probe pattern, including `mcp:` tools matched
    /// by a file-related pattern. The name suggests these are treated as potential data
    /// exfiltration channels.
    ExfilCapable,
    /// Everything else; never probed.
    Low,
}
63
/// Outcome of a safety check for a single proposed tool call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProbeVerdict {
    /// The call may proceed. Also produced when the probe output cannot be parsed, and on
    /// probe errors/timeouts when the probe is configured to fail open.
    Allow,
    /// The call was rejected by the probe.
    Deny {
        /// Human-readable explanation: either the model-supplied reason or a fixed message
        /// for errors/timeouts.
        reason: String,
    },
    /// No probe was run: the sentinel is disabled, the tool is low risk, or the per-turn
    /// probe budget is exhausted.
    Skip,
}
81
/// One entry in the per-session safety "shadow" trail stored in `safety_shadow_events`.
///
/// Events are written by [`ShadowSentinel`] and read back as context for later probes.
#[derive(Debug, Clone)]
pub struct ShadowEvent {
    /// Database row id; `0` for events constructed in memory before being persisted.
    pub id: i64,
    /// Session the event belongs to.
    pub session_id: String,
    /// Agent turn in which the event occurred.
    pub turn_number: u64,
    /// Event kind; this file writes `"tool_call"` and `"probe_result"`.
    pub event_type: String,
    /// Qualified id of the tool involved, if any.
    pub tool_id: Option<String>,
    /// Optional risk signal label; not populated by code in this file.
    pub risk_signal: Option<String>,
    /// Caller-supplied risk level string at the time of the event (e.g. `"calm"` in tests).
    pub risk_level: String,
    /// `"allow"`, `"deny"` or `"skip"` for `probe_result` events; `None` otherwise.
    pub probe_verdict: Option<String>,
    /// Short free-text summary included in probe prompts.
    pub context_summary: Option<String>,
    /// Creation time in Unix seconds (see `unix_now`).
    pub created_at: i64,
}
110
/// Pluggable evaluator deciding whether a proposed tool call may run.
///
/// The future is returned boxed and pinned so the trait stays object-safe; [`ShadowSentinel`]
/// stores probes as `Box<dyn SafetyProbe>`.
pub trait SafetyProbe: Send + Sync {
    /// Evaluates one tool call.
    ///
    /// * `tool_id` – qualified tool id (e.g. `builtin:shell`, `mcp:server/tool`).
    /// * `tool_args` – the call's JSON arguments.
    /// * `trajectory` – recent session events; when supplied by
    ///   [`ShadowSentinel::check_tool_call`] they are oldest-first and exclude earlier
    ///   `probe_result` events.
    fn evaluate<'a>(
        &'a self,
        tool_id: &'a str,
        tool_args: &'a JsonValue,
        trajectory: &'a [ShadowEvent],
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>>;
}
141
/// [`SafetyProbe`] that asks an LLM to return a JSON allow/deny verdict.
pub struct LlmSafetyProbe {
    /// Provider used for the verdict chat request.
    provider: Arc<AnyProvider>,
    /// Upper bound for a single verdict request, in milliseconds.
    timeout_ms: u64,
    /// When `true`, provider errors AND timeouts yield `Deny` (fail closed); otherwise `Allow`.
    deny_on_timeout: bool,
}
154
155impl LlmSafetyProbe {
156 #[must_use]
164 pub fn new(provider: Arc<AnyProvider>, timeout_ms: u64, deny_on_timeout: bool) -> Self {
165 Self {
166 provider,
167 timeout_ms,
168 deny_on_timeout,
169 }
170 }
171
172 fn build_prompt(
173 tool_id: &str,
174 tool_args: &JsonValue,
175 trajectory: &[ShadowEvent],
176 ) -> Vec<Message> {
177 let context = if trajectory.is_empty() {
178 "No prior events in this session.".to_owned()
179 } else {
180 trajectory
181 .iter()
182 .map(|e| {
183 format!(
184 "[turn {}] {} | tool={} | {}",
185 e.turn_number,
186 e.event_type,
187 e.tool_id.as_deref().unwrap_or("-"),
188 e.context_summary.as_deref().unwrap_or(""),
189 )
190 })
191 .collect::<Vec<_>>()
192 .join("\n")
193 };
194
195 let args_preview = serde_json::to_string(tool_args)
196 .unwrap_or_default()
197 .chars()
198 .take(500)
199 .collect::<String>();
200
201 let system = "You are a safety evaluator for an AI agent. \
202 Evaluate whether the proposed tool call is safe to execute given the session context. \
203 Reply ONLY with a JSON object: {\"verdict\": \"allow\"} or {\"verdict\": \"deny\", \"reason\": \"...\"}. \
204 Deny only if there is clear evidence of harmful intent or a dangerous pattern. \
205 When uncertain, allow.";
206
207 let user =
208 format!("Tool: {tool_id}\nArgs: {args_preview}\n\nRecent session events:\n{context}");
209
210 vec![
211 Message::from_legacy(Role::System, system),
212 Message::from_legacy(Role::User, user),
213 ]
214 }
215
216 fn parse_verdict(response: &str) -> ProbeVerdict {
217 let start = response.find('{');
219 let end = response.rfind('}');
220 if let (Some(s), Some(e)) = (start, end)
221 && let Ok(v) = serde_json::from_str::<serde_json::Value>(&response[s..=e])
222 {
223 match v.get("verdict").and_then(|x| x.as_str()) {
224 Some("allow") => return ProbeVerdict::Allow,
225 Some("deny") => {
226 let reason = v
227 .get("reason")
228 .and_then(|r| r.as_str())
229 .unwrap_or("safety probe denied this tool call")
230 .to_owned();
231 return ProbeVerdict::Deny { reason };
232 }
233 _ => {}
234 }
235 }
236 tracing::warn!(
238 raw = %response,
239 "ShadowSentinel: probe response could not be parsed, defaulting to Allow"
240 );
241 ProbeVerdict::Allow
242 }
243}
244
245impl SafetyProbe for LlmSafetyProbe {
246 fn evaluate<'a>(
247 &'a self,
248 tool_id: &'a str,
249 tool_args: &'a JsonValue,
250 trajectory: &'a [ShadowEvent],
251 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>> {
252 let span = info_span!("security.shadow.probe", tool_id = %tool_id);
253 Box::pin(
254 async move {
255 let messages = Self::build_prompt(tool_id, tool_args, trajectory);
256 let timeout = std::time::Duration::from_millis(self.timeout_ms);
257
258 match tokio::time::timeout(timeout, self.provider.chat(&messages)).await {
259 Ok(Ok(response)) => Self::parse_verdict(&response),
260 Ok(Err(e)) => {
261 tracing::warn!(error = %e, "ShadowSentinel: probe LLM error");
262 if self.deny_on_timeout {
263 ProbeVerdict::Deny {
264 reason: format!("probe LLM error: {e}"),
265 }
266 } else {
267 ProbeVerdict::Allow
268 }
269 }
270 Err(_) => {
271 tracing::warn!(
272 timeout_ms = self.timeout_ms,
273 "ShadowSentinel: probe timed out"
274 );
275 if self.deny_on_timeout {
276 ProbeVerdict::Deny {
277 reason: "safety probe timed out".to_owned(),
278 }
279 } else {
280 ProbeVerdict::Allow
281 }
282 }
283 }
284 }
285 .instrument(span),
286 )
287 }
288}
289
/// Persistence for [`ShadowEvent`]s in the `safety_shadow_events` table.
///
/// Cloned freely (e.g. into spawned persistence tasks); cloning copies the `DbPool`
/// handle — presumably a cheap reference-counted pool handle, confirm against `zeph_db`.
#[derive(Clone)]
pub struct ShadowEventStore {
    pool: DbPool,
}
300
301impl ShadowEventStore {
302 #[must_use]
304 pub fn new(pool: DbPool) -> Self {
305 Self { pool }
306 }
307
308 #[tracing::instrument(name = "security.shadow.record", skip_all, fields(event_type = %event.event_type))]
316 pub async fn record(&self, event: &ShadowEvent) -> Result<(), AgentError> {
317 sqlx::query(
318 "INSERT INTO safety_shadow_events \
319 (session_id, turn_number, event_type, tool_id, risk_signal, risk_level, \
320 probe_verdict, context_summary, created_at) \
321 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
322 )
323 .bind(&event.session_id)
324 .bind(i64::try_from(event.turn_number).unwrap_or(i64::MAX))
325 .bind(&event.event_type)
326 .bind(&event.tool_id)
327 .bind(&event.risk_signal)
328 .bind(&event.risk_level)
329 .bind(&event.probe_verdict)
330 .bind(&event.context_summary)
331 .bind(event.created_at)
332 .execute(&self.pool)
333 .await
334 .map_err(|e| AgentError::Db(e.to_string()))?;
335
336 Ok(())
337 }
338
339 #[tracing::instrument(name = "security.shadow.get_trajectory", skip(self), fields(session_id = %session_id))]
347 pub async fn get_trajectory(
348 &self,
349 session_id: &str,
350 limit: usize,
351 ) -> Result<Vec<ShadowEvent>, AgentError> {
352 let rows = sqlx::query_as::<_, ShadowEventRow>(
353 "SELECT id, session_id, turn_number, event_type, tool_id, risk_signal, \
354 risk_level, probe_verdict, context_summary, created_at \
355 FROM safety_shadow_events \
356 WHERE session_id = ? \
357 ORDER BY created_at DESC \
358 LIMIT ?",
359 )
360 .bind(session_id)
361 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
362 .fetch_all(&self.pool)
363 .await
364 .map_err(|e| AgentError::Db(e.to_string()))?;
365
366 let mut events: Vec<ShadowEvent> = rows.into_iter().map(ShadowEvent::from).collect();
368 events.reverse();
369 Ok(events)
370 }
371
372 #[tracing::instrument(name = "security.shadow.get_tool_history", skip(self), fields(tool_id = %tool_id))]
380 pub async fn get_tool_history(
381 &self,
382 tool_id: &str,
383 limit: usize,
384 ) -> Result<Vec<ShadowEvent>, AgentError> {
385 let rows = sqlx::query_as::<_, ShadowEventRow>(
386 "SELECT id, session_id, turn_number, event_type, tool_id, risk_signal, \
387 risk_level, probe_verdict, context_summary, created_at \
388 FROM safety_shadow_events \
389 WHERE tool_id = ? \
390 ORDER BY created_at DESC \
391 LIMIT ?",
392 )
393 .bind(tool_id)
394 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
395 .fetch_all(&self.pool)
396 .await
397 .map_err(|e| AgentError::Db(e.to_string()))?;
398
399 Ok(rows.into_iter().map(ShadowEvent::from).collect())
400 }
401}
402
/// Raw row shape of `safety_shadow_events`, decoded via `sqlx::FromRow`.
///
/// Mirrors [`ShadowEvent`] except that `turn_number` is the stored signed `i64`.
#[derive(sqlx::FromRow)]
struct ShadowEventRow {
    id: i64,
    session_id: String,
    turn_number: i64,
    event_type: String,
    tool_id: Option<String>,
    risk_signal: Option<String>,
    risk_level: String,
    probe_verdict: Option<String>,
    context_summary: Option<String>,
    created_at: i64,
}
417
418impl From<ShadowEventRow> for ShadowEvent {
419 fn from(r: ShadowEventRow) -> Self {
420 Self {
421 id: r.id,
422 session_id: r.session_id,
423 turn_number: u64::try_from(r.turn_number).unwrap_or(0),
424 event_type: r.event_type,
425 tool_id: r.tool_id,
426 risk_signal: r.risk_signal,
427 risk_level: r.risk_level,
428 probe_verdict: r.probe_verdict,
429 context_summary: r.context_summary,
430 created_at: r.created_at,
431 }
432 }
433}
434
/// Per-session guard that probes risky tool calls before they run and keeps a persistent
/// trail of tool and probe events for later probes to use as context.
pub struct ShadowSentinel {
    /// Where tool and probe events are persisted and trajectories are read from.
    store: ShadowEventStore,
    /// Evaluator consulted for non-`Low` tool calls.
    probe: Box<dyn SafetyProbe>,
    /// Enable flag, probe patterns, per-turn budget and trajectory size.
    config: zeph_config::ShadowSentinelConfig,
    /// Probes started in the current turn; reset by [`ShadowSentinel::advance_turn`].
    probes_this_turn: AtomicU32,
    /// Session whose events are recorded and read back.
    session_id: String,
}
462
463impl ShadowSentinel {
464 #[must_use]
473 pub fn new(
474 store: ShadowEventStore,
475 probe: Box<dyn SafetyProbe>,
476 config: zeph_config::ShadowSentinelConfig,
477 session_id: impl Into<String>,
478 ) -> Self {
479 Self {
480 store,
481 probe,
482 config,
483 probes_this_turn: AtomicU32::new(0),
484 session_id: session_id.into(),
485 }
486 }
487
488 #[must_use]
494 pub fn classify_tool(&self, qualified_tool_id: &str) -> ToolRiskCategory {
495 if qualified_tool_id == "builtin:shell"
497 || qualified_tool_id == "builtin:bash"
498 || qualified_tool_id.starts_with("builtin:shell")
499 || qualified_tool_id == "bash"
500 || qualified_tool_id == "shell"
501 || qualified_tool_id == "sh"
502 {
503 return ToolRiskCategory::Shell;
504 }
505 if qualified_tool_id == "builtin:write"
506 || qualified_tool_id == "builtin:edit"
507 || qualified_tool_id == "builtin:delete"
508 || qualified_tool_id == "write"
509 || qualified_tool_id == "edit"
510 || qualified_tool_id == "delete"
511 {
512 return ToolRiskCategory::FileWrite;
513 }
514
515 for pattern in &self.config.probe_patterns {
517 if glob_matches(pattern, qualified_tool_id) {
518 if pattern.contains("shell") || pattern.contains("exec") {
520 return ToolRiskCategory::Shell;
521 }
522 if pattern.contains("write") || pattern.contains("edit") || pattern.contains("file")
523 {
524 if qualified_tool_id.starts_with("mcp:") {
525 return ToolRiskCategory::ExfilCapable;
526 }
527 return ToolRiskCategory::FileWrite;
528 }
529 return ToolRiskCategory::ExfilCapable;
530 }
531 }
532
533 ToolRiskCategory::Low
534 }
535
536 #[tracing::instrument(name = "security.shadow.check", skip(self, tool_args), fields(tool_id = %qualified_tool_id))]
550 pub async fn check_tool_call(
551 &self,
552 qualified_tool_id: &str,
553 tool_args: &JsonValue,
554 turn_number: u64,
555 current_risk_level: &str,
556 ) -> ProbeVerdict {
557 if !self.config.enabled {
558 return ProbeVerdict::Skip;
559 }
560
561 let category = self.classify_tool(qualified_tool_id);
562 if category == ToolRiskCategory::Low {
563 return ProbeVerdict::Skip;
564 }
565
566 let count = self.probes_this_turn.fetch_add(1, Ordering::Relaxed);
568 let max_probes = u32::try_from(self.config.max_probes_per_turn).unwrap_or(u32::MAX);
569 if count >= max_probes {
570 self.probes_this_turn.fetch_sub(1, Ordering::Relaxed);
572 tracing::debug!(
573 max = self.config.max_probes_per_turn,
574 "ShadowSentinel: probe budget exhausted for this turn, skipping"
575 );
576 return ProbeVerdict::Skip;
577 }
578
579 let trajectory = match self
583 .store
584 .get_trajectory(&self.session_id, self.config.max_context_events)
585 .await
586 {
587 Ok(t) => t
588 .into_iter()
589 .filter(|e| e.event_type != "probe_result")
590 .collect(),
591 Err(e) => {
592 tracing::warn!(error = %e, "ShadowSentinel: failed to load trajectory, proceeding without context");
593 vec![]
594 }
595 };
596
597 let verdict = self
598 .probe
599 .evaluate(qualified_tool_id, tool_args, &trajectory)
600 .await;
601
602 let probe_verdict_str = match &verdict {
604 ProbeVerdict::Allow => "allow",
605 ProbeVerdict::Deny { .. } => "deny",
606 ProbeVerdict::Skip => "skip",
607 };
608 let summary = match &verdict {
609 ProbeVerdict::Deny { reason } => {
610 format!("probe denied: {}", &reason[..reason.len().min(120)])
611 }
612 ProbeVerdict::Allow => format!("probe allowed {qualified_tool_id}"),
613 ProbeVerdict::Skip => format!("probe skipped {qualified_tool_id}"),
614 };
615 let event = ShadowEvent {
616 id: 0,
617 session_id: self.session_id.clone(),
618 turn_number,
619 event_type: "probe_result".to_owned(),
620 tool_id: Some(qualified_tool_id.to_owned()),
621 risk_signal: None,
622 risk_level: current_risk_level.to_owned(),
623 probe_verdict: Some(probe_verdict_str.to_owned()),
624 context_summary: Some(summary),
625 created_at: unix_now(),
626 };
627 let store = self.store.clone();
628 tokio::spawn(async move {
629 if let Err(e) = store.record(&event).await {
630 tracing::warn!(error = %e, "ShadowSentinel: failed to persist probe result");
631 }
632 });
633
634 verdict
635 }
636
637 pub fn record_tool_event(
641 &self,
642 qualified_tool_id: &str,
643 turn_number: u64,
644 risk_level: &str,
645 context_summary: &str,
646 ) {
647 if !self.config.enabled {
648 return;
649 }
650 let event = ShadowEvent {
651 id: 0,
652 session_id: self.session_id.clone(),
653 turn_number,
654 event_type: "tool_call".to_owned(),
655 tool_id: Some(qualified_tool_id.to_owned()),
656 risk_signal: None,
657 risk_level: risk_level.to_owned(),
658 probe_verdict: None,
659 context_summary: Some(context_summary.chars().take(250).collect()),
660 created_at: unix_now(),
661 };
662 let store = self.store.clone();
663 tokio::spawn(async move {
664 if let Err(e) = store.record(&event).await {
665 tracing::warn!(error = %e, "ShadowSentinel: failed to persist tool event");
666 }
667 });
668 }
669
670 pub fn advance_turn(&self) {
675 self.probes_this_turn.store(0, Ordering::Release);
676 }
677}
678
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the clock reads before the epoch or the value does not fit `i64`.
fn unix_now() -> i64 {
    let Ok(since_epoch) = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)
    else {
        return 0;
    };
    i64::try_from(since_epoch.as_secs()).unwrap_or(0)
}
689
/// Minimal glob matcher supporting only `*` (any run of characters, possibly empty).
///
/// A pattern without `*` must equal `value` exactly. Otherwise the text before the first
/// `*` must be a prefix, each middle segment is found left-to-right in what remains, and
/// the text after the last `*` must be a suffix of the remainder.
fn glob_matches(pattern: &str, value: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    let Some((head, after_head)) = pattern.split_once('*') else {
        return pattern == value;
    };
    // With no further '*', everything after the first one is the trailing segment.
    let (middle, last) = after_head.rsplit_once('*').unwrap_or(("", after_head));

    let Some(mut rest) = value.strip_prefix(head) else {
        return false;
    };
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    rest.ends_with(last)
}
721
#[cfg(test)]
mod tests {
    //! Unit tests for tool classification, the glob matcher, verdict parsing, and the
    //! per-turn probe budget. Sentinels are built with an always-`Allow` probe over an
    //! in-memory SQLite pool (see `make_test_sentinel`).
    use super::*;

    #[tokio::test]
    async fn classify_builtin_shell_is_shell_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(
            sentinel.classify_tool("builtin:shell"),
            ToolRiskCategory::Shell
        );
        assert_eq!(
            sentinel.classify_tool("builtin:bash"),
            ToolRiskCategory::Shell
        );
    }

    #[tokio::test]
    async fn classify_builtin_write_is_file_write_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(
            sentinel.classify_tool("builtin:write"),
            ToolRiskCategory::FileWrite
        );
        assert_eq!(
            sentinel.classify_tool("builtin:edit"),
            ToolRiskCategory::FileWrite
        );
    }

    // Relies on the default `probe_patterns` not matching these ids — TODO confirm if the
    // default config ever gains broad patterns.
    #[tokio::test]
    async fn classify_low_risk_returns_low() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(
            sentinel.classify_tool("builtin:read"),
            ToolRiskCategory::Low
        );
        assert_eq!(
            sentinel.classify_tool("builtin:search"),
            ToolRiskCategory::Low
        );
    }

    #[tokio::test]
    async fn classify_bare_shell_names_are_shell_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(sentinel.classify_tool("bash"), ToolRiskCategory::Shell);
        assert_eq!(sentinel.classify_tool("shell"), ToolRiskCategory::Shell);
        assert_eq!(sentinel.classify_tool("sh"), ToolRiskCategory::Shell);
    }

    #[tokio::test]
    async fn classify_bare_file_write_names_are_file_write_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(sentinel.classify_tool("write"), ToolRiskCategory::FileWrite);
        assert_eq!(sentinel.classify_tool("edit"), ToolRiskCategory::FileWrite);
        assert_eq!(
            sentinel.classify_tool("delete"),
            ToolRiskCategory::FileWrite
        );
    }

    #[tokio::test]
    async fn advance_turn_resets_counter() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        sentinel.probes_this_turn.store(3, Ordering::Relaxed);
        sentinel.advance_turn();
        assert_eq!(sentinel.probes_this_turn.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn glob_matches_star_wildcard() {
        assert!(glob_matches("mcp:*/file_*", "mcp:myserver/file_read"));
        assert!(glob_matches("mcp:*/file_*", "mcp:other/file_write"));
        assert!(!glob_matches("mcp:*/file_*", "builtin:shell"));
    }

    #[test]
    fn glob_matches_exact() {
        assert!(glob_matches("builtin:shell", "builtin:shell"));
        assert!(!glob_matches("builtin:shell", "builtin:write"));
    }

    #[test]
    fn parse_verdict_allow() {
        let v = LlmSafetyProbe::parse_verdict(r#"{"verdict": "allow"}"#);
        assert_eq!(v, ProbeVerdict::Allow);
    }

    #[test]
    fn parse_verdict_deny_with_reason() {
        let v =
            LlmSafetyProbe::parse_verdict(r#"{"verdict": "deny", "reason": "suspicious pattern"}"#);
        assert_eq!(
            v,
            ProbeVerdict::Deny {
                reason: "suspicious pattern".to_owned()
            }
        );
    }

    // Non-JSON replies fail open by design.
    #[test]
    fn parse_verdict_unparseable_allows() {
        let v = LlmSafetyProbe::parse_verdict("I think this is fine");
        assert_eq!(v, ProbeVerdict::Allow);
    }

    // No schema is created for the in-memory DB here, so trajectory loading and the
    // background writes presumably fail and are only logged; the budget logic must still hold.
    #[tokio::test]
    async fn check_tool_call_skips_after_budget_exhausted() {
        let config = zeph_config::ShadowSentinelConfig {
            enabled: true,
            max_probes_per_turn: 2,
            ..zeph_config::ShadowSentinelConfig::default()
        };
        let sentinel = make_test_sentinel(config).await;

        let args = serde_json::Value::Object(serde_json::Map::new());
        // Two probes fit in the budget of 2.
        let v1 = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        let v2 = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        assert_ne!(v1, ProbeVerdict::Skip, "first call within budget");
        assert_ne!(v2, ProbeVerdict::Skip, "second call within budget");

        let v3 = sentinel
            // Same turn: budget is now exhausted.
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        assert_eq!(
            v3,
            ProbeVerdict::Skip,
            "third call must be skipped (budget exhausted)"
        );
    }

    #[tokio::test]
    async fn check_tool_call_returns_skip_when_disabled() {
        let config = zeph_config::ShadowSentinelConfig {
            enabled: false,
            ..zeph_config::ShadowSentinelConfig::default()
        };
        let sentinel = make_test_sentinel(config).await;
        let args = serde_json::Value::Object(serde_json::Map::new());
        let verdict = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        assert_eq!(
            verdict,
            ProbeVerdict::Skip,
            "disabled sentinel must always return Skip without calling the probe"
        );
    }

    /// Builds a sentinel for session `test-session` with `config`, a probe that always
    /// returns `Allow`, and a fresh in-memory SQLite pool (no tables created).
    async fn make_test_sentinel(config: zeph_config::ShadowSentinelConfig) -> ShadowSentinel {
        struct NoopProbe;
        impl SafetyProbe for NoopProbe {
            fn evaluate<'a>(
                &'a self,
                _: &'a str,
                _: &'a JsonValue,
                _: &'a [ShadowEvent],
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>>
            {
                Box::pin(async { ProbeVerdict::Allow })
            }
        }
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite pool");
        let store = ShadowEventStore::new(pool);
        ShadowSentinel::new(store, Box::new(NoopProbe), config, "test-session")
    }
}