1use std::collections::HashMap;
2
3use rustrails_support::encryption::{EncryptorError, MessageEncryptor, MessageVerifier};
4
/// Per-field encryption settings declared by a model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedFieldConfig {
    /// Name of the attribute being encrypted.
    pub field: String,
    /// When `true`, encrypting this field also produces a blind index so it can be matched by
    /// equality; the ciphertext itself still differs between encryptions of the same plaintext.
    pub deterministic: bool,
}
13
14impl EncryptedFieldConfig {
15 #[must_use]
17 pub fn new(field: &str) -> Self {
18 Self {
19 field: field.to_owned(),
20 deterministic: false,
21 }
22 }
23
24 #[must_use]
26 pub fn deterministic(mut self) -> Self {
27 self.deterministic = true;
28 self
29 }
30}
31
/// Starts an encrypted-field declaration for `field`.
///
/// Equivalent to [`EncryptedFieldConfig::new`]: the result is non-deterministic until
/// [`EncryptedFieldConfig::deterministic`] is chained on.
#[must_use]
pub fn encrypts(field: &str) -> EncryptedFieldConfig {
    EncryptedFieldConfig::new(field)
}
37
/// The persisted form of an encrypted attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StoredEncryptedValue {
    /// Output of the encryptor's `encrypt_and_sign` for the plaintext.
    pub ciphertext: String,
    /// Id of the key-ring entry used to encrypt; decryption tries this key first.
    pub key_id: String,
    /// Blind index for deterministic fields, `None` for non-deterministic ones.
    pub blind_index: Option<String>,
}
48
/// Errors raised while encrypting, decrypting, or indexing encrypted attributes.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum EncryptionError {
    /// No secret is registered in the key ring under the given key id.
    #[error("encryption key not found: {0}")]
    MissingKey(String),
    /// Building an encryptor or encrypting a value failed.
    ///
    /// NOTE(review): `#[from]` makes `?` convert *any* `EncryptorError` into this variant, so an
    /// error from a decryption path propagated with a bare `?` would be misreported as an
    /// encryption failure. Decryption failures must be wrapped in `Decrypt` explicitly.
    #[error("encryption failed: {0}")]
    Encrypt(#[from] EncryptorError),
    /// The ciphertext could not be verified/decrypted.
    #[error("decryption failed: {0}")]
    Decrypt(EncryptorError),
    /// Decryption succeeded but the recovered bytes are not valid UTF-8.
    #[error("decrypted plaintext is not valid utf-8")]
    InvalidUtf8,
}
65
/// A set of named 32-byte secrets, one of which is active for new encryptions.
///
/// `Debug` is implemented by hand so that formatting a key ring (e.g. in logs or panic
/// messages) never prints secret material; only the key ids are shown.
#[derive(Clone)]
pub struct EncryptionKeyRing {
    /// Id of the key used for new encryptions and listed first in equality tokens.
    active_key_id: String,
    /// Secret bytes keyed by key id.
    keys: HashMap<String, [u8; 32]>,
}

