1use serde::{Deserialize, Serialize};
5
6use crate::defaults::default_true;
7
/// Serde default for `ContentIsolationConfig::max_content_size`: 65 536 (64 * 1024).
fn default_max_content_size() -> usize {
    const DEFAULT_MAX_CONTENT_SIZE: usize = 64 * 1024;
    DEFAULT_MAX_CONTENT_SIZE
}
15
16#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
19pub struct EmbeddingGuardConfig {
20 #[serde(default)]
22 pub enabled: bool,
23 #[serde(
25 default = "default_embedding_threshold",
26 deserialize_with = "validate_embedding_threshold"
27 )]
28 pub threshold: f64,
29 #[serde(
32 default = "default_embedding_min_samples",
33 deserialize_with = "validate_min_samples"
34 )]
35 pub min_samples: usize,
36 #[serde(default = "default_ema_floor")]
43 pub ema_floor: f32,
44}
45
46fn validate_embedding_threshold<'de, D>(deserializer: D) -> Result<f64, D::Error>
47where
48 D: serde::Deserializer<'de>,
49{
50 let value = <f64 as serde::Deserialize>::deserialize(deserializer)?;
51 if value.is_nan() || value.is_infinite() {
52 return Err(serde::de::Error::custom(
53 "embedding_guard.threshold must be a finite number",
54 ));
55 }
56 if !(value > 0.0 && value <= 1.0) {
57 return Err(serde::de::Error::custom(
58 "embedding_guard.threshold must be in (0.0, 1.0]",
59 ));
60 }
61 Ok(value)
62}
63
64fn validate_min_samples<'de, D>(deserializer: D) -> Result<usize, D::Error>
65where
66 D: serde::Deserializer<'de>,
67{
68 let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
69 if value == 0 {
70 return Err(serde::de::Error::custom(
71 "embedding_guard.min_samples must be >= 1",
72 ));
73 }
74 Ok(value)
75}
76
/// Serde default for `EmbeddingGuardConfig::threshold`: `0.35`.
fn default_embedding_threshold() -> f64 {
    const DEFAULT_THRESHOLD: f64 = 0.35;
    DEFAULT_THRESHOLD
}
80
/// Serde default for `EmbeddingGuardConfig::min_samples`: `10`.
fn default_embedding_min_samples() -> usize {
    const DEFAULT_MIN_SAMPLES: usize = 10;
    DEFAULT_MIN_SAMPLES
}
84
/// Serde default for `EmbeddingGuardConfig::ema_floor`: `0.01`.
fn default_ema_floor() -> f32 {
    const DEFAULT_EMA_FLOOR: f32 = 0.01;
    DEFAULT_EMA_FLOOR
}
88
89impl Default for EmbeddingGuardConfig {
90 fn default() -> Self {
91 Self {
92 enabled: false,
93 threshold: default_embedding_threshold(),
94 min_samples: default_embedding_min_samples(),
95 ema_floor: default_ema_floor(),
96 }
97 }
98}
99
100#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
103#[allow(clippy::struct_excessive_bools)]
104pub struct ContentIsolationConfig {
105 #[serde(default = "default_true")]
107 pub enabled: bool,
108
109 #[serde(default = "default_max_content_size")]
111 pub max_content_size: usize,
112
113 #[serde(default = "default_true")]
116 pub flag_injection_patterns: bool,
117
118 #[serde(default = "default_true")]
121 pub spotlight_untrusted: bool,
122
123 #[serde(default)]
125 pub quarantine: QuarantineConfig,
126
127 #[serde(default)]
129 pub embedding_guard: EmbeddingGuardConfig,
130
131 #[serde(default = "default_true")]
136 pub mcp_to_acp_boundary: bool,
137}
138
139impl Default for ContentIsolationConfig {
140 fn default() -> Self {
141 Self {
142 enabled: true,
143 max_content_size: default_max_content_size(),
144 flag_injection_patterns: true,
145 spotlight_untrusted: true,
146 quarantine: QuarantineConfig::default(),
147 embedding_guard: EmbeddingGuardConfig::default(),
148 mcp_to_acp_boundary: true,
149 }
150 }
151}
152
153#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
156pub struct QuarantineConfig {
157 #[serde(default)]
159 pub enabled: bool,
160
161 #[serde(default = "default_quarantine_sources")]
163 pub sources: Vec<String>,
164
165 #[serde(default = "default_quarantine_model")]
167 pub model: String,
168}
169
/// Serde default for `QuarantineConfig::sources`: `web_scrape` and `a2a_message`, in that order.
fn default_quarantine_sources() -> Vec<String> {
    ["web_scrape", "a2a_message"]
        .iter()
        .map(|source| (*source).to_owned())
        .collect()
}
173
/// Serde default for `QuarantineConfig::model`: `"claude"`.
fn default_quarantine_model() -> String {
    String::from("claude")
}
177
178impl Default for QuarantineConfig {
179 fn default() -> Self {
180 Self {
181 enabled: false,
182 sources: default_quarantine_sources(),
183 model: default_quarantine_model(),
184 }
185 }
186}
187
188#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
195pub struct ExfiltrationGuardConfig {
196 #[serde(default = "default_true")]
198 pub block_markdown_images: bool,
199
200 #[serde(default = "default_true")]
202 pub validate_tool_urls: bool,
203
204 #[serde(default = "default_true")]
206 pub guard_memory_writes: bool,
207}
208
209impl Default for ExfiltrationGuardConfig {
210 fn default() -> Self {
211 Self {
212 block_markdown_images: true,
213 validate_tool_urls: true,
214 guard_memory_writes: true,
215 }
216 }
217}
218
/// Serde default for `MemoryWriteValidationConfig::max_content_bytes`: 4096 (4 * 1024).
fn default_max_content_bytes() -> usize {
    const DEFAULT_MAX_CONTENT_BYTES: usize = 4 * 1024;
    DEFAULT_MAX_CONTENT_BYTES
}
226
/// Serde default for `MemoryWriteValidationConfig::max_entity_name_bytes`: 256.
fn default_max_entity_name_bytes() -> usize {
    const DEFAULT_MAX_ENTITY_NAME_BYTES: usize = 256;
    DEFAULT_MAX_ENTITY_NAME_BYTES
}
230
/// Serde default for `MemoryWriteValidationConfig::min_entity_name_bytes`: 3.
fn default_min_entity_name_bytes() -> usize {
    const DEFAULT_MIN_ENTITY_NAME_BYTES: usize = 3;
    DEFAULT_MIN_ENTITY_NAME_BYTES
}
234
/// Serde default for `MemoryWriteValidationConfig::max_fact_bytes`: 1024.
fn default_max_fact_bytes() -> usize {
    const DEFAULT_MAX_FACT_BYTES: usize = 1 << 10;
    DEFAULT_MAX_FACT_BYTES
}
238
/// Serde default for `MemoryWriteValidationConfig::max_entities_per_extraction`: 50.
fn default_max_entities() -> usize {
    const DEFAULT_MAX_ENTITIES: usize = 50;
    DEFAULT_MAX_ENTITIES
}
242
/// Serde default for `MemoryWriteValidationConfig::max_edges_per_extraction`: 100.
fn default_max_edges() -> usize {
    const DEFAULT_MAX_EDGES: usize = 100;
    DEFAULT_MAX_EDGES
}
246
247#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
251pub struct MemoryWriteValidationConfig {
252 #[serde(default = "default_true")]
254 pub enabled: bool,
255 #[serde(default = "default_max_content_bytes")]
257 pub max_content_bytes: usize,
258 #[serde(default = "default_min_entity_name_bytes")]
260 pub min_entity_name_bytes: usize,
261 #[serde(default = "default_max_entity_name_bytes")]
263 pub max_entity_name_bytes: usize,
264 #[serde(default = "default_max_fact_bytes")]
266 pub max_fact_bytes: usize,
267 #[serde(default = "default_max_entities")]
269 pub max_entities_per_extraction: usize,
270 #[serde(default = "default_max_edges")]
272 pub max_edges_per_extraction: usize,
273 #[serde(default)]
275 pub forbidden_content_patterns: Vec<String>,
276}
277
278impl Default for MemoryWriteValidationConfig {
279 fn default() -> Self {
280 Self {
281 enabled: true,
282 max_content_bytes: default_max_content_bytes(),
283 min_entity_name_bytes: default_min_entity_name_bytes(),
284 max_entity_name_bytes: default_max_entity_name_bytes(),
285 max_fact_bytes: default_max_fact_bytes(),
286 max_entities_per_extraction: default_max_entities(),
287 max_edges_per_extraction: default_max_edges(),
288 forbidden_content_patterns: Vec::new(),
289 }
290 }
291}
292
293#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
299pub struct CustomPiiPattern {
300 pub name: String,
302 pub pattern: String,
304 #[serde(default = "default_custom_replacement")]
306 pub replacement: String,
307}
308
/// Serde default for `CustomPiiPattern::replacement`: `"[PII:custom]"`.
fn default_custom_replacement() -> String {
    String::from("[PII:custom]")
}
312
313#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
317#[allow(clippy::struct_excessive_bools)]
318pub struct PiiFilterConfig {
319 #[serde(default)]
321 pub enabled: bool,
322 #[serde(default = "default_true")]
324 pub filter_email: bool,
325 #[serde(default = "default_true")]
327 pub filter_phone: bool,
328 #[serde(default = "default_true")]
330 pub filter_ssn: bool,
331 #[serde(default = "default_true")]
333 pub filter_credit_card: bool,
334 #[serde(default)]
336 pub custom_patterns: Vec<CustomPiiPattern>,
337}
338
339impl Default for PiiFilterConfig {
340 fn default() -> Self {
341 Self {
342 enabled: false,
343 filter_email: true,
344 filter_phone: true,
345 filter_ssn: true,
346 filter_credit_card: true,
347 custom_patterns: Vec::new(),
348 }
349 }
350}
351
/// Value of `GuardrailConfig::action`.
///
/// Serialized in lowercase (`"block"` / `"warn"`). Only compiled with the
/// `guardrail` feature.
#[cfg(feature = "guardrail")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailAction {
    /// The default variant.
    #[default]
    Block,
    /// Alternative to `Block`.
    Warn,
}
367
/// Value of `GuardrailConfig::fail_strategy`.
///
/// Serialized in lowercase (`"closed"` / `"open"`). Only compiled with the
/// `guardrail` feature.
#[cfg(feature = "guardrail")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailFailStrategy {
    /// The default variant.
    #[default]
    Closed,
    /// Alternative to `Closed`.
    Open,
}
379
/// Guardrail settings. Only compiled with the `guardrail` feature.
///
/// Every field is optional when deserializing.
#[cfg(feature = "guardrail")]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GuardrailConfig {
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Optional provider name; `None` when omitted.
    #[serde(default)]
    pub provider: Option<String>,
    /// Optional model name; `None` when omitted.
    #[serde(default)]
    pub model: Option<String>,
    /// `500` when omitted.
    #[serde(default = "default_guardrail_timeout_ms")]
    pub timeout_ms: u64,
    /// `GuardrailAction::Block` (the enum's `Default`) when omitted.
    #[serde(default)]
    pub action: GuardrailAction,
    /// `GuardrailFailStrategy::Closed` when omitted.
    #[serde(default = "default_fail_strategy")]
    pub fail_strategy: GuardrailFailStrategy,
    /// `false` when omitted.
    #[serde(default)]
    pub scan_tool_output: bool,
    /// `4096` when omitted.
    #[serde(default = "default_max_input_chars")]
    pub max_input_chars: usize,
}
409
/// Serde default for `GuardrailConfig::timeout_ms`: 500.
#[cfg(feature = "guardrail")]
fn default_guardrail_timeout_ms() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 500;
    DEFAULT_TIMEOUT_MS
}
414
/// Serde default for `GuardrailConfig::max_input_chars`: 4096 (4 * 1024).
#[cfg(feature = "guardrail")]
fn default_max_input_chars() -> usize {
    const DEFAULT_MAX_INPUT_CHARS: usize = 4 * 1024;
    DEFAULT_MAX_INPUT_CHARS
}
419
/// Serde default for `GuardrailConfig::fail_strategy`: `GuardrailFailStrategy::Closed`.
#[cfg(feature = "guardrail")]
fn default_fail_strategy() -> GuardrailFailStrategy {
    const FAIL_CLOSED: GuardrailFailStrategy = GuardrailFailStrategy::Closed;
    FAIL_CLOSED
}
424
#[cfg(feature = "guardrail")]
impl Default for GuardrailConfig {
    /// Returns a disabled guardrail with no provider or model, tool-output
    /// scanning off, and the remaining fields taken from the helpers serde
    /// uses for omitted fields.
    fn default() -> Self {
        let timeout_ms = default_guardrail_timeout_ms();
        let fail_strategy = default_fail_strategy();
        let max_input_chars = default_max_input_chars();

        GuardrailConfig {
            enabled: false,
            provider: None,
            model: None,
            timeout_ms,
            action: Default::default(),
            fail_strategy,
            scan_tool_output: false,
            max_input_chars,
        }
    }
}
440
441#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
451pub struct ResponseVerificationConfig {
452 #[serde(default = "default_true")]
454 pub enabled: bool,
455 #[serde(default)]
461 pub block_on_detection: bool,
462 #[serde(default)]
469 pub verifier_provider: String,
470}
471
472impl Default for ResponseVerificationConfig {
473 fn default() -> Self {
474 Self {
475 enabled: true,
476 block_on_detection: false,
477 verifier_provider: String::new(),
478 }
479 }
480}
481
/// Tests for `ContentIsolationConfig` defaults and `EmbeddingGuardConfig` validation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a standalone `EmbeddingGuardConfig` TOML table.
    fn de_guard(input: &str) -> Result<EmbeddingGuardConfig, toml::de::Error> {
        toml::from_str(input)
    }

    #[test]
    fn content_isolation_default_mcp_to_acp_boundary_true() {
        let cfg = ContentIsolationConfig::default();
        assert!(cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_mcp_to_acp_boundary_false() {
        let input = r"
            mcp_to_acp_boundary = false
        ";
        let cfg: ContentIsolationConfig = toml::from_str(input).unwrap();
        assert!(!cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_absent_defaults_true() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert!(cfg.mcp_to_acp_boundary);
    }

    // The serde `default = "..."` attributes and the hand-written `Default`
    // impl are maintained separately; this pins that they agree.
    #[test]
    fn content_isolation_empty_input_matches_default() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, ContentIsolationConfig::default());
    }

    #[test]
    fn embedding_guard_empty_input_matches_default() {
        assert_eq!(de_guard("").unwrap(), EmbeddingGuardConfig::default());
    }

    #[test]
    fn threshold_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 5").unwrap();
        assert!((cfg.threshold - 0.35).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_one_valid() {
        let cfg = de_guard("threshold = 1.0\nmin_samples = 1").unwrap();
        assert!((cfg.threshold - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_zero_rejected() {
        assert!(de_guard("threshold = 0.0\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_above_one_rejected() {
        assert!(de_guard("threshold = 1.5\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_negative_rejected() {
        assert!(de_guard("threshold = -0.1\nmin_samples = 1").is_err());
    }

    // TOML has `nan` / `inf` float literals, so the validator's non-finite
    // branch is reachable from a config file and needs its own coverage.
    #[test]
    fn threshold_nan_rejected() {
        let err = de_guard("threshold = nan").unwrap_err();
        assert!(
            err.to_string().contains("finite number"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn threshold_infinite_rejected() {
        for literal in &["inf", "-inf"] {
            let err = de_guard(&format!("threshold = {}", literal)).unwrap_err();
            assert!(
                err.to_string().contains("finite number"),
                "unexpected error for {}: {}",
                literal,
                err
            );
        }
    }

    #[test]
    fn min_samples_zero_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = 0").is_err());
    }

    #[test]
    fn min_samples_one_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 1").unwrap();
        assert_eq!(cfg.min_samples, 1);
    }
}
549
/// Serde default for `CausalIpiConfig::threshold`: `0.7`.
fn default_causal_threshold() -> f32 {
    const DEFAULT_CAUSAL_THRESHOLD: f32 = 0.7;
    DEFAULT_CAUSAL_THRESHOLD
}
557
558fn validate_causal_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
559where
560 D: serde::Deserializer<'de>,
561{
562 let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
563 if value.is_nan() || value.is_infinite() {
564 return Err(serde::de::Error::custom(
565 "causal_ipi.threshold must be a finite number",
566 ));
567 }
568 if !(value > 0.0 && value <= 1.0) {
569 return Err(serde::de::Error::custom(
570 "causal_ipi.threshold must be in (0.0, 1.0]",
571 ));
572 }
573 Ok(value)
574}
575
/// Serde default for `CausalIpiConfig::probe_max_tokens`: 100.
fn default_probe_max_tokens() -> u32 {
    const DEFAULT_PROBE_MAX_TOKENS: u32 = 100;
    DEFAULT_PROBE_MAX_TOKENS
}
579
/// Serde default for `CausalIpiConfig::probe_timeout_ms`: 3000.
fn default_probe_timeout_ms() -> u64 {
    const DEFAULT_PROBE_TIMEOUT_MS: u64 = 3_000;
    DEFAULT_PROBE_TIMEOUT_MS
}
583
584#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
592pub struct CausalIpiConfig {
593 #[serde(default)]
595 pub enabled: bool,
596
597 #[serde(
602 default = "default_causal_threshold",
603 deserialize_with = "validate_causal_threshold"
604 )]
605 pub threshold: f32,
606
607 #[serde(default)]
612 pub provider: Option<String>,
613
614 #[serde(default = "default_probe_max_tokens")]
618 pub probe_max_tokens: u32,
619
620 #[serde(default = "default_probe_timeout_ms")]
624 pub probe_timeout_ms: u64,
625}
626
627impl Default for CausalIpiConfig {
628 fn default() -> Self {
629 Self {
630 enabled: false,
631 threshold: default_causal_threshold(),
632 provider: None,
633 probe_max_tokens: default_probe_max_tokens(),
634 probe_timeout_ms: default_probe_timeout_ms(),
635 }
636 }
637}
638
/// Tests for `CausalIpiConfig` defaults and threshold validation.
#[cfg(test)]
mod causal_ipi_tests {
    use super::*;

    /// Parses `input` as a standalone `CausalIpiConfig` TOML table.
    fn de_causal(input: &str) -> Result<CausalIpiConfig, toml::de::Error> {
        toml::from_str(input)
    }

    #[test]
    fn causal_ipi_defaults() {
        let cfg = CausalIpiConfig::default();
        assert!(!cfg.enabled);
        assert!((cfg.threshold - 0.7).abs() < 1e-6);
        assert!(cfg.provider.is_none());
        assert_eq!(cfg.probe_max_tokens, 100);
        assert_eq!(cfg.probe_timeout_ms, 3000);
    }

    // The serde `default = "..."` attributes and the hand-written `Default`
    // impl are maintained separately; this pins that they agree.
    #[test]
    fn causal_ipi_empty_input_matches_default() {
        assert_eq!(de_causal("").unwrap(), CausalIpiConfig::default());
    }

    #[test]
    fn causal_ipi_deserialize_enabled() {
        let input = r#"
            enabled = true
            threshold = 0.8
            provider = "fast"
            probe_max_tokens = 150
            probe_timeout_ms = 5000
        "#;
        let cfg = de_causal(input).unwrap();
        assert!(cfg.enabled);
        assert!((cfg.threshold - 0.8).abs() < 1e-6);
        assert_eq!(cfg.provider.as_deref(), Some("fast"));
        assert_eq!(cfg.probe_max_tokens, 150);
        assert_eq!(cfg.probe_timeout_ms, 5000);
    }

    #[test]
    fn causal_ipi_threshold_zero_rejected() {
        assert!(de_causal("threshold = 0.0").is_err());
    }

    #[test]
    fn causal_ipi_threshold_negative_rejected() {
        assert!(de_causal("threshold = -0.1").is_err());
    }

    #[test]
    fn causal_ipi_threshold_above_one_rejected() {
        assert!(de_causal("threshold = 1.1").is_err());
    }

    #[test]
    fn causal_ipi_threshold_exactly_one_accepted() {
        let cfg = de_causal("threshold = 1.0").unwrap();
        assert!((cfg.threshold - 1.0).abs() < 1e-6);
    }

    // TOML has `nan` / `inf` float literals, so the validator's non-finite
    // branch is reachable from a config file and needs its own coverage.
    #[test]
    fn causal_ipi_threshold_non_finite_rejected() {
        for literal in &["nan", "inf", "-inf"] {
            let err = de_causal(&format!("threshold = {}", literal)).unwrap_err();
            assert!(
                err.to_string().contains("finite number"),
                "unexpected error for {}: {}",
                literal,
                err
            );
        }
    }
}