// zeph_config/classifiers.rs
use serde::{Deserialize, Serialize};

/// Serde/`Default` value for [`ClassifiersConfig::timeout_ms`]: 5000 ms.
fn default_classifier_timeout_ms() -> u64 {
    5000
}

/// Serde/`Default` value for [`ClassifiersConfig::injection_model`].
fn default_injection_model() -> String {
    "protectai/deberta-v3-small-prompt-injection-v2".into()
}

/// Serde/`Default` value for [`ClassifiersConfig::injection_threshold`]: `0.8`.
fn default_injection_threshold() -> f32 {
    0.8
}

/// Serde/`Default` value for [`ClassifiersConfig::pii_model`].
fn default_pii_model() -> String {
    "iiiorg/piiranha-v1-detect-personal-information".into()
}

/// Serde/`Default` value for [`ClassifiersConfig::pii_threshold`]: `0.75`.
fn default_pii_threshold() -> f32 {
    0.75
}

26#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
34pub struct ClassifiersConfig {
35 #[serde(default)]
37 pub enabled: bool,
38
39 #[serde(default = "default_classifier_timeout_ms")]
43 pub timeout_ms: u64,
44
45 #[serde(default = "default_injection_model")]
47 pub injection_model: String,
48
49 #[serde(default = "default_injection_threshold")]
53 pub injection_threshold: f32,
54
55 #[serde(default)]
60 pub injection_model_sha256: Option<String>,
61
62 #[serde(default)]
67 pub pii_enabled: bool,
68
69 #[serde(default = "default_pii_model")]
71 pub pii_model: String,
72
73 #[serde(default = "default_pii_threshold")]
79 pub pii_threshold: f32,
80
81 #[serde(default)]
83 pub pii_model_sha256: Option<String>,
84}
85
86impl Default for ClassifiersConfig {
87 fn default() -> Self {
88 Self {
89 enabled: false,
90 timeout_ms: default_classifier_timeout_ms(),
91 injection_model: default_injection_model(),
92 injection_threshold: default_injection_threshold(),
93 injection_model_sha256: None,
94 pii_enabled: false,
95 pii_model: default_pii_model(),
96 pii_threshold: default_pii_threshold(),
97 pii_model_sha256: None,
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_values() {
        let cfg = ClassifiersConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
        assert!(cfg.injection_model_sha256.is_none());
        assert!(!cfg.pii_enabled);
        assert_eq!(
            cfg.pii_model,
            "iiiorg/piiranha-v1-detect-personal-information"
        );
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        assert!(cfg.pii_model_sha256.is_none());
    }

    #[test]
    fn deserialize_empty_section_uses_defaults() {
        let cfg: ClassifiersConfig = toml::from_str("").unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
        assert!(!cfg.pii_enabled);
        assert!((cfg.pii_threshold - 0.75).abs() < 1e-6);
        // Compare the whole struct so the serde field defaults and the
        // `Default` impl cannot drift apart on fields not spot-checked above
        // (`pii_model`, both `*_sha256` digests).
        assert_eq!(cfg, ClassifiersConfig::default());
    }

    #[test]
    fn deserialize_custom_values() {
        let toml = r#"
            enabled = true
            timeout_ms = 2000
            injection_model = "custom/model-v1"
            injection_threshold = 0.9
            pii_enabled = true
            pii_model = "custom/pii-v1"
            pii_threshold = 0.85
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert_eq!(cfg.timeout_ms, 2000);
        assert_eq!(cfg.injection_model, "custom/model-v1");
        assert!((cfg.injection_threshold - 0.9).abs() < 1e-6);
        assert!(cfg.pii_enabled);
        assert_eq!(cfg.pii_model, "custom/pii-v1");
        assert!((cfg.pii_threshold - 0.85).abs() < 1e-6);
    }

    #[test]
    fn deserialize_sha256_fields() {
        let toml = r#"
            injection_model_sha256 = "abc123"
            pii_model_sha256 = "def456"
        "#;
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.injection_model_sha256.as_deref(), Some("abc123"));
        assert_eq!(cfg.pii_model_sha256.as_deref(), Some("def456"));
    }

    #[test]
    fn serialize_roundtrip() {
        let original = ClassifiersConfig {
            enabled: true,
            timeout_ms: 3000,
            injection_model: "org/model".into(),
            injection_threshold: 0.75,
            injection_model_sha256: Some("deadbeef".into()),
            pii_enabled: true,
            pii_model: "org/pii-model".into(),
            pii_threshold: 0.80,
            pii_model_sha256: None,
        };
        let serialized = toml::to_string(&original).unwrap();
        let deserialized: ClassifiersConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn partial_override_timeout_only() {
        let toml = "timeout_ms = 1000";
        let cfg: ClassifiersConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.enabled);
        assert_eq!(cfg.timeout_ms, 1000);
        assert_eq!(
            cfg.injection_model,
            "protectai/deberta-v3-small-prompt-injection-v2"
        );
        assert!((cfg.injection_threshold - 0.8).abs() < 1e-6);
    }
}