1use crate::providers::ProviderName;
5use serde::{Deserialize, Serialize};
6
7use crate::defaults::default_true;
8
/// Serde and `Default` fallback for `ContentIsolationConfig::max_content_size`:
/// 64 * 1024 = 65 536.
fn default_max_content_size() -> usize {
    64 * 1024
}
16
/// Configuration for the embedding-based guard.
///
/// Disabled by default. `threshold` and `min_samples` are range-checked when
/// deserialized; `ema_floor` is accepted as-is. A missing field takes the value
/// of its `default_*` function, which matches `EmbeddingGuardConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct EmbeddingGuardConfig {
    /// Whether the guard is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Decision threshold. It must be finite and in `(0.0, 1.0]`, which is
    /// enforced by `validate_embedding_threshold`. Defaults to `0.35`.
    #[serde(
        default = "default_embedding_threshold",
        deserialize_with = "validate_embedding_threshold"
    )]
    pub threshold: f64,
    /// Minimum sample count. It must be `>= 1`, which is enforced by
    /// `validate_min_samples`. Defaults to `10`.
    /// NOTE(review): presumably the number of observations required before the
    /// guard starts deciding — confirm against the guard implementation.
    #[serde(
        default = "default_embedding_min_samples",
        deserialize_with = "validate_min_samples"
    )]
    pub min_samples: usize,
    /// Floor value, presumably for an exponential moving average (inferred from
    /// the name — confirm). Defaults to `0.01`. Not validated at load time, so
    /// any `f32`, including NaN, is accepted.
    #[serde(default = "default_ema_floor")]
    pub ema_floor: f32,
}
46
47fn validate_embedding_threshold<'de, D>(deserializer: D) -> Result<f64, D::Error>
48where
49 D: serde::Deserializer<'de>,
50{
51 let value = <f64 as serde::Deserialize>::deserialize(deserializer)?;
52 if value.is_nan() || value.is_infinite() {
53 return Err(serde::de::Error::custom(
54 "embedding_guard.threshold must be a finite number",
55 ));
56 }
57 if !(value > 0.0 && value <= 1.0) {
58 return Err(serde::de::Error::custom(
59 "embedding_guard.threshold must be in (0.0, 1.0]",
60 ));
61 }
62 Ok(value)
63}
64
65fn validate_min_samples<'de, D>(deserializer: D) -> Result<usize, D::Error>
66where
67 D: serde::Deserializer<'de>,
68{
69 let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
70 if value == 0 {
71 return Err(serde::de::Error::custom(
72 "embedding_guard.min_samples must be >= 1",
73 ));
74 }
75 Ok(value)
76}
77
/// Default for `EmbeddingGuardConfig::threshold`.
fn default_embedding_threshold() -> f64 {
    0.35_f64
}

/// Default for `EmbeddingGuardConfig::min_samples`.
fn default_embedding_min_samples() -> usize {
    10_usize
}

/// Default for `EmbeddingGuardConfig::ema_floor`.
fn default_ema_floor() -> f32 {
    0.01_f32
}
89
90impl Default for EmbeddingGuardConfig {
91 fn default() -> Self {
92 Self {
93 enabled: false,
94 threshold: default_embedding_threshold(),
95 min_samples: default_embedding_min_samples(),
96 ema_floor: default_ema_floor(),
97 }
98 }
99}
100
/// Settings for isolating untrusted content.
///
/// Every field has a serde default. An empty table therefore deserializes to
/// the same value as `ContentIsolationConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
// This struct has four `bool` toggles, which exceeds clippy's default
// `max-struct-bools` limit of 3; each one is an independent config switch.
#[allow(clippy::struct_excessive_bools)]
pub struct ContentIsolationConfig {
    /// Master switch for content isolation. Defaults to `true` via `default_true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Content size limit. Defaults to 65 536.
    /// NOTE(review): unit presumably bytes — confirm at the enforcement site.
    #[serde(default = "default_max_content_size")]
    pub max_content_size: usize,

    /// Whether injection patterns are flagged. Defaults to `true`.
    #[serde(default = "default_true")]
    pub flag_injection_patterns: bool,

    /// Whether untrusted content is spotlighted. Defaults to `true`.
    #[serde(default = "default_true")]
    pub spotlight_untrusted: bool,

    /// Quarantine sub-settings. Defaults to `QuarantineConfig::default()` (disabled).
    #[serde(default)]
    pub quarantine: QuarantineConfig,

    /// Embedding guard sub-settings. Defaults to
    /// `EmbeddingGuardConfig::default()` (disabled).
    #[serde(default)]
    pub embedding_guard: EmbeddingGuardConfig,

