Skip to main content

zeph_config/
classifiers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
/// Serde default for `ClassifiersConfig::timeout_ms`: the per-inference timeout in milliseconds.
fn default_classifier_timeout_ms() -> u64 {
    const TIMEOUT_MS: u64 = 5_000;
    TIMEOUT_MS
}
9
/// Serde default for `ClassifiersConfig::injection_model`: the repo ID of the injection
/// detection model.
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde default for `ClassifiersConfig::injection_threshold` (the hard threshold).
fn default_injection_threshold() -> f32 {
    const HARD_THRESHOLD: f32 = 0.95;
    HARD_THRESHOLD
}
17
/// Serde default for `ClassifiersConfig::injection_threshold_soft` (the soft threshold).
fn default_injection_threshold_soft() -> f32 {
    const SOFT_THRESHOLD: f32 = 0.5;
    SOFT_THRESHOLD
}
21
22fn default_enforcement_mode() -> InjectionEnforcementMode {
23    InjectionEnforcementMode::Warn
24}
25
/// Serde default for `ClassifiersConfig::pii_model`: the repo ID of the PII NER model.
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
29
/// Serde default for `ClassifiersConfig::pii_threshold`: the minimum per-token confidence.
fn default_pii_threshold() -> f32 {
    const MIN_TOKEN_CONFIDENCE: f32 = 0.75;
    MIN_TOKEN_CONFIDENCE
}
33
/// Serde default for `ClassifiersConfig::pii_ner_max_chars`: the per-call NER input size limit.
fn default_pii_ner_max_chars() -> usize {
    const MAX_INPUT_LEN: usize = 8_192;
    MAX_INPUT_LEN
}
37
/// Serde default for `ClassifiersConfig::pii_ner_circuit_breaker`: the number of consecutive
/// NER timeouts tolerated before the breaker trips.
fn default_pii_ner_circuit_breaker() -> u32 {
    const CONSECUTIVE_TIMEOUTS: u32 = 2;
    CONSECUTIVE_TIMEOUTS
}
41
/// Serde default for `ClassifiersConfig::pii_ner_allowlist`.
///
/// Returns the built-in entries, in this order: `Zeph`, `Rust`, `OpenAI`, `Ollama`, `Claude`.
fn default_pii_ner_allowlist() -> Vec<String> {
    const BUILTIN_ENTRIES: [&str; 5] = ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"];
    BUILTIN_ENTRIES
        .iter()
        .map(|&entry| entry.to_owned())
        .collect()
}
51
/// Serde default for `ClassifiersConfig::three_class_threshold`.
fn default_three_class_threshold() -> f32 {
    const MISALIGNED_CONFIDENCE: f32 = 0.7;
    MISALIGNED_CONFIDENCE
}
55
56fn validate_unit_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
57where
58    D: serde::Deserializer<'de>,
59{
60    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
61    if value.is_nan() || value.is_infinite() {
62        return Err(serde::de::Error::custom(
63            "threshold must be a finite number",
64        ));
65    }
66    if !(value > 0.0 && value <= 1.0) {
67        return Err(serde::de::Error::custom("threshold must be in (0.0, 1.0]"));
68    }
69    Ok(value)
70}
71
72/// Enforcement mode for the injection classifier.
73///
74/// `warn` (default): scores above `injection_threshold` emit WARN and increment metrics
75/// but do NOT block content. Use this when deploying `DeBERTa` classifiers on tool outputs —
76/// FPR of 12-37% on benign content makes hard-blocking unsafe.
77///
78/// `block`: scores above `injection_threshold` block content (behavior before v0.17).
79/// Only safe for well-calibrated models or when FPR is verified on your workload.
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
81#[serde(rename_all = "snake_case")]
82#[non_exhaustive]
83pub enum InjectionEnforcementMode {
84    /// Log + metric only, never block.
85    Warn,
86    /// Block content above hard threshold.
87    Block,
88}
89
90/// Configuration for the ML-backed classifier subsystem.
91///
92/// Placed under `[classifiers]` in `config.toml`. All fields are optional with safe defaults
93/// so existing configs continue to work when this section is absent.
94///
95/// When `enabled = false` (the default), all classifier code is bypassed and the existing
96/// regex-based detection runs unchanged.
97#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
98pub struct ClassifiersConfig {
99    /// Master switch. When `false`, classifiers are never loaded or invoked.
100    #[serde(default)]
101    pub enabled: bool,
102
103    /// Per-inference timeout in milliseconds.
104    ///
105    /// On timeout the call site falls back to regex. Separate from model download time.
106    #[serde(default = "default_classifier_timeout_ms")]
107    pub timeout_ms: u64,
108
109    /// Resolved `HuggingFace` Hub API token.
110    ///
111    /// Must be the **token value** (not a vault key name) — resolved by the caller before
112    /// constructing `ClassifiersConfig`. When `None`, model downloads are unauthenticated,
113    /// which fails for gated or private repos.
114    #[serde(default)]
115    pub hf_token: Option<String>,
116
117    /// When `true`, the ML injection classifier runs on direct user chat messages.
118    ///
119    /// Default `false`: the `DeBERTa` model is intended for external/untrusted content
120    /// (tool output, web scrapes) — not for direct user input. Enabling this may cause
121    /// false positives on benign conversational messages.
122    #[serde(default)]
123    pub scan_user_input: bool,
124
125    /// `HuggingFace` repo ID for the injection detection model.
126    #[serde(default = "default_injection_model")]
127    pub injection_model: String,
128
129    /// Enforcement mode for the injection classifier.
130    ///
131    /// `warn` (default): scores above `injection_threshold` emit WARN and increment metrics
132    /// but do NOT block content. Use this when deploying classifiers on tool outputs —
133    /// FPR of 12-37% on benign content makes hard-blocking unsafe.
134    ///
135    /// `block`: scores above `injection_threshold` block content. Only safe for well-calibrated
136    /// models or when FPR is verified on your workload.
137    #[serde(default = "default_enforcement_mode")]
138    pub enforcement_mode: InjectionEnforcementMode,
139
140    /// Soft threshold: classifier score at or above this emits a WARN log and increments
141    /// the suspicious-injection metric, but content is allowed through.
142    ///
143    /// Range: `(0.0, 1.0]`. Default `0.5`. Must be ≤ `injection_threshold`.
144    #[serde(
145        default = "default_injection_threshold_soft",
146        deserialize_with = "validate_unit_threshold"
147    )]
148    pub injection_threshold_soft: f32,
149
150    /// Hard threshold: classifier score at or above this blocks the content (in `block` mode)
151    /// or emits WARN (in `warn` mode).
152    ///
153    /// Range: `(0.0, 1.0]`. Conservative default of `0.95` minimises false positives.
154    /// Real-world ML injection classifiers have 12–37% recall gaps at high thresholds —
155    /// defense-in-depth via regex fallback and spotlighting is mandatory.
156    #[serde(
157        default = "default_injection_threshold",
158        deserialize_with = "validate_unit_threshold"
159    )]
160    pub injection_threshold: f32,
161
162    /// Optional SHA-256 hex digest of the injection model safetensors file.
163    ///
164    /// When set, the file is verified before loading. Mismatch aborts startup with an error.
165    /// Useful for security-sensitive deployments to detect corruption or tampering.
166    #[serde(default)]
167    pub injection_model_sha256: Option<String>,
168
169    /// Optional `HuggingFace` repo ID or local path for the three-class `AlignSentinel` model.
170    ///
171    /// When set, content flagged as Suspicious or Blocked by the binary `DeBERTa` classifier
172    /// is passed to this model for refinement. If the three-class model classifies the content
173    /// as `aligned-instruction` or `no-instruction`, the verdict is downgraded to `Clean`.
174    /// This directly reduces false positives from legitimate instruction-style content.
175    #[serde(default)]
176    pub three_class_model: Option<String>,
177
178    /// Confidence threshold for the three-class model's `misaligned-instruction` label.
179    ///
180    /// Content is only kept as Suspicious/Blocked when the misaligned score meets this threshold.
181    /// Range: `(0.0, 1.0]`. Default `0.7`.
182    #[serde(
183        default = "default_three_class_threshold",
184        deserialize_with = "validate_unit_threshold"
185    )]
186    pub three_class_threshold: f32,
187
188    /// Optional SHA-256 hex digest of the three-class model safetensors file.
189    #[serde(default)]
190    pub three_class_model_sha256: Option<String>,
191
192    /// Enable PII detection via the NER model (`pii_model`).
193    ///
194    /// When `true`, `CandlePiiClassifier` runs on user messages in addition to the
195    /// regex-based `PiiFilter`. Both results are merged (union with deduplication).
196    #[serde(default)]
197    pub pii_enabled: bool,
198
199    /// `HuggingFace` repo ID for the PII NER model.
200    #[serde(default = "default_pii_model")]
201    pub pii_model: String,
202
203    /// Minimum per-token confidence to accept a PII label.
204    ///
205    /// Tokens below this threshold are treated as O (no entity).
206    /// Default `0.75` balances recall on rarer entity types (DRIVERLICENSE, PASSPORT, IBAN)
207    /// with precision. Raise to `0.85` to prefer precision over recall.
208    #[serde(default = "default_pii_threshold")]
209    pub pii_threshold: f32,
210
211    /// Optional SHA-256 hex digest of the PII model safetensors file.
212    #[serde(default)]
213    pub pii_model_sha256: Option<String>,
214
215    /// Maximum number of bytes passed to the NER PII classifier per call.
216    ///
217    /// Input is truncated at a valid UTF-8 boundary before classification to prevent
218    /// timeout on large tool outputs (e.g. `search_code`). Default `8192`.
219    #[serde(default = "default_pii_ner_max_chars")]
220    pub pii_ner_max_chars: usize,
221
222    /// Allowlist of tokens that are never redacted by the NER PII classifier, regardless
223    /// of model confidence.
224    ///
225    /// Matching is case-insensitive and exact (whole span text must equal an allowlist entry).
226    /// This suppresses common false positives from the piiranha model — for example,
227    /// "Zeph" is misclassified as a city (PII:CITY) by the base model.
228    ///
229    /// Default entries: `["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]`.
230    /// Set to `[]` to disable the allowlist entirely.
231    #[serde(default = "default_pii_ner_allowlist")]
232    pub pii_ner_allowlist: Vec<String>,
233
234    /// Number of consecutive NER timeouts before the circuit breaker trips and disables NER
235    /// for the remainder of the session.
236    ///
237    /// When the breaker trips, all subsequent chunks fall back to regex-only PII detection,
238    /// preventing repeated timeout stalls on paginated reads (e.g. 12 chunks × 30 s = 6 min).
239    /// Set to `0` to disable the circuit breaker (NER is always attempted).
240    ///
241    /// Default: `2`. Takes effect on the next session start if changed mid-session.
242    #[serde(default = "default_pii_ner_circuit_breaker")]
243    pub pii_ner_circuit_breaker: u32,
244}
245
246impl Default for ClassifiersConfig {
247    fn default() -> Self {
248        Self {
249            enabled: false,
250            timeout_ms: default_classifier_timeout_ms(),
251            hf_token: None,
252            scan_user_input: false,
253            injection_model: default_injection_model(),
254            enforcement_mode: default_enforcement_mode(),
255            injection_threshold_soft: default_injection_threshold_soft(),
256            injection_threshold: default_injection_threshold(),
257            injection_model_sha256: None,
258            three_class_model: None,
259            three_class_threshold: default_three_class_threshold(),
260            three_class_model_sha256: None,
261            pii_enabled: false,
262            pii_model: default_pii_model(),
263            pii_threshold: default_pii_threshold(),
264            pii_model_sha256: None,
265            pii_ner_max_chars: default_pii_ner_max_chars(),
266            pii_ner_allowlist: default_pii_ner_allowlist(),
267            pii_ner_circuit_breaker: default_pii_ner_circuit_breaker(),
268        }
269    }
270}
271
/// Unit tests for `ClassifiersConfig` defaults, TOML (de)serialization, and threshold
/// validation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing `f32` values that went through TOML parsing.
    const EPS: f32 = 1e-6;

    /// Deserializes `src` into a `ClassifiersConfig`, panicking at the caller's location
    /// when the document is rejected.
    #[track_caller]
    fn parse(src: &str) -> ClassifiersConfig {
        toml::from_str(src).unwrap()
    }

    /// Returns `true` when `src` fails to deserialize into a `ClassifiersConfig`.
    fn is_rejected(src: &str) -> bool {
        toml::from_str::<ClassifiersConfig>(src).is_err()
    }

    /// Approximate `f32` equality within `EPS`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert!(cfg.hf_token.is_none());
        assert!(!cfg.scan_user_input);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.95));
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.three_class_model.is_none());
        assert!(approx(cfg.three_class_threshold, 0.7));
        assert!(cfg.three_class_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!(approx(cfg.pii_threshold, 0.75));
        assert!(cfg.pii_model_sha256.is_none());
        assert_eq!(cfg.pii_ner_max_chars, 8192);
        assert_eq!(
            cfg.pii_ner_allowlist,
            vec!["Zeph", "Rust", "OpenAI", "Ollama", "Claude"]
        );
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn empty_document_matches_default_impl() {
        // The serde field defaults and the hand-written `Default` impl must agree.
        assert_eq!(parse(""), ClassifiersConfig::default());
    }

    #[test]
    fn hf_token_and_scan_user_input_round_trip() {
        let cfg = parse(
            r#"
            hf_token = "hf_secret"
            scan_user_input = true
        "#,
        );
        assert_eq!(cfg.hf_token.as_deref(), Some("hf_secret"));
        assert!(cfg.scan_user_input);
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg = parse("");
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.95));
        assert!(!cfg.pii_enabled);
        assert!(approx(cfg.pii_threshold, 0.75));
    }

    #[test]
    fn deserialize_custom_values() {
        let cfg = parse(
            r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#,
        );
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.9));
        assert!(cfg.pii_enabled);
        assert!(approx(cfg.pii_threshold, 0.85));
    }

    #[test]
    fn deserialize_sha256_fields() {
        let cfg = parse(
            r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#,
        );
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            hf_token: Some("hf_test_token".into()),
            scan_user_input: true,
            injection_model: "org/model".into(),
            enforcement_mode: InjectionEnforcementMode::Block,
            injection_threshold_soft: 0.45,
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            three_class_model: Some("org/three-class".into()),
            three_class_threshold: 0.65,
            three_class_model_sha256: Some("abc456".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
            pii_ner_max_chars: 4096,
            pii_ner_allowlist: vec!["MyProject".into(), "Rust".into()],
            pii_ner_circuit_breaker: 3,
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized = parse(&serialized);
        assert_eq!(original, deserialized);
    }

    #[test]
    fn dual_threshold_deserialization() {
        let cfg = parse(
            r"
            injection_threshold_soft = 0.4
            injection_threshold = 0.85
        ",
        );
        assert!(approx(cfg.injection_threshold_soft, 0.4));
        assert!(approx(cfg.injection_threshold, 0.85));
    }

    #[test]
    fn soft_threshold_defaults_when_only_hard_provided() {
        let cfg = parse("injection_threshold = 0.9");
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.9));
    }

    #[test]
    fn partial_override_timeout_only() {
        let cfg = parse("timeout_ms = 1000");
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold_soft, 0.5));
        assert!(approx(cfg.injection_threshold, 0.95));
    }

    #[test]
    fn enforcement_mode_warn_is_default() {
        let cfg = parse("");
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Warn);
    }

    #[test]
    fn enforcement_mode_block_roundtrip() {
        let cfg = parse(r#"enforcement_mode = "block""#);
        assert_eq!(cfg.enforcement_mode, InjectionEnforcementMode::Block);
        let back = toml::to_string(&cfg).unwrap();
        let cfg2 = parse(&back);
        assert_eq!(cfg2.enforcement_mode, InjectionEnforcementMode::Block);
    }

    #[test]
    fn threshold_validation_rejects_zero() {
        assert!(is_rejected("injection_threshold = 0.0"));
    }

    #[test]
    fn threshold_validation_rejects_above_one() {
        assert!(is_rejected("injection_threshold = 1.1"));
    }

    #[test]
    fn threshold_validation_rejects_negative() {
        assert!(is_rejected("injection_threshold = -0.1"));
    }

    #[test]
    fn threshold_validation_rejects_non_finite() {
        // TOML has literal spellings for NaN and the infinities; none is a usable threshold.
        for literal in ["nan", "inf", "-inf"] {
            let doc = format!("injection_threshold = {literal}");
            assert!(is_rejected(&doc), "{literal} must be rejected");
        }
    }

    #[test]
    fn threshold_validation_accepts_exactly_one() {
        let cfg = parse("injection_threshold = 1.0");
        assert!(approx(cfg.injection_threshold, 1.0));
    }

    #[test]
    fn threshold_validation_soft_rejects_zero() {
        assert!(is_rejected("injection_threshold_soft = 0.0"));
    }

    #[test]
    fn three_class_model_roundtrip() {
        let cfg = parse(
            r#"
            three_class_model = "org/align-sentinel"
            three_class_threshold = 0.65
            three_class_model_sha256 = "aabbcc"
        "#,
        );
        assert_eq!(cfg.three_class_model.as_deref(), Some("org/align-sentinel"));
        assert!(approx(cfg.three_class_threshold, 0.65));
        assert_eq!(cfg.three_class_model_sha256.as_deref(), Some("aabbcc"));
    }

    #[test]
    fn pii_ner_allowlist_default_entries() {
        let cfg = ClassifiersConfig::default();
        for expected in ["Zeph", "Rust", "OpenAI", "Ollama", "Claude"] {
            assert!(
                cfg.pii_ner_allowlist.iter().any(|entry| entry == expected),
                "default allowlist is missing {expected}"
            );
        }
    }

    #[test]
    fn pii_ner_allowlist_configurable() {
        let cfg = parse(r#"pii_ner_allowlist = ["MyProject", "AcmeCorp"]"#);
        assert_eq!(cfg.pii_ner_allowlist, vec!["MyProject", "AcmeCorp"]);
    }

    #[test]
    fn pii_ner_allowlist_empty_disables() {
        let cfg = parse("pii_ner_allowlist = []");
        assert!(cfg.pii_ner_allowlist.is_empty());
    }

    #[test]
    fn three_class_threshold_validation_rejects_zero() {
        assert!(is_rejected("three_class_threshold = 0.0"));
    }

    #[test]
    fn pii_ner_max_chars_configurable() {
        let cfg = parse("pii_ner_max_chars = 4096");
        assert_eq!(cfg.pii_ner_max_chars, 4096);
    }

    #[test]
    fn pii_ner_circuit_breaker_default() {
        let cfg = ClassifiersConfig::default();
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }

    #[test]
    fn pii_ner_circuit_breaker_configurable() {
        let cfg = parse("pii_ner_circuit_breaker = 5");
        assert_eq!(cfg.pii_ner_circuit_breaker, 5);
    }

    #[test]
    fn pii_ner_circuit_breaker_zero_disables() {
        let cfg = parse("pii_ner_circuit_breaker = 0");
        assert_eq!(cfg.pii_ner_circuit_breaker, 0);
    }

    #[test]
    fn pii_ner_circuit_breaker_missing_uses_default() {
        let cfg = parse("");
        assert_eq!(cfg.pii_ner_circuit_breaker, 2);
    }
}