Skip to main content

zeph_core/agent/
shadow_sentinel.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! `ShadowSentinel`: persistent safety memory stream + LLM-based pre-execution probe.
//!
//! Extends [`TrajectorySentinel`](crate::agent::trajectory) (Phase 1, spec 050) with:
//!
//! 1. **Persistent event stream**: the `safety_shadow_events` table stores safety-relevant
//!    events across sessions (not limited to the last 8 turns like the in-memory sentinel).
//! 2. **[`SafetyProbe`] trait**: before a high-risk tool category runs (shell, file write,
//!    exfiltration-capable MCP tools), an LLM evaluates the trajectory context and approves
//!    or denies the call.
//!
//! `ShadowSentinel` is **defence-in-depth only** — it is NOT the primary security gate.
//! `PolicyGateExecutor` and `TrajectorySentinel` remain the primary enforcement mechanisms
//! and continue to run regardless of probe results or timeouts.
//!
//! # Fail-open default
//!
//! With `deny_on_timeout = false` (the default), a probe timeout or LLM error results in
//! [`ProbeVerdict::Allow`]. This is correct because:
//!
//! - `ShadowSentinel` is defence-in-depth; the policy gate still runs after it.
//! - Failing closed on timeout would enable a `DoS`: a slow context would block every
//!   high-risk tool.
//! - Operators who want fail-closed behaviour can set `deny_on_timeout = true` in config.
//!
//! A probe reply that cannot be parsed into a verdict always yields `Allow`, regardless of
//! `deny_on_timeout`.
//!
//! # LLM isolation invariant
//!
//! The probe prompt MUST NEVER include the `TrajectorySentinel` score or risk level.
//! Exposing internal risk scores to the LLM would allow prompt injection attacks that
//! manipulate probe verdicts by crafting tool outputs to lower the perceived risk level.
//! For the same reason, earlier probe verdicts (`probe_result` events) are filtered out of
//! the trajectory context handed to the probe.
31
32use std::sync::{
33    Arc,
34    atomic::{AtomicU32, Ordering},
35};
36use tokio::sync::Mutex;
37use tokio::task::JoinSet;
38
39use serde_json::Value as JsonValue;
40use tracing::{Instrument as _, info_span};
41use zeph_db::DbPool;
42use zeph_llm::LlmProvider;
43use zeph_llm::any::AnyProvider;
44use zeph_llm::provider::{Message, Role};
45
46use zeph_common::SessionId;
47
48use crate::agent::error::AgentError;
49
50// ── Risk category ────────────────────────────────────────────────────────────
51
/// Risk tier assigned to a tool call, used to decide whether the safety probe runs.
///
/// The three elevated tiers (`Shell`, `FileWrite`, `ExfilCapable`) are probed; `Low`
/// tools skip the probe altogether and therefore add no latency.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ToolRiskCategory {
    /// Arbitrary command execution through a shell — the highest-risk tier.
    Shell,
    /// Writes, edits, or deletes files, leaving persistent side effects.
    FileWrite,
    /// Network-capable MCP tools that could be used to exfiltrate data.
    ExfilCapable,
    /// Everything else; no probe is performed.
    Low,
}
68
69// ── Probe verdict ─────────────────────────────────────────────────────────────
70
/// Outcome produced by a `SafetyProbe` evaluation.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ProbeVerdict {
    /// The tool call may proceed.
    Allow,
    /// The tool call is refused.
    ///
    /// `reason` (usually LLM-generated; timeouts and provider errors produce fixed text) is
    /// handed back to the agent loop as the tool result so the model can change strategy.
    Deny {
        /// Human-readable explanation of the refusal.
        reason: String,
    },
    /// No evaluation took place: the tool is not high-risk, the feature is disabled, or the
    /// per-turn probe budget is used up.
    Skip,
}
87
88// ── Sentinel event ───────────────────────────────────────────────────────────
89
90/// A single probe trajectory record in the persistent safety sentinel stream.
91///
92/// Stored in `safety_shadow_events` and retrieved for cross-session probe context.
93#[derive(Debug, Clone)]
94pub struct SentinelEvent {
95    /// Database row id (0 for unsaved records).
96    pub id: i64,
97    /// Agent session identifier.
98    pub session_id: SessionId,
99    /// Turn number within the session.
100    pub turn_number: u64,
101    /// Event category: `"tool_call"`, `"tool_result"`, `"risk_signal"`, `"probe_result"`.
102    pub event_type: String,
103    /// Fully-qualified tool id for tool events, `None` for non-tool events.
104    pub tool_id: Option<String>,
105    /// Serialised risk signal variant (from `TrajectorySentinel`), if applicable.
106    pub risk_signal: Option<String>,
107    /// Risk level at the time of the event: `"calm"`, `"elevated"`, `"high"`, `"critical"`.
108    pub risk_level: String,
109    /// Probe verdict for `probe_result` events: `"allow"`, `"deny"`, `"skip"`.
110    pub probe_verdict: Option<String>,
111    /// Short human-readable summary included in the LLM probe context.
112    pub context_summary: Option<String>,
113    /// Unix timestamp (seconds) when the event was recorded.
114    pub created_at: i64,
115}
116
117// ── SafetyProbe trait ─────────────────────────────────────────────────────────
118
119/// LLM-based pre-execution safety evaluator.
120///
121/// Implementors receive the full trajectory context and the proposed tool call
122/// and return a [`ProbeVerdict`]. The probe runs BEFORE [`zeph_tools::PolicyGateExecutor`].
123///
124/// # Contract
125///
126/// - Probe timeout is mandatory (configured via `probe_timeout_ms`).
127/// - Probe failure (LLM error, timeout when `deny_on_timeout = false`) results in `Allow`.
128/// - Probe results are persisted to `safety_shadow_events` for cross-session learning.
129/// - The probe prompt MUST NOT include the sentinel score or risk level (LLM isolation).
130///
131/// Uses `Pin<Box<dyn Future>>` returns for dyn-compatibility (stored as `Box<dyn SafetyProbe>`).
132pub trait SafetyProbe: Send + Sync {
133    /// Evaluate whether the proposed tool call is safe given the trajectory context.
134    ///
135    /// # Arguments
136    ///
137    /// * `tool_id` — fully-qualified tool identifier (e.g. `"builtin:shell"`).
138    /// * `tool_args` — JSON arguments for the tool call.
139    /// * `trajectory` — recent shadow events for context (last N events from the store).
140    fn evaluate<'a>(
141        &'a self,
142        tool_id: &'a str,
143        tool_args: &'a JsonValue,
144        trajectory: &'a [SentinelEvent],
145    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>>;
146}
147
148// ── LlmSafetyProbe ───────────────────────────────────────────────────────────
149
150/// LLM-backed implementation of `SafetyProbe`.
151///
152/// Sends a structured prompt to an LLM asking it to evaluate whether the proposed
153/// tool call is safe given the trajectory. The prompt is intentionally minimal and
154/// never includes internal risk scores (LLM isolation invariant).
155pub struct LlmSafetyProbe {
156    provider: Arc<AnyProvider>,
157    timeout_ms: u64,
158    deny_on_timeout: bool,
159}
160
161impl LlmSafetyProbe {
162    /// Create a new `LlmSafetyProbe`.
163    ///
164    /// # Arguments
165    ///
166    /// * `provider` — LLM provider used for probe evaluation.
167    /// * `timeout_ms` — maximum milliseconds to wait for the LLM response.
168    /// * `deny_on_timeout` — when `true`, a timeout results in `Deny`; when `false`, `Allow`.
169    #[must_use]
170    pub fn new(provider: Arc<AnyProvider>, timeout_ms: u64, deny_on_timeout: bool) -> Self {
171        Self {
172            provider,
173            timeout_ms,
174            deny_on_timeout,
175        }
176    }
177
178    fn build_prompt(
179        tool_id: &str,
180        tool_args: &JsonValue,
181        trajectory: &[SentinelEvent],
182    ) -> Vec<Message> {
183        let context = if trajectory.is_empty() {
184            "No prior events in this session.".to_owned()
185        } else {
186            trajectory
187                .iter()
188                .map(|e| {
189                    format!(
190                        "[turn {}] {} | tool={} | {}",
191                        e.turn_number,
192                        e.event_type,
193                        e.tool_id.as_deref().unwrap_or("-"),
194                        e.context_summary.as_deref().unwrap_or(""),
195                    )
196                })
197                .collect::<Vec<_>>()
198                .join("\n")
199        };
200
201        let args_preview = serde_json::to_string(tool_args)
202            .unwrap_or_default()
203            .chars()
204            .take(500)
205            .collect::<String>();
206
207        let system = "You are a safety evaluator for an AI agent. \
208            Evaluate whether the proposed tool call is safe to execute given the session context. \
209            Reply ONLY with a JSON object: {\"verdict\": \"allow\"} or {\"verdict\": \"deny\", \"reason\": \"...\"}. \
210            Deny only if there is clear evidence of harmful intent or a dangerous pattern. \
211            When uncertain, allow.";
212
213        let user =
214            format!("Tool: {tool_id}\nArgs: {args_preview}\n\nRecent session events:\n{context}");
215
216        vec![
217            Message::from_legacy(Role::System, system),
218            Message::from_legacy(Role::User, user),
219        ]
220    }
221
222    fn parse_verdict(response: &str) -> ProbeVerdict {
223        // Try to extract JSON from the response.
224        let start = response.find('{');
225        let end = response.rfind('}');
226        if let (Some(s), Some(e)) = (start, end)
227            && let Ok(v) = serde_json::from_str::<serde_json::Value>(&response[s..=e])
228        {
229            match v.get("verdict").and_then(|x| x.as_str()) {
230                Some("allow") => return ProbeVerdict::Allow,
231                Some("deny") => {
232                    let reason = v
233                        .get("reason")
234                        .and_then(|r| r.as_str())
235                        .unwrap_or("safety probe denied this tool call")
236                        .to_owned();
237                    return ProbeVerdict::Deny { reason };
238                }
239                _ => {}
240            }
241        }
242        // Unparseable response → allow (fail-open)
243        tracing::warn!(
244            raw = %response,
245            "ShadowSentinel: probe response could not be parsed, defaulting to Allow"
246        );
247        ProbeVerdict::Allow
248    }
249}
250
251impl SafetyProbe for LlmSafetyProbe {
252    fn evaluate<'a>(
253        &'a self,
254        tool_id: &'a str,
255        tool_args: &'a JsonValue,
256        trajectory: &'a [SentinelEvent],
257    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>> {
258        let span = info_span!("security.shadow.probe", tool_id = %tool_id);
259        Box::pin(
260            async move {
261                let messages = Self::build_prompt(tool_id, tool_args, trajectory);
262                let timeout = std::time::Duration::from_millis(self.timeout_ms);
263
264                match tokio::time::timeout(timeout, self.provider.chat(&messages)).await {
265                    Ok(Ok(response)) => Self::parse_verdict(&response),
266                    Ok(Err(e)) => {
267                        tracing::warn!(error = %e, "ShadowSentinel: probe LLM error");
268                        if self.deny_on_timeout {
269                            ProbeVerdict::Deny {
270                                reason: format!("probe LLM error: {e}"),
271                            }
272                        } else {
273                            ProbeVerdict::Allow
274                        }
275                    }
276                    Err(_) => {
277                        tracing::warn!(
278                            timeout_ms = self.timeout_ms,
279                            "ShadowSentinel: probe timed out"
280                        );
281                        if self.deny_on_timeout {
282                            ProbeVerdict::Deny {
283                                reason: "safety probe timed out".to_owned(),
284                            }
285                        } else {
286                            ProbeVerdict::Allow
287                        }
288                    }
289                }
290            }
291            .instrument(span),
292        )
293    }
294}
295
296// ── ShadowEventStore ─────────────────────────────────────────────────────────
297
298/// Persistent storage for the safety shadow event stream.
299///
300/// Thin wrapper around [`DbPool`] for the `safety_shadow_events` table.
301/// Methods are `async` and return typed errors.
302#[derive(Clone)]
303pub struct ShadowEventStore {
304    pool: DbPool,
305}
306
307impl ShadowEventStore {
308    /// Create a `ShadowEventStore` backed by the given pool.
309    #[must_use]
310    pub fn new(pool: DbPool) -> Self {
311        Self { pool }
312    }
313
314    /// Persist a shadow event to the database.
315    ///
316    /// The `id` field of the event is ignored; the database assigns a new row id.
317    ///
318    /// # Errors
319    ///
320    /// Returns `AgentError` on database failure.
321    #[tracing::instrument(name = "security.shadow.record", skip_all, fields(event_type = %event.event_type))]
322    pub async fn record(&self, event: &SentinelEvent) -> Result<(), AgentError> {
323        sqlx::query(
324            "INSERT INTO safety_shadow_events \
325             (session_id, turn_number, event_type, tool_id, risk_signal, risk_level, \
326              probe_verdict, context_summary, created_at) \
327             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
328        )
329        .bind(event.session_id.as_str())
330        .bind(i64::try_from(event.turn_number).unwrap_or(i64::MAX))
331        .bind(&event.event_type)
332        .bind(&event.tool_id)
333        .bind(&event.risk_signal)
334        .bind(&event.risk_level)
335        .bind(&event.probe_verdict)
336        .bind(&event.context_summary)
337        .bind(event.created_at)
338        .execute(&self.pool)
339        .await
340        .map_err(|e| AgentError::Db(e.into()))?;
341
342        Ok(())
343    }
344
345    /// Retrieve the last `limit` events for a session in ascending time order.
346    ///
347    /// Used to build the trajectory context for probe evaluation.
348    ///
349    /// # Errors
350    ///
351    /// Returns `AgentError` on database failure.
352    #[tracing::instrument(name = "security.shadow.get_trajectory", skip(self), fields(session_id = %session_id))]
353    pub async fn get_trajectory(
354        &self,
355        session_id: &str,
356        limit: usize,
357    ) -> Result<Vec<SentinelEvent>, AgentError> {
358        let rows = sqlx::query_as::<_, ShadowEventRow>(
359            "SELECT id, session_id, turn_number, event_type, tool_id, risk_signal, \
360             risk_level, probe_verdict, context_summary, created_at \
361             FROM safety_shadow_events \
362             WHERE session_id = ? \
363             ORDER BY created_at DESC \
364             LIMIT ?",
365        )
366        .bind(session_id)
367        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
368        .fetch_all(&self.pool)
369        .await
370        .map_err(|e| AgentError::Db(e.into()))?;
371
372        // DB returns DESC (newest first); reverse once to get ASC (oldest first) for LLM context.
373        let mut events: Vec<SentinelEvent> = rows.into_iter().map(SentinelEvent::from).collect();
374        events.reverse();
375        Ok(events)
376    }
377
378    /// Retrieve the last `limit` events for a specific tool across all sessions.
379    ///
380    /// Used for cross-session pattern detection.
381    ///
382    /// # Errors
383    ///
384    /// Returns `AgentError` on database failure.
385    #[tracing::instrument(name = "security.shadow.get_tool_history", skip(self), fields(tool_id = %tool_id))]
386    pub async fn get_tool_history(
387        &self,
388        tool_id: &str,
389        limit: usize,
390    ) -> Result<Vec<SentinelEvent>, AgentError> {
391        let rows = sqlx::query_as::<_, ShadowEventRow>(
392            "SELECT id, session_id, turn_number, event_type, tool_id, risk_signal, \
393             risk_level, probe_verdict, context_summary, created_at \
394             FROM safety_shadow_events \
395             WHERE tool_id = ? \
396             ORDER BY created_at DESC \
397             LIMIT ?",
398        )
399        .bind(tool_id)
400        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
401        .fetch_all(&self.pool)
402        .await
403        .map_err(|e| AgentError::Db(e.into()))?;
404
405        Ok(rows.into_iter().map(SentinelEvent::from).collect())
406    }
407}
408
409// Internal sqlx row type for `safety_shadow_events`.
410#[derive(sqlx::FromRow)]
411struct ShadowEventRow {
412    id: i64,
413    session_id: String,
414    turn_number: i64,
415    event_type: String,
416    tool_id: Option<String>,
417    risk_signal: Option<String>,
418    risk_level: String,
419    probe_verdict: Option<String>,
420    context_summary: Option<String>,
421    created_at: i64,
422}
423
424impl From<ShadowEventRow> for SentinelEvent {
425    fn from(r: ShadowEventRow) -> Self {
426        Self {
427            id: r.id,
428            session_id: SessionId::new(r.session_id),
429            turn_number: u64::try_from(r.turn_number).unwrap_or(0),
430            event_type: r.event_type,
431            tool_id: r.tool_id,
432            risk_signal: r.risk_signal,
433            risk_level: r.risk_level,
434            probe_verdict: r.probe_verdict,
435            context_summary: r.context_summary,
436            created_at: r.created_at,
437        }
438    }
439}
440
441// ── ShadowSentinel ────────────────────────────────────────────────────────────
442
/// Upper bound on in-flight fire-and-forget persist tasks tracked in `pending_writes`.
///
/// Before each spawn, every task that has already finished is reaped, without waiting on
/// tasks that are still running. If the set is still at this limit afterwards, the new
/// write is dropped with a debug log. Persistence is best-effort, and the sentinel must
/// never block tool dispatch.
const MAX_PENDING_WRITES: usize = 32;
449
450/// Orchestrates the persistent safety stream and LLM pre-execution probe.
451///
452/// `ShadowSentinel` is wrapped in `Arc` and shared between `ShadowProbeExecutor` instances
453/// when tools run in parallel. All mutable state uses `AtomicU32` to allow `&self` access
454/// from concurrent tool dispatch without a `Mutex`.
455///
456/// # Turn lifecycle
457///
458/// - `advance_turn()` — call once per turn before tool execution; resets the per-turn
459///   probe counter.
460/// - `check_tool_call()` — call before each tool execution to probe high-risk calls.
461/// - `record_tool_event()` — call after tool execution to persist the event.
462/// - `drain_pending()` — call at session shutdown to await all queued persist writes.
463///
464/// # NEVER
465///
466/// Never expose the `ShadowSentinel` state or probe verdicts to LLM-visible context.
467pub struct ShadowSentinel {
468    store: ShadowEventStore,
469    probe: Box<dyn SafetyProbe>,
470    config: zeph_config::ShadowSentinelConfig,
471    /// Counter of probe calls made in the current turn. Uses `AtomicU32` so all
472    /// probe-checking methods can take `&self` even under parallel tool execution.
473    probes_this_turn: AtomicU32,
474    session_id: SessionId,
475    /// Bounded set of fire-and-forget DB persist tasks. Prevents unbounded task accumulation
476    /// and ensures panics surface at `drain_pending()` instead of being silently swallowed.
477    pending_writes: Mutex<JoinSet<()>>,
478}
479
480impl ShadowSentinel {
481    /// Create a new `ShadowSentinel`.
482    ///
483    /// # Arguments
484    ///
485    /// * `store` — persistent shadow event store.
486    /// * `probe` — safety probe implementation.
487    /// * `config` — subsystem configuration.
488    /// * `session_id` — current agent session identifier.
489    #[must_use]
490    pub fn new(
491        store: ShadowEventStore,
492        probe: Box<dyn SafetyProbe>,
493        config: zeph_config::ShadowSentinelConfig,
494        session_id: impl Into<SessionId>,
495    ) -> Self {
496        Self {
497            store,
498            probe,
499            config,
500            probes_this_turn: AtomicU32::new(0),
501            session_id: session_id.into(),
502            pending_writes: Mutex::new(JoinSet::new()),
503        }
504    }
505
506    /// Classify a fully-qualified tool id into a risk tier.
507    ///
508    /// Pattern matching is prefix/glob-based against the configured `probe_patterns`.
509    /// For efficiency, we check common built-in names first before falling back to
510    /// glob matching against the configured patterns.
511    #[must_use]
512    pub fn classify_tool(&self, qualified_tool_id: &str) -> ToolRiskCategory {
513        // Fast-path for well-known high-risk builtins.
514        if qualified_tool_id == "builtin:shell"
515            || qualified_tool_id == "builtin:bash"
516            || qualified_tool_id.starts_with("builtin:shell")
517            || qualified_tool_id == "bash"
518            || qualified_tool_id == "shell"
519            || qualified_tool_id == "sh"
520        {
521            return ToolRiskCategory::Shell;
522        }
523        if qualified_tool_id == "builtin:write"
524            || qualified_tool_id == "builtin:edit"
525            || qualified_tool_id == "builtin:delete"
526            || qualified_tool_id == "write"
527            || qualified_tool_id == "edit"
528            || qualified_tool_id == "delete"
529        {
530            return ToolRiskCategory::FileWrite;
531        }
532
533        // Glob matching against configured patterns.
534        for pattern in &self.config.probe_patterns {
535            if glob_matches(pattern, qualified_tool_id) {
536                // Classify based on the pattern name.
537                if pattern.contains("shell") || pattern.contains("exec") {
538                    return ToolRiskCategory::Shell;
539                }
540                if pattern.contains("write") || pattern.contains("edit") || pattern.contains("file")
541                {
542                    if qualified_tool_id.starts_with("mcp:") {
543                        return ToolRiskCategory::ExfilCapable;
544                    }
545                    return ToolRiskCategory::FileWrite;
546                }
547                return ToolRiskCategory::ExfilCapable;
548            }
549        }
550
551        ToolRiskCategory::Low
552    }
553
554    /// Evaluate a proposed tool call and return a probe verdict.
555    ///
556    /// Returns `ProbeVerdict::Skip` when:
557    /// - The tool is not in a high-risk category.
558    /// - The feature is disabled.
559    /// - The per-turn probe budget (`max_probes_per_turn`) is exhausted.
560    ///
561    /// This method takes `&self` so it can be called from parallel tool dispatch.
562    ///
563    /// # Errors
564    ///
565    /// Does not return errors; probe failures are handled internally (fail-open or
566    /// fail-closed depending on `deny_on_timeout`).
567    #[tracing::instrument(name = "security.shadow.check", skip(self, tool_args), fields(tool_id = %qualified_tool_id))]
568    pub async fn check_tool_call(
569        &self,
570        qualified_tool_id: &str,
571        tool_args: &JsonValue,
572        turn_number: u64,
573        current_risk_level: &str,
574    ) -> ProbeVerdict {
575        if !self.config.enabled {
576            return ProbeVerdict::Skip;
577        }
578
579        let category = self.classify_tool(qualified_tool_id);
580        if category == ToolRiskCategory::Low {
581            return ProbeVerdict::Skip;
582        }
583
584        // Check per-turn probe budget using relaxed atomics (false sharing is acceptable here).
585        let count = self.probes_this_turn.fetch_add(1, Ordering::Relaxed);
586        let max_probes = u32::try_from(self.config.max_probes_per_turn).unwrap_or(u32::MAX);
587        if count >= max_probes {
588            // Undo the increment so future fast-path checks are accurate.
589            self.probes_this_turn.fetch_sub(1, Ordering::Relaxed);
590            tracing::debug!(
591                max = self.config.max_probes_per_turn,
592                "ShadowSentinel: probe budget exhausted for this turn, skipping"
593            );
594            return ProbeVerdict::Skip;
595        }
596
597        // Load recent trajectory for probe context.
598        // Filter out probe_result events — exposing probe verdicts to the LLM would allow
599        // prompt injection attacks that craft tool outputs to manipulate perceived safety.
600        let trajectory = match self
601            .store
602            .get_trajectory(&self.session_id, self.config.max_context_events)
603            .await
604        {
605            Ok(t) => t
606                .into_iter()
607                .filter(|e| e.event_type != "probe_result")
608                .collect(),
609            Err(e) => {
610                tracing::warn!(error = %e, "ShadowSentinel: failed to load trajectory, proceeding without context");
611                vec![]
612            }
613        };
614
615        let verdict = self
616            .probe
617            .evaluate(qualified_tool_id, tool_args, &trajectory)
618            .await;
619
620        // Persist the probe result asynchronously (best-effort — never blocks tool path).
621        let probe_verdict_str = match &verdict {
622            ProbeVerdict::Allow => "allow",
623            ProbeVerdict::Deny { .. } => "deny",
624            ProbeVerdict::Skip => "skip",
625        };
626        let summary = match &verdict {
627            ProbeVerdict::Deny { reason } => {
628                format!("probe denied: {}", &reason[..reason.len().min(120)])
629            }
630            ProbeVerdict::Allow => format!("probe allowed {qualified_tool_id}"),
631            ProbeVerdict::Skip => format!("probe skipped {qualified_tool_id}"),
632        };
633        let event = SentinelEvent {
634            id: 0,
635            session_id: self.session_id.clone(),
636            turn_number,
637            event_type: "probe_result".to_owned(),
638            tool_id: Some(qualified_tool_id.to_owned()),
639            risk_signal: None,
640            risk_level: current_risk_level.to_owned(),
641            probe_verdict: Some(probe_verdict_str.to_owned()),
642            context_summary: Some(summary),
643            created_at: unix_now(),
644        };
645        let store = self.store.clone();
646        self.spawn_persist(async move {
647            if let Err(e) = store.record(&event).await {
648                tracing::warn!(error = %e, "ShadowSentinel: failed to persist probe result");
649            }
650        })
651        .await;
652
653        verdict
654    }
655
656    /// Persist a tool execution event in the shadow stream (fire-and-forget).
657    ///
658    /// Called after a tool finishes execution to maintain the trajectory for future probes.
659    pub async fn record_tool_event(
660        &self,
661        qualified_tool_id: &str,
662        turn_number: u64,
663        risk_level: &str,
664        context_summary: &str,
665    ) {
666        if !self.config.enabled {
667            return;
668        }
669        let event = SentinelEvent {
670            id: 0,
671            session_id: self.session_id.clone(),
672            turn_number,
673            event_type: "tool_call".to_owned(),
674            tool_id: Some(qualified_tool_id.to_owned()),
675            risk_signal: None,
676            risk_level: risk_level.to_owned(),
677            probe_verdict: None,
678            context_summary: Some(context_summary.chars().take(250).collect()),
679            created_at: unix_now(),
680        };
681        let store = self.store.clone();
682        self.spawn_persist(async move {
683            if let Err(e) = store.record(&event).await {
684                tracing::warn!(error = %e, "ShadowSentinel: failed to persist tool event");
685            }
686        })
687        .await;
688    }
689
690    /// Await all queued fire-and-forget persist tasks.
691    ///
692    /// Call once at session shutdown to ensure no DB writes are silently dropped.
693    /// All errors have already been logged inside each task; this method only joins the handles.
694    pub async fn drain_pending(&self) {
695        let mut set = self.pending_writes.lock().await;
696        while set.join_next().await.is_some() {}
697    }
698
699    /// Spawn a background persist task into the bounded `JoinSet`.
700    ///
701    /// Reaps completed handles before spawning to stay within `MAX_PENDING_WRITES`. If the set
702    /// is still at capacity after reaping (all tasks still running), the new task is dropped and
703    /// a debug message is emitted — persistence is best-effort and must never block the tool path.
704    async fn spawn_persist<F>(&self, fut: F)
705    where
706        F: std::future::Future<Output = ()> + Send + 'static,
707    {
708        let mut set = self.pending_writes.lock().await;
709        // Reap only already-finished handles — never block waiting for a running task.
710        // try_join_next() returns immediately if no task has completed yet.
711        while set.try_join_next().is_some() {}
712        if set.len() < MAX_PENDING_WRITES {
713            set.spawn(fut);
714        } else {
715            tracing::debug!(
716                max = MAX_PENDING_WRITES,
717                "ShadowSentinel: pending_writes at capacity, skipping persist"
718            );
719        }
720    }
721
722    /// Reset the per-turn probe counter.
723    ///
724    /// Must be called once per turn BEFORE any tool calls, alongside
725    /// `TrajectorySentinel::advance_turn()`.
726    pub fn advance_turn(&self) {
727        self.probes_this_turn.store(0, Ordering::Release);
728    }
729}
730
731// ── Helpers ───────────────────────────────────────────────────────────────────
732
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reads earlier than the epoch or the seconds
/// value does not fit in an `i64`.
fn unix_now() -> i64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_secs()).unwrap_or(0),
        Err(_) => 0,
    }
}
741
/// Simple glob matching where `*` matches any sequence of characters.
///
/// The sequence may be empty and may contain separators such as `/` and `:`. Every other
/// character matches literally; there is no `?` or character-class syntax. A pattern
/// without `*` must equal `value` exactly.
///
/// Matching is anchored at both ends:
/// - the text before the first `*` must be a prefix of `value`;
/// - the text after the last `*` must be a suffix of what remains.
fn glob_matches(pattern: &str, value: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item: the literal prefix (possibly empty).
    let prefix = segments.next().unwrap_or_default();
    let Some(mut remaining) = value.strip_prefix(prefix) else {
        return false;
    };
    let rest: Vec<&str> = segments.collect();
    let Some((&suffix, middle)) = rest.split_last() else {
        // No `*` in the pattern: the prefix must have consumed the whole value.
        return remaining.is_empty();
    };
    // Taking the leftmost match of each inner segment leaves the most room for the
    // segments after it, so no backtracking is needed.
    for &segment in middle {
        match remaining.find(segment) {
            Some(pos) => remaining = &remaining[pos + segment.len()..],
            None => return false,
        }
    }
    remaining.ends_with(suffix)
}
773
// ── AgentError extension ──────────────────────────────────────────────────────
// `ShadowEventStore` maps database failures to `AgentError::Db`. That variant is declared
// in agent/error.rs; this module only references it.
777
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Arc`, `AtomicU32` and `Ordering` imports.
    use super::*;

    #[tokio::test]
    async fn classify_builtin_shell_is_shell_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(
            sentinel.classify_tool("builtin:shell"),
            ToolRiskCategory::Shell
        );
        assert_eq!(
            sentinel.classify_tool("builtin:bash"),
            ToolRiskCategory::Shell
        );
    }

    #[tokio::test]
    async fn classify_builtin_write_is_file_write_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(
            sentinel.classify_tool("builtin:write"),
            ToolRiskCategory::FileWrite
        );
        assert_eq!(
            sentinel.classify_tool("builtin:edit"),
            ToolRiskCategory::FileWrite
        );
    }

    #[tokio::test]
    async fn classify_low_risk_returns_low() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(
            sentinel.classify_tool("builtin:read"),
            ToolRiskCategory::Low
        );
        assert_eq!(
            sentinel.classify_tool("builtin:search"),
            ToolRiskCategory::Low
        );
    }

    #[tokio::test]
    async fn classify_bare_shell_names_are_shell_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(sentinel.classify_tool("bash"), ToolRiskCategory::Shell);
        assert_eq!(sentinel.classify_tool("shell"), ToolRiskCategory::Shell);
        assert_eq!(sentinel.classify_tool("sh"), ToolRiskCategory::Shell);
    }

    #[tokio::test]
    async fn classify_bare_file_write_names_are_file_write_risk() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        assert_eq!(sentinel.classify_tool("write"), ToolRiskCategory::FileWrite);
        assert_eq!(sentinel.classify_tool("edit"), ToolRiskCategory::FileWrite);
        assert_eq!(
            sentinel.classify_tool("delete"),
            ToolRiskCategory::FileWrite
        );
    }

    #[tokio::test]
    async fn advance_turn_resets_counter() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        sentinel.probes_this_turn.store(3, Ordering::Relaxed);
        sentinel.advance_turn();
        assert_eq!(sentinel.probes_this_turn.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn glob_matches_star_wildcard() {
        assert!(glob_matches("mcp:*/file_*", "mcp:myserver/file_read"));
        assert!(glob_matches("mcp:*/file_*", "mcp:other/file_write"));
        assert!(!glob_matches("mcp:*/file_*", "builtin:shell"));
    }

    #[test]
    fn glob_matches_exact() {
        assert!(glob_matches("builtin:shell", "builtin:shell"));
        assert!(!glob_matches("builtin:shell", "builtin:write"));
    }

    /// Pins the actual wildcard semantics: `*` may span `/`, and both ends are anchored.
    #[test]
    fn glob_matches_star_spans_separator_and_anchors_ends() {
        assert!(glob_matches("mcp:*/file_*", "mcp:a/b/file_x"));
        assert!(!glob_matches("mcp:*/file_*", "mcp:a/b/other"));
        assert!(glob_matches("*_write", "mcp:fs/file_write"));
        assert!(glob_matches("*", ""));
        assert!(!glob_matches("ab*ba", "aba"));
        assert!(glob_matches("ab*ba", "abba"));
    }

    #[test]
    fn parse_verdict_allow() {
        let v = LlmSafetyProbe::parse_verdict(r#"{"verdict": "allow"}"#);
        assert_eq!(v, ProbeVerdict::Allow);
    }

    #[test]
    fn parse_verdict_deny_with_reason() {
        let v =
            LlmSafetyProbe::parse_verdict(r#"{"verdict": "deny", "reason": "suspicious pattern"}"#);
        assert_eq!(
            v,
            ProbeVerdict::Deny {
                reason: "suspicious pattern".to_owned()
            }
        );
    }

    #[test]
    fn parse_verdict_unparseable_allows() {
        let v = LlmSafetyProbe::parse_verdict("I think this is fine");
        assert_eq!(v, ProbeVerdict::Allow);
    }

    #[tokio::test]
    async fn check_tool_call_skips_after_budget_exhausted() {
        let config = zeph_config::ShadowSentinelConfig {
            enabled: true,
            max_probes_per_turn: 2,
            ..zeph_config::ShadowSentinelConfig::default()
        };
        let sentinel = make_test_sentinel(config).await;

        // First two calls should not be skipped (noop probe returns Allow).
        let args = serde_json::Value::Object(serde_json::Map::new());
        let v1 = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        let v2 = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        assert_ne!(v1, ProbeVerdict::Skip, "first call within budget");
        assert_ne!(v2, ProbeVerdict::Skip, "second call within budget");

        // Third call exceeds max_probes_per_turn = 2 → must skip.
        let v3 = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        assert_eq!(
            v3,
            ProbeVerdict::Skip,
            "third call must be skipped (budget exhausted)"
        );
    }

    #[tokio::test]
    async fn check_tool_call_returns_skip_when_disabled() {
        let config = zeph_config::ShadowSentinelConfig {
            enabled: false,
            ..zeph_config::ShadowSentinelConfig::default()
        };
        let sentinel = make_test_sentinel(config).await;
        let args = serde_json::Value::Object(serde_json::Map::new());
        let verdict = sentinel
            .check_tool_call("builtin:shell", &args, 1, "calm")
            .await;
        assert_eq!(
            verdict,
            ProbeVerdict::Skip,
            "disabled sentinel must always return Skip without calling the probe"
        );
    }

    // ── JoinSet regression tests (#4570) ─────────────────────────────────────

    /// `drain_pending` awaits all spawned persist tasks and returns when the set is empty.
    #[tokio::test]
    async fn drain_pending_awaits_all_tasks() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;

        let counter = Arc::new(AtomicU32::new(0));
        for _ in 0..5 {
            let c = Arc::clone(&counter);
            sentinel
                .spawn_persist(async move {
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    c.fetch_add(1, Ordering::Relaxed);
                })
                .await;
        }

        sentinel.drain_pending().await;

        assert_eq!(
            counter.load(Ordering::Relaxed),
            5,
            "drain_pending must join all 5 tasks before returning"
        );
    }

    /// Calling `spawn_persist` more than `MAX_PENDING_WRITES` times must never panic.
    ///
    /// Each task finishes almost immediately, so completed tasks can be reaped between
    /// spawns and most submissions are expected to be accepted. At capacity, extra tasks
    /// are dropped rather than queued. The final assertion is deliberately loose (at least
    /// `MAX_PENDING_WRITES` tasks ran) because how many get dropped depends on scheduling.
    #[tokio::test]
    async fn spawn_persist_beyond_capacity_does_not_panic() {
        let config = zeph_config::ShadowSentinelConfig::default();
        let sentinel = make_test_sentinel(config).await;
        let counter = Arc::new(AtomicU32::new(0));

        // Spawn twice the capacity; each task completes instantly.
        // spawn_persist will reap completed tasks between spawns, so most will be accepted.
        for _ in 0..(MAX_PENDING_WRITES * 2) {
            let c = Arc::clone(&counter);
            sentinel
                .spawn_persist(async move {
                    c.fetch_add(1, Ordering::Relaxed);
                })
                .await;
        }

        sentinel.drain_pending().await;

        // All tasks (or at least MAX_PENDING_WRITES of them) must have run; none panicked.
        let ran = counter.load(Ordering::Relaxed);
        assert!(
            ran >= u32::try_from(MAX_PENDING_WRITES).unwrap(),
            "at least MAX_PENDING_WRITES tasks must complete; ran={ran}"
        );
    }

    // Build a minimal ShadowSentinel whose probe always returns `ProbeVerdict::Allow`.
    //
    // The store is backed by an in-memory SQLite pool, and no schema is applied here.
    // The classification, counter and JoinSet tests do not touch the store. The
    // `check_tool_call` tests exercise the full call path, so any store access there runs
    // against an empty database.
    // NOTE(review): confirm `check_tool_call` tolerates store errors from missing tables.
    async fn make_test_sentinel(config: zeph_config::ShadowSentinelConfig) -> ShadowSentinel {
        struct NoopProbe;
        impl SafetyProbe for NoopProbe {
            fn evaluate<'a>(
                &'a self,
                _: &'a str,
                _: &'a JsonValue,
                _: &'a [SentinelEvent],
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeVerdict> + Send + 'a>>
            {
                Box::pin(async { ProbeVerdict::Allow })
            }
        }
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite pool");
        let store = ShadowEventStore::new(pool);
        ShadowSentinel::new(store, Box::new(NoopProbe), config, "test-session")
    }
}