1use serde::{Deserialize, Serialize};
5
/// Serde default for `ClassifiersConfig::timeout_ms`, expressed in milliseconds.
const fn default_classifier_timeout_ms() -> u64 {
    5_000
}
9
/// Serde default for `ClassifiersConfig::injection_model`: identifier of the
/// prompt-injection classifier model used when none is configured.
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde default for `ClassifiersConfig::injection_threshold`.
const fn default_injection_threshold() -> f32 {
    0.95_f32
}
17
/// Serde default for `ClassifiersConfig::injection_threshold_soft`.
const fn default_injection_threshold_soft() -> f32 {
    0.5_f32
}
21
22fn default_enforcement_mode() -> InjectionEnforcementMode {
23 InjectionEnforcementMode::Warn
24}
25
/// Serde default for `ClassifiersConfig::pii_model`: identifier of the PII
/// NER model used when none is configured.
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
29
/// Serde default for `ClassifiersConfig::pii_threshold`.
const fn default_pii_threshold() -> f32 {
    0.75_f32
}
33
/// Serde default for `ClassifiersConfig::pii_ner_max_chars` (8192 characters).
const fn default_pii_ner_max_chars() -> usize {
    8 * 1024
}
37
/// Serde default for `ClassifiersConfig::pii_ner_circuit_breaker`.
const fn default_pii_ner_circuit_breaker() -> u32 {
    2_u32
}
41
/// Serde default for `ClassifiersConfig::pii_ner_allowlist`.
///
/// Returns the built-in entries in a fixed order. These are presumably names the
/// PII NER pass should not report (confirm at the consumer).
fn default_pii_ner_allowlist() -> Vec<String> {
    const BUILTIN: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];
    BUILTIN.iter().map(|&name| name.to_owned()).collect()
}
51
/// Serde default for `ClassifiersConfig::three_class_threshold`.
const fn default_three_class_threshold() -> f32 {
    0.7_f32
}
55
56fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
57where
58 D: serde::Deserializer<'de>,
59{
60 let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
61 if value.is_nan() || value.is_infinite() {
62 return Err(serde::de::Error::custom(
63 "threshold must be a finite number",
64 ));
65 }
66 if !(value > 0.0 && value <= 1.0) {
67 return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
68 }
69 Ok(value)
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
81#[serde(rename_all = "snake_case")]
82#[non_exhaustive]
83pub enum InjectionEnforcementMode {
84 Warn,
86 Block,
88}
89
90#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
98pub struct ClassifiersConfig {
99 #[serde(default)]
101 pub enabled: bool,
102
103 #[serde(default = "default_classifier_timeout_ms")]
107 pub timeout_ms: u64,
108
109 #[serde(default)]
115 pub hf_token: Option<String>,
116
117 #[serde(default)]
123 pub scan_user_input: bool,
124
125 #[serde(default = "default_injection_model")]
127 pub injection_model: String,
128
129 #[serde(default = "default_enforcement_mode")]
138 pub enforcement_mode: InjectionEnforcementMode,
139
140 #[serde(
145 default = "default_injection_threshold_soft",
146 deserialize_with = "validate_unit_threshold"
147 )]
148 pub injection_threshold_soft: f32,
149
150 #[serde(
157 default = "default_injection_threshold",
158 deserialize_with = "validate_unit_threshold"
159 )]
160 pub injection_threshold: f32,
161
162 #[serde(default)]
167 pub injection_model_sha256: Option<String>,
168
169 #[serde(default)]
176 pub three_class_model: Option<String>,
177
178 #[serde(
183 default = "default_three_class_threshold",
184 deserialize_with = "validate_unit_threshold"
185 )]
186 pub three_class_threshold: f32,
187
188 #[serde(default)]
190 pub three_class_model_sha256: Option<String>,
191
192 #[serde(default)]
197 pub pii_enabled: bool,
198
199 #[serde(default = "default_pii_model")]
201 pub pii_model: String,
202
203 #[serde(default = "default_pii_threshold")]
209 pub pii_threshold: f32,
210
211 #[serde(default)]
213 pub pii_model_sha256: Option<String>,
214
215 #[serde(default = "default_pii_ner_max_chars")]
220 pub pii_ner_max_chars: usize,
221
222 #[serde(default = "default_pii_ner_allowlist")]
232 pub pii_ner_allowlist: Vec<String>,
233
234 #[serde(default = "default_pii_ner_circuit_breaker")]
243 pub pii_ner_circuit_breaker: u32,
244}
245
246impl Default for ClassifiersConfig {
247 fn default() -> Self {
248 Self {
249 enabled: false,
250 timeout_ms: default_classifier_timeout_ms(),
251 hf_token: None,
252 scan_user_input: false,
253 injection_model: default_injection_model(),
254 enforcement_mode: default_enforcement_mode(),
255 injection_threshold_soft: default_injection_threshold_soft(),
256 injection_threshold: default_injection_threshold(),
257 injection_model_sha256: None,
258 three_class_model: None,
259 three_class_threshold: default_three_class_threshold(),
260 three_class_model_sha256: None,
261 pii_enabled: false,
262 pii_model: default_pii_model(),
263 pii_threshold: default_pii_threshold(),
264 pii_model_sha256: None,
265 pii_ner_max_chars: default_pii_ner_max_chars(),
266 pii_ner_allowlist: default_pii_ner_allowlist(),
267 pii_ner_circuit_breaker: default_pii_ner_circuit_breaker(),
268 }
269 }
270}
271
// Unit tests for `ClassifiersConfig`: defaults, TOML (de)serialization and
// threshold validation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!((cfg.three_class_threshold - 0.7).abs() < 1e-6);
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    // The serde field defaults and the hand-written `Default` impl must not drift apart.
    #[test]
    fn empty_section_matches_default_impl() {
        let parsed: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(parsed, ClassifiersConfig::default());
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let toml = r#"
            hf_token = "hf_secret"
            scan_user_input = true
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
        assert!(!cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
    }

    #[test]
    fn deserialize_custom_values() {
        let toml = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
        assert!(cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn deserialize_sha256_fields() {
        let toml = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
            pii_ner_circuit_breaker: 3,
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized: ClassifiersConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let toml = r"
            injection_threshold_soft = 0.4
            injection_threshold = 0.85
        ";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.4).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let toml = "injection_threshold = 0.9";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
    }

    #[test]
    fn partial_override_timeout_only() {
        let toml = "timeout_ms = 1000";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.95).abs() < 1e-6);
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let toml = r#"enforcement_mode = "block""#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2: ClassifiersConfig = toml::from_str(&back).unwrap();
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_negative() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = -0.5");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 1.1");
        assert!(result.is_err());
    }

    // TOML 1.0 allows `nan` / `inf` float literals; they must not slip through.
    #[test]
    fn threshold_validation_rejects_nan() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = nan");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_infinity() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = inf");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg: ClassifiersConfig = toml::from_str("injection_threshold = 1.0").unwrap();
        assert!((cfg.injection_threshold - 1.0).abs() < 1e-6);
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_model_roundtrip() {
        let toml = r#"
            three_class_model = "org/align-sentinel"
            three_class_threshold = 0.65
            three_class_model_sha256 = "aabbcc"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!((cfg.three_class_threshold - 0.65).abs() < 1e-6);
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        assert!(cfg.pii_ner_allowlist.contains(&"Zeph".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Rust".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"OpenAI".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Ollama".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Claude".to_owned()));
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let toml = r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let toml = "pii_ner_allowlist = []";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 1.5");
        assert!(result.is_err());
    }

    #[test]
    fn pii_ner_circuit_breaker_default() {
        let cfg = ClassifiersConfig::default();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn pii_ner_circuit_breaker_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_circuit_breaker = 5").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 5);
    }

    #[test]
    fn pii_ner_circuit_breaker_zero_disables() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_circuit_breaker = 0").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 0);
    }

    #[test]
    fn pii_ner_circuit_breaker_missing_uses_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }
}