zeph_memory/shadow/
mod.rs1use std::collections::VecDeque;
25
26use tracing::info_span;
27use zeph_config::TrajectoryRiskAccumulatorConfig;
28
29pub use zeph_common::audit::{AuditSignalType, Severity};
30
31fn signal_type_label(t: AuditSignalType) -> &'static str {
32 match t {
33 AuditSignalType::PolicyViolation => "policy_violation",
34 AuditSignalType::PromptInjectionPattern => "prompt_injection",
35 AuditSignalType::ToolChainAnomaly => "tool_chain_anomaly",
36 AuditSignalType::ConfidenceDrop => "confidence_drop",
37 _ => "unknown",
38 }
39}
40
41fn severity_label(s: Severity) -> &'static str {
42 match s {
43 Severity::Low => "low",
44 Severity::Medium => "medium",
45 Severity::High => "high",
46 _ => "unknown",
47 }
48}
49
50#[derive(Debug, Clone)]
52pub struct SignalEvent {
53 pub turn_id: u32,
55 pub signal_type: AuditSignalType,
57 pub severity: Severity,
59 pub raw_score: f64,
61}
62
63pub struct TrajectoryRiskAccumulator {
73 config: Option<TrajectoryRiskAccumulatorConfig>,
75 trajectory_risk: f64,
77 turn_count: u32,
79 signal_history: VecDeque<SignalEvent>,
81}
82
83impl TrajectoryRiskAccumulator {
84 #[must_use]
89 pub fn new_noop() -> Self {
90 Self {
91 config: None,
92 trajectory_risk: 0.0,
93 turn_count: 0,
94 signal_history: VecDeque::new(),
95 }
96 }
97
98 #[must_use]
103 pub fn new(config: TrajectoryRiskAccumulatorConfig) -> Self {
104 if !config.enabled {
105 return Self::new_noop();
106 }
107 let cap = config.signal_history_cap;
108 Self {
109 config: Some(config),
110 trajectory_risk: 0.0,
111 turn_count: 0,
112 signal_history: VecDeque::with_capacity(cap.min(1024)),
113 }
114 }
115
116 pub fn advance_turn(&mut self) {
123 let _span = info_span!("memory.shadow.advance_turn").entered();
124 let Some(config) = &self.config else { return };
125 self.turn_count = self.turn_count.saturating_add(1);
126 let halflife = if config.risk_halflife_turns == 0 {
127 tracing::warn!("risk_halflife_turns = 0 is invalid, clamping to 1");
128 1u32
129 } else {
130 config.risk_halflife_turns
131 };
132 let decay = (-std::f64::consts::LN_2 / f64::from(halflife)).exp();
133 self.trajectory_risk *= decay;
134 }
135
136 pub fn ingest(&mut self, signal_type: AuditSignalType, severity: Severity) {
147 let _span = info_span!("memory.shadow.ingest").entered();
148 let Some(config) = &self.config else { return };
149
150 let base_weight = match signal_type {
151 AuditSignalType::PolicyViolation => config.signal_weights.policy_violation,
152 AuditSignalType::PromptInjectionPattern => config.signal_weights.prompt_injection,
153 AuditSignalType::ToolChainAnomaly => config.signal_weights.tool_chain_anomaly,
154 AuditSignalType::ConfidenceDrop => config.signal_weights.confidence_drop,
155 _ => 0.0,
156 };
157 let severity_mult = match severity {
158 Severity::Low => config.severity_multipliers.low,
159 Severity::Medium => config.severity_multipliers.medium,
160 Severity::High => config.severity_multipliers.high,
161 _ => 1.0,
162 };
163 let raw_score = base_weight * severity_mult;
164
165 self.trajectory_risk = (self.trajectory_risk + raw_score).min(1.0);
166
167 let cap = config.signal_history_cap;
168 if self.signal_history.len() >= cap {
169 self.signal_history.pop_front();
170 }
171 self.signal_history.push_back(SignalEvent {
172 turn_id: self.turn_count,
173 signal_type,
174 severity,
175 raw_score,
176 });
177
178 metrics::counter!(
179 "shadow_memory_signals_total",
180 "type" => signal_type_label(signal_type),
181 "severity" => severity_label(severity),
182 )
183 .increment(1);
184 }
185
186 pub fn record_block(&self) {
191 metrics::counter!("shadow_memory_blocks_total").increment(1);
192 }
193
194 pub fn record_escalation(&self) {
198 metrics::counter!("shadow_memory_escalations_total").increment(1);
199 }
200
201 #[must_use]
205 pub fn current_risk(&self) -> f64 {
206 if self.config.is_none() {
207 return 0.0;
208 }
209 self.trajectory_risk
210 }
211
212 #[must_use]
216 pub fn is_blocked(&self) -> bool {
217 let Some(config) = &self.config else {
218 return false;
219 };
220 self.trajectory_risk >= config.risk_threshold
221 }
222
223 #[must_use]
227 pub fn should_escalate(&self) -> bool {
228 let Some(config) = &self.config else {
229 return false;
230 };
231 self.trajectory_risk >= config.escalation_threshold
232 && self.trajectory_risk < config.risk_threshold
233 }
234
235 #[must_use]
237 pub fn top_signals(&self, n: usize) -> Vec<&SignalEvent> {
238 let mut signals: Vec<&SignalEvent> = self.signal_history.iter().collect();
239 signals.sort_by(|a, b| {
240 b.raw_score
241 .partial_cmp(&a.raw_score)
242 .unwrap_or(std::cmp::Ordering::Equal)
243 });
244 signals.truncate(n);
245 signals
246 }
247
248 pub fn reset(&mut self) {
252 if self.config.is_none() {
253 return;
254 }
255 self.trajectory_risk = 0.0;
256 self.signal_history.clear();
257 }
258
259 #[must_use]
261 pub fn is_enabled(&self) -> bool {
262 self.config.is_some()
263 }
264
265 #[must_use]
267 pub fn turn_count(&self) -> u32 {
268 self.turn_count
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_config::{
        TrajectoryRiskAccumulatorConfig, TrajectorySeverityMultipliers, TrajectorySignalWeights,
    };

    /// Enabled configuration shared by all tests: block at 0.75, escalate at
    /// 0.50, half-life of 10 turns, default weights and multipliers.
    fn enabled_config() -> TrajectoryRiskAccumulatorConfig {
        TrajectoryRiskAccumulatorConfig {
            enabled: true,
            risk_threshold: 0.75,
            escalation_threshold: 0.50,
            risk_halflife_turns: 10,
            signal_history_cap: 200,
            tui_show_risk_gauge: true,
            reset_on_compaction: false,
            signal_weights: TrajectorySignalWeights::default(),
            severity_multipliers: TrajectorySeverityMultipliers::default(),
        }
    }

    /// Fresh accumulator built from [`enabled_config`].
    fn enabled_accumulator() -> TrajectoryRiskAccumulator {
        TrajectoryRiskAccumulator::new(enabled_config())
    }

    /// Runs `turns` turns, ingesting one signal right after each advance.
    fn run_signal_turns(
        acc: &mut TrajectoryRiskAccumulator,
        turns: usize,
        signal_type: AuditSignalType,
        severity: Severity,
    ) {
        for _ in 0..turns {
            acc.advance_turn();
            acc.ingest(signal_type, severity);
        }
    }

    /// Runs `turns` turns without ingesting anything.
    fn run_clean_turns(acc: &mut TrajectoryRiskAccumulator, turns: usize) {
        for _ in 0..turns {
            acc.advance_turn();
        }
    }

    #[test]
    fn new_noop_returns_zero_risk() {
        let noop = TrajectoryRiskAccumulator::new_noop();
        assert!(noop.current_risk() < f64::EPSILON);
        assert!(!noop.is_blocked());
        assert!(!noop.is_enabled());
    }

    #[test]
    fn single_signal_below_threshold_not_blocked() {
        let mut acc = enabled_accumulator();
        run_signal_turns(&mut acc, 1, AuditSignalType::PolicyViolation, Severity::Medium);
        // One medium signal must register but stay under the 0.75 threshold.
        assert!(acc.current_risk() > 0.0);
        assert!(acc.current_risk() < 0.75);
        assert!(!acc.is_blocked());
    }

    #[test]
    fn multi_turn_chain_accumulates_and_blocks() {
        let mut acc = enabled_accumulator();
        run_signal_turns(
            &mut acc,
            5,
            AuditSignalType::PromptInjectionPattern,
            Severity::High,
        );
        assert!(acc.is_blocked(), "risk={}", acc.current_risk());
    }

    #[test]
    fn temporal_decay_reduces_score() {
        let mut acc = enabled_accumulator();
        run_signal_turns(
            &mut acc,
            1,
            AuditSignalType::PromptInjectionPattern,
            Severity::High,
        );
        let risk_after_signal = acc.current_risk();
        assert!(risk_after_signal > 0.0);

        // 100 clean turns is ten half-lives with the configured half-life of 10.
        run_clean_turns(&mut acc, 100);
        assert!(
            acc.current_risk() < risk_after_signal / 2.0,
            "expected significant decay, got {}",
            acc.current_risk()
        );
    }

    #[test]
    fn risk_clamped_at_one() {
        let mut acc = enabled_accumulator();
        run_signal_turns(
            &mut acc,
            20,
            AuditSignalType::PromptInjectionPattern,
            Severity::High,
        );
        assert!(
            acc.current_risk() <= 1.0,
            "trajectory_risk exceeded 1.0: {}",
            acc.current_risk()
        );
    }

    #[test]
    fn advance_turn_before_ingest_applies_decay() {
        let mut acc = enabled_accumulator();
        run_signal_turns(&mut acc, 1, AuditSignalType::PolicyViolation, Severity::High);
        let risk_t1 = acc.current_risk();

        // Advancing alone must already lower the score.
        acc.advance_turn();
        let risk_after_decay = acc.current_risk();
        assert!(
            risk_after_decay < risk_t1,
            "decay should reduce risk before new ingest: {risk_after_decay} vs {risk_t1}"
        );

        // A new signal in the same turn then raises it again.
        acc.ingest(AuditSignalType::PolicyViolation, Severity::High);
        assert!(
            acc.current_risk() > risk_after_decay,
            "ingest should increase risk: {} vs {}",
            acc.current_risk(),
            risk_after_decay
        );
    }

    #[test]
    fn decay_formula_matches_spec() {
        let mut acc = enabled_accumulator();
        run_signal_turns(&mut acc, 5, AuditSignalType::ConfidenceDrop, Severity::Medium);

        // The expected value hard-codes 0.15 as the per-signal score for
        // ConfidenceDrop at Medium severity under the default weights. The
        // k-th most recent signal has been decayed k times.
        let decay = (-std::f64::consts::LN_2 / 10.0_f64).exp();
        let expected: f64 = (0..5).map(|k| 0.15_f64 * decay.powi(k)).sum();
        assert!(
            expected < 1.0,
            "test precondition: expected sum {expected} must be < 1.0 (no clamping)"
        );
        assert!(
            (acc.current_risk() - expected).abs() < 1e-9,
            "expected {expected:.12}, got {:.12}",
            acc.current_risk()
        );
    }

    #[test]
    fn fifty_clean_turns_zero_risk() {
        let mut acc = enabled_accumulator();
        run_clean_turns(&mut acc, 50);
        assert!(
            acc.current_risk() < f64::EPSILON,
            "no signals → risk must stay 0.0"
        );
        assert!(!acc.is_blocked());
    }
}