1use serde::{Deserialize, Serialize};
5
/// Serde default for `timeout_ms`: the classifier timeout in milliseconds.
fn default_classifier_timeout_ms() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 5_000;
    DEFAULT_TIMEOUT_MS
}
9
/// Serde default for `injection_model`: the model identifier used when the
/// configuration does not name one.
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde default for `injection_threshold`.
fn default_injection_threshold() -> f32 {
    const DEFAULT_HARD_THRESHOLD: f32 = 0.95;
    DEFAULT_HARD_THRESHOLD
}
17
/// Serde default for `injection_threshold_soft`.
fn default_injection_threshold_soft() -> f32 {
    const DEFAULT_SOFT_THRESHOLD: f32 = 0.5;
    DEFAULT_SOFT_THRESHOLD
}
21
22fn default_enforcement_mode() -> InjectionEnforcementMode {
23 InjectionEnforcementMode::Warn
24}
25
/// Serde default for `pii_model`: the model identifier used when the
/// configuration does not name one.
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
29
/// Serde default for `pii_threshold`.
fn default_pii_threshold() -> f32 {
    const DEFAULT_PII_THRESHOLD: f32 = 0.75;
    DEFAULT_PII_THRESHOLD
}
33
/// Serde default for `pii_ner_max_chars`.
fn default_pii_ner_max_chars() -> usize {
    const DEFAULT_MAX_CHARS: usize = 8_192;
    DEFAULT_MAX_CHARS
}
37
/// Serde default for `pii_ner_circuit_breaker`.
fn default_pii_ner_circuit_breaker() -> u32 {
    const DEFAULT_CIRCUIT_BREAKER: u32 = 2;
    DEFAULT_CIRCUIT_BREAKER
}
41
/// Serde default for `pii_ner_allowlist`: a fixed list of five names, returned
/// as owned strings in declaration order.
fn default_pii_ner_allowlist() -> Vec<String> {
    const DEFAULT_ENTRIES: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];

    DEFAULT_ENTRIES
        .iter()
        .map(|entry| (*entry).to_owned())
        .collect()
}
51
/// Serde default for `three_class_threshold`.
fn default_three_class_threshold() -> f32 {
    const DEFAULT_THREE_CLASS_THRESHOLD: f32 = 0.7;
    DEFAULT_THREE_CLASS_THRESHOLD
}
55
56fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
57where
58 D: serde::Deserializer<'de>,
59{
60 let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
61 if value.is_nan() || value.is_infinite() {
62 return Err(serde::de::Error::custom(
63 "threshold must be a finite number",
64 ));
65 }
66 if !(value > 0.0 && value <= 1.0) {
67 return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
68 }
69 Ok(value)
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
81#[serde(rename_all = "snake_case")]
82pub enum InjectionEnforcementMode {
83 Warn,
85 Block,
87}
88
89#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
97pub struct ClassifiersConfig {
98 #[serde(default)]
100 pub enabled: bool,
101
102 #[serde(default = "default_classifier_timeout_ms")]
106 pub timeout_ms: u64,
107
108 #[serde(default)]
114 pub hf_token: Option<String>,
115
116 #[serde(default)]
122 pub scan_user_input: bool,
123
124 #[serde(default = "default_injection_model")]
126 pub injection_model: String,
127
128 #[serde(default = "default_enforcement_mode")]
137 pub enforcement_mode: InjectionEnforcementMode,
138
139 #[serde(
144 default = "default_injection_threshold_soft",
145 deserialize_with = "validate_unit_threshold"
146 )]
147 pub injection_threshold_soft: f32,
148
149 #[serde(
156 default = "default_injection_threshold",
157 deserialize_with = "validate_unit_threshold"
158 )]
159 pub injection_threshold: f32,
160
161 #[serde(default)]
166 pub injection_model_sha256: Option<String>,
167
168 #[serde(default)]
175 pub three_class_model: Option<String>,
176
177 #[serde(
182 default = "default_three_class_threshold",
183 deserialize_with = "validate_unit_threshold"
184 )]
185 pub three_class_threshold: f32,
186
187 #[serde(default)]
189 pub three_class_model_sha256: Option<String>,
190
191 #[serde(default)]
196 pub pii_enabled: bool,
197
198 #[serde(default = "default_pii_model")]
200 pub pii_model: String,
201
202 #[serde(default = "default_pii_threshold")]
208 pub pii_threshold: f32,
209
210 #[serde(default)]
212 pub pii_model_sha256: Option<String>,
213
214 #[serde(default = "default_pii_ner_max_chars")]
219 pub pii_ner_max_chars: usize,
220
221 #[serde(default = "default_pii_ner_allowlist")]
231 pub pii_ner_allowlist: Vec<String>,
232
233 #[serde(default = "default_pii_ner_circuit_breaker")]
242 pub pii_ner_circuit_breaker: u32,
243}
244
245impl Default for ClassifiersConfig {
246 fn default() -> Self {
247 Self {
248 enabled: false,
249 timeout_ms: default_classifier_timeout_ms(),
250 hf_token: None,
251 scan_user_input: false,
252 injection_model: default_injection_model(),
253 enforcement_mode: default_enforcement_mode(),
254 injection_threshold_soft: default_injection_threshold_soft(),
255 injection_threshold: default_injection_threshold(),
256 injection_model_sha256: None,
257 three_class_model: None,
258 three_class_threshold: default_three_class_threshold(),
259 three_class_model_sha256: None,
260 pii_enabled: false,
261 pii_model: default_pii_model(),
262 pii_threshold: default_pii_threshold(),
263 pii_model_sha256: None,
264 pii_ner_max_chars: default_pii_ner_max_chars(),
265 pii_ner_allowlist: default_pii_ner_allowlist(),
266 pii_ner_circuit_breaker: default_pii_ner_circuit_breaker(),
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `input` as TOML, panicking with the parser error on failure.
    fn parse(input: &str) -> ClassifiersConfig {
        toml::from_str(input).expect("test TOML should deserialize")
    }

    /// Returns `true` when `input` is rejected during deserialization.
    fn is_rejected(input: &str) -> bool {
        toml::from_str::<ClassifiersConfig>(input).is_err()
    }

    /// Approximate `f32` comparison used for threshold assertions.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.95));
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!(approx(cfg.three_class_threshold, 0.7));
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!(approx(cfg.pii_threshold, 0.75));
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let input = r#"
            hf_token = "hf_secret"
            scan_user_input = true
        "#;
        let cfg = parse(input);
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg = parse("");
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.95));
        assert!(!cfg.pii_enabled);
        assert!(approx(cfg.pii_threshold, 0.75));
        // The serde defaults and the `Default` impl must never drift apart.
        assert_eq!(cfg, ClassifiersConfig::default());
    }

    #[test]
    fn deserialize_custom_values() {
        let input = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#;
        let cfg = parse(input);
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.9));
        assert!(cfg.pii_enabled);
        assert!(approx(cfg.pii_threshold, 0.85));
    }

    #[test]
    fn deserialize_sha256_fields() {
        let input = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg = parse(input);
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
            pii_ner_circuit_breaker: 3,
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized = parse(&serialized);
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let input = r"
            injection_threshold_soft = 0.4
            injection_threshold = 0.85
        ";
        let cfg = parse(input);
        assert!(approx(cfg.injection_threshold_soft, 0.4));
        assert!(approx(cfg.injection_threshold, 0.85));
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let cfg = parse("injection_threshold = 0.9");
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.9));
    }

    #[test]
    fn partial_override_timeout_only() {
        let cfg = parse("timeout_ms = 1000");
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.95));
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg = parse("");
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let cfg = parse(r#"enforcement_mode = "block""#);
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2 = parse(&back);
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn enforcement_mode_rejects_unknown_value() {
        assert!(is_rejected(r#"enforcement_mode = "audit""#));
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        assert!(is_rejected("injection_threshold = 0.0"));
    }

    #[test]
    fn threshold_validation_rejects_negative() {
        assert!(is_rejected("injection_threshold = -0.5"));
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        assert!(is_rejected("injection_threshold = 1.1"));
    }

    #[test]
    fn threshold_validation_rejects_non_finite() {
        // TOML has literal `nan` and `inf` floats; neither is a usable threshold.
        assert!(is_rejected("injection_threshold = nan"));
        assert!(is_rejected("injection_threshold = inf"));
        assert!(is_rejected("injection_threshold = -inf"));
    }

    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg = parse("injection_threshold = 1.0");
        assert!(approx(cfg.injection_threshold, 1.0));
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        assert!(is_rejected("injection_threshold_soft = 0.0"));
    }

    #[test]
    fn three_class_model_roundtrip() {
        let input = r#"
            three_class_model = "org/align-sentinel"
            three_class_threshold = 0.65
            three_class_model_sha256 = "aabbcc"
        "#;
        let cfg = parse(input);
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!(approx(cfg.three_class_threshold, 0.65));
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        assert!(is_rejected("three_class_threshold = 0.0"));
    }

    #[test]
    fn pii_ner_max_chars_missing_uses_default() {
        let cfg = parse("");
        assert_eq!(cfg.pii_ner_max_chars, 8192);
    }

    #[test]
    fn pii_ner_max_chars_configurable() {
        let cfg = parse("pii_ner_max_chars = 1024");
        assert_eq!(cfg.pii_ner_max_chars, 1024);
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        for expected in ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"].iter() {
            assert!(
                cfg.pii_ner_allowlist.iter().any(|entry| entry == expected),
                "default allowlist is missing {expected}"
            );
        }
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let cfg = parse(r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#);
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let cfg = parse("pii_ner_allowlist = []");
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn pii_ner_circuit_breaker_default() {
        let cfg = ClassifiersConfig::default();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn pii_ner_circuit_breaker_configurable() {
        let cfg = parse("pii_ner_circuit_breaker = 5");
        assert_eq!(cfg.pii_ner_circuit_breaker, 5);
    }

    #[test]
    fn pii_ner_circuit_breaker_zero_disables() {
        let cfg = parse("pii_ner_circuit_breaker = 0");
        assert_eq!(cfg.pii_ner_circuit_breaker, 0);
    }

    #[test]
    fn pii_ner_circuit_breaker_missing_uses_default() {
        let cfg = parse("");
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }
}