Skip to main content

zeph_config/
classifiers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
/// Serde fallback for `timeout_ms`: the per-inference timeout, in milliseconds.
fn default_classifier_timeout_ms() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 5_000;
    DEFAULT_TIMEOUT_MS
}
9
/// Serde fallback for `injection_model`: the repo ID of the stock injection-detection model.
fn default_injection_model() -> String {
    String::from("protectai/deberta-v3-small-prompt-injection-v2")
}
13
/// Serde fallback for `injection_threshold`: the minimum score treated as an injection.
fn default_injection_threshold() -> f32 {
    const DEFAULT_THRESHOLD: f32 = 0.8;
    DEFAULT_THRESHOLD
}
17
/// Serde fallback for `pii_model`: the repo ID of the stock PII NER model.
fn default_pii_model() -> String {
    String::from("iiiorg/piiranha-v1-detect-personal-information")
}
21
/// Serde fallback for `pii_threshold`: the minimum per-token confidence for a PII label.
fn default_pii_threshold() -> f32 {
    const DEFAULT_THRESHOLD: f32 = 0.75;
    DEFAULT_THRESHOLD
}
25
26/// Configuration for the ML-backed classifier subsystem.
27///
28/// Placed under `[classifiers]` in `config.toml`. All fields are optional with safe defaults
29/// so existing configs continue to work when this section is absent.
30///
31/// When `enabled = false` (the default), all classifier code is bypassed and the existing
32/// regex-based detection runs unchanged.
33#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
34pub struct ClassifiersConfig {
35    /// Master switch. When `false`, classifiers are never loaded or invoked.
36    #[serde(default)]
37    pub enabled: bool,
38
39    /// Per-inference timeout in milliseconds.
40    ///
41    /// On timeout the call site falls back to regex. Separate from model download time.
42    #[serde(default = "default_classifier_timeout_ms")]
43    pub timeout_ms: u64,
44
45    /// `HuggingFace` repo ID for the injection detection model.
46    #[serde(default = "default_injection_model")]
47    pub injection_model: String,
48
49    /// Minimum classifier score to treat a result as an injection.
50    ///
51    /// Range: `(0.0, 1.0]`. Conservative default of `0.8` minimises false positives.
52    #[serde(default = "default_injection_threshold")]
53    pub injection_threshold: f32,
54
55    /// Optional SHA-256 hex digest of the injection model safetensors file.
56    ///
57    /// When set, the file is verified before loading. Mismatch aborts startup with an error.
58    /// Useful for security-sensitive deployments to detect corruption or tampering.
59    #[serde(default)]
60    pub injection_model_sha256: Option<String>,
61
62    /// Enable PII detection via the NER model (`pii_model`).
63    ///
64    /// When `true`, `CandlePiiClassifier` runs on user messages in addition to the
65    /// regex-based `PiiFilter`. Both results are merged (union with deduplication).
66    #[serde(default)]
67    pub pii_enabled: bool,
68
69    /// `HuggingFace` repo ID for the PII NER model.
70    #[serde(default = "default_pii_model")]
71    pub pii_model: String,
72
73    /// Minimum per-token confidence to accept a PII label.
74    ///
75    /// Tokens below this threshold are treated as O (no entity).
76    /// Default `0.75` balances recall on rarer entity types (DRIVERLICENSE, PASSPORT, IBAN)
77    /// with precision. Raise to `0.85` to prefer precision over recall.
78    #[serde(default = "default_pii_threshold")]
79    pub pii_threshold: f32,
80
81    /// Optional SHA-256 hex digest of the PII model safetensors file.
82    #[serde(default)]
83    pub pii_model_sha256: Option<String>,
84}
85
86impl Default for ClassifiersConfig {
87    fn default() -> Self {
88        Self {
89            enabled: false,
90            timeout_ms: default_classifier_timeout_ms(),
91            injection_model: default_injection_model(),
92            injection_threshold: default_injection_threshold(),
93            injection_model_sha256: None,
94            pii_enabled: false,
95            pii_model: default_pii_model(),
96            pii_threshold: default_pii_threshold(),
97            pii_model_sha256: None,
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` thresholds against literals.
    const EPS: f32 = 1e-6;

    /// Returns `true` when `actual` is within [`EPS`] of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Parses a TOML document into a [`ClassifiersConfig`], panicking with context on failure.
    fn parse(input: &str) -> ClassifiersConfig {
        toml::from_str(input).expect("test TOML must deserialize into ClassifiersConfig")
    }

    /// Pins the literal value of every field produced by `Default`.
    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold, 0.8));
        assert!(cfg.injection_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!(approx(cfg.pii_threshold, 0.75));
        assert!(cfg.pii_model_sha256.is_none());
    }

    /// An empty document must yield the literal defaults for every field and must
    /// agree with the `Default` impl.
    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg = parse("");
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold, 0.8));
        assert!(cfg.injection_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!(approx(cfg.pii_threshold, 0.75));
        assert!(cfg.pii_model_sha256.is_none());

        // Serde defaults and `Default` share the same helpers, so the two must be equal.
        assert_eq!(cfg, ClassifiersConfig::default());
    }

    /// Explicit keys override their defaults; omitted keys keep theirs.
    #[test]
    fn deserialize_custom_values() {
        let input = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_threshold = 0.85
        "#;
        let cfg = parse(input);
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!(approx(cfg.injection_threshold, 0.9));
        assert!(cfg.pii_enabled);
        assert!(approx(cfg.pii_threshold, 0.85));

        // Keys absent from the input fall back to their defaults.
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!(cfg.injection_model_sha256.is_none());
        assert!(cfg.pii_model_sha256.is_none());
    }

    /// Both checksum fields deserialize, and nothing else is disturbed.
    #[test]
    fn deserialize_sha256_fields() {
        let input = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg = parse(input);
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));

        let expected = ClassifiersConfig {
            injection_model_sha256: Some("abc123".into()),
            pii_model_sha256: Some("def456".into()),
            ..ClassifiersConfig::default()
        };
        assert_eq!(cfg, expected);
    }

    /// Serializing to TOML and parsing the result yields an equal value.
    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            injection_model: "org/model".into(),
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
        };
        let serialized =
            toml::to_string(&original).expect("ClassifiersConfig must serialize to TOML");
        let restored = parse(&serialized);
        assert_eq!(original, restored);
    }

    /// Overriding a single key leaves every other field at its default.
    #[test]
    fn partial_override_timeout_only() {
        let cfg = parse("timeout_ms = 1000");
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!(approx(cfg.injection_threshold, 0.8));

        let expected = ClassifiersConfig {
            timeout_ms: 1000,
            ..ClassifiersConfig::default()
        };
        assert_eq!(cfg, expected);
    }
}