Skip to main content

rustrails_record/
encryption.rs

1use std::collections::HashMap;
2
3use rustrails_support::encryption::{EncryptorError, MessageEncryptor, MessageVerifier};
4
/// Declaration of a single encrypted attribute on a record type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EncryptedFieldConfig {
    /// Name of the attribute being encrypted.
    pub field: String,
    /// When `true`, encryption also emits a deterministic blind index so the attribute can be
    /// matched in equality queries.
    pub deterministic: bool,
}
13
14impl EncryptedFieldConfig {
15    /// Creates encryption metadata for `field`.
16    #[must_use]
17    pub fn new(field: &str) -> Self {
18        Self {
19            field: field.to_owned(),
20            deterministic: false,
21        }
22    }
23
24    /// Enables deterministic equality-query support.
25    #[must_use]
26    pub fn deterministic(mut self) -> Self {
27        self.deterministic = true;
28        self
29    }
30}
31
32/// Declares an encrypted attribute.
33#[must_use]
34pub fn encrypts(field: &str) -> EncryptedFieldConfig {
35    EncryptedFieldConfig::new(field)
36}
37
/// What gets persisted for an encrypted attribute: the ciphertext, the id of the key that
/// produced it, and an optional blind index for equality lookups.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StoredEncryptedValue {
    /// Ciphertext envelope produced by the encryptor.
    pub ciphertext: String,
    /// Identifier of the keyring entry that encrypted `ciphertext`.
    pub key_id: String,
    /// Deterministic equality token; `EncryptionKeyRing::encrypt_value` fills this only for
    /// fields configured as deterministic.
    pub blind_index: Option<String>,
}
48
49/// Errors returned by encrypted-attribute helpers.
50#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
51pub enum EncryptionError {
52    /// The configured key id does not exist.
53    #[error("encryption key not found: {0}")]
54    MissingKey(String),
55    /// Encryption failed.
56    #[error("encryption failed: {0}")]
57    Encrypt(#[from] EncryptorError),
58    /// Decryption failed.
59    #[error("decryption failed: {0}")]
60    Decrypt(EncryptorError),
61    /// Generated plaintext was not valid UTF-8.
62    #[error("decrypted plaintext is not valid utf-8")]
63    InvalidUtf8,
64}
65
/// Keyring used for encryption, decryption, and key rotation.
///
/// Holds raw 32-byte secrets keyed by identifier, plus the id of the key used for new writes.
/// `Debug` is implemented by hand so the secrets never end up in logs or panic messages.
#[derive(Clone)]
pub struct EncryptionKeyRing {
    active_key_id: String,
    keys: HashMap<String, [u8; 32]>,
}

impl std::fmt::Debug for EncryptionKeyRing {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Expose only key identifiers, never key material. Sort them so the output does not
        // depend on HashMap iteration order.
        let mut key_ids: Vec<&str> = self.keys.keys().map(String::as_str).collect();
        key_ids.sort_unstable();
        f.debug_struct("EncryptionKeyRing")
            .field("active_key_id", &self.active_key_id)
            .field("key_ids", &key_ids)
            .finish_non_exhaustive()
    }
}
72
73impl EncryptionKeyRing {
74    /// Creates a keyring with an active key identifier.
75    #[must_use]
76    pub fn new(active_key_id: &str, keys: HashMap<String, [u8; 32]>) -> Self {
77        Self {
78            active_key_id: active_key_id.to_owned(),
79            keys,
80        }
81    }
82
83    /// Encrypts `plaintext` according to the provided field configuration.
84    pub fn encrypt_value(
85        &self,
86        config: &EncryptedFieldConfig,
87        plaintext: &str,
88    ) -> Result<StoredEncryptedValue, EncryptionError> {
89        let encryptor = self.encryptor(&self.active_key_id)?;
90        let ciphertext = encryptor.encrypt_and_sign(plaintext.as_bytes())?;
91        let blind_index = if config.deterministic {
92            Some(self.blind_index_for(&self.active_key_id, plaintext)?)
93        } else {
94            None
95        };
96
97        Ok(StoredEncryptedValue {
98            ciphertext,
99            key_id: self.active_key_id.clone(),
100            blind_index,
101        })
102    }
103
104    /// Decrypts an encrypted value, trying the recorded key first and then rotated keys.
105    pub fn decrypt_value(&self, value: &StoredEncryptedValue) -> Result<String, EncryptionError> {
106        let mut key_ids = self.keys.keys().cloned().collect::<Vec<_>>();
107        key_ids.sort();
108        if let Some(index) = key_ids.iter().position(|key_id| key_id == &value.key_id) {
109            let key_id = key_ids.remove(index);
110            key_ids.insert(0, key_id);
111        } else {
112            return Err(EncryptionError::MissingKey(value.key_id.clone()));
113        }
114
115        let mut last_error = None;
116        for key_id in key_ids {
117            let encryptor = self.encryptor(&key_id)?;
118            match encryptor.decrypt_and_verify(&value.ciphertext) {
119                Ok(bytes) => {
120                    return String::from_utf8(bytes).map_err(|_| EncryptionError::InvalidUtf8);
121                }
122                Err(error) => last_error = Some(error),
123            }
124        }
125
126        match last_error {
127            Some(error) => Err(EncryptionError::Decrypt(error)),
128            None => Err(EncryptionError::MissingKey(value.key_id.clone())),
129        }
130    }
131
132    /// Returns equality tokens for every configured key, newest key first.
133    pub fn equality_tokens(&self, plaintext: &str) -> Result<Vec<String>, EncryptionError> {
134        let mut key_ids = self.keys.keys().cloned().collect::<Vec<_>>();
135        key_ids.sort();
136        if let Some(index) = key_ids
137            .iter()
138            .position(|key_id| key_id == &self.active_key_id)
139        {
140            let key_id = key_ids.remove(index);
141            key_ids.insert(0, key_id);
142        }
143
144        key_ids
145            .into_iter()
146            .map(|key_id| self.blind_index_for(&key_id, plaintext))
147            .collect()
148    }
149
150    fn encryptor(&self, key_id: &str) -> Result<MessageEncryptor, EncryptionError> {
151        let secret = self
152            .keys
153            .get(key_id)
154            .ok_or_else(|| EncryptionError::MissingKey(key_id.to_owned()))?;
155        MessageEncryptor::new(secret).map_err(EncryptionError::Encrypt)
156    }
157
158    fn blind_index_for(&self, key_id: &str, plaintext: &str) -> Result<String, EncryptionError> {
159        let secret = self
160            .keys
161            .get(key_id)
162            .ok_or_else(|| EncryptionError::MissingKey(key_id.to_owned()))?;
163        Ok(MessageVerifier::new(secret).generate(plaintext.as_bytes()))
164    }
165}
166
167/// Trait implemented by records that declare encrypted attributes.
168pub trait EncryptedAttribute {
169    /// Returns encrypted-attribute metadata for the record type.
170    fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
171        &[]
172    }
173}
174
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use super::{
        EncryptedAttribute, EncryptedFieldConfig, EncryptionError, EncryptionKeyRing, encrypts,
    };

    struct UserRecord;

    impl EncryptedAttribute for UserRecord {
        fn encrypted_attributes() -> &'static [EncryptedFieldConfig] {
            static CONFIGS: LazyLock<Vec<EncryptedFieldConfig>> =
                LazyLock::new(|| vec![encrypts("email"), encrypts("ssn").deterministic()]);
            CONFIGS.as_slice()
        }
    }

    /// Keyring with active key "new" and rotated key "old".
    fn keyring() -> EncryptionKeyRing {
        EncryptionKeyRing::new(
            "new",
            HashMap::from([
                ("new".to_owned(), [1_u8; 32]),
                ("old".to_owned(), [2_u8; 32]),
            ]),
        )
    }

    /// Keyring containing only the rotated key, as it existed before rotation.
    fn legacy_keyring() -> EncryptionKeyRing {
        EncryptionKeyRing::new("old", HashMap::from([("old".to_owned(), [2_u8; 32])]))
    }

    #[test]
    fn encrypts_builder_enables_deterministic_mode() {
        let config = encrypts("email").deterministic();
        assert!(config.deterministic);
        assert_eq!(config.field, "email");
    }

    #[test]
    fn encrypt_value_round_trips_plaintext() {
        let stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");
        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("decryption should succeed");
        assert_eq!(plaintext, "alice@example.com");
    }

    #[test]
    fn deterministic_fields_emit_blind_indexes() {
        let stored = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), "123-45-6789")
            .expect("encryption should succeed");
        assert!(stored.blind_index.is_some());
    }

    #[test]
    fn encrypted_attribute_metadata_preserves_field_order_and_flags() {
        assert_eq!(
            UserRecord::encrypted_attributes(),
            &[encrypts("email"), encrypts("ssn").deterministic()]
        );
    }

    #[test]
    fn non_deterministic_fields_do_not_emit_blind_indexes() {
        let stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");

        assert_eq!(stored.blind_index, None);
    }

    #[test]
    fn deterministic_encryptions_keep_blind_index_stable_but_change_ciphertext() {
        let config = encrypts("ssn").deterministic();
        let first = keyring()
            .encrypt_value(&config, "123-45-6789")
            .expect("encryption should succeed");
        let second = keyring()
            .encrypt_value(&config, "123-45-6789")
            .expect("encryption should succeed");

        assert_eq!(first.blind_index, second.blind_index);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn equality_tokens_order_active_key_before_rotated_keys() {
        let plaintext = "123-45-6789";
        let active = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), plaintext)
            .expect("encryption should succeed");
        let rotated = legacy_keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), plaintext)
            .expect("encryption should succeed");

        let tokens = keyring()
            .equality_tokens(plaintext)
            .expect("tokens should generate");

        assert_eq!(
            tokens,
            vec![
                active
                    .blind_index
                    .expect("deterministic field should emit a blind index"),
                rotated
                    .blind_index
                    .expect("deterministic field should emit a blind index"),
            ]
        );
    }

    #[test]
    fn equality_tokens_include_stored_blind_index() {
        let stored = keyring()
            .encrypt_value(&encrypts("ssn").deterministic(), "123-45-6789")
            .expect("encryption should succeed");
        let tokens = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");

        assert_eq!(tokens.len(), 2);
        assert_eq!(tokens.first(), stored.blind_index.as_ref());
    }

    #[test]
    fn equality_tokens_without_configured_active_key_cover_remaining_keys() {
        let keyring =
            EncryptionKeyRing::new("missing", HashMap::from([("old".to_owned(), [2_u8; 32])]));
        let tokens = keyring
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        let legacy_tokens = legacy_keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");

        assert_eq!(tokens, legacy_tokens);
    }

    #[test]
    fn encrypt_value_returns_missing_key_when_active_key_is_unconfigured() {
        let keyring =
            EncryptionKeyRing::new("missing", HashMap::from([("old".to_owned(), [2_u8; 32])]));

        assert_eq!(
            keyring.encrypt_value(&encrypts("email"), "alice@example.com"),
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn tampered_ciphertext_returns_decrypt_error() {
        let mut stored = keyring()
            .encrypt_value(&encrypts("email"), "alice@example.com")
            .expect("encryption should succeed");
        stored.ciphertext.push_str("tampered");

        assert!(matches!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::Decrypt(_))
        ));
    }

    #[test]
    fn ciphertext_is_non_empty_and_does_not_echo_plaintext() {
        let plaintext = "alice@example.com";
        let stored = keyring()
            .encrypt_value(&encrypts("email"), plaintext)
            .expect("encryption should succeed");

        assert!(!stored.ciphertext.is_empty());
        assert_ne!(stored.ciphertext, plaintext);
        assert_eq!(stored.key_id, "new");
    }

    #[test]
    fn equality_tokens_are_stable_for_same_plaintext() {
        let first = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        let second = keyring()
            .equality_tokens("123-45-6789")
            .expect("tokens should generate");
        assert_eq!(first, second);
    }

    #[test]
    fn equality_tokens_change_for_different_plaintext() {
        let first = keyring()
            .equality_tokens("alpha")
            .expect("tokens should generate");
        let second = keyring()
            .equality_tokens("beta")
            .expect("tokens should generate");
        assert_ne!(first, second);
    }

    #[test]
    fn rotated_keys_can_decrypt_old_ciphertext() {
        let stored = legacy_keyring()
            .encrypt_value(&encrypts("email"), "legacy@example.com")
            .expect("encryption should succeed");

        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("rotated keys should decrypt legacy values");
        assert_eq!(plaintext, "legacy@example.com");
    }

    #[test]
    fn decrypt_falls_back_to_other_keys_when_recorded_key_id_is_stale() {
        let mut stored = legacy_keyring()
            .encrypt_value(&encrypts("email"), "legacy@example.com")
            .expect("encryption should succeed");
        // Simulate a row whose key id was rewritten without re-encrypting the ciphertext.
        stored.key_id = "new".to_owned();

        let plaintext = keyring()
            .decrypt_value(&stored)
            .expect("fallback keys should decrypt values with a stale key id");
        assert_eq!(plaintext, "legacy@example.com");
    }

    #[test]
    fn missing_key_returns_error() {
        let stored = super::StoredEncryptedValue {
            ciphertext: "ciphertext".to_owned(),
            key_id: "missing".to_owned(),
            blind_index: None,
        };

        assert_eq!(
            keyring().decrypt_value(&stored),
            Err(EncryptionError::MissingKey("missing".to_owned()))
        );
    }

    #[test]
    fn trait_exposes_declared_encrypted_attributes() {
        assert_eq!(UserRecord::encrypted_attributes().len(), 2);
        assert!(UserRecord::encrypted_attributes()[1].deterministic);
    }
}
377    fn trait_exposes_declared_encrypted_attributes() {
378        assert_eq!(UserRecord::encrypted_attributes().len(), 2);
379        assert!(UserRecord::encrypted_attributes()[1].deterministic);
380    }
381}