1use serde::{Deserialize, Serialize};
5
/// Serde default for `ClassifiersConfig::timeout_ms`, in milliseconds.
fn default_classifier_timeout_ms() -> u64 {
    5_000
}
9
/// Serde default for `ClassifiersConfig::injection_model` (a model identifier).
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde default for `ClassifiersConfig::injection_threshold`.
fn default_injection_threshold() -> f32 {
    0.8_f32
}
17
/// Serde default for `ClassifiersConfig::injection_threshold_soft`.
fn default_injection_threshold_soft() -> f32 {
    0.5_f32
}
21
22fn default_enforcement_mode() -> InjectionEnforcementMode {
23 InjectionEnforcementMode::Warn
24}
25
/// Serde default for `ClassifiersConfig::pii_model` (a model identifier).
fn default_pii_model() -> String {
    "iiiorg/piiranha-v1-detect-personal-information".to_owned()
}
29
/// Serde default for `ClassifiersConfig::pii_threshold`.
fn default_pii_threshold() -> f32 {
    0.75_f32
}
33
/// Serde default for `ClassifiersConfig::pii_ner_max_chars` (8 Ki characters).
fn default_pii_ner_max_chars() -> usize {
    8 * 1024
}
37
/// Serde default for `ClassifiersConfig::pii_ner_allowlist`.
///
/// Returns the built-in allowlist terms, in this fixed order.
fn default_pii_ner_allowlist() -> Vec<String> {
    const BUILTIN_TERMS: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];

    BUILTIN_TERMS.iter().copied().map(String::from).collect()
}
47
/// Serde default for `ClassifiersConfig::three_class_threshold`.
fn default_three_class_threshold() -> f32 {
    0.7_f32
}
51
52fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
53where
54 D: serde::Deserializer<'de>,
55{
56 let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
57 if value.is_nan() || value.is_infinite() {
58 return Err(serde::de::Error::custom(
59 "threshold must be a finite number",
60 ));
61 }
62 if !(value > 0.0 && value <= 1.0) {
63 return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
64 }
65 Ok(value)
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
77#[serde(rename_all = "snake_case")]
78pub enum InjectionEnforcementMode {
79 Warn,
81 Block,
83}
84
85#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
93pub struct ClassifiersConfig {
94 #[serde(default)]
96 pub enabled: bool,
97
98 #[serde(default = "default_classifier_timeout_ms")]
102 pub timeout_ms: u64,
103
104 #[serde(default)]
110 pub hf_token: Option<String>,
111
112 #[serde(default)]
118 pub scan_user_input: bool,
119
120 #[serde(default = "default_injection_model")]
122 pub injection_model: String,
123
124 #[serde(default = "default_enforcement_mode")]
133 pub enforcement_mode: InjectionEnforcementMode,
134
135 #[serde(
140 default = "default_injection_threshold_soft",
141 deserialize_with = "validate_unit_threshold"
142 )]
143 pub injection_threshold_soft: f32,
144
145 #[serde(
152 default = "default_injection_threshold",
153 deserialize_with = "validate_unit_threshold"
154 )]
155 pub injection_threshold: f32,
156
157 #[serde(default)]
162 pub injection_model_sha256: Option<String>,
163
164 #[serde(default)]
171 pub three_class_model: Option<String>,
172
173 #[serde(
178 default = "default_three_class_threshold",
179 deserialize_with = "validate_unit_threshold"
180 )]
181 pub three_class_threshold: f32,
182
183 #[serde(default)]
185 pub three_class_model_sha256: Option<String>,
186
187 #[serde(default)]
192 pub pii_enabled: bool,
193
194 #[serde(default = "default_pii_model")]
196 pub pii_model: String,
197
198 #[serde(default = "default_pii_threshold")]
204 pub pii_threshold: f32,
205
206 #[serde(default)]
208 pub pii_model_sha256: Option<String>,
209
210 #[serde(default = "default_pii_ner_max_chars")]
215 pub pii_ner_max_chars: usize,
216
217 #[serde(default = "default_pii_ner_allowlist")]
227 pub pii_ner_allowlist: Vec<String>,
228}
229
230impl Default for ClassifiersConfig {
231 fn default() -> Self {
232 Self {
233 enabled: false,
234 timeout_ms: default_classifier_timeout_ms(),
235 hf_token: None,
236 scan_user_input: false,
237 injection_model: default_injection_model(),
238 enforcement_mode: default_enforcement_mode(),
239 injection_threshold_soft: default_injection_threshold_soft(),
240 injection_threshold: default_injection_threshold(),
241 injection_model_sha256: None,
242 three_class_model: None,
243 three_class_threshold: default_three_class_threshold(),
244 three_class_model_sha256: None,
245 pii_enabled: false,
246 pii_model: default_pii_model(),
247 pii_threshold: default_pii_threshold(),
248 pii_model_sha256: None,
249 pii_ner_max_chars: default_pii_ner_max_chars(),
250 pii_ner_allowlist: default_pii_ner_allowlist(),
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    // Unit tests for `ClassifiersConfig` defaults, TOML (de)serialization, and
    // threshold validation.
    use super::*;

    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!((cfg.three_class_threshold - 0.7).abs() < 1e-6);
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
    }

    #[test]
    fn empty_section_equals_default_impl() {
        // Guards against the serde defaults and `Default` impl drifting apart.
        let parsed: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(parsed, ClassifiersConfig::default());
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let toml = r#"
            hf_token = "hf_secret"
            scan_user_input = true
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
        assert!(!cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
    }

    #[test]
    fn deserialize_custom_values() {
        let toml = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
        assert!(cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn deserialize_sha256_fields() {
        let toml = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized: ClassifiersConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let toml = r"
            injection_threshold_soft = 0.4
            injection_threshold = 0.85
        ";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.4).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let toml = "injection_threshold = 0.9";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
    }

    #[test]
    fn partial_override_timeout_only() {
        let toml = "timeout_ms = 1000";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let toml = r#"enforcement_mode = "block""#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2: ClassifiersConfig = toml::from_str(&back).unwrap();
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn enforcement_mode_rejects_unknown_value() {
        let result: Result<ClassifiersConfig, _> = toml::from_str(r#"enforcement_mode = "ignore""#);
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_negative() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = -0.5");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 1.1");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_non_finite() {
        for input in &[
            "injection_threshold = nan",
            "injection_threshold = inf",
            "injection_threshold = -inf",
        ] {
            let result: Result<ClassifiersConfig, _> = toml::from_str(input);
            assert!(result.is_err(), "expected rejection for {:?}", input);
        }
    }

    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg: ClassifiersConfig = toml::from_str("injection_threshold = 1.0").unwrap();
        assert!((cfg.injection_threshold - 1.0).abs() < 1e-6);
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_soft_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = 1.5");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_model_roundtrip() {
        let toml = r#"
            three_class_model = "org/align-sentinel"
            three_class_threshold = 0.65
            three_class_model_sha256 = "aabbcc"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!((cfg.three_class_threshold - 0.65).abs() < 1e-6);
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        assert!(cfg.pii_ner_allowlist.contains(&"Zeph".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Rust".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"OpenAI".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Ollama".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Claude".to_owned()));
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let toml = r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let toml = "pii_ner_allowlist = []";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn pii_ner_max_chars_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_max_chars = 1024").unwrap();
        assert_eq!(cfg.pii_ner_max_chars, 1024);
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 2.0");
        assert!(result.is_err());
    }
}