Skip to main content

zeph_config/
sanitizer.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::providers::ProviderName;
5use serde::{Deserialize, Serialize};
6
7use crate::defaults::default_true;
8
// ---------------------------------------------------------------------------
// ContentIsolationConfig
// ---------------------------------------------------------------------------

/// Serde default for `ContentIsolationConfig::max_content_size`: 64 KiB.
fn default_max_content_size() -> usize {
    64 * 1024
}
16
17/// Configuration for the embedding anomaly guard, nested under
18/// `[security.content_isolation.embedding_guard]`.
19#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
20pub struct EmbeddingGuardConfig {
21    /// Enable embedding-based anomaly detection (default: false — opt-in).
22    #[serde(default)]
23    pub enabled: bool,
24    /// Cosine distance threshold above which outputs are flagged as anomalous.
25    #[serde(
26        default = "default_embedding_threshold",
27        deserialize_with = "validate_embedding_threshold"
28    )]
29    pub threshold: f64,
30    /// Minimum clean samples before centroid-based detection activates.
31    /// Before this count, regex fallback is used instead.
32    #[serde(
33        default = "default_embedding_min_samples",
34        deserialize_with = "validate_min_samples"
35    )]
36    pub min_samples: usize,
37    /// EMA alpha floor for centroid updates after stabilization (n >= `min_samples`).
38    ///
39    /// Once the centroid has accumulated `min_samples` clean outputs, each new sample
40    /// can shift it by at most this fraction. Lower values make the centroid more
41    /// resistant to slow drift attacks but slower to adapt to legitimate distribution
42    /// changes. Default: 0.01 (1% per sample).
43    #[serde(default = "default_ema_floor")]
44    pub ema_floor: f32,
45}
46
47fn validate_embedding_threshold<'de, D>(deserializer: D) -> Result<f64, D::Error>
48where
49    D: serde::Deserializer<'de>,
50{
51    let value = <f64 as serde::Deserialize>::deserialize(deserializer)?;
52    if value.is_nan() || value.is_infinite() {
53        return Err(serde::de::Error::custom(
54            "embedding_guard.threshold must be a finite number",
55        ));
56    }
57    if !(value > 0.0 && value <= 1.0) {
58        return Err(serde::de::Error::custom(
59            "embedding_guard.threshold must be in (0.0, 1.0]",
60        ));
61    }
62    Ok(value)
63}
64
65fn validate_min_samples<'de, D>(deserializer: D) -> Result<usize, D::Error>
66where
67    D: serde::Deserializer<'de>,
68{
69    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
70    if value == 0 {
71        return Err(serde::de::Error::custom(
72            "embedding_guard.min_samples must be >= 1",
73        ));
74    }
75    Ok(value)
76}
77
/// Serde default for `EmbeddingGuardConfig::threshold`.
fn default_embedding_threshold() -> f64 {
    0.35_f64
}

/// Serde default for `EmbeddingGuardConfig::min_samples`.
fn default_embedding_min_samples() -> usize {
    10_usize
}

/// Serde default for `EmbeddingGuardConfig::ema_floor` (1% per sample).
fn default_ema_floor() -> f32 {
    0.01_f32
}
89
90impl Default for EmbeddingGuardConfig {
91    fn default() -> Self {
92        Self {
93            enabled: false,
94            threshold: default_embedding_threshold(),
95            min_samples: default_embedding_min_samples(),
96            ema_floor: default_ema_floor(),
97        }
98    }
99}
100
101/// Configuration for the content isolation pipeline, nested under
102/// `[security.content_isolation]` in the agent config file.
103#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
104#[allow(clippy::struct_excessive_bools)] // config struct — boolean flags are idiomatic for TOML-deserialized configuration
105pub struct ContentIsolationConfig {
106    /// When `false`, the sanitizer is a no-op: content passes through unchanged.
107    #[serde(default = "default_true")]
108    pub enabled: bool,
109
110    /// Maximum byte length of untrusted content before truncation.
111    #[serde(default = "default_max_content_size")]
112    pub max_content_size: usize,
113
114    /// When `true`, injection patterns detected in content are recorded as
115    /// flags and a warning is prepended to the spotlighting wrapper.
116    #[serde(default = "default_true")]
117    pub flag_injection_patterns: bool,
118
119    /// When `true`, untrusted content is wrapped in spotlighting XML delimiters
120    /// that instruct the LLM to treat the enclosed text as data, not instructions.
121    #[serde(default = "default_true")]
122    pub spotlight_untrusted: bool,
123
124    /// Quarantine summarizer configuration.
125    #[serde(default)]
126    pub quarantine: QuarantineConfig,
127
128    /// Embedding anomaly guard configuration.
129    #[serde(default)]
130    pub embedding_guard: EmbeddingGuardConfig,
131
132    /// When `true`, MCP tool results flowing through ACP-serving sessions receive
133    /// unconditional quarantine summarization and cross-boundary audit log entries.
134    /// This prevents confused-deputy attacks where untrusted MCP output influences
135    /// responses served to ACP clients (e.g. IDE integrations).
136    #[serde(default = "default_true")]
137    pub mcp_to_acp_boundary: bool,
138
139    /// NLI entailment check stage configuration.
140    #[serde(default)]
141    pub nli: NliConfig,
142
143    /// PAAC secret placeholder masking configuration.
144    #[serde(default)]
145    pub secret_masking: SecretMaskingConfig,
146}
147
148impl Default for ContentIsolationConfig {
149    fn default() -> Self {
150        Self {
151            enabled: true,
152            max_content_size: default_max_content_size(),
153            flag_injection_patterns: true,
154            spotlight_untrusted: true,
155            quarantine: QuarantineConfig::default(),
156            embedding_guard: EmbeddingGuardConfig::default(),
157            mcp_to_acp_boundary: true,
158            nli: NliConfig::default(),
159            secret_masking: SecretMaskingConfig::default(),
160        }
161    }
162}
163
164/// Configuration for the SONAR NLI entailment check stage, nested under
165/// `[security.content_isolation.nli]` in the agent config file.
166///
167/// When `enabled = false` (the default), the NLI stage is skipped entirely.
168#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
169pub struct NliConfig {
170    /// Enable NLI entailment-based injection detection (default: false — opt-in).
171    #[serde(default)]
172    pub enabled: bool,
173
174    /// Provider name from `[[llm.providers]]` to use for NLI inference.
175    ///
176    /// An empty [`ProviderName`] falls back to the default provider. Prefer a fast, cheap model.
177    #[serde(default)]
178    pub provider: ProviderName,
179
180    /// Entailment score threshold above which content is flagged (default: 0.75).
181    #[serde(default = "default_nli_threshold")]
182    pub threshold: f32,
183
184    /// Maximum milliseconds to wait for the NLI provider response (default: 5000).
185    #[serde(default = "default_nli_timeout_ms")]
186    pub timeout_ms: u64,
187
188    /// Maximum characters of content sent to the NLI provider (default: 2048).
189    #[serde(default = "default_nli_max_content_len")]
190    pub max_content_len: usize,
191}
192
193fn default_nli_threshold() -> f32 {
194    0.75
195}
196
197fn default_nli_timeout_ms() -> u64 {
198    5000
199}
200
201fn default_nli_max_content_len() -> usize {
202    2048
203}
204
205impl Default for NliConfig {
206    fn default() -> Self {
207        Self {
208            enabled: false,
209            provider: ProviderName::default(),
210            threshold: default_nli_threshold(),
211            timeout_ms: default_nli_timeout_ms(),
212            max_content_len: default_nli_max_content_len(),
213        }
214    }
215}
216
217/// Configuration for PAAC secret placeholder masking, nested under
218/// `[security.secret_masking]` in the agent config file.
219///
220/// When `enabled = false` (the default), vault secrets are not masked.
221#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
222pub struct SecretMaskingConfig {
223    /// Enable secret placeholder masking (default: false — opt-in).
224    #[serde(default)]
225    pub enabled: bool,
226
227    /// Minimum secret byte length to be eligible for masking (default: 8).
228    ///
229    /// Secrets shorter than this value are not substituted to avoid false matches
230    /// on common short strings.
231    #[serde(default = "default_min_secret_len")]
232    pub min_secret_len: usize,
233}
234
235fn default_min_secret_len() -> usize {
236    8
237}
238
239impl Default for SecretMaskingConfig {
240    fn default() -> Self {
241        Self {
242            enabled: false,
243            min_secret_len: default_min_secret_len(),
244        }
245    }
246}
247
248/// Configuration for the quarantine summarizer, nested under
249/// `[security.content_isolation.quarantine]` in the agent config file.
250#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
251pub struct QuarantineConfig {
252    /// When `false`, quarantine summarization is disabled entirely.
253    #[serde(default)]
254    pub enabled: bool,
255
256    /// Source kinds to route through the quarantine LLM.
257    #[serde(default = "default_quarantine_sources")]
258    pub sources: Vec<String>,
259
260    /// Provider name passed to `create_named_provider`.
261    #[serde(default = "default_quarantine_model")]
262    pub model: String,
263
264    /// Maximum time in milliseconds to wait for the quarantine LLM to respond.
265    ///
266    /// When the LLM does not respond within this window, `extract_facts` returns a timeout
267    /// error so the agent can recover rather than stalling indefinitely.
268    /// Defaults to 30 000 ms (30 s).
269    #[serde(default = "default_quarantine_timeout_ms")]
270    pub timeout_ms: u64,
271}
272
273fn default_quarantine_sources() -> Vec<String> {
274    vec!["web_scrape".to_owned(), "a2a_message".to_owned()]
275}
276
277fn default_quarantine_model() -> String {
278    "claude".to_owned()
279}
280
281fn default_quarantine_timeout_ms() -> u64 {
282    30_000
283}
284
285impl Default for QuarantineConfig {
286    fn default() -> Self {
287        Self {
288            enabled: false,
289            sources: default_quarantine_sources(),
290            model: default_quarantine_model(),
291            timeout_ms: default_quarantine_timeout_ms(),
292        }
293    }
294}
295
296// ---------------------------------------------------------------------------
297// ExfiltrationGuardConfig
298// ---------------------------------------------------------------------------
299
300/// Configuration for exfiltration guards, nested under
301/// `[security.exfiltration_guard]` in the agent config file.
302#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
303pub struct ExfiltrationGuardConfig {
304    /// Strip external markdown images from LLM output to prevent pixel-tracking exfiltration.
305    #[serde(default = "default_true")]
306    pub block_markdown_images: bool,
307
308    /// Cross-reference tool call arguments against URLs seen in flagged untrusted content.
309    #[serde(default = "default_true")]
310    pub validate_tool_urls: bool,
311
312    /// Skip Qdrant embedding for messages that contained injection-flagged content.
313    #[serde(default = "default_true")]
314    pub guard_memory_writes: bool,
315}
316
317impl Default for ExfiltrationGuardConfig {
318    fn default() -> Self {
319        Self {
320            block_markdown_images: true,
321            validate_tool_urls: true,
322            guard_memory_writes: true,
323        }
324    }
325}
326
// ---------------------------------------------------------------------------
// MemoryWriteValidationConfig
// ---------------------------------------------------------------------------

/// Serde default for `max_content_bytes` (4 KiB).
fn default_max_content_bytes() -> usize {
    4 * 1024
}

/// Serde default for `min_entity_name_bytes`.
fn default_min_entity_name_bytes() -> usize {
    3
}

/// Serde default for `max_entity_name_bytes`.
fn default_max_entity_name_bytes() -> usize {
    256
}

/// Serde default for `max_fact_bytes` (1 KiB).
fn default_max_fact_bytes() -> usize {
    1024
}

/// Serde default for `max_entities_per_extraction`.
fn default_max_entities() -> usize {
    50
}

/// Serde default for `max_edges_per_extraction`.
fn default_max_edges() -> usize {
    100
}
354
355/// Configuration for memory write validation, nested under `[security.memory_validation]`.
356///
357/// Enabled by default with conservative limits.
358#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
359pub struct MemoryWriteValidationConfig {
360    /// Master switch. When `false`, validation is a no-op.
361    #[serde(default = "default_true")]
362    pub enabled: bool,
363    /// Maximum byte length of content passed to `memory_save`.
364    #[serde(default = "default_max_content_bytes")]
365    pub max_content_bytes: usize,
366    /// Minimum byte length of an entity name in graph extraction.
367    #[serde(default = "default_min_entity_name_bytes")]
368    pub min_entity_name_bytes: usize,
369    /// Maximum byte length of a single entity name in graph extraction.
370    #[serde(default = "default_max_entity_name_bytes")]
371    pub max_entity_name_bytes: usize,
372    /// Maximum byte length of an edge fact string in graph extraction.
373    #[serde(default = "default_max_fact_bytes")]
374    pub max_fact_bytes: usize,
375    /// Maximum number of entities allowed per graph extraction result.
376    #[serde(default = "default_max_entities")]
377    pub max_entities_per_extraction: usize,
378    /// Maximum number of edges allowed per graph extraction result.
379    #[serde(default = "default_max_edges")]
380    pub max_edges_per_extraction: usize,
381    /// Forbidden substring patterns.
382    #[serde(default)]
383    pub forbidden_content_patterns: Vec<String>,
384}
385
386impl Default for MemoryWriteValidationConfig {
387    fn default() -> Self {
388        Self {
389            enabled: true,
390            max_content_bytes: default_max_content_bytes(),
391            min_entity_name_bytes: default_min_entity_name_bytes(),
392            max_entity_name_bytes: default_max_entity_name_bytes(),
393            max_fact_bytes: default_max_fact_bytes(),
394            max_entities_per_extraction: default_max_entities(),
395            max_edges_per_extraction: default_max_edges(),
396            forbidden_content_patterns: Vec::new(),
397        }
398    }
399}
400
401// ---------------------------------------------------------------------------
402// PiiFilterConfig
403// ---------------------------------------------------------------------------
404
405/// A single user-defined PII pattern.
406#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
407pub struct CustomPiiPattern {
408    /// Human-readable name used in the replacement label.
409    pub name: String,
410    /// Regular expression pattern.
411    pub pattern: String,
412    /// Replacement text. Defaults to `[PII:custom]`.
413    #[serde(default = "default_custom_replacement")]
414    pub replacement: String,
415}
416
417fn default_custom_replacement() -> String {
418    "[PII:custom]".to_owned()
419}
420
421/// Configuration for the PII filter, nested under `[security.pii_filter]` in the config file.
422///
423/// Disabled by default — opt-in to avoid unexpected data loss.
424#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
425#[allow(clippy::struct_excessive_bools)] // config struct — boolean flags are idiomatic for TOML-deserialized configuration
426pub struct PiiFilterConfig {
427    /// Master switch. When `false`, the filter is a no-op.
428    #[serde(default)]
429    pub enabled: bool,
430    /// Scrub email addresses.
431    #[serde(default = "default_true")]
432    pub filter_email: bool,
433    /// Scrub US phone numbers.
434    #[serde(default = "default_true")]
435    pub filter_phone: bool,
436    /// Scrub US Social Security Numbers.
437    #[serde(default = "default_true")]
438    pub filter_ssn: bool,
439    /// Scrub credit card numbers (16-digit patterns).
440    #[serde(default = "default_true")]
441    pub filter_credit_card: bool,
442    /// Custom regex patterns to add on top of the built-ins.
443    #[serde(default)]
444    pub custom_patterns: Vec<CustomPiiPattern>,
445}
446
447impl Default for PiiFilterConfig {
448    fn default() -> Self {
449        Self {
450            enabled: false,
451            filter_email: true,
452            filter_phone: true,
453            filter_ssn: true,
454            filter_credit_card: true,
455            custom_patterns: Vec::new(),
456        }
457    }
458}
459
460// ---------------------------------------------------------------------------
461// GuardrailConfig
462// ---------------------------------------------------------------------------
463
464/// What happens when the guardrail flags input.
465#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
466#[serde(rename_all = "lowercase")]
467#[non_exhaustive]
468pub enum GuardrailAction {
469    /// Block the input and return an error message to the user.
470    #[default]
471    Block,
472    /// Allow the input but emit a warning message.
473    Warn,
474}
475
476/// Behavior on timeout or LLM error.
477#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
478#[serde(rename_all = "lowercase")]
479#[non_exhaustive]
480pub enum GuardrailFailStrategy {
481    /// Block input on timeout/error (safe default for security-sensitive deployments).
482    #[default]
483    Closed,
484    /// Allow input on timeout/error (for availability-sensitive deployments).
485    Open,
486}
487
488/// Configuration for the LLM-based guardrail, nested under `[security.guardrail]`.
489#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
490pub struct GuardrailConfig {
491    /// Enable the guardrail (default: false).
492    #[serde(default)]
493    pub enabled: bool,
494    /// Provider to use for guardrail classification (e.g. `"ollama"`, `"claude"`).
495    #[serde(default)]
496    pub provider: Option<String>,
497    /// Model to use (e.g. `"llama-guard-3:1b"`).
498    #[serde(default)]
499    pub model: Option<String>,
500    /// Timeout for each guardrail LLM call in milliseconds (default: 500).
501    #[serde(default = "default_guardrail_timeout_ms")]
502    pub timeout_ms: u64,
503    /// Action to take when a message is flagged (default: block).
504    #[serde(default)]
505    pub action: GuardrailAction,
506    /// What to do on timeout or LLM error (default: closed — block).
507    #[serde(default = "default_fail_strategy")]
508    pub fail_strategy: GuardrailFailStrategy,
509    /// When `true`, also scan tool outputs before they enter message history (default: false).
510    #[serde(default)]
511    pub scan_tool_output: bool,
512    /// Maximum number of characters to send to the guard model (default: 4096).
513    #[serde(default = "default_max_input_chars")]
514    pub max_input_chars: usize,
515}
516fn default_guardrail_timeout_ms() -> u64 {
517    500
518}
519fn default_max_input_chars() -> usize {
520    4096
521}
522fn default_fail_strategy() -> GuardrailFailStrategy {
523    GuardrailFailStrategy::Closed
524}
525impl Default for GuardrailConfig {
526    fn default() -> Self {
527        Self {
528            enabled: false,
529            provider: None,
530            model: None,
531            timeout_ms: default_guardrail_timeout_ms(),
532            action: GuardrailAction::default(),
533            fail_strategy: default_fail_strategy(),
534            scan_tool_output: false,
535            max_input_chars: default_max_input_chars(),
536        }
537    }
538}
539
540// ---------------------------------------------------------------------------
541// ResponseVerificationConfig
542// ---------------------------------------------------------------------------
543
544/// Configuration for post-LLM response verification, nested under
545/// `[security.response_verification]` in the agent config file.
546///
547/// Scans LLM responses for injected instruction patterns before tool dispatch.
548/// This is defense-in-depth layer 3 (after input sanitization and pre-execution verification).
549#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
550pub struct ResponseVerificationConfig {
551    /// Enable post-LLM response verification (default: true).
552    #[serde(default = "default_true")]
553    pub enabled: bool,
554    /// Block tool dispatch when injection patterns are detected (default: false).
555    ///
556    /// When `false`, flagged responses are logged and shown in the TUI SEC panel
557    /// but still delivered. When `true`, the response is suppressed and the user
558    /// is notified.
559    #[serde(default)]
560    pub block_on_detection: bool,
561    /// Optional LLM provider for async deep verification of flagged responses.
562    ///
563    /// When set: suspicious responses are delivered immediately with a `[FLAGGED]`
564    /// annotation, and background LLM verification runs asynchronously. The verifier
565    /// receives a sanitized summary (via `QuarantinedSummarizer`) to prevent recursive
566    /// injection. Empty string = disabled (regex-only verification).
567    #[serde(default)]
568    pub verifier_provider: ProviderName,
569}
570
571impl Default for ResponseVerificationConfig {
572    fn default() -> Self {
573        Self {
574            enabled: true,
575            block_on_detection: false,
576            verifier_provider: ProviderName::default(),
577        }
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the body of a `[security.content_isolation]` table.
    fn de_isolation(src: &str) -> Result<ContentIsolationConfig, toml::de::Error> {
        toml::from_str(src)
    }

    /// Parses the body of a `[security.content_isolation.embedding_guard]` table.
    fn de_guard(src: &str) -> Result<EmbeddingGuardConfig, toml::de::Error> {
        toml::from_str(src)
    }

    #[test]
    fn content_isolation_default_mcp_to_acp_boundary_true() {
        assert!(ContentIsolationConfig::default().mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_mcp_to_acp_boundary_false() {
        let cfg = de_isolation("mcp_to_acp_boundary = false").unwrap();
        assert!(!cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_absent_defaults_true() {
        let cfg = de_isolation("").unwrap();
        assert!(cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn threshold_valid() {
        let threshold = de_guard("threshold = 0.35\nmin_samples = 5")
            .unwrap()
            .threshold;
        assert!((threshold - 0.35).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_one_valid() {
        let threshold = de_guard("threshold = 1.0\nmin_samples = 1")
            .unwrap()
            .threshold;
        assert!((threshold - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_zero_rejected() {
        assert!(de_guard("threshold = 0.0\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_above_one_rejected() {
        assert!(de_guard("threshold = 1.5\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_negative_rejected() {
        assert!(de_guard("threshold = -0.1\nmin_samples = 1").is_err());
    }

    #[test]
    fn min_samples_zero_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = 0").is_err());
    }

    #[test]
    fn min_samples_one_valid() {
        let min_samples = de_guard("threshold = 0.35\nmin_samples = 1")
            .unwrap()
            .min_samples;
        assert_eq!(min_samples, 1);
    }
}
648
// ---------------------------------------------------------------------------
// CausalIpiConfig
// ---------------------------------------------------------------------------

/// Serde default for `CausalIpiConfig::threshold`.
fn default_causal_threshold() -> f32 {
    0.7_f32
}
656
657fn validate_causal_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
658where
659    D: serde::Deserializer<'de>,
660{
661    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
662    if value.is_nan() || value.is_infinite() {
663        return Err(serde::de::Error::custom(
664            "causal_ipi.threshold must be a finite number",
665        ));
666    }
667    if !(value > 0.0 && value <= 1.0) {
668        return Err(serde::de::Error::custom(
669            "causal_ipi.threshold must be in (0.0, 1.0]",
670        ));
671    }
672    Ok(value)
673}
674
/// Serde default for `CausalIpiConfig::probe_max_tokens`.
fn default_probe_max_tokens() -> u32 {
    100_u32
}

/// Serde default for `CausalIpiConfig::probe_timeout_ms` (3 s).
fn default_probe_timeout_ms() -> u64 {
    3 * 1_000
}
682
683/// Temporal causal IPI analysis at tool-return boundaries.
684///
685/// When enabled, the agent generates behavioral probes before and after tool batch dispatch
686/// and compares them to detect behavioral deviation caused by injected instructions in
687/// tool outputs. Probes are per-batch (2 LLM calls total), not per individual tool.
688///
689/// Config section: `[security.causal_ipi]`
690#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
691pub struct CausalIpiConfig {
692    /// Master switch. Default: false (opt-in).
693    #[serde(default)]
694    pub enabled: bool,
695
696    /// Causal attribution score threshold for flagging. Range: (0.0, 1.0]. Default 0.7.
697    ///
698    /// Scores above this value trigger a WARN log, metric increment, and `SecurityEvent`.
699    /// Content is never blocked — this is an observation layer only.
700    #[serde(
701        default = "default_causal_threshold",
702        deserialize_with = "validate_causal_threshold"
703    )]
704    pub threshold: f32,
705
706    /// LLM provider name from `[[llm.providers]]` for probe calls.
707    ///
708    /// Should reference a fast/cheap provider — probes run on every tool batch return.
709    /// When `None`, falls back to the agent's default provider.
710    #[serde(default)]
711    pub provider: Option<String>,
712
713    /// Maximum tokens for each probe response. Limits cost per probe call. Default: 100.
714    ///
715    /// Two probes per batch = max `2 * probe_max_tokens` output tokens per tool batch.
716    #[serde(default = "default_probe_max_tokens")]
717    pub probe_max_tokens: u32,
718
719    /// Timeout in milliseconds for each individual probe LLM call. Default: 3000.
720    ///
721    /// On timeout: WARN log, skip causal analysis for the batch (never block).
722    #[serde(default = "default_probe_timeout_ms")]
723    pub probe_timeout_ms: u64,
724
725    /// Shadow memory configuration for cross-turn trajectory analysis.
726    #[serde(default)]
727    pub shadow_memory: ShadowMemoryConfig,
728}
729
730impl Default for CausalIpiConfig {
731    fn default() -> Self {
732        Self {
733            enabled: false,
734            threshold: default_causal_threshold(),
735            provider: None,
736            probe_max_tokens: default_probe_max_tokens(),
737            probe_timeout_ms: default_probe_timeout_ms(),
738            shadow_memory: ShadowMemoryConfig::default(),
739        }
740    }
741}
742
// ---------------------------------------------------------------------------
// ShadowMemoryConfig
// ---------------------------------------------------------------------------

/// Serde default for `ShadowMemoryConfig::window_size`.
fn default_shadow_window() -> usize {
    8_usize
}

/// Serde default for `ShadowMemoryConfig::max_events`.
fn default_shadow_max_events() -> usize {
    64_usize
}

/// Serde default for `ShadowMemoryConfig::drift_threshold`.
fn default_shadow_drift_threshold() -> f32 {
    0.6_f32
}
758
759fn validate_shadow_window<'de, D>(deserializer: D) -> Result<usize, D::Error>
760where
761    D: serde::Deserializer<'de>,
762{
763    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
764    if value == 0 {
765        return Err(serde::de::Error::custom(
766            "shadow_memory.window_size must be >= 1",
767        ));
768    }
769    Ok(value)
770}
771
772fn validate_shadow_max_events<'de, D>(deserializer: D) -> Result<usize, D::Error>
773where
774    D: serde::Deserializer<'de>,
775{
776    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
777    if value == 0 {
778        return Err(serde::de::Error::custom(
779            "shadow_memory.max_events must be >= 1",
780        ));
781    }
782    Ok(value)
783}
784
785fn validate_shadow_drift_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
786where
787    D: serde::Deserializer<'de>,
788{
789    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
790    if value.is_nan() || value.is_infinite() {
791        return Err(serde::de::Error::custom(
792            "shadow_memory.drift_threshold must be a finite number",
793        ));
794    }
795    if !(value > 0.0 && value <= 1.0) {
796        return Err(serde::de::Error::custom(
797            "shadow_memory.drift_threshold must be in (0.0, 1.0]",
798        ));
799    }
800    Ok(value)
801}
802
803/// Per-session append-only event store for cross-turn trajectory analysis.
804///
805/// Detects multi-turn attacks that distribute payload across several turns —
806/// invisible to the stateless [`CausalIpiConfig`] single-batch analysis.
807///
808/// Config section: `[security.causal_ipi.shadow_memory]`
809///
810/// # Examples
811///
812/// ```toml
813/// [security.causal_ipi.shadow_memory]
814/// enabled = true
815/// window_size = 8
816/// max_events = 64
817/// drift_threshold = 0.6
818/// ```
819#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
820pub struct ShadowMemoryConfig {
821    /// Enable shadow memory trajectory tracking. Default: false.
822    #[serde(default)]
823    pub enabled: bool,
824
825    /// Sliding window size for drift computation. Must be >= 1. Default: 8.
826    #[serde(
827        default = "default_shadow_window",
828        deserialize_with = "validate_shadow_window"
829    )]
830    pub window_size: usize,
831
832    /// Maximum events retained before oldest are evicted. Must be >= 1. Default: 64.
833    #[serde(
834        default = "default_shadow_max_events",
835        deserialize_with = "validate_shadow_max_events"
836    )]
837    pub max_events: usize,
838
839    /// Goal drift score threshold for flagging. Range: (0.0, 1.0]. Default: 0.6.
840    #[serde(
841        default = "default_shadow_drift_threshold",
842        deserialize_with = "validate_shadow_drift_threshold"
843    )]
844    pub drift_threshold: f32,
845}
846
847impl Default for ShadowMemoryConfig {
848    fn default() -> Self {
849        Self {
850            enabled: false,
851            window_size: default_shadow_window(),
852            max_events: default_shadow_max_events(),
853            drift_threshold: default_shadow_drift_threshold(),
854        }
855    }
856}
857
#[cfg(test)]
mod causal_ipi_tests {
    use super::*;

    /// Tolerance for comparing `f32` thresholds parsed from TOML.
    const EPS: f32 = 1e-6;

    /// Parses the body of a `[security.causal_ipi]` table.
    fn parse(src: &str) -> Result<CausalIpiConfig, toml::de::Error> {
        toml::from_str(src)
    }

    #[test]
    fn causal_ipi_defaults() {
        let CausalIpiConfig {
            enabled,
            threshold,
            provider,
            probe_max_tokens,
            probe_timeout_ms,
            ..
        } = CausalIpiConfig::default();
        assert!(!enabled);
        assert!((threshold - 0.7).abs() < EPS);
        assert!(provider.is_none());
        assert_eq!(probe_max_tokens, 100);
        assert_eq!(probe_timeout_ms, 3000);
    }

    #[test]
    fn causal_ipi_deserialize_enabled() {
        let src = "enabled = true\n\
                   threshold = 0.8\n\
                   provider = \"fast\"\n\
                   probe_max_tokens = 150\n\
                   probe_timeout_ms = 5000\n";
        let cfg = parse(src).unwrap();
        assert!(cfg.enabled);
        assert!((cfg.threshold - 0.8).abs() < EPS);
        assert_eq!(cfg.provider.as_deref(), Some("fast"));
        assert_eq!(cfg.probe_max_tokens, 150);
        assert_eq!(cfg.probe_timeout_ms, 5000);
    }

    #[test]
    fn causal_ipi_threshold_zero_rejected() {
        assert!(parse("threshold = 0.0").is_err());
    }

    #[test]
    fn causal_ipi_threshold_above_one_rejected() {
        assert!(parse("threshold = 1.1").is_err());
    }

    #[test]
    fn causal_ipi_threshold_exactly_one_accepted() {
        let threshold = parse("threshold = 1.0").unwrap().threshold;
        assert!((threshold - 1.0).abs() < EPS);
    }
}
907
#[cfg(test)]
mod shadow_memory_config_tests {
    use super::*;

    /// Tolerance for comparing `f32` thresholds parsed from TOML.
    const EPS: f32 = 1e-6;

    /// Parses the body of a `[security.causal_ipi.shadow_memory]` table.
    fn parse(src: &str) -> Result<ShadowMemoryConfig, toml::de::Error> {
        toml::from_str(src)
    }

    #[test]
    fn shadow_memory_defaults() {
        let ShadowMemoryConfig {
            enabled,
            window_size,
            max_events,
            drift_threshold,
        } = ShadowMemoryConfig::default();
        assert!(!enabled);
        assert_eq!(window_size, 8);
        assert_eq!(max_events, 64);
        assert!((drift_threshold - 0.6).abs() < EPS);
    }

    #[test]
    fn shadow_memory_window_zero_rejected() {
        assert!(parse("window_size = 0").is_err());
    }

    #[test]
    fn shadow_memory_max_events_zero_rejected() {
        assert!(parse("max_events = 0").is_err());
    }

    #[test]
    fn shadow_memory_drift_threshold_zero_rejected() {
        assert!(parse("drift_threshold = 0.0").is_err());
    }

    #[test]
    fn shadow_memory_drift_threshold_above_one_rejected() {
        assert!(parse("drift_threshold = 1.1").is_err());
    }

    #[test]
    fn shadow_memory_drift_threshold_exactly_one_accepted() {
        let drift = parse("drift_threshold = 1.0").unwrap().drift_threshold;
        assert!((drift - 1.0).abs() < EPS);
    }

    #[test]
    fn shadow_memory_full_deserialization() {
        let src = "enabled = true\n\
                   window_size = 4\n\
                   max_events = 32\n\
                   drift_threshold = 0.8\n";
        let cfg = parse(src).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.window_size, 4);
        assert_eq!(cfg.max_events, 32);
        assert!((cfg.drift_threshold - 0.8).abs() < EPS);
    }
}