Skip to main content

zeph_config/
classifiers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
/// Serde default for `timeout_ms`: five seconds, expressed in milliseconds.
fn default_classifier_timeout_ms() -> u64 {
    5_000
}
9
/// Serde default for `injection_model`: the `HuggingFace` repo ID of the binary
/// prompt-injection classifier.
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde default for the hard injection threshold (`injection_threshold`).
fn default_injection_threshold() -> f32 {
    0.8_f32
}
17
/// Serde default for the soft (warn-only) injection threshold (`injection_threshold_soft`).
fn default_injection_threshold_soft() -> f32 {
    0.5_f32
}
21
22fn default_enforcement_mode() -> InjectionEnforcementMode {
23    InjectionEnforcementMode::Warn
24}
25
/// Serde default for `pii_model`: the `HuggingFace` repo ID of the NER PII model.
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
29
/// Serde default for `pii_threshold`: minimum per-token confidence for a PII label.
fn default_pii_threshold() -> f32 {
    0.75_f32
}
33
/// Serde default for `pii_ner_max_chars`: an 8 KiB input cap per NER call.
fn default_pii_ner_max_chars() -> usize {
    8 * 1024
}
37
/// Serde default for `pii_ner_allowlist`: names that the NER PII classifier must never
/// redact, in this fixed order.
fn default_pii_ner_allowlist() -> Vec<String> {
    const NAMES: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];
    NAMES.iter().copied().map(String::from).collect()
}
47
/// Serde default for `three_class_threshold`: minimum `misaligned-instruction` score.
fn default_three_class_threshold() -> f32 {
    0.7_f32
}
51
52fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
53where
54    D: serde::Deserializer<'de>,
55{
56    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
57    if value.is_nan() || value.is_infinite() {
58        return Err(serde::de::Error::custom(
59            "threshold must be a finite number",
60        ));
61    }
62    if !(value > 0.0 && value <= 1.0) {
63        return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
64    }
65    Ok(value)
66}
67
68/// Enforcement mode for the injection classifier.
69///
70/// `warn` (default): scores above `injection_threshold` emit WARN and increment metrics
71/// but do NOT block content. Use this when deploying `DeBERTa` classifiers on tool outputs —
72/// FPR of 12-37% on benign content makes hard-blocking unsafe.
73///
74/// `block`: scores above `injection_threshold` block content (behavior before v0.17).
75/// Only safe for well-calibrated models or when FPR is verified on your workload.
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
77#[serde(rename_all = "snake_case")]
78pub enum InjectionEnforcementMode {
79    /// Log + metric only, never block.
80    Warn,
81    /// Block content above hard threshold.
82    Block,
83}
84
85/// Configuration for the ML-backed classifier subsystem.
86///
87/// Placed under `[classifiers]` in `config.toml`. All fields are optional with safe defaults
88/// so existing configs continue to work when this section is absent.
89///
90/// When `enabled = false` (the default), all classifier code is bypassed and the existing
91/// regex-based detection runs unchanged.
92#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
93pub struct ClassifiersConfig {
94    /// Master switch. When `false`, classifiers are never loaded or invoked.
95    #[serde(default)]
96    pub enabled: bool,
97
98    /// Per-inference timeout in milliseconds.
99    ///
100    /// On timeout the call site falls back to regex. Separate from model download time.
101    #[serde(default = "default_classifier_timeout_ms")]
102    pub timeout_ms: u64,
103
104    /// Resolved `HuggingFace` Hub API token.
105    ///
106    /// Must be the **token value** (not a vault key name) — resolved by the caller before
107    /// constructing `ClassifiersConfig`. When `None`, model downloads are unauthenticated,
108    /// which fails for gated or private repos.
109    #[serde(default)]
110    pub hf_token: Option<String>,
111
112    /// When `true`, the ML injection classifier runs on direct user chat messages.
113    ///
114    /// Default `false`: the `DeBERTa` model is intended for external/untrusted content
115    /// (tool output, web scrapes) — not for direct user input. Enabling this may cause
116    /// false positives on benign conversational messages.
117    #[serde(default)]
118    pub scan_user_input: bool,
119
120    /// `HuggingFace` repo ID for the injection detection model.
121    #[serde(default = "default_injection_model")]
122    pub injection_model: String,
123
124    /// Enforcement mode for the injection classifier.
125    ///
126    /// `warn` (default): scores above `injection_threshold` emit WARN and increment metrics
127    /// but do NOT block content. Use this when deploying classifiers on tool outputs —
128    /// FPR of 12-37% on benign content makes hard-blocking unsafe.
129    ///
130    /// `block`: scores above `injection_threshold` block content. Only safe for well-calibrated
131    /// models or when FPR is verified on your workload.
132    #[serde(default = "default_enforcement_mode")]
133    pub enforcement_mode: InjectionEnforcementMode,
134
135    /// Soft threshold: classifier score at or above this emits a WARN log and increments
136    /// the suspicious-injection metric, but content is allowed through.
137    ///
138    /// Range: `(0.0, 1.0]`. Default `0.5`. Must be ≤ `injection_threshold`.
139    #[serde(
140        default = "default_injection_threshold_soft",
141        deserialize_with = "validate_unit_threshold"
142    )]
143    pub injection_threshold_soft: f32,
144
145    /// Hard threshold: classifier score at or above this blocks the content (in `block` mode)
146    /// or emits WARN (in `warn` mode).
147    ///
148    /// Range: `(0.0, 1.0]`. Conservative default of `0.8` minimises false positives.
149    /// Real-world ML injection classifiers have 12–37% recall gaps at high thresholds —
150    /// defense-in-depth via regex fallback and spotlighting is mandatory.
151    #[serde(
152        default = "default_injection_threshold",
153        deserialize_with = "validate_unit_threshold"
154    )]
155    pub injection_threshold: f32,
156
157    /// Optional SHA-256 hex digest of the injection model safetensors file.
158    ///
159    /// When set, the file is verified before loading. Mismatch aborts startup with an error.
160    /// Useful for security-sensitive deployments to detect corruption or tampering.
161    #[serde(default)]
162    pub injection_model_sha256: Option<String>,
163
164    /// Optional `HuggingFace` repo ID or local path for the three-class `AlignSentinel` model.
165    ///
166    /// When set, content flagged as Suspicious or Blocked by the binary `DeBERTa` classifier
167    /// is passed to this model for refinement. If the three-class model classifies the content
168    /// as `aligned-instruction` or `no-instruction`, the verdict is downgraded to `Clean`.
169    /// This directly reduces false positives from legitimate instruction-style content.
170    #[serde(default)]
171    pub three_class_model: Option<String>,
172
173    /// Confidence threshold for the three-class model's `misaligned-instruction` label.
174    ///
175    /// Content is only kept as Suspicious/Blocked when the misaligned score meets this threshold.
176    /// Range: `(0.0, 1.0]`. Default `0.7`.
177    #[serde(
178        default = "default_three_class_threshold",
179        deserialize_with = "validate_unit_threshold"
180    )]
181    pub three_class_threshold: f32,
182
183    /// Optional SHA-256 hex digest of the three-class model safetensors file.
184    #[serde(default)]
185    pub three_class_model_sha256: Option<String>,
186
187    /// Enable PII detection via the NER model (`pii_model`).
188    ///
189    /// When `true`, `CandlePiiClassifier` runs on user messages in addition to the
190    /// regex-based `PiiFilter`. Both results are merged (union with deduplication).
191    #[serde(default)]
192    pub pii_enabled: bool,
193
194    /// `HuggingFace` repo ID for the PII NER model.
195    #[serde(default = "default_pii_model")]
196    pub pii_model: String,
197
198    /// Minimum per-token confidence to accept a PII label.
199    ///
200    /// Tokens below this threshold are treated as O (no entity).
201    /// Default `0.75` balances recall on rarer entity types (DRIVERLICENSE, PASSPORT, IBAN)
202    /// with precision. Raise to `0.85` to prefer precision over recall.
203    #[serde(default = "default_pii_threshold")]
204    pub pii_threshold: f32,
205
206    /// Optional SHA-256 hex digest of the PII model safetensors file.
207    #[serde(default)]
208    pub pii_model_sha256: Option<String>,
209
210    /// Maximum number of bytes passed to the NER PII classifier per call.
211    ///
212    /// Input is truncated at a valid UTF-8 boundary before classification to prevent
213    /// timeout on large tool outputs (e.g. `search_code`). Default `8192`.
214    #[serde(default = "default_pii_ner_max_chars")]
215    pub pii_ner_max_chars: usize,
216
217    /// Allowlist of tokens that are never redacted by the NER PII classifier, regardless
218    /// of model confidence.
219    ///
220    /// Matching is case-insensitive and exact (whole span text must equal an allowlist entry).
221    /// This suppresses common false positives from the piiranha model — for example,
222    /// "Zeph" is misclassified as a city (PII:CITY) by the base model.
223    ///
224    /// Default entries: `["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]`.
225    /// Set to `[]` to disable the allowlist entirely.
226    #[serde(default = "default_pii_ner_allowlist")]
227    pub pii_ner_allowlist: Vec<String>,
228}
229
230impl Default for ClassifiersConfig {
231    fn default() -> Self {
232        Self {
233            enabled: false,
234            timeout_ms: default_classifier_timeout_ms(),
235            hf_token: None,
236            scan_user_input: false,
237            injection_model: default_injection_model(),
238            enforcement_mode: default_enforcement_mode(),
239            injection_threshold_soft: default_injection_threshold_soft(),
240            injection_threshold: default_injection_threshold(),
241            injection_model_sha256: None,
242            three_class_model: None,
243            three_class_threshold: default_three_class_threshold(),
244            three_class_model_sha256: None,
245            pii_enabled: false,
246            pii_model: default_pii_model(),
247            pii_threshold: default_pii_threshold(),
248            pii_model_sha256: None,
249            pii_ner_max_chars: default_pii_ner_max_chars(),
250            pii_ner_allowlist: default_pii_ner_allowlist(),
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!((cfg.three_class_threshold - 0.7).abs() < 1e-6);
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
    }

    /// `Default` and the per-field serde defaults must stay in sync. If they drift, a config
    /// built in code differs from one loaded from an empty `[classifiers]` table.
    #[test]
    fn default_impl_matches_empty_deserialization() {
        let from_toml: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(from_toml, ClassifiersConfig::default());
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let toml = r#"
            hf_token = "hf_secret"
            scan_user_input = true
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
        assert!(!cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
    }

    #[test]
    fn deserialize_custom_values() {
        let toml = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
        assert!(cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn deserialize_sha256_fields() {
        let toml = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized: ClassifiersConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let toml = r"
            injection_threshold_soft = 0.4
            injection_threshold = 0.85
        ";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.4).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let toml = "injection_threshold = 0.9";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
    }

    #[test]
    fn partial_override_timeout_only() {
        let toml = "timeout_ms = 1000";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold_soft - 0.5).abs() < 1e-6);
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let toml = r#"enforcement_mode = "block""#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2: ClassifiersConfig = toml::from_str(&back).unwrap();
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn enforcement_mode_rejects_unknown_variant() {
        let result: Result<ClassifiersConfig, _> = toml::from_str(r#"enforcement_mode = "deny""#);
        assert!(result.is_err());
    }

    /// `rename_all = "snake_case"` must produce lowercase variant names on the wire.
    #[test]
    fn enforcement_mode_serializes_as_snake_case() {
        let serialized = toml::to_string(&ClassifiersConfig::default()).unwrap();
        let line = serialized
            .lines()
            .find(|l| l.starts_with("enforcement_mode"))
            .expect("enforcement_mode is always serialized");
        assert!(line.contains("warn"));
        assert!(!line.contains("Warn"));
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_negative() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = -0.5");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = 1.1");
        assert!(result.is_err());
    }

    /// TOML allows `nan`; the dedicated non-finite branch must reject it.
    #[test]
    fn threshold_validation_rejects_nan() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = nan");
        assert!(result.is_err());
    }

    /// TOML allows `inf`; the dedicated non-finite branch must reject it.
    #[test]
    fn threshold_validation_rejects_infinity() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold = inf");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg: ClassifiersConfig = toml::from_str("injection_threshold = 1.0").unwrap();
        assert!((cfg.injection_threshold - 1.0).abs() < 1e-6);
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn threshold_validation_soft_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("injection_threshold_soft = 1.1");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_model_roundtrip() {
        let toml = r#"
            three_class_model = "org/align-sentinel"
            three_class_threshold = 0.65
            three_class_model_sha256 = "aabbcc"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!((cfg.three_class_threshold - 0.65).abs() < 1e-6);
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        assert!(cfg.pii_ner_allowlist.contains(&"Zeph".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Rust".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"OpenAI".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Ollama".to_owned()));
        assert!(cfg.pii_ner_allowlist.contains(&"Claude".to_owned()));
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let toml = r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let toml = "pii_ner_allowlist = []";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn pii_ner_max_chars_configurable() {
        let cfg: ClassifiersConfig = toml::from_str("pii_ner_max_chars = 1024").unwrap();
        assert_eq!(cfg.pii_ner_max_chars, 1024);
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn three_class_threshold_validation_rejects_above_one() {
        let result: Result<ClassifiersConfig, _> = toml::from_str("three_class_threshold = 1.1");
        assert!(result.is_err());
    }
}