Skip to main content

zeph_memory/shadow/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! MAGE shadow memory stream — trajectory-level risk accumulation (spec 004-16).
5//!
6//! [`TrajectoryRiskAccumulator`] maintains a per-session rolling risk score by ingesting
7//! [`AuditSignalType`] events from `zeph-sanitizer`. The score decays exponentially between
8//! turns and is used to gate tool execution when it exceeds a configured threshold.
9//!
10//! When `enabled = false` (default), every method is a zero-cost no-op — no allocations,
11//! no computation.
12//!
13//! # Example
14//!
15//! ```rust
//! use zeph_memory::shadow::TrajectoryRiskAccumulator;
//!
//! let acc = TrajectoryRiskAccumulator::new_noop();
20//! assert_eq!(acc.current_risk(), 0.0);
21//! assert!(!acc.is_blocked());
22//! ```
23
24use std::collections::VecDeque;
25
26use tracing::info_span;
27use zeph_config::TrajectoryRiskAccumulatorConfig;
28
29pub use zeph_common::audit::{AuditSignalType, Severity};
30
31fn signal_type_label(t: AuditSignalType) -> &'static str {
32    match t {
33        AuditSignalType::PolicyViolation => "policy_violation",
34        AuditSignalType::PromptInjectionPattern => "prompt_injection",
35        AuditSignalType::ToolChainAnomaly => "tool_chain_anomaly",
36        AuditSignalType::ConfidenceDrop => "confidence_drop",
37        _ => "unknown",
38    }
39}
40
41fn severity_label(s: Severity) -> &'static str {
42    match s {
43        Severity::Low => "low",
44        Severity::Medium => "medium",
45        Severity::High => "high",
46        _ => "unknown",
47    }
48}
49
50/// A recorded safety signal ingested during a specific turn.
51#[derive(Debug, Clone)]
52pub struct SignalEvent {
53    /// Turn index at which the signal was ingested.
54    pub turn_id: u32,
55    /// Category of the detected signal.
56    pub signal_type: AuditSignalType,
57    /// Severity of the detected signal.
58    pub severity: Severity,
59    /// Computed contribution: `base_weight × severity_multiplier`.
60    pub raw_score: f64,
61}
62
63/// Per-session trajectory risk accumulator (MAGE spec 004-16).
64///
65/// Maintains a rolling `trajectory_risk` score in `[0.0, 1.0]` that accumulates safety
66/// signals with exponential temporal decay. Designed to detect multi-turn attacks that
67/// evade per-turn controls.
68///
69/// When constructed via [`new_noop`][`TrajectoryRiskAccumulator::new_noop`] or when
70/// `config.enabled = false`, **all methods are zero-cost no-ops** — no allocations and
71/// `current_risk()` always returns `0.0`.
72pub struct TrajectoryRiskAccumulator {
73    /// `None` means noop mode — all operations are skipped.
74    config: Option<TrajectoryRiskAccumulatorConfig>,
75    /// Current accumulated risk score, clamped to `[0.0, 1.0]`.
76    trajectory_risk: f64,
77    /// Number of `advance_turn` calls since creation.
78    turn_count: u32,
79    /// Capped ring buffer of the most recent ingested signals.
80    signal_history: VecDeque<SignalEvent>,
81}
82
83impl TrajectoryRiskAccumulator {
84    /// Construct an accumulator that operates as a zero-cost noop.
85    ///
86    /// Use when shadow memory is disabled or during testing scenarios that do not need
87    /// risk tracking. No heap allocation is performed.
88    #[must_use]
89    pub fn new_noop() -> Self {
90        Self {
91            config: None,
92            trajectory_risk: 0.0,
93            turn_count: 0,
94            signal_history: VecDeque::new(),
95        }
96    }
97
98    /// Construct an accumulator from configuration.
99    ///
100    /// When `config.enabled = false`, delegates to [`new_noop`][Self::new_noop] — no
101    /// allocation. When enabled, pre-allocates the signal history ring buffer.
102    #[must_use]
103    pub fn new(config: TrajectoryRiskAccumulatorConfig) -> Self {
104        if !config.enabled {
105            return Self::new_noop();
106        }
107        let cap = config.signal_history_cap;
108        Self {
109            config: Some(config),
110            trajectory_risk: 0.0,
111            turn_count: 0,
112            signal_history: VecDeque::with_capacity(cap.min(1024)),
113        }
114    }
115
116    /// Advance the turn counter and apply exponential decay to the accumulated risk.
117    ///
118    /// Must be called **once per turn, before** [`ingest`][Self::ingest] is called for
119    /// that turn. Decay formula: `risk *= exp(-ln(2) / halflife_turns)`.
120    ///
121    /// No-op when disabled.
122    pub fn advance_turn(&mut self) {
123        let _span = info_span!("memory.shadow.advance_turn").entered();
124        let Some(config) = &self.config else { return };
125        self.turn_count = self.turn_count.saturating_add(1);
126        let halflife = if config.risk_halflife_turns == 0 {
127            tracing::warn!("risk_halflife_turns = 0 is invalid, clamping to 1");
128            1u32
129        } else {
130            config.risk_halflife_turns
131        };
132        let decay = (-std::f64::consts::LN_2 / f64::from(halflife)).exp();
133        self.trajectory_risk *= decay;
134    }
135
136    /// Ingest a safety signal and add its weighted contribution to `trajectory_risk`.
137    ///
138    /// The raw score is `base_weight(signal_type) × severity_multiplier(severity)`.
139    /// After addition, `trajectory_risk` is clamped to `[0.0, 1.0]`. The event is
140    /// appended to the signal history ring buffer; the oldest entry is evicted when
141    /// the buffer is full.
142    ///
143    /// Emits `shadow_memory_signals_total{type, severity}` counter (NFR-007).
144    ///
145    /// No-op when disabled.
146    pub fn ingest(&mut self, signal_type: AuditSignalType, severity: Severity) {
147        let _span = info_span!("memory.shadow.ingest").entered();
148        let Some(config) = &self.config else { return };
149
150        let base_weight = match signal_type {
151            AuditSignalType::PolicyViolation => config.signal_weights.policy_violation,
152            AuditSignalType::PromptInjectionPattern => config.signal_weights.prompt_injection,
153            AuditSignalType::ToolChainAnomaly => config.signal_weights.tool_chain_anomaly,
154            AuditSignalType::ConfidenceDrop => config.signal_weights.confidence_drop,
155            _ => 0.0,
156        };
157        let severity_mult = match severity {
158            Severity::Low => config.severity_multipliers.low,
159            Severity::Medium => config.severity_multipliers.medium,
160            Severity::High => config.severity_multipliers.high,
161            _ => 1.0,
162        };
163        let raw_score = base_weight * severity_mult;
164
165        self.trajectory_risk = (self.trajectory_risk + raw_score).min(1.0);
166
167        let cap = config.signal_history_cap;
168        if self.signal_history.len() >= cap {
169            self.signal_history.pop_front();
170        }
171        self.signal_history.push_back(SignalEvent {
172            turn_id: self.turn_count,
173            signal_type,
174            severity,
175            raw_score,
176        });
177
178        metrics::counter!(
179            "shadow_memory_signals_total",
180            "type" => signal_type_label(signal_type),
181            "severity" => severity_label(severity),
182        )
183        .increment(1);
184    }
185
186    /// Increment `shadow_memory_blocks_total` counter (NFR-007).
187    ///
188    /// Call this once when a tool execution is actually blocked due to trajectory risk.
189    /// Do **not** call on every `is_blocked()` query — only when a block action fires.
190    pub fn record_block(&self) {
191        metrics::counter!("shadow_memory_blocks_total").increment(1);
192    }
193
194    /// Increment `shadow_memory_escalations_total` counter (NFR-007).
195    ///
196    /// Call this once when an escalation-to-human-confirmation is triggered.
197    pub fn record_escalation(&self) {
198        metrics::counter!("shadow_memory_escalations_total").increment(1);
199    }
200
201    /// Returns the current accumulated risk score in `[0.0, 1.0]`.
202    ///
203    /// Always returns `0.0` when disabled.
204    #[must_use]
205    pub fn current_risk(&self) -> f64 {
206        if self.config.is_none() {
207            return 0.0;
208        }
209        self.trajectory_risk
210    }
211
212    /// Returns `true` when `trajectory_risk >= risk_threshold` and shadow memory is enabled.
213    ///
214    /// Always returns `false` when disabled.
215    #[must_use]
216    pub fn is_blocked(&self) -> bool {
217        let Some(config) = &self.config else {
218            return false;
219        };
220        self.trajectory_risk >= config.risk_threshold
221    }
222
223    /// Returns `true` when risk is in `[escalation_threshold, risk_threshold)`.
224    ///
225    /// Always returns `false` when disabled.
226    #[must_use]
227    pub fn should_escalate(&self) -> bool {
228        let Some(config) = &self.config else {
229            return false;
230        };
231        self.trajectory_risk >= config.escalation_threshold
232            && self.trajectory_risk < config.risk_threshold
233    }
234
235    /// Returns the top `n` signals by `raw_score` descending from recent history.
236    #[must_use]
237    pub fn top_signals(&self, n: usize) -> Vec<&SignalEvent> {
238        let mut signals: Vec<&SignalEvent> = self.signal_history.iter().collect();
239        signals.sort_by(|a, b| {
240            b.raw_score
241                .partial_cmp(&a.raw_score)
242                .unwrap_or(std::cmp::Ordering::Equal)
243        });
244        signals.truncate(n);
245        signals
246    }
247
248    /// Resets `trajectory_risk` to zero and clears signal history.
249    ///
250    /// Called on context compaction when `reset_on_compaction = true`. No-op when disabled.
251    pub fn reset(&mut self) {
252        if self.config.is_none() {
253            return;
254        }
255        self.trajectory_risk = 0.0;
256        self.signal_history.clear();
257    }
258
259    /// Returns `true` when shadow memory is enabled (i.e., not in noop mode).
260    #[must_use]
261    pub fn is_enabled(&self) -> bool {
262        self.config.is_some()
263    }
264
265    /// Returns the current turn count.
266    #[must_use]
267    pub fn turn_count(&self) -> u32 {
268        self.turn_count
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_config::{
        TrajectoryRiskAccumulatorConfig, TrajectorySeverityMultipliers, TrajectorySignalWeights,
    };

    /// Enabled configuration shared by most tests (risk threshold 0.75, escalation 0.50,
    /// half-life 10 turns).
    ///
    /// Uses the `zeph_config` default weights and multipliers. `decay_formula_matches_spec`
    /// pins `confidence_drop = 0.15` and `medium = 1.0` exactly; other tests rely on those two
    /// values.
    fn enabled_config() -> TrajectoryRiskAccumulatorConfig {
        TrajectoryRiskAccumulatorConfig {
            enabled: true,
            risk_threshold: 0.75,
            escalation_threshold: 0.50,
            risk_halflife_turns: 10,
            signal_history_cap: 200,
            tui_show_risk_gauge: true,
            reset_on_compaction: false,
            signal_weights: TrajectorySignalWeights::default(),
            severity_multipliers: TrajectorySeverityMultipliers::default(),
        }
    }

    #[test]
    fn new_noop_returns_zero_risk() {
        let acc = TrajectoryRiskAccumulator::new_noop();
        assert!(acc.current_risk() < f64::EPSILON);
        assert!(!acc.is_blocked());
        assert!(!acc.is_enabled());
    }

    #[test]
    fn disabled_config_behaves_as_noop() {
        let mut acc = TrajectoryRiskAccumulator::new(TrajectoryRiskAccumulatorConfig {
            enabled: false,
            ..enabled_config()
        });
        assert!(!acc.is_enabled());
        acc.advance_turn();
        acc.ingest(AuditSignalType::PromptInjectionPattern, Severity::High);
        assert!(acc.current_risk() < f64::EPSILON);
        assert!(!acc.is_blocked());
        assert!(!acc.should_escalate());
        assert_eq!(acc.turn_count(), 0);
        assert!(acc.top_signals(10).is_empty());
    }

    #[test]
    fn single_signal_below_threshold_not_blocked() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        acc.advance_turn();
        // PolicyViolation medium = 0.30 * 1.0 = 0.30 < 0.75 (default weights)
        acc.ingest(AuditSignalType::PolicyViolation, Severity::Medium);
        assert!(acc.current_risk() > 0.0);
        assert!(acc.current_risk() < 0.75);
        assert!(!acc.is_blocked());
    }

    #[test]
    fn multi_turn_chain_accumulates_and_blocks() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        // With default weights, PromptInjectionPattern high = 0.50 * 2.0 = 1.0 per signal, so
        // the first signal already saturates the score at 1.0. The later turns check that
        // decay followed by re-ingest keeps the session blocked.
        for _ in 0..5 {
            acc.advance_turn();
            acc.ingest(AuditSignalType::PromptInjectionPattern, Severity::High);
        }
        assert!(acc.is_blocked(), "risk={}", acc.current_risk());
    }

    #[test]
    fn escalation_band_between_thresholds() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        acc.advance_turn();
        // 4 × 0.15 ≈ 0.60 within a single turn (no decay between ingests): inside [0.50, 0.75).
        for _ in 0..4 {
            acc.ingest(AuditSignalType::ConfidenceDrop, Severity::Medium);
        }
        assert!(acc.should_escalate(), "risk={}", acc.current_risk());
        assert!(!acc.is_blocked());
    }

    #[test]
    fn temporal_decay_reduces_score() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        acc.advance_turn();
        acc.ingest(AuditSignalType::PromptInjectionPattern, Severity::High);
        let risk_after_signal = acc.current_risk();
        assert!(risk_after_signal > 0.0);

        // Advance 100 turns without new signals — risk should decay significantly
        for _ in 0..100 {
            acc.advance_turn();
        }
        assert!(
            acc.current_risk() < risk_after_signal / 2.0,
            "expected significant decay, got {}",
            acc.current_risk()
        );
    }

    #[test]
    fn risk_clamped_at_one() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        for _ in 0..20 {
            acc.advance_turn();
            acc.ingest(AuditSignalType::PromptInjectionPattern, Severity::High);
        }
        assert!(
            acc.current_risk() <= 1.0,
            "trajectory_risk exceeded 1.0: {}",
            acc.current_risk()
        );
    }

    #[test]
    fn advance_turn_before_ingest_applies_decay() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        // Seed some risk first
        acc.advance_turn();
        acc.ingest(AuditSignalType::PolicyViolation, Severity::High);
        let risk_t1 = acc.current_risk();

        // Advance a turn (decay applied) before next ingest
        acc.advance_turn();
        let risk_after_decay = acc.current_risk();

        // After decay, risk should be strictly less than risk_t1 (no new signals yet)
        assert!(
            risk_after_decay < risk_t1,
            "decay should reduce risk before new ingest: {risk_after_decay} vs {risk_t1}"
        );

        acc.ingest(AuditSignalType::PolicyViolation, Severity::High);
        // After ingest, risk should be higher than the decayed value
        assert!(
            acc.current_risk() > risk_after_decay,
            "ingest should increase risk: {} vs {}",
            acc.current_risk(),
            risk_after_decay
        );
    }

    #[test]
    fn decay_formula_matches_spec() {
        // halflife=10, confidence_drop base_weight=0.15, medium severity=1.0
        // 5 turns: each turn calls advance_turn() then ingest(ConfidenceDrop, Medium)
        // per-signal contribution = 0.15 * 1.0 = 0.15; sum over 5 turns < 1.0 so no clamping.
        // After turn 5, the accumulator holds:
        //   risk = 0.15*d^0 + 0.15*d^1 + 0.15*d^2 + 0.15*d^3 + 0.15*d^4
        // where d = exp(-ln(2)/10), most recent signal (turn 5) has least decay (d^0).
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        for _ in 0..5 {
            acc.advance_turn();
            acc.ingest(AuditSignalType::ConfidenceDrop, Severity::Medium);
        }
        let decay = (-std::f64::consts::LN_2 / 10.0_f64).exp();
        // sum_{k=0}^{4} 0.15 * decay^k (most recent turn = k=0, least decay applied)
        let expected: f64 = (0..5).map(|k| 0.15_f64 * decay.powi(k)).sum();
        assert!(
            expected < 1.0,
            "test precondition: expected sum {expected} must be < 1.0 (no clamping)"
        );
        assert!(
            (acc.current_risk() - expected).abs() < 1e-9,
            "expected {expected:.12}, got {:.12}",
            acc.current_risk()
        );
    }

    #[test]
    fn zero_halflife_is_treated_as_one() {
        let mut acc = TrajectoryRiskAccumulator::new(TrajectoryRiskAccumulatorConfig {
            risk_halflife_turns: 0,
            ..enabled_config()
        });
        acc.advance_turn();
        acc.ingest(AuditSignalType::ConfidenceDrop, Severity::Medium);
        let before = acc.current_risk();
        acc.advance_turn();
        // A half-life of one turn halves the score per turn (instead of zeroing it).
        assert!(
            (acc.current_risk() - before / 2.0).abs() < 1e-12,
            "expected {}, got {}",
            before / 2.0,
            acc.current_risk()
        );
    }

    #[test]
    fn reset_clears_risk_and_history_but_keeps_turn_count() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        acc.advance_turn();
        acc.ingest(AuditSignalType::PolicyViolation, Severity::High);
        acc.reset();
        assert!(acc.current_risk() < f64::EPSILON);
        assert!(acc.top_signals(10).is_empty());
        assert_eq!(acc.turn_count(), 1);
    }

    #[test]
    fn signal_history_evicts_oldest_when_full() {
        let mut acc = TrajectoryRiskAccumulator::new(TrajectoryRiskAccumulatorConfig {
            signal_history_cap: 3,
            ..enabled_config()
        });
        for _ in 0..5 {
            acc.advance_turn();
            acc.ingest(AuditSignalType::ConfidenceDrop, Severity::Medium);
        }
        let kept = acc.top_signals(usize::MAX);
        assert_eq!(kept.len(), 3);
        // Equal scores keep ingestion order, so the survivors are turns 3..=5.
        let turns: Vec<u32> = kept.iter().map(|e| e.turn_id).collect();
        assert_eq!(turns, vec![3, 4, 5]);
    }

    #[test]
    fn top_signals_orders_by_score_descending() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        acc.advance_turn();
        acc.ingest(AuditSignalType::ConfidenceDrop, Severity::Medium);
        acc.ingest(AuditSignalType::PromptInjectionPattern, Severity::High);
        acc.ingest(AuditSignalType::PolicyViolation, Severity::Medium);
        let top = acc.top_signals(3);
        assert_eq!(top.len(), 3);
        assert!(top.windows(2).all(|w| w[0].raw_score >= w[1].raw_score));
        assert_eq!(acc.top_signals(1).len(), 1);
        assert!(acc.top_signals(0).is_empty());
    }

    #[test]
    fn fifty_clean_turns_zero_risk() {
        let mut acc = TrajectoryRiskAccumulator::new(enabled_config());
        for _ in 0..50 {
            acc.advance_turn();
        }
        assert!(
            acc.current_risk() < f64::EPSILON,
            "no signals → risk must stay 0.0"
        );
        assert!(!acc.is_blocked());
    }
}