    /// Whether the MCP-to-ACP boundary treatment is applied. Defaults to `true`.
    /// NOTE(review): exact semantics live in the consumer — confirm there.
    #[serde(default = "default_true")]
    pub mcp_to_acp_boundary: bool,
}
139
140impl Default for ContentIsolationConfig {
141 fn default() -> Self {
142 Self {
143 enabled: true,
144 max_content_size: default_max_content_size(),
145 flag_injection_patterns: true,
146 spotlight_untrusted: true,
147 quarantine: QuarantineConfig::default(),
148 embedding_guard: EmbeddingGuardConfig::default(),
149 mcp_to_acp_boundary: true,
150 }
151 }
152}
153
/// Quarantine settings.
///
/// Disabled by default. Every field has a serde default, so an empty table
/// equals `QuarantineConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct QuarantineConfig {
    /// Whether quarantine is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Source identifiers subject to quarantine.
    /// Defaults to `["web_scrape", "a2a_message"]`.
    #[serde(default = "default_quarantine_sources")]
    pub sources: Vec<String>,

    /// Model identifier used for quarantine. Defaults to `"claude"`.
    /// NOTE(review): presumably resolved to a provider/model elsewhere — confirm.
    #[serde(default = "default_quarantine_model")]
    pub model: String,
}
170
/// Default `QuarantineConfig::sources`: `"web_scrape"` and `"a2a_message"`, in that order.
fn default_quarantine_sources() -> Vec<String> {
    ["web_scrape", "a2a_message"]
        .iter()
        .map(|&source| source.to_owned())
        .collect()
}

/// Default `QuarantineConfig::model`.
fn default_quarantine_model() -> String {
    String::from("claude")
}
178
179impl Default for QuarantineConfig {
180 fn default() -> Self {
181 Self {
182 enabled: false,
183 sources: default_quarantine_sources(),
184 model: default_quarantine_model(),
185 }
186 }
187}
188
/// Toggles for the exfiltration guards.
///
/// Every toggle defaults to `true` (via `default_true`), matching
/// `ExfiltrationGuardConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ExfiltrationGuardConfig {
    /// Whether markdown images are blocked. Defaults to `true`.
    #[serde(default = "default_true")]
    pub block_markdown_images: bool,

    /// Whether URLs in tool calls are validated. Defaults to `true`.
    #[serde(default = "default_true")]
    pub validate_tool_urls: bool,

    /// Whether memory writes are guarded. Defaults to `true`.
    #[serde(default = "default_true")]
    pub guard_memory_writes: bool,
}
209
210impl Default for ExfiltrationGuardConfig {
211 fn default() -> Self {
212 Self {
213 block_markdown_images: true,
214 validate_tool_urls: true,
215 guard_memory_writes: true,
216 }
217 }
218}
219
/// Default `MemoryWriteValidationConfig::max_content_bytes` (4 KiB).
fn default_max_content_bytes() -> usize {
    4 * 1024
}

/// Default `MemoryWriteValidationConfig::max_entity_name_bytes`.
fn default_max_entity_name_bytes() -> usize {
    256
}

/// Default `MemoryWriteValidationConfig::min_entity_name_bytes`.
fn default_min_entity_name_bytes() -> usize {
    3
}

/// Default `MemoryWriteValidationConfig::max_fact_bytes` (1 KiB).
fn default_max_fact_bytes() -> usize {
    1024
}

/// Default `MemoryWriteValidationConfig::max_entities_per_extraction`.
fn default_max_entities() -> usize {
    50
}

