Skip to main content

zeph_config/
sanitizer.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::providers::ProviderName;
5use serde::{Deserialize, Serialize};
6
7use crate::defaults::default_true;
8
9// ---------------------------------------------------------------------------
10// ContentIsolationConfig
11// ---------------------------------------------------------------------------
12
/// Default byte cap (64 KiB) on untrusted content; the serde default for
/// `ContentIsolationConfig::max_content_size`.
fn default_max_content_size() -> usize {
    64 * 1024
}
16
17/// Configuration for the embedding anomaly guard, nested under
18/// `[security.content_isolation.embedding_guard]`.
19#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
20pub struct EmbeddingGuardConfig {
21    /// Enable embedding-based anomaly detection (default: false — opt-in).
22    #[serde(default)]
23    pub enabled: bool,
24    /// Cosine distance threshold above which outputs are flagged as anomalous.
25    #[serde(
26        default = "default_embedding_threshold",
27        deserialize_with = "validate_embedding_threshold"
28    )]
29    pub threshold: f64,
30    /// Minimum clean samples before centroid-based detection activates.
31    /// Before this count, regex fallback is used instead.
32    #[serde(
33        default = "default_embedding_min_samples",
34        deserialize_with = "validate_min_samples"
35    )]
36    pub min_samples: usize,
37    /// EMA alpha floor for centroid updates after stabilization (n >= `min_samples`).
38    ///
39    /// Once the centroid has accumulated `min_samples` clean outputs, each new sample
40    /// can shift it by at most this fraction. Lower values make the centroid more
41    /// resistant to slow drift attacks but slower to adapt to legitimate distribution
42    /// changes. Default: 0.01 (1% per sample).
43    #[serde(default = "default_ema_floor")]
44    pub ema_floor: f32,
45}
46
47fn validate_embedding_threshold<'de, D>(deserializer: D) -> Result<f64, D::Error>
48where
49    D: serde::Deserializer<'de>,
50{
51    let value = <f64 as serde::Deserialize>::deserialize(deserializer)?;
52    if value.is_nan() || value.is_infinite() {
53        return Err(serde::de::Error::custom(
54            "embedding_guard.threshold must be a finite number",
55        ));
56    }
57    if !(value > 0.0 && value <= 1.0) {
58        return Err(serde::de::Error::custom(
59            "embedding_guard.threshold must be in (0.0, 1.0]",
60        ));
61    }
62    Ok(value)
63}
64
65fn validate_min_samples<'de, D>(deserializer: D) -> Result<usize, D::Error>
66where
67    D: serde::Deserializer<'de>,
68{
69    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
70    if value == 0 {
71        return Err(serde::de::Error::custom(
72            "embedding_guard.min_samples must be >= 1",
73        ));
74    }
75    Ok(value)
76}
77
/// Value used for `EmbeddingGuardConfig::threshold` when the field is omitted.
fn default_embedding_threshold() -> f64 {
    0.35_f64
}

/// Value used for `EmbeddingGuardConfig::min_samples` when the field is omitted.
fn default_embedding_min_samples() -> usize {
    10_usize
}

/// Value used for `EmbeddingGuardConfig::ema_floor` when omitted (1% per sample).
fn default_ema_floor() -> f32 {
    0.01_f32
}
89
90impl Default for EmbeddingGuardConfig {
91    fn default() -> Self {
92        Self {
93            enabled: false,
94            threshold: default_embedding_threshold(),
95            min_samples: default_embedding_min_samples(),
96            ema_floor: default_ema_floor(),
97        }
98    }
99}
100
101/// Configuration for the content isolation pipeline, nested under
102/// `[security.content_isolation]` in the agent config file.
103#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
104#[allow(clippy::struct_excessive_bools)]
105pub struct ContentIsolationConfig {
106    /// When `false`, the sanitizer is a no-op: content passes through unchanged.
107    #[serde(default = "default_true")]
108    pub enabled: bool,
109
110    /// Maximum byte length of untrusted content before truncation.
111    #[serde(default = "default_max_content_size")]
112    pub max_content_size: usize,
113
114    /// When `true`, injection patterns detected in content are recorded as
115    /// flags and a warning is prepended to the spotlighting wrapper.
116    #[serde(default = "default_true")]
117    pub flag_injection_patterns: bool,
118
119    /// When `true`, untrusted content is wrapped in spotlighting XML delimiters
120    /// that instruct the LLM to treat the enclosed text as data, not instructions.
121    #[serde(default = "default_true")]
122    pub spotlight_untrusted: bool,
123
124    /// Quarantine summarizer configuration.
125    #[serde(default)]
126    pub quarantine: QuarantineConfig,
127
128    /// Embedding anomaly guard configuration.
129    #[serde(default)]
130    pub embedding_guard: EmbeddingGuardConfig,
131
132    /// When `true`, MCP tool results flowing through ACP-serving sessions receive
133    /// unconditional quarantine summarization and cross-boundary audit log entries.
134    /// This prevents confused-deputy attacks where untrusted MCP output influences
135    /// responses served to ACP clients (e.g. IDE integrations).
136    #[serde(default = "default_true")]
137    pub mcp_to_acp_boundary: bool,
138}
139
140impl Default for ContentIsolationConfig {
141    fn default() -> Self {
142        Self {
143            enabled: true,
144            max_content_size: default_max_content_size(),
145            flag_injection_patterns: true,
146            spotlight_untrusted: true,
147            quarantine: QuarantineConfig::default(),
148            embedding_guard: EmbeddingGuardConfig::default(),
149            mcp_to_acp_boundary: true,
150        }
151    }
152}
153
154/// Configuration for the quarantine summarizer, nested under
155/// `[security.content_isolation.quarantine]` in the agent config file.
156#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
157pub struct QuarantineConfig {
158    /// When `false`, quarantine summarization is disabled entirely.
159    #[serde(default)]
160    pub enabled: bool,
161
162    /// Source kinds to route through the quarantine LLM.
163    #[serde(default = "default_quarantine_sources")]
164    pub sources: Vec<String>,
165
166    /// Provider name passed to `create_named_provider`.
167    #[serde(default = "default_quarantine_model")]
168    pub model: String,
169}
170
/// Source kinds quarantined when `sources` is omitted from the config.
fn default_quarantine_sources() -> Vec<String> {
    ["web_scrape", "a2a_message"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Provider name used for quarantine summarization when `model` is omitted.
fn default_quarantine_model() -> String {
    String::from("claude")
}
178
179impl Default for QuarantineConfig {
180    fn default() -> Self {
181        Self {
182            enabled: false,
183            sources: default_quarantine_sources(),
184            model: default_quarantine_model(),
185        }
186    }
187}
188
189// ---------------------------------------------------------------------------
190// ExfiltrationGuardConfig
191// ---------------------------------------------------------------------------
192
193/// Configuration for exfiltration guards, nested under
194/// `[security.exfiltration_guard]` in the agent config file.
195#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
196pub struct ExfiltrationGuardConfig {
197    /// Strip external markdown images from LLM output to prevent pixel-tracking exfiltration.
198    #[serde(default = "default_true")]
199    pub block_markdown_images: bool,
200
201    /// Cross-reference tool call arguments against URLs seen in flagged untrusted content.
202    #[serde(default = "default_true")]
203    pub validate_tool_urls: bool,
204
205    /// Skip Qdrant embedding for messages that contained injection-flagged content.
206    #[serde(default = "default_true")]
207    pub guard_memory_writes: bool,
208}
209
210impl Default for ExfiltrationGuardConfig {
211    fn default() -> Self {
212        Self {
213            block_markdown_images: true,
214            validate_tool_urls: true,
215            guard_memory_writes: true,
216        }
217    }
218}
219
220// ---------------------------------------------------------------------------
221// MemoryWriteValidationConfig
222// ---------------------------------------------------------------------------
223
/// Default `max_content_bytes`: 4 KiB per `memory_save` payload.
fn default_max_content_bytes() -> usize {
    4 * 1024
}

/// Default `max_entity_name_bytes` for graph-extracted entity names.
fn default_max_entity_name_bytes() -> usize {
    256
}

/// Default `min_entity_name_bytes` for graph-extracted entity names.
fn default_min_entity_name_bytes() -> usize {
    3
}

/// Default `max_fact_bytes`: 1 KiB per edge fact string.
fn default_max_fact_bytes() -> usize {
    1024
}

/// Default `max_entities_per_extraction`.
fn default_max_entities() -> usize {
    50
}

/// Default `max_edges_per_extraction`.
fn default_max_edges() -> usize {
    100
}
247
248/// Configuration for memory write validation, nested under `[security.memory_validation]`.
249///
250/// Enabled by default with conservative limits.
251#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
252pub struct MemoryWriteValidationConfig {
253    /// Master switch. When `false`, validation is a no-op.
254    #[serde(default = "default_true")]
255    pub enabled: bool,
256    /// Maximum byte length of content passed to `memory_save`.
257    #[serde(default = "default_max_content_bytes")]
258    pub max_content_bytes: usize,
259    /// Minimum byte length of an entity name in graph extraction.
260    #[serde(default = "default_min_entity_name_bytes")]
261    pub min_entity_name_bytes: usize,
262    /// Maximum byte length of a single entity name in graph extraction.
263    #[serde(default = "default_max_entity_name_bytes")]
264    pub max_entity_name_bytes: usize,
265    /// Maximum byte length of an edge fact string in graph extraction.
266    #[serde(default = "default_max_fact_bytes")]
267    pub max_fact_bytes: usize,
268    /// Maximum number of entities allowed per graph extraction result.
269    #[serde(default = "default_max_entities")]
270    pub max_entities_per_extraction: usize,
271    /// Maximum number of edges allowed per graph extraction result.
272    #[serde(default = "default_max_edges")]
273    pub max_edges_per_extraction: usize,
274    /// Forbidden substring patterns.
275    #[serde(default)]
276    pub forbidden_content_patterns: Vec<String>,
277}
278
279impl Default for MemoryWriteValidationConfig {
280    fn default() -> Self {
281        Self {
282            enabled: true,
283            max_content_bytes: default_max_content_bytes(),
284            min_entity_name_bytes: default_min_entity_name_bytes(),
285            max_entity_name_bytes: default_max_entity_name_bytes(),
286            max_fact_bytes: default_max_fact_bytes(),
287            max_entities_per_extraction: default_max_entities(),
288            max_edges_per_extraction: default_max_edges(),
289            forbidden_content_patterns: Vec::new(),
290        }
291    }
292}
293
294// ---------------------------------------------------------------------------
295// PiiFilterConfig
296// ---------------------------------------------------------------------------
297
298/// A single user-defined PII pattern.
299#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
300pub struct CustomPiiPattern {
301    /// Human-readable name used in the replacement label.
302    pub name: String,
303    /// Regular expression pattern.
304    pub pattern: String,
305    /// Replacement text. Defaults to `[PII:custom]`.
306    #[serde(default = "default_custom_replacement")]
307    pub replacement: String,
308}
309
/// Replacement label used when a custom PII pattern omits `replacement`.
fn default_custom_replacement() -> String {
    String::from("[PII:custom]")
}
313
314/// Configuration for the PII filter, nested under `[security.pii_filter]` in the config file.
315///
316/// Disabled by default — opt-in to avoid unexpected data loss.
317#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
318#[allow(clippy::struct_excessive_bools)]
319pub struct PiiFilterConfig {
320    /// Master switch. When `false`, the filter is a no-op.
321    #[serde(default)]
322    pub enabled: bool,
323    /// Scrub email addresses.
324    #[serde(default = "default_true")]
325    pub filter_email: bool,
326    /// Scrub US phone numbers.
327    #[serde(default = "default_true")]
328    pub filter_phone: bool,
329    /// Scrub US Social Security Numbers.
330    #[serde(default = "default_true")]
331    pub filter_ssn: bool,
332    /// Scrub credit card numbers (16-digit patterns).
333    #[serde(default = "default_true")]
334    pub filter_credit_card: bool,
335    /// Custom regex patterns to add on top of the built-ins.
336    #[serde(default)]
337    pub custom_patterns: Vec<CustomPiiPattern>,
338}
339
340impl Default for PiiFilterConfig {
341    fn default() -> Self {
342        Self {
343            enabled: false,
344            filter_email: true,
345            filter_phone: true,
346            filter_ssn: true,
347            filter_credit_card: true,
348            custom_patterns: Vec::new(),
349        }
350    }
351}
352
353// ---------------------------------------------------------------------------
354// GuardrailConfig
355// ---------------------------------------------------------------------------
356
357/// What happens when the guardrail flags input.
358#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
359#[serde(rename_all = "lowercase")]
360pub enum GuardrailAction {
361    /// Block the input and return an error message to the user.
362    #[default]
363    Block,
364    /// Allow the input but emit a warning message.
365    Warn,
366}
367
368/// Behavior on timeout or LLM error.
369#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
370#[serde(rename_all = "lowercase")]
371pub enum GuardrailFailStrategy {
372    /// Block input on timeout/error (safe default for security-sensitive deployments).
373    #[default]
374    Closed,
375    /// Allow input on timeout/error (for availability-sensitive deployments).
376    Open,
377}
378
379/// Configuration for the LLM-based guardrail, nested under `[security.guardrail]`.
380#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
381pub struct GuardrailConfig {
382    /// Enable the guardrail (default: false).
383    #[serde(default)]
384    pub enabled: bool,
385    /// Provider to use for guardrail classification (e.g. `"ollama"`, `"claude"`).
386    #[serde(default)]
387    pub provider: Option<String>,
388    /// Model to use (e.g. `"llama-guard-3:1b"`).
389    #[serde(default)]
390    pub model: Option<String>,
391    /// Timeout for each guardrail LLM call in milliseconds (default: 500).
392    #[serde(default = "default_guardrail_timeout_ms")]
393    pub timeout_ms: u64,
394    /// Action to take when a message is flagged (default: block).
395    #[serde(default)]
396    pub action: GuardrailAction,
397    /// What to do on timeout or LLM error (default: closed — block).
398    #[serde(default = "default_fail_strategy")]
399    pub fail_strategy: GuardrailFailStrategy,
400    /// When `true`, also scan tool outputs before they enter message history (default: false).
401    #[serde(default)]
402    pub scan_tool_output: bool,
403    /// Maximum number of characters to send to the guard model (default: 4096).
404    #[serde(default = "default_max_input_chars")]
405    pub max_input_chars: usize,
406}
/// Default per-call guardrail timeout, in milliseconds.
fn default_guardrail_timeout_ms() -> u64 {
    500_u64
}

/// Default character budget for text sent to the guard model.
fn default_max_input_chars() -> usize {
    4 * 1024
}
413fn default_fail_strategy() -> GuardrailFailStrategy {
414    GuardrailFailStrategy::Closed
415}
416impl Default for GuardrailConfig {
417    fn default() -> Self {
418        Self {
419            enabled: false,
420            provider: None,
421            model: None,
422            timeout_ms: default_guardrail_timeout_ms(),
423            action: GuardrailAction::default(),
424            fail_strategy: default_fail_strategy(),
425            scan_tool_output: false,
426            max_input_chars: default_max_input_chars(),
427        }
428    }
429}
430
431// ---------------------------------------------------------------------------
432// ResponseVerificationConfig
433// ---------------------------------------------------------------------------
434
435/// Configuration for post-LLM response verification, nested under
436/// `[security.response_verification]` in the agent config file.
437///
438/// Scans LLM responses for injected instruction patterns before tool dispatch.
439/// This is defense-in-depth layer 3 (after input sanitization and pre-execution verification).
440#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
441pub struct ResponseVerificationConfig {
442    /// Enable post-LLM response verification (default: true).
443    #[serde(default = "default_true")]
444    pub enabled: bool,
445    /// Block tool dispatch when injection patterns are detected (default: false).
446    ///
447    /// When `false`, flagged responses are logged and shown in the TUI SEC panel
448    /// but still delivered. When `true`, the response is suppressed and the user
449    /// is notified.
450    #[serde(default)]
451    pub block_on_detection: bool,
452    /// Optional LLM provider for async deep verification of flagged responses.
453    ///
454    /// When set: suspicious responses are delivered immediately with a `[FLAGGED]`
455    /// annotation, and background LLM verification runs asynchronously. The verifier
456    /// receives a sanitized summary (via `QuarantinedSummarizer`) to prevent recursive
457    /// injection. Empty string = disabled (regex-only verification).
458    #[serde(default)]
459    pub verifier_provider: ProviderName,
460}
461
462impl Default for ResponseVerificationConfig {
463    fn default() -> Self {
464        Self {
465            enabled: true,
466            block_on_detection: false,
467            verifier_provider: ProviderName::default(),
468        }
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_isolation_default_mcp_to_acp_boundary_true() {
        let cfg = ContentIsolationConfig::default();
        assert!(cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_mcp_to_acp_boundary_false() {
        let src = r"
            mcp_to_acp_boundary = false
        ";
        let cfg: ContentIsolationConfig = toml::from_str(src).unwrap();
        assert!(!cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_absent_defaults_true() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert!(cfg.mcp_to_acp_boundary);
    }

    /// Omitting every field must yield exactly the `Default` impl, which keeps the
    /// serde field defaults and `Default::default()` from drifting apart.
    #[test]
    fn content_isolation_empty_toml_matches_default() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, ContentIsolationConfig::default());
    }

    /// Deserializes an `EmbeddingGuardConfig` from a TOML snippet.
    ///
    /// The parameter is named `src` so that it does not shadow the `toml` crate.
    fn de_guard(src: &str) -> Result<EmbeddingGuardConfig, toml::de::Error> {
        toml::from_str(src)
    }

    #[test]
    fn embedding_guard_empty_toml_matches_default() {
        assert_eq!(de_guard("").unwrap(), EmbeddingGuardConfig::default());
    }

    #[test]
    fn threshold_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 5").unwrap();
        assert!((cfg.threshold - 0.35).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_one_valid() {
        let cfg = de_guard("threshold = 1.0\nmin_samples = 1").unwrap();
        assert!((cfg.threshold - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_zero_rejected() {
        assert!(de_guard("threshold = 0.0\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_above_one_rejected() {
        assert!(de_guard("threshold = 1.5\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_negative_rejected() {
        assert!(de_guard("threshold = -0.1\nmin_samples = 1").is_err());
    }

    // TOML 1.0 has literals for the special float values, so the explicit
    // non-finite branch of `validate_embedding_threshold` can be exercised.
    #[test]
    fn threshold_nan_rejected() {
        assert!(de_guard("threshold = nan\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_infinite_rejected() {
        assert!(de_guard("threshold = inf\nmin_samples = 1").is_err());
    }

    #[test]
    fn min_samples_zero_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = 0").is_err());
    }

    #[test]
    fn min_samples_one_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 1").unwrap();
        assert_eq!(cfg.min_samples, 1);
    }
}
539
540// ---------------------------------------------------------------------------
541// CausalIpiConfig
542// ---------------------------------------------------------------------------
543
/// Causal attribution score used for `CausalIpiConfig::threshold` when omitted.
fn default_causal_threshold() -> f32 {
    0.7_f32
}
547
548fn validate_causal_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
549where
550    D: serde::Deserializer<'de>,
551{
552    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
553    if value.is_nan() || value.is_infinite() {
554        return Err(serde::de::Error::custom(
555            "causal_ipi.threshold must be a finite number",
556        ));
557    }
558    if !(value > 0.0 && value <= 1.0) {
559        return Err(serde::de::Error::custom(
560            "causal_ipi.threshold must be in (0.0, 1.0]",
561        ));
562    }
563    Ok(value)
564}
565
/// Default output-token cap for each causal IPI probe call.
fn default_probe_max_tokens() -> u32 {
    100_u32
}

/// Default timeout for each causal IPI probe call, in milliseconds.
fn default_probe_timeout_ms() -> u64 {
    3_000
}
573
574/// Temporal causal IPI analysis at tool-return boundaries.
575///
576/// When enabled, the agent generates behavioral probes before and after tool batch dispatch
577/// and compares them to detect behavioral deviation caused by injected instructions in
578/// tool outputs. Probes are per-batch (2 LLM calls total), not per individual tool.
579///
580/// Config section: `[security.causal_ipi]`
581#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
582pub struct CausalIpiConfig {
583    /// Master switch. Default: false (opt-in).
584    #[serde(default)]
585    pub enabled: bool,
586
587    /// Causal attribution score threshold for flagging. Range: (0.0, 1.0]. Default 0.7.
588    ///
589    /// Scores above this value trigger a WARN log, metric increment, and `SecurityEvent`.
590    /// Content is never blocked — this is an observation layer only.
591    #[serde(
592        default = "default_causal_threshold",
593        deserialize_with = "validate_causal_threshold"
594    )]
595    pub threshold: f32,
596
597    /// LLM provider name from `[[llm.providers]]` for probe calls.
598    ///
599    /// Should reference a fast/cheap provider — probes run on every tool batch return.
600    /// When `None`, falls back to the agent's default provider.
601    #[serde(default)]
602    pub provider: Option<String>,
603
604    /// Maximum tokens for each probe response. Limits cost per probe call. Default: 100.
605    ///
606    /// Two probes per batch = max `2 * probe_max_tokens` output tokens per tool batch.
607    #[serde(default = "default_probe_max_tokens")]
608    pub probe_max_tokens: u32,
609
610    /// Timeout in milliseconds for each individual probe LLM call. Default: 3000.
611    ///
612    /// On timeout: WARN log, skip causal analysis for the batch (never block).
613    #[serde(default = "default_probe_timeout_ms")]
614    pub probe_timeout_ms: u64,
615}
616
617impl Default for CausalIpiConfig {
618    fn default() -> Self {
619        Self {
620            enabled: false,
621            threshold: default_causal_threshold(),
622            provider: None,
623            probe_max_tokens: default_probe_max_tokens(),
624            probe_timeout_ms: default_probe_timeout_ms(),
625        }
626    }
627}
628
#[cfg(test)]
mod causal_ipi_tests {
    use super::*;

    #[test]
    fn causal_ipi_defaults() {
        let cfg = CausalIpiConfig::default();
        assert!(!cfg.enabled);
        assert!((cfg.threshold - 0.7).abs() < 1e-6);
        assert!(cfg.provider.is_none());
        assert_eq!(cfg.probe_max_tokens, 100);
        assert_eq!(cfg.probe_timeout_ms, 3000);
    }

    /// An empty section must deserialize to exactly the `Default` impl, so the
    /// serde field defaults and `Default::default()` stay in sync.
    #[test]
    fn causal_ipi_empty_toml_matches_default() {
        let cfg: CausalIpiConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, CausalIpiConfig::default());
    }

    #[test]
    fn causal_ipi_deserialize_enabled() {
        let src = r#"
            enabled = true
            threshold = 0.8
            provider = "fast"
            probe_max_tokens = 150
            probe_timeout_ms = 5000
        "#;
        let cfg: CausalIpiConfig = toml::from_str(src).unwrap();
        assert!(cfg.enabled);
        assert!((cfg.threshold - 0.8).abs() < 1e-6);
        assert_eq!(cfg.provider.as_deref(), Some("fast"));
        assert_eq!(cfg.probe_max_tokens, 150);
        assert_eq!(cfg.probe_timeout_ms, 5000);
    }

    #[test]
    fn causal_ipi_threshold_zero_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_negative_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = -0.5");
        assert!(result.is_err());
    }

    // TOML 1.0 supports `nan`, which exercises the explicit non-finite branch.
    #[test]
    fn causal_ipi_threshold_nan_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = nan");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_above_one_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 1.1");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_exactly_one_accepted() {
        let cfg: CausalIpiConfig = toml::from_str("threshold = 1.0").unwrap();
        assert!((cfg.threshold - 1.0).abs() < 1e-6);
    }
}