impl std::fmt::Debug for EncryptionKeyRing {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort so the output is stable regardless of HashMap iteration order.
        let mut key_ids: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        key_ids.sort_unstable();
        f.debug_struct("EncryptionKeyRing")
            .field("active_key_id", &self.active_key_id)
            .field("key_ids", &key_ids)
            .finish_non_exhaustive()
    }
}
72
73impl EncryptionKeyRing {
74 #[must_use]
76 pub fn new(active_key_id: &str, keys: HashMap<String, [u8; 32]>) -> Self {
77 Self {
78 active_key_id: active_key_id.to_owned(),
79 keys,
80 }
81 }
82
83 pub fn encrypt_value(
85 &self,
86 config: &EncryptedFieldConfig,
87 plaintext: &str,
88 ) -> Result<StoredEncryptedValue, EncryptionError> {
89 let encryptor = self.encryptor(&self.active_key_id)?;
90 let ciphertext = encryptor.encrypt_and_sign(plaintext.as_bytes())?;
91 let blind_index = if config.deterministic {
92 Some(self.blind_index_for(&self.active_key_id, plaintext)?)
93 } else {
94 None
95 };
96
97 Ok(StoredEncryptedValue {
98 ciphertext,
99 key_id: self.active_key_id.clone(),
100 blind_index,
101 })
102 }
103
104 pub fn decrypt_value(&self, value: &StoredEncryptedValue) -> Result<String, EncryptionError> {
106 let mut key_ids = self.keys.keys().cloned().collect::<Vec<_>>();
107 key_ids.sort();
108 if let Some(index) = key_ids.iter().position(|key_id| key_id == &value.key_id) {
109 let key_id = key_ids.remove(index);
110 key_ids.insert(0, key_id);
111 } else {
112 return Err(EncryptionError::MissingKey(value.key_id.clone()));
113 }
114
115 let mut last_error = None;
116 for key_id in key_ids {
117 let encryptor = self.encryptor(&key_id)?;
118 match encryptor.decrypt_and_verify(&value.ciphertext) {
119 Ok(bytes) => {
120 return String::from_utf8(bytes).map_err(|_| EncryptionError::InvalidUtf8);
121 }
122 Err(error) => last_error = Some(error),
123 }
124 }
125
126 match last_error {
127 Some(error) => Err(EncryptionError::Decrypt(error)),
128 None => Err(EncryptionError::MissingKey(value.key_id.clone())),
129 }
130 }
131
132 pub fn equality_tokens(&self, plaintext: &str) -> Result<Vec<String>, EncryptionError> {
134 let mut key_ids = self.keys.keys().cloned().collect::<Vec<_>>();
135 key_ids.sort();
136 if let Some(index) = key_ids
137 .iter()
138 .position(|key_id| key_id == &self.active_key_id)
139 {
140 let key_id = key_ids.remove(index);
141 key_ids.insert(0, key_id);
142 }
143
144 key_ids
145 .into_iter()
146 .map(|key_id| self.blind_index_for(&key_id, plaintext))
147 .collect()
148 }
149
150 fn encryptor(&self, key_id: &str) -> Result<MessageEncryptor, EncryptionError> {
151 let secret = self
152 .keys
153 .get(key_id)
154 .ok_or_else(|| EncryptionError::MissingKey(key_id.to_owned()))?;
155 MessageEncryptor::new(secret).map_err(EncryptionError::Encrypt)
156 }
157
158 fn blind_index_for(&self, key_id: &str, plaintext: &str) -> Result<String, EncryptionError> {
159 let secret = self
160 .keys
161 .get(key_id)
162 .ok_or_else(|| EncryptionError::MissingKey(key_id.to_owned()))?;
163 Ok(MessageVerifier::new(secret).generate(plaintext.as_bytes()))
164 }
165}
166
167pub trait EncryptedAttribute {
169 fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
171 &[]
172 }
173}
174
// Unit tests covering field configuration, the key ring (encrypt/decrypt/rotation),
// blind-index generation, and the `EncryptedAttribute` metadata trait.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use super::{
        EncryptedAttribute, EncryptedFieldConfig, EncryptionError, EncryptionKeyRing, encrypts,
    };

    // Sample model declaring one non-deterministic and one deterministic encrypted field.
    struct UserRecord;

    impl EncryptedAttribute for UserRecord {
        fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
            // `LazyLock` is needed because `encrypts` allocates and cannot run in a const context.
            static CONFIGS: LazyLock<Vec<EncryptedFieldConfig>> =
                LazyLock::new(|| vec![encrypts("email"), encrypts("ssn").deterministic()]);
            CONFIGS.as_slice()
        }
    }

    // Key ring with active key "new" plus a rotated-out "old" key; the secrets are
    // distinct fixed byte patterns so tests are reproducible.
    fn keyring() -> EncryptionKeyRing {
        EncryptionKeyRing::new(
            "new",
            HashMap::from([
                ("new".to_owned(), [1_u8; 32]),
                ("old".to_owned(), [2_u8; 32]),
            ]),
        )
    }

    #[test]
    fn encrypts_builder_enables_deterministic_mode() {
        let config = encrypts("email").deterministic();
        assert!(config.deterministic);
        assert_eq!(config.field, "email");
    }

    #[test]
    fn encrypt_value_round_trips_plaintext() {
        let stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");
        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("decryption should succeed");
        assert_eq!(plaintext, "alice@example.com");
    }

    #[test]
    fn deterministic_fields_emit_blind_indexes() {
        let stored = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), "123-45-6789")
            .expect("encryption should succeed");
        assert!(stored.blind_index.is_some());
    }

    #[test]
    fn encrypted_attribute_metadata_preserves_field_order_and_flags() {
        assert_eq!(
            UserRecord::encrypted_attributes(),
            &[encrypts("email"), encrypts("ssn").deterministic()]
        );
    }

    #[test]
    fn non_deterministic_fields_do_not_emit_blind_indexes() {
        let stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");

        assert_eq!(stored.blind_index, None);
    }

    // Deterministic mode only affects the blind index; the ciphertext is still expected to
    // differ between two encryptions of the same plaintext.
    #[test]
    fn deterministic_encryptions_keep_blind_index_stable_but_change_ciphertext() {
        let config = encrypts("ssn").deterministic();
        let first = keyring()
            .encrypt_value(&config, "123-45-6789")
            .expect("encryption should succeed");
        let second = keyring()
            .encrypt_value(&config, "123-45-6789")
            .expect("encryption should succeed");

        assert_eq!(first.blind_index, second.blind_index);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    // The first token must match an index written with the active key, and the second one
    // written with the rotated "old" key (same secret bytes as in `keyring()`).
    #[test]
    fn equality_tokens_order_active_key_before_rotated_keys() {
        let plaintext = "123-45-6789";
        let active = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), plaintext)
            .expect("encryption should succeed");
        let old_ring =
            EncryptionKeyRing::new("old", HashMap::from([("old".to_owned(), [2_u8; 32])]));
        let rotated = old_ring
            .encrypt_value(&encrypts("ssn").deterministic(), plaintext)
            .expect("encryption should succeed");

        let tokens = keyring()
            .equality_tokens(plaintext)
            .expect("tokens should generate");

        assert_eq!(
            tokens,
            vec![
                active
                    .blind_index
                    .expect("deterministic field should emit a blind index"),
                rotated
                    .blind_index
                    .expect("deterministic field should emit a blind index"),
            ]
        );
    }

    #[test]
    fn encrypt_value_returns_missing_key_when_active_key_is_unconfigured() {
        let keyring =
            EncryptionKeyRing::new("missing", HashMap::from([("old".to_owned(), [2_u8; 32])]));

        assert_eq!(
            keyring.encrypt_value(&encrypts("email"), "alice@example.com"),
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    // Appending junk must make every key fail verification, surfacing a `Decrypt` error.
    #[test]
    fn tampered_ciphertext_returns_decrypt_error() {
        let mut stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");
        stored.ciphertext.push_str("tampered");

        assert!(matches!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::Decrypt(_))
        ));
    }

    #[test]
    fn ciphertext_is_non_empty_and_does_not_echo_plaintext() {
        let plaintext = "alice@example.com";
        let stored = keyring()
            .encrypt_value(&encrypts("email"), plaintext)
            .expect("encryption should succeed");

        assert!(!stored.ciphertext.is_empty());
        assert_ne!(stored.ciphertext, plaintext);
        assert_eq!(stored.key_id, "new");
    }

    #[test]
    fn equality_tokens_are_stable_for_same_plaintext() {
        let first = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        let second = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        assert_eq!(first, second);
    }

    #[test]
    fn equality_tokens_change_for_different_plaintext() {
        let first = keyring()
            .equality_tokens("alpha")
            .expect("tokens should generate");
        let second = keyring()
            .equality_tokens("beta")
            .expect("tokens should generate");
        assert_ne!(first, second);
    }

    // A value written when "old" was active must still decrypt after "new" became active.
    #[test]
    fn rotated_keys_can_decrypt_old_ciphertext() {
        let old_ring =
            EncryptionKeyRing::new("old", HashMap::from([("old".to_owned(), [2_u8; 32])]));
        let stored = old_ring
            .encrypt_value(&encrypts("email"), "legacy@example.com")
            .expect("encryption should succeed");

        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("rotated keys should decrypt legacy values");
        assert_eq!(plaintext, "legacy@example.com");
    }

    // An unknown recorded key id is rejected before any decryption is attempted.
    #[test]
    fn missing_key_returns_error() {
        let stored = super::StoredEncryptedValue {
            ciphertext: "ciphertext".to_owned(),
            key_id: "missing".to_owned(),
            blind_index: None,
        };

        assert_eq!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn trait_exposes_declared_encrypted_attributes() {
        assert_eq!(UserRecord::encrypted_attributes().len(), 2);
        assert!(UserRecord::encrypted_attributes()[1].deterministic);
    }
}