1use std::sync::{
33 Arc,
34 atomic::{AtomicU32, Ordering},
35};
36use tokio::sync::Mutex;
37use tokio::task::JoinSet;
38
39use serde_json::Value as JsonValue;
40use tracing::{Instrument as _, info_span};
41use zeph_db::DbPool;
42use zeph_llm::LlmProvider;
43use zeph_llm::any::AnyProvider;
44use zeph_llm::provider::{Message, Role};
45
46use zeph_common::SessionId;
47
48use crate::agent::error::AgentError;
49
/// Coarse risk class assigned to a tool by [`ShadowSentinel::classify_tool`].
///
/// Every category other than [`ToolRiskCategory::Low`] makes
/// [`ShadowSentinel::check_tool_call`] consult the safety probe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ToolRiskCategory {
    /// Shell / command-execution tools.
    Shell,
    /// Tools that create, modify or delete files.
    FileWrite,
    /// Tools matched by a configured probe pattern that are treated as a possible
    /// data-exfiltration channel.
    ExfilCapable,
    /// Everything else; no probe is run for these tools.
    Low,
}
68
/// Outcome of a safety check for a single tool call.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProbeVerdict {
    /// The probe ran and approved the call.
    Allow,
    /// The probe ran and rejected the call.
    Deny {
        /// Human-readable explanation of the rejection.
        reason: String,
    },
    /// No probe ran (sentinel disabled, low-risk tool, or per-turn budget spent).
    Skip,
}
87
88#[derive(Debug, Clone)]
94pub struct SentinelEvent {
95 pub id: i64,
97 pub session_id: SessionId,
99 pub turn_number: u64,
101 pub event_type: String,
103 pub tool_id: Option<String>,
105 pub risk_signal: Option<String>,
107 pub risk_level: String,
109 pub probe_verdict: Option<String>,
111 pub context_summary: Option<String>,
113 pub created_at: i64,
115}
116
117pub trait SafetyProbe: Send + Sync {
133 fn evaluate<'a>(
141 &'a self,
142 tool_id: &'a str,
143 tool_args: &'a JsonValue,
144 trajectory: &'a [SentinelEvent],
145 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>>;
146}
147
148pub struct LlmSafetyProbe {
156 provider: Arc<AnyProvider>,
157 timeout_ms: u64,
158 deny_on_timeout: bool,
159}
160
161impl LlmSafetyProbe {
162 #[must_use]
170 pub fn new(provider: Arc<AnyProvider>, timeout_ms: u64, deny_on_timeout: bool) -> Self {
171 Self {
172 provider,
173 timeout_ms,
174 deny_on_timeout,
175 }
176 }
177
178 fn build_prompt(
179 tool_id: &str,
180 tool_args: &JsonValue,
181 trajectory: &[SentinelEvent],
182 ) -> Vec<Message> {
183 let context = if trajectory.is_empty() {
184 "No prior events in this session.".to_owned()
185 } else {
186 trajectory
187 .iter()
188 .map(|e| {
189 format!(
190 "[turn {}] {} | tool={} | {}",
191 e.turn_number,
192 e.event_type,
193 e.tool_id.as_deref().unwrap_or("-"),
194 e.context_summary.as_deref().unwrap_or(""),
195 )
196 })
197 .collect::<Vec<_>>()
198 .join("\n")
199 };
200
201 let args_preview = serde_json::to_string(tool_args)
202 .unwrap_or_default()
203 .chars()
204 .take(500)
205 .collect::<String>();
206
207 let system = "You are a safety evaluator for an AI agent. \
208 Evaluate whether the proposed tool call is safe to execute given the session context. \
209 Reply ONLY with a JSON object: {\"verdict\": \"allow\"} or {\"verdict\": \"deny\", \"reason\": \"...\"}. \
210 Deny only if there is clear evidence of harmful intent or a dangerous pattern. \
211 When uncertain, allow.";
212
213 let user =
214 format!("Tool: {tool_id}\nArgs: {args_preview}\n\nRecent session events:\n{context}");
215
216 vec![
217 Message::from_legacy(Role::System, system),
218 Message::from_legacy(Role::User, user),
219 ]
220 }
221
222 fn parse_verdict(response: &str) -> ProbeVerdict {
223 let start = response.find('{');
225 let end = response.rfind('}');
226 if let (Some(s), Some(e)) = (start, end)
227 && let Ok(v) = serde_json::from_str::<serde_json::Value>(&response[s..=e])
228 {
229 match v.get("verdict").and_then(|x| x.as_str()) {
230 Some("allow") => return ProbeVerdict::Allow,
231 Some("deny") => {
232 let reason = v
233 .get("reason")
234 .and_then(|r| r.as_str())
235 .unwrap_or("safety probe denied this tool call")
236 .to_owned();
237 return ProbeVerdict::Deny { reason };
238 }
239 _ => {}
240 }
241 }
242 tracing::warn!(
244 raw = %response,
245 "ShadowSentinel: probe response could not be parsed, defaulting to Allow"
246 );
247 ProbeVerdict::Allow
248 }
249}
250
251impl SafetyProbe for LlmSafetyProbe {
252 fn evaluate<'a>(
253 &'a self,
254 tool_id: &'a str,
255 tool_args: &'a JsonValue,
256 trajectory: &'a [SentinelEvent],
257 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>> {
258 let span = info_span!("security.shadow.probe", tool_id = %tool_id);
259 Box::pin(
260 async move {
261 let messages = Self::build_prompt(tool_id, tool_args, trajectory);
262 let timeout = std::time::Duration::from_millis(self.timeout_ms);
263
264 match tokio::time::timeout(timeout, self.provider.chat(&messages)).await {
265 Ok(Ok(response)) => Self::parse_verdict(&response),
266 Ok(Err(e)) => {
267 tracing::warn!(error = %e, "ShadowSentinel: probe LLM error");
268 if self.deny_on_timeout {
269 ProbeVerdict::Deny {
270 reason: format!("probe LLM error: {e}"),
271 }
272 } else {
273 ProbeVerdict::Allow
274 }
275 }
276 Err(_) => {
277 tracing::warn!(
278 timeout_ms = self.timeout_ms,
279 "ShadowSentinel: probe timed out"
280 );
281 if self.deny_on_timeout {
282 ProbeVerdict::Deny {
283 reason: "safety probe timed out".to_owned(),
284 }
285 } else {
286 ProbeVerdict::Allow
287 }
288 }
289 }
290 }
291 .instrument(span),
292 )
293 }
294}
295
296#[derive(Clone)]
303pub struct ShadowEventStore {
304 pool: DbPool,
305}
306
307impl ShadowEventStore {
308 #[must_use]
310 pub fn new(pool: DbPool) -> Self {
311 Self { pool }
312 }
313
314 #[tracing::instrument(name = "security.shadow.record", skip_all, fields(event_type = %event.event_type))]
322 pub async fn record(&self, event: &SentinelEvent) -> Result<(), AgentError> {
323 sqlx::query(
324 "INSERT INTO safety_shadow_events \
325 (session_id, turn_number, event_type, tool_id, risk_signal, risk_level, \
326 probe_verdict, context_summary, created_at) \
327 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
328 )
329 .bind(event.session_id.as_str())
330 .bind(i64::try_from(event.turn_number).unwrap_or(i64::MAX))
331 .bind(&event.event_type)
332 .bind(&event.tool_id)
333 .bind(&event.risk_signal)
334 .bind(&event.risk_level)
335 .bind(&event.probe_verdict)
336 .bind(&event.context_summary)
337 .bind(event.created_at)
338 .execute(&self.pool)
339 .await
340 .map_err(|e| AgentError::Db(e.into()))?;
341
342 Ok(())
343 }
344
345 #[tracing::instrument(name = "security.shadow.get_trajectory", skip(self), fields(session_id = %session_id))]
353 pub async fn get_trajectory(
354 &self,
355 session_id: &str,
356 limit: usize,
357 ) -> Result<Vec<SentinelEvent>, AgentError> {
358 let rows = sqlx::query_as::<_, ShadowEventRow>(
359 "SELECT id, session_id, turn_number, event_type, tool_id, risk_signal, \
360 risk_level, probe_verdict, context_summary, created_at \
361 FROM safety_shadow_events \
362 WHERE session_id = ? \
363 ORDER BY created_at DESC \
364 LIMIT ?",
365 )
366 .bind(session_id)
367 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
368 .fetch_all(&self.pool)
369 .await
370 .map_err(|e| AgentError::Db(e.into()))?;
371
372 let mut events: Vec<SentinelEvent> = rows.into_iter().map(SentinelEvent::from).collect();
374 events.reverse();
375 Ok(events)
376 }
377
378 #[tracing::instrument(name = "security.shadow.get_tool_history", skip(self), fields(tool_id = %tool_id))]
386 pub async fn get_tool_history(
387 &self,
388 tool_id: &str,
389 limit: usize,
390 ) -> Result<Vec<SentinelEvent>, AgentError> {
391 let rows = sqlx::query_as::<_, ShadowEventRow>(
392 "SELECT id, session_id, turn_number, event_type, tool_id, risk_signal, \
393 risk_level, probe_verdict, context_summary, created_at \
394 FROM safety_shadow_events \
395 WHERE tool_id = ? \
396 ORDER BY created_at DESC \
397 LIMIT ?",
398 )
399 .bind(tool_id)
400 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
401 .fetch_all(&self.pool)
402 .await
403 .map_err(|e| AgentError::Db(e.into()))?;
404
405 Ok(rows.into_iter().map(SentinelEvent::from).collect())
406 }
407}
408
409#[derive(sqlx::FromRow)]
411struct ShadowEventRow {
412 id: i64,
413 session_id: String,
414 turn_number: i64,
415 event_type: String,
416 tool_id: Option<String>,
417 risk_signal: Option<String>,
418 risk_level: String,
419 probe_verdict: Option<String>,
420 context_summary: Option<String>,
421 created_at: i64,
422}
423
424impl From<ShadowEventRow> for SentinelEvent {
425 fn from(r: ShadowEventRow) -> Self {
426 Self {
427 id: r.id,
428 session_id: SessionId::new(r.session_id),
429 turn_number: u64::try_from(r.turn_number).unwrap_or(0),
430 event_type: r.event_type,
431 tool_id: r.tool_id,
432 risk_signal: r.risk_signal,
433 risk_level: r.risk_level,
434 probe_verdict: r.probe_verdict,
435 context_summary: r.context_summary,
436 created_at: r.created_at,
437 }
438 }
439}
440
/// Upper bound on concurrently running background persistence tasks per sentinel.
///
/// When this many writes are still in flight, further events are dropped (with a debug
/// log) instead of being queued.
const MAX_PENDING_WRITES: usize = 32;
449
450pub struct ShadowSentinel {
468 store: ShadowEventStore,
469 probe: Box<dyn SafetyProbe>,
470 config: zeph_config::ShadowSentinelConfig,
471 probes_this_turn: AtomicU32,
474 session_id: SessionId,
475 pending_writes: Mutex<JoinSet<()>>,
478}
479
480impl ShadowSentinel {
481 #[must_use]
490 pub fn new(
491 store: ShadowEventStore,
492 probe: Box<dyn SafetyProbe>,
493 config: zeph_config::ShadowSentinelConfig,
494 session_id: impl Into<SessionId>,
495 ) -> Self {
496 Self {
497 store,
498 probe,
499 config,
500 probes_this_turn: AtomicU32::new(0),
501 session_id: session_id.into(),
502 pending_writes: Mutex::new(JoinSet::new()),
503 }
504 }
505
506 #[must_use]
512 pub fn classify_tool(&self, qualified_tool_id: &str) -> ToolRiskCategory {
513 if qualified_tool_id == "builtin:shell"
515 || qualified_tool_id == "builtin:bash"
516 || qualified_tool_id.starts_with("builtin:shell")
517 || qualified_tool_id == "bash"
518 || qualified_tool_id == "shell"
519 || qualified_tool_id == "sh"
520 {
521 return ToolRiskCategory::Shell;
522 }
523 if qualified_tool_id == "builtin:write"
524 || qualified_tool_id == "builtin:edit"
525 || qualified_tool_id == "builtin:delete"
526 || qualified_tool_id == "write"
527 || qualified_tool_id == "edit"
528 || qualified_tool_id == "delete"
529 {
530 return ToolRiskCategory::FileWrite;
531 }
532
533 for pattern in &self.config.probe_patterns {
535 if glob_matches(pattern, qualified_tool_id) {
536 if pattern.contains("shell") || pattern.contains("exec") {
538 return ToolRiskCategory::Shell;
539 }
540 if pattern.contains("write") || pattern.contains("edit") || pattern.contains("file")
541 {
542 if qualified_tool_id.starts_with("mcp:") {
543 return ToolRiskCategory::ExfilCapable;
544 }
545 return ToolRiskCategory::FileWrite;
546 }
547 return ToolRiskCategory::ExfilCapable;
548 }
549 }
550
551 ToolRiskCategory::Low
552 }
553
554 #[tracing::instrument(name = "security.shadow.check", skip(self, tool_args), fields(tool_id = %qualified_tool_id))]
568 pub async fn check_tool_call(
569 &self,
570 qualified_tool_id: &str,
571 tool_args: &JsonValue,
572 turn_number: u64,
573 current_risk_level: &str,
574 ) -> ProbeVerdict {
575 if !self.config.enabled {
576 return ProbeVerdict::Skip;
577 }
578
579 let category = self.classify_tool(qualified_tool_id);
580 if category == ToolRiskCategory::Low {
581 return ProbeVerdict::Skip;
582 }
583
584 let count = self.probes_this_turn.fetch_add(1, Ordering::Relaxed);
586 let max_probes = u32::try_from(self.config.max_probes_per_turn).unwrap_or(u32::MAX);
587 if count >= max_probes {
588 self.probes_this_turn.fetch_sub(1, Ordering::Relaxed);
590 tracing::debug!(
591 max = self.config.max_probes_per_turn,
592 "ShadowSentinel: probe budget exhausted for this turn, skipping"
593 );
594 return ProbeVerdict::Skip;
595 }
596
597 let trajectory = match self
601 .store
602 .get_trajectory(&self.session_id, self.config.max_context_events)
603 .await
604 {
605 Ok(t) => t
606 .into_iter()
607 .filter(|e| e.event_type != "probe_result")
608 .collect(),
609 Err(e) => {
610 tracing::warn!(error = %e, "ShadowSentinel: failed to load trajectory, proceeding without context");
611 vec![]
612 }
613 };
614
615 let verdict = self
616 .probe
617 .evaluate(qualified_tool_id, tool_args, &trajectory)
618 .await;
619
620 let probe_verdict_str = match &verdict {
622 ProbeVerdict::Allow => "allow",
623 ProbeVerdict::Deny { .. } => "deny",
624 ProbeVerdict::Skip => "skip",
625 };
626 let summary = match &verdict {
627 ProbeVerdict::Deny { reason } => {
628 format!("probe denied: {}", &reason[..reason.len().min(120)])
629 }
630 ProbeVerdict::Allow => format!("probe allowed {qualified_tool_id}"),
631 ProbeVerdict::Skip => format!("probe skipped {qualified_tool_id}"),
632 };
633 let event = SentinelEvent {
634 id: 0,
635 session_id: self.session_id.clone(),
636 turn_number,
637 event_type: "probe_result".to_owned(),
638 tool_id: Some(qualified_tool_id.to_owned()),
639 risk_signal: None,
640 risk_level: current_risk_level.to_owned(),
641 probe_verdict: Some(probe_verdict_str.to_owned()),
642 context_summary: Some(summary),
643 created_at: unix_now(),
644 };
645 let store = self.store.clone();
646 self.spawn_persist(async move {
647 if let Err(e) = store.record(&event).await {
648 tracing::warn!(error = %e, "ShadowSentinel: failed to persist probe result");
649 }
650 })
651 .await;
652
653 verdict
654 }
655
656 pub async fn record_tool_event(
660 &self,
661 qualified_tool_id: &str,
662 turn_number: u64,
663 risk_level: &str,
664 context_summary: &str,
665 ) {
666 if !self.config.enabled {
667 return;
668 }
669 let event = SentinelEvent {
670 id: 0,
671 session_id: self.session_id.clone(),
672 turn_number,
673 event_type: "tool_call".to_owned(),
674 tool_id: Some(qualified_tool_id.to_owned()),
675 risk_signal: None,
676 risk_level: risk_level.to_owned(),
677 probe_verdict: None,
678 context_summary: Some(context_summary.chars().take(250).collect()),
679 created_at: unix_now(),
680 };
681 let store = self.store.clone();
682 self.spawn_persist(async move {
683 if let Err(e) = store.record(&event).await {
684 tracing::warn!(error = %e, "ShadowSentinel: failed to persist tool event");
685 }
686 })
687 .await;
688 }
689
690 pub async fn drain_pending(&self) {
695 let mut set = self.pending_writes.lock().await;
696 while set.join_next().await.is_some() {}
697 }
698
699 async fn spawn_persist<F>(&self, fut: F)
705 where
706 F: std::future::Future<Output = ()> + Send + 'static,
707 {
708 let mut set = self.pending_writes.lock().await;
709 while set.try_join_next().is_some() {}
712 if set.len() < MAX_PENDING_WRITES {
713 set.spawn(fut);
714 } else {
715 tracing::debug!(
716 max = MAX_PENDING_WRITES,
717 "ShadowSentinel: pending_writes at capacity, skipping persist"
718 );
719 }
720 }
721
722 pub fn advance_turn(&self) {
727 self.probes_this_turn.store(0, Ordering::Release);
728 }
729}
730
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch or the second count does
/// not fit in `i64`.
fn unix_now() -> i64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_secs()).unwrap_or(0),
        Err(_) => 0,
    }
}
741
/// Minimal glob matcher used for `probe_patterns`.
///
/// `*` matches any (possibly empty) run of characters; every other character matches
/// itself literally. There is no `?`, no character classes and no escaping. A pattern
/// without `*` must equal `value` exactly.
fn glob_matches(pattern: &str, value: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == value;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last_idx = segments.len() - 1;
    let mut rest = value;

    for (idx, segment) in segments.into_iter().enumerate() {
        // Empty segments come from leading/trailing/consecutive `*` and match anything.
        if segment.is_empty() {
            continue;
        }
        if idx == 0 {
            // Text before the first `*` must be a prefix.
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if idx == last_idx {
            // Text after the last `*` must be a suffix of what is left.
            return rest.ends_with(segment);
        } else {
            // Middle segments: consume up to the leftmost occurrence.
            match rest.find(segment) {
                Some(pos) => rest = &rest[pos + segment.len()..],
                None => return false,
            }
        }
    }
    true
}
773
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_config::ShadowSentinelConfig;

    /// Tool id used by the budget / enable-switch tests.
    const SHELL_TOOL: &str = "builtin:shell";

    /// Sentinel built from `ShadowSentinelConfig::default()`.
    async fn default_sentinel() -> ShadowSentinel {
        make_test_sentinel(ShadowSentinelConfig::default()).await
    }

    /// Empty JSON object used as tool arguments where their content is irrelevant.
    fn no_args() -> JsonValue {
        JsonValue::Object(serde_json::Map::new())
    }

    #[tokio::test]
    async fn classify_builtin_shell_is_shell_risk() {
        let sentinel = default_sentinel().await;
        for id in ["builtin:shell", "builtin:bash"] {
            assert_eq!(sentinel.classify_tool(id), ToolRiskCategory::Shell);
        }
    }

    #[tokio::test]
    async fn classify_builtin_write_is_file_write_risk() {
        let sentinel = default_sentinel().await;
        for id in ["builtin:write", "builtin:edit"] {
            assert_eq!(sentinel.classify_tool(id), ToolRiskCategory::FileWrite);
        }
    }

    #[tokio::test]
    async fn classify_low_risk_returns_low() {
        let sentinel = default_sentinel().await;
        for id in ["builtin:read", "builtin:search"] {
            assert_eq!(sentinel.classify_tool(id), ToolRiskCategory::Low);
        }
    }

    #[tokio::test]
    async fn classify_bare_shell_names_are_shell_risk() {
        let sentinel = default_sentinel().await;
        for id in ["bash", "shell", "sh"] {
            assert_eq!(sentinel.classify_tool(id), ToolRiskCategory::Shell);
        }
    }

    #[tokio::test]
    async fn classify_bare_file_write_names_are_file_write_risk() {
        let sentinel = default_sentinel().await;
        for id in ["write", "edit", "delete"] {
            assert_eq!(sentinel.classify_tool(id), ToolRiskCategory::FileWrite);
        }
    }

    #[tokio::test]
    async fn advance_turn_resets_counter() {
        let sentinel = default_sentinel().await;
        sentinel.probes_this_turn.store(3, Ordering::Relaxed);
        sentinel.advance_turn();
        assert_eq!(sentinel.probes_this_turn.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn glob_matches_star_wildcard() {
        let pattern = "mcp:*/file_*";
        assert!(glob_matches(pattern, "mcp:myserver/file_read"));
        assert!(glob_matches(pattern, "mcp:other/file_write"));
        assert!(!glob_matches(pattern, "builtin:shell"));
    }

    #[test]
    fn glob_matches_exact() {
        assert!(glob_matches(SHELL_TOOL, SHELL_TOOL));
        assert!(!glob_matches(SHELL_TOOL, "builtin:write"));
    }

    #[test]
    fn parse_verdict_allow() {
        let verdict = LlmSafetyProbe::parse_verdict(r#"{"verdict": "allow"}"#);
        assert_eq!(verdict, ProbeVerdict::Allow);
    }

    #[test]
    fn parse_verdict_deny_with_reason() {
        let raw = r#"{"verdict": "deny", "reason": "suspicious pattern"}"#;
        let expected = ProbeVerdict::Deny {
            reason: "suspicious pattern".to_owned(),
        };
        assert_eq!(LlmSafetyProbe::parse_verdict(raw), expected);
    }

    #[test]
    fn parse_verdict_unparseable_allows() {
        let verdict = LlmSafetyProbe::parse_verdict("I think this is fine");
        assert_eq!(verdict, ProbeVerdict::Allow);
    }

    #[tokio::test]
    async fn check_tool_call_skips_after_budget_exhausted() {
        let sentinel = make_test_sentinel(ShadowSentinelConfig {
            enabled: true,
            max_probes_per_turn: 2,
            ..ShadowSentinelConfig::default()
        })
        .await;
        let args = no_args();

        // The first two calls fit inside the budget of 2.
        for message in ["first call within budget", "second call within budget"] {
            let verdict = sentinel.check_tool_call(SHELL_TOOL, &args, 1, "calm").await;
            assert_ne!(verdict, ProbeVerdict::Skip, "{message}");
        }

        let over_budget = sentinel.check_tool_call(SHELL_TOOL, &args, 1, "calm").await;
        assert_eq!(
            over_budget,
            ProbeVerdict::Skip,
            "third call must be skipped (budget exhausted)"
        );
    }

    #[tokio::test]
    async fn check_tool_call_returns_skip_when_disabled() {
        let sentinel = make_test_sentinel(ShadowSentinelConfig {
            enabled: false,
            ..ShadowSentinelConfig::default()
        })
        .await;
        let args = no_args();
        let verdict = sentinel.check_tool_call(SHELL_TOOL, &args, 1, "calm").await;
        assert_eq!(
            verdict,
            ProbeVerdict::Skip,
            "disabled sentinel must always return Skip without calling the probe"
        );
    }

    /// `drain_pending` must not return before every spawned write has completed.
    #[tokio::test]
    async fn drain_pending_awaits_all_tasks() {
        let sentinel = default_sentinel().await;
        let completed = Arc::new(AtomicU32::new(0));

        for _ in 0..5 {
            let completed = Arc::clone(&completed);
            sentinel
                .spawn_persist(async move {
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    completed.fetch_add(1, Ordering::Relaxed);
                })
                .await;
        }

        sentinel.drain_pending().await;

        assert_eq!(
            completed.load(Ordering::Relaxed),
            5,
            "drain_pending must join all 5 tasks before returning"
        );
    }

    /// Spawning past `MAX_PENDING_WRITES` drops work instead of panicking, while at least
    /// `MAX_PENDING_WRITES` tasks still run.
    #[tokio::test]
    async fn spawn_persist_beyond_capacity_does_not_panic() {
        let sentinel = default_sentinel().await;
        let completed = Arc::new(AtomicU32::new(0));

        for _ in 0..(MAX_PENDING_WRITES * 2) {
            let completed = Arc::clone(&completed);
            sentinel
                .spawn_persist(async move {
                    completed.fetch_add(1, Ordering::Relaxed);
                })
                .await;
        }

        sentinel.drain_pending().await;

        let ran = completed.load(Ordering::Relaxed);
        assert!(
            ran >= u32::try_from(MAX_PENDING_WRITES).unwrap(),
            "at least MAX_PENDING_WRITES tasks must complete; ran={ran}"
        );
    }

    /// Builds a sentinel over an in-memory SQLite pool with a probe that always allows.
    ///
    /// No schema is set up here; none of these tests depend on persisted rows.
    async fn make_test_sentinel(config: ShadowSentinelConfig) -> ShadowSentinel {
        struct AlwaysAllow;
        impl SafetyProbe for AlwaysAllow {
            fn evaluate<'a>(
                &'a self,
                _: &'a str,
                _: &'a JsonValue,
                _: &'a [SentinelEvent],
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>>
            {
                Box::pin(async { ProbeVerdict::Allow })
            }
        }

        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite pool");
        ShadowSentinel::new(
            ShadowEventStore::new(pool),
            Box::new(AlwaysAllow),
            config,
            "test-session",
        )
    }
}