/// Default `MemoryWriteValidationConfig::max_edges_per_extraction`.
fn default_max_edges() -> usize {
    100
}
247
/// Limits applied when validating memory writes.
///
/// Enabled by default. Every field has a serde default matching
/// `MemoryWriteValidationConfig::default()`.
/// NOTE(review): no cross-field validation happens at load time — nothing here
/// enforces `min_entity_name_bytes <= max_entity_name_bytes`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct MemoryWriteValidationConfig {
    /// Whether validation runs. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Maximum content size in bytes. Defaults to `4096`.
    #[serde(default = "default_max_content_bytes")]
    pub max_content_bytes: usize,
    /// Minimum entity-name length in bytes. Defaults to `3`.
    #[serde(default = "default_min_entity_name_bytes")]
    pub min_entity_name_bytes: usize,
    /// Maximum entity-name length in bytes. Defaults to `256`.
    #[serde(default = "default_max_entity_name_bytes")]
    pub max_entity_name_bytes: usize,
    /// Maximum fact length in bytes. Defaults to `1024`.
    #[serde(default = "default_max_fact_bytes")]
    pub max_fact_bytes: usize,
    /// Maximum number of entities per extraction. Defaults to `50`.
    #[serde(default = "default_max_entities")]
    pub max_entities_per_extraction: usize,
    /// Maximum number of edges per extraction. Defaults to `100`.
    #[serde(default = "default_max_edges")]
    pub max_edges_per_extraction: usize,
    /// Content patterns that are rejected. Empty by default.
    /// NOTE(review): pattern syntax (regex vs. substring) is decided by the
    /// consumer — confirm there.
    #[serde(default)]
    pub forbidden_content_patterns: Vec<String>,
}
278
279impl Default for MemoryWriteValidationConfig {
280 fn default() -> Self {
281 Self {
282 enabled: true,
283 max_content_bytes: default_max_content_bytes(),
284 min_entity_name_bytes: default_min_entity_name_bytes(),
285 max_entity_name_bytes: default_max_entity_name_bytes(),
286 max_fact_bytes: default_max_fact_bytes(),
287 max_entities_per_extraction: default_max_entities(),
288 max_edges_per_extraction: default_max_edges(),
289 forbidden_content_patterns: Vec::new(),
290 }
291 }
292}
293
/// A user-defined PII pattern and the text that replaces its matches.
///
/// `name` and `pattern` have no serde default and are required when deserializing.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CustomPiiPattern {
    /// Identifier for this pattern. Required.
    pub name: String,
    /// Pattern text. Required.
    /// NOTE(review): presumably a regular expression — confirm against the PII filter.
    pub pattern: String,
    /// Replacement text. Defaults to `"[PII:custom]"`.
    #[serde(default = "default_custom_replacement")]
    pub replacement: String,
}
309
/// Default `CustomPiiPattern::replacement` when none is configured.
fn default_custom_replacement() -> String {
    String::from("[PII:custom]")
}
313
/// Settings for the PII filter.
///
/// The filter is disabled by default. The four built-in category toggles
/// (email, phone, SSN, credit card) each default to `true`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
// This struct has five `bool` toggles, which exceeds clippy's default
// `max-struct-bools` limit of 3; each one is an independent config switch.
#[allow(clippy::struct_excessive_bools)]
pub struct PiiFilterConfig {
    /// Whether the filter runs. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Email filtering. Defaults to `true`.
    #[serde(default = "default_true")]
    pub filter_email: bool,
    /// Phone-number filtering. Defaults to `true`.
    #[serde(default = "default_true")]
    pub filter_phone: bool,
    /// SSN filtering. Defaults to `true`.
    #[serde(default = "default_true")]
    pub filter_ssn: bool,
    /// Credit-card filtering. Defaults to `true`.
    #[serde(default = "default_true")]
    pub filter_credit_card: bool,
    /// Additional user-defined patterns. Empty by default.
    #[serde(default)]
    pub custom_patterns: Vec<CustomPiiPattern>,
}
339
340impl Default for PiiFilterConfig {
341 fn default() -> Self {
342 Self {
343 enabled: false,
344 filter_email: true,
345 filter_phone: true,
346 filter_ssn: true,
347 filter_credit_card: true,
348 custom_patterns: Vec::new(),
349 }
350 }
351}
352
/// Action taken when the guardrail triggers.
///
/// Serialized in lowercase (`"block"` / `"warn"`). Defaults to [`GuardrailAction::Block`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailAction {
    /// Default variant; serialized as `"block"`.
    #[default]
    Block,
    /// Serialized as `"warn"`.
    Warn,
}
367
/// Failure-handling strategy for the guardrail.
///
/// Serialized in lowercase (`"closed"` / `"open"`). Defaults to
/// [`GuardrailFailStrategy::Closed`].
/// NOTE(review): names suggest fail-closed (treat a failed check as blocking)
/// vs. fail-open (let input through) — confirm at the enforcement site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailFailStrategy {
    /// Default variant; serialized as `"closed"`.
    #[default]
    Closed,
    /// Serialized as `"open"`.
    Open,
}
378
/// Settings for the guardrail check.
///
/// Disabled by default. Every field has a serde default matching
/// `GuardrailConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct GuardrailConfig {
    /// Whether the guardrail is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Optional provider name. Defaults to `None`.
    #[serde(default)]
    pub provider: Option<String>,
    /// Optional model name. Defaults to `None`.
    #[serde(default)]
    pub model: Option<String>,
    /// Timeout in milliseconds. Defaults to `500`.
    #[serde(default = "default_guardrail_timeout_ms")]
    pub timeout_ms: u64,
    /// Action on trigger. Defaults to [`GuardrailAction::Block`].
    #[serde(default)]
    pub action: GuardrailAction,
    /// Failure-handling strategy. Defaults to [`GuardrailFailStrategy::Closed`].
    #[serde(default = "default_fail_strategy")]
    pub fail_strategy: GuardrailFailStrategy,
    /// Whether tool output is scanned as well. Defaults to `false`.
    #[serde(default)]
    pub scan_tool_output: bool,
    /// Input cap in characters. Defaults to `4096`.
    #[serde(default = "default_max_input_chars")]
    pub max_input_chars: usize,
}
/// Default `GuardrailConfig::timeout_ms`: 500 ms.
fn default_guardrail_timeout_ms() -> u64 {
    500
}

/// Default `GuardrailConfig::max_input_chars` (4 * 1024).
fn default_max_input_chars() -> usize {
    4 * 1024
}
413fn default_fail_strategy() -> GuardrailFailStrategy {
414 GuardrailFailStrategy::Closed
415}
416impl Default for GuardrailConfig {
417 fn default() -> Self {
418 Self {
419 enabled: false,
420 provider: None,
421 model: None,
422 timeout_ms: default_guardrail_timeout_ms(),
423 action: GuardrailAction::default(),
424 fail_strategy: default_fail_strategy(),
425 scan_tool_output: false,
426 max_input_chars: default_max_input_chars(),
427 }
428 }
429}
430
/// Settings for verifying responses.
///
/// Enabled by default, with `block_on_detection` off. Every field has a serde
/// default matching `ResponseVerificationConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ResponseVerificationConfig {
    /// Whether verification runs. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Whether a detection blocks the response. Defaults to `false`.
    /// NOTE(review): presumably detections are only reported when `false` — confirm.
    #[serde(default)]
    pub block_on_detection: bool,
    /// Provider used for verification. Defaults to `ProviderName::default()`.
    #[serde(default)]
    pub verifier_provider: ProviderName,
}
461
462impl Default for ResponseVerificationConfig {
463 fn default() -> Self {
464 Self {
465 enabled: true,
466 block_on_detection: false,
467 verifier_provider: ProviderName::default(),
468 }
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_isolation_default_mcp_to_acp_boundary_true() {
        let cfg = ContentIsolationConfig::default();
        assert!(cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_mcp_to_acp_boundary_false() {
        let toml = r"
        mcp_to_acp_boundary = false
        ";
        let cfg: ContentIsolationConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_absent_defaults_true() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert!(cfg.mcp_to_acp_boundary);
    }

    /// Guards against drift between the per-field serde defaults and the
    /// hand-written `Default` impls, including the nested sections.
    #[test]
    fn content_isolation_empty_table_matches_default() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, ContentIsolationConfig::default());
    }

    /// Deserializes an `EmbeddingGuardConfig` from a TOML snippet.
    fn de_guard(toml: &str) -> Result<EmbeddingGuardConfig, toml::de::Error> {
        toml::from_str(toml)
    }

    #[test]
    fn embedding_guard_empty_table_matches_default() {
        assert_eq!(de_guard("").unwrap(), EmbeddingGuardConfig::default());
    }

    #[test]
    fn threshold_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 5").unwrap();
        assert!((cfg.threshold - 0.35).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_one_valid() {
        let cfg = de_guard("threshold = 1.0\nmin_samples = 1").unwrap();
        assert!((cfg.threshold - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_zero_rejected() {
        assert!(de_guard("threshold = 0.0\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_above_one_rejected() {
        assert!(de_guard("threshold = 1.5\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_negative_rejected() {
        assert!(de_guard("threshold = -0.1\nmin_samples = 1").is_err());
    }

    /// Exercises the finite-number branch of `validate_embedding_threshold`.
    /// TOML 1.0 allows the special float values `nan`, `inf`, and `-inf`.
    #[test]
    fn threshold_non_finite_rejected() {
        for input in ["threshold = nan", "threshold = inf", "threshold = -inf"] {
            assert!(de_guard(input).is_err(), "{input} should be rejected");
        }
    }

    #[test]
    fn min_samples_zero_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = 0").is_err());
    }

    #[test]
    fn min_samples_negative_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = -1").is_err());
    }

    #[test]
    fn min_samples_one_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 1").unwrap();
        assert_eq!(cfg.min_samples, 1);
    }
}
539
/// Default `CausalIpiConfig::threshold`; it lies inside the validated `(0.0, 1.0]` range.
fn default_causal_threshold() -> f32 {
    0.7_f32
}
547
548fn validate_causal_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
549where
550 D: serde::Deserializer<'de>,
551{
552 let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
553 if value.is_nan() || value.is_infinite() {
554 return Err(serde::de::Error::custom(
555 "causal_ipi.threshold must be a finite number",
556 ));
557 }
558 if !(value > 0.0 && value <= 1.0) {
559 return Err(serde::de::Error::custom(
560 "causal_ipi.threshold must be in (0.0, 1.0]",
561 ));
562 }
563 Ok(value)
564}
565
/// Default `CausalIpiConfig::probe_max_tokens`.
fn default_probe_max_tokens() -> u32 {
    100_u32
}

/// Default `CausalIpiConfig::probe_timeout_ms`: 3 seconds.
fn default_probe_timeout_ms() -> u64 {
    3_000
}
573
/// Settings for the causal IPI probe.
///
/// Disabled by default. Every field has a serde default matching
/// `CausalIpiConfig::default()`; `threshold` is range-checked when deserialized.
/// NOTE(review): "IPI" presumably stands for indirect prompt injection — confirm.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CausalIpiConfig {
    /// Whether the check runs. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Decision threshold. It must be finite and in `(0.0, 1.0]`, which is
    /// enforced by `validate_causal_threshold`. Defaults to `0.7`.
    #[serde(
        default = "default_causal_threshold",
        deserialize_with = "validate_causal_threshold"
    )]
    pub threshold: f32,

    /// Optional provider name for probes. Defaults to `None`.
    /// NOTE(review): what `None` falls back to is decided by the consumer — confirm.
    #[serde(default)]
    pub provider: Option<String>,

    /// Token budget per probe. Defaults to `100`.
    #[serde(default = "default_probe_max_tokens")]
    pub probe_max_tokens: u32,

    /// Probe timeout in milliseconds. Defaults to `3000`.
    #[serde(default = "default_probe_timeout_ms")]
    pub probe_timeout_ms: u64,
}
616
617impl Default for CausalIpiConfig {
618 fn default() -> Self {
619 Self {
620 enabled: false,
621 threshold: default_causal_threshold(),
622 provider: None,
623 probe_max_tokens: default_probe_max_tokens(),
624 probe_timeout_ms: default_probe_timeout_ms(),
625 }
626 }
627}
628
#[cfg(test)]
mod causal_ipi_tests {
    use super::*;

    #[test]
    fn causal_ipi_defaults() {
        let cfg = CausalIpiConfig::default();
        assert!(!cfg.enabled);
        assert!((cfg.threshold - 0.7).abs() < 1e-6);
        assert!(cfg.provider.is_none());
        assert_eq!(cfg.probe_max_tokens, 100);
        assert_eq!(cfg.probe_timeout_ms, 3000);
    }

    /// Guards against drift between the per-field serde defaults and the
    /// hand-written `Default` impl.
    #[test]
    fn causal_ipi_empty_table_matches_default() {
        let cfg: CausalIpiConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, CausalIpiConfig::default());
    }

    #[test]
    fn causal_ipi_deserialize_enabled() {
        let toml = r#"
        enabled = true
        threshold = 0.8
        provider = "fast"
        probe_max_tokens = 150
        probe_timeout_ms = 5000
        "#;
        let cfg: CausalIpiConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert!((cfg.threshold - 0.8).abs() < 1e-6);
        assert_eq!(cfg.provider.as_deref(), Some("fast"));
        assert_eq!(cfg.probe_max_tokens, 150);
        assert_eq!(cfg.probe_timeout_ms, 5000);
    }

    #[test]
    fn causal_ipi_threshold_zero_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_negative_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = -0.5");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_above_one_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 1.1");
        assert!(result.is_err());
    }

    /// Exercises the finite-number branch of `validate_causal_threshold`.
    /// TOML 1.0 allows the special float values `nan`, `inf`, and `-inf`.
    #[test]
    fn causal_ipi_threshold_non_finite_rejected() {
        for input in ["threshold = nan", "threshold = inf", "threshold = -inf"] {
            let result: Result<CausalIpiConfig, _> = toml::from_str(input);
            assert!(result.is_err(), "{input} should be rejected");
        }
    }

    #[test]
    fn causal_ipi_threshold_exactly_one_accepted() {
        let cfg: CausalIpiConfig = toml::from_str("threshold = 1.0").unwrap();
        assert!((cfg.threshold - 1.0).abs() < 1e-6);
    }
}