1use aes_gcm_siv::{
35 Aes256GcmSiv, Nonce,
36 aead::{Aead, KeyInit, OsRng, Payload},
37};
38use hkdf::Hkdf;
39use rand::RngCore;
40use sha2::Sha256;
41use sochdb_core::SochDBError;
42use zeroize::Zeroize;
43
/// Format version byte written at the start of every encrypted block.
const ENCRYPTION_VERSION: u8 = 1;
/// Length in bytes of the random per-message nonce stored after the version byte.
const NONCE_SIZE: usize = 12;
/// Bytes that precede the AEAD ciphertext: the nonce plus one version byte.
const HEADER_SIZE: usize = NONCE_SIZE + 1;
50
51pub struct EncryptionEngine {
56 cipher: Aes256GcmSiv,
57 enabled: bool,
59}
60
61impl EncryptionEngine {
62 pub fn new(key: &[u8; 32]) -> Self {
67 let cipher =
68 Aes256GcmSiv::new_from_slice(key).expect("AES-256-GCM-SIV key must be 32 bytes");
69 Self {
70 cipher,
71 enabled: true,
72 }
73 }
74
75 pub fn from_key(key: &EncryptionKey) -> Self {
80 Self::new(key.as_bytes())
81 }
82
83 pub fn disabled() -> Self {
87 let key = [0u8; 32];
89 let cipher =
90 Aes256GcmSiv::new_from_slice(&key).expect("AES-256-GCM-SIV key must be 32 bytes");
91 Self {
92 cipher,
93 enabled: false,
94 }
95 }
96
97 pub fn is_enabled(&self) -> bool {
99 self.enabled
100 }
101
102 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
111 self.encrypt_with_aad(plaintext, &[])
112 }
113
114 pub fn encrypt_with_aad(
128 &self,
129 plaintext: &[u8],
130 aad: &[u8],
131 ) -> Result<Vec<u8>, EncryptionError> {
132 if !self.enabled {
133 return Ok(plaintext.to_vec());
136 }
137
138 let mut nonce_bytes = [0u8; NONCE_SIZE];
140 OsRng.fill_bytes(&mut nonce_bytes);
141 let nonce = Nonce::from_slice(&nonce_bytes);
142
143 let ciphertext = self
144 .cipher
145 .encrypt(
146 nonce,
147 Payload {
148 msg: plaintext,
149 aad,
150 },
151 )
152 .map_err(|_| EncryptionError::EncryptFailed)?;
153
154 let mut output = Vec::with_capacity(HEADER_SIZE + ciphertext.len());
156 output.push(ENCRYPTION_VERSION);
157 output.extend_from_slice(&nonce_bytes);
158 output.extend_from_slice(&ciphertext);
159
160 Ok(output)
161 }
162
163 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>, EncryptionError> {
167 self.decrypt_with_aad(encrypted, &[])
168 }
169
170 pub fn decrypt_with_aad(
177 &self,
178 encrypted: &[u8],
179 aad: &[u8],
180 ) -> Result<Vec<u8>, EncryptionError> {
181 if !self.enabled {
182 return Ok(encrypted.to_vec());
183 }
184
185 if encrypted.len() < HEADER_SIZE + 16 {
186 return Err(EncryptionError::InvalidFormat(
187 "Data too short for encrypted block".into(),
188 ));
189 }
190
191 let version = encrypted[0];
192 if version != ENCRYPTION_VERSION {
193 return Err(EncryptionError::UnsupportedVersion(version));
194 }
195
196 let nonce = Nonce::from_slice(&encrypted[1..HEADER_SIZE]);
197 let ciphertext = &encrypted[HEADER_SIZE..];
198
199 self.cipher
200 .decrypt(
201 nonce,
202 Payload {
203 msg: ciphertext,
204 aad,
205 },
206 )
207 .map_err(|_| EncryptionError::DecryptFailed)
208 }
209
210 pub fn encrypt_in_place(&self, buffer: &mut Vec<u8>) -> Result<(), EncryptionError> {
215 if !self.enabled {
216 return Ok(());
217 }
218
219 let encrypted = self.encrypt(buffer)?;
220 *buffer = encrypted;
221 Ok(())
222 }
223}
224
225pub fn derive_subkey(ikm: &[u8], salt: &[u8], info: &[u8]) -> EncryptionKey {
232 let hk = Hkdf::<Sha256>::new(Some(salt), ikm);
233 let mut okm = [0u8; 32];
234 hk.expand(info, &mut okm)
235 .expect("HKDF expand of 32 bytes never fails");
236 let key = EncryptionKey::new(okm);
237 okm.zeroize();
238 key
239}
240
/// Errors produced by [`EncryptionEngine`] operations.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can compare and propagate
/// errors without matching on every variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncryptionError {
    /// The AEAD cipher returned an error while encrypting.
    EncryptFailed,
    /// Authentication failed: wrong key, different associated data, or modified bytes.
    DecryptFailed,
    /// The input is not structurally a valid encrypted block; the string says why.
    InvalidFormat(String),
    /// The block's leading version byte is not a supported format version.
    UnsupportedVersion(u8),
}
253
254impl std::fmt::Display for EncryptionError {
255 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 match self {
257 EncryptionError::EncryptFailed => write!(f, "Encryption failed"),
258 EncryptionError::DecryptFailed => {
259 write!(f, "Decryption failed (wrong key or tampered data)")
260 }
261 EncryptionError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
262 EncryptionError::UnsupportedVersion(v) => {
263 write!(f, "Unsupported encryption version: {}", v)
264 }
265 }
266 }
267}
268
269impl EncryptionError {
270 pub fn is_integrity_failure(&self) -> bool {
277 matches!(
278 self,
279 EncryptionError::DecryptFailed
280 | EncryptionError::InvalidFormat(_)
281 | EncryptionError::UnsupportedVersion(_)
282 )
283 }
284}
285
286impl std::error::Error for EncryptionError {}
287
288impl From<EncryptionError> for SochDBError {
289 fn from(e: EncryptionError) -> Self {
290 SochDBError::Encryption(e.to_string())
293 }
294}
295
296pub fn generate_key() -> [u8; 32] {
302 let mut key = [0u8; 32];
303 OsRng.fill_bytes(&mut key);
304 key
305}
306
307#[derive(Zeroize)]
309#[zeroize(drop)]
310pub struct EncryptionKey {
311 bytes: [u8; 32],
312}
313
314impl EncryptionKey {
315 pub fn new(bytes: [u8; 32]) -> Self {
316 Self { bytes }
317 }
318
319 pub fn as_bytes(&self) -> &[u8; 32] {
320 &self.bytes
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the AES-GCM-SIV authentication tag appended to every ciphertext.
    const TAG_LEN: usize = 16;

    /// Returns an enabled engine keyed with fresh random material.
    fn fresh_engine() -> EncryptionEngine {
        EncryptionEngine::new(&generate_key())
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let engine = fresh_engine();

        let plaintext = b"Hello, SochDB enterprise encryption!";
        let encrypted = engine.encrypt(plaintext).unwrap();

        // Output is exactly header + ciphertext (same length as plaintext) + tag.
        assert_eq!(encrypted.len(), HEADER_SIZE + plaintext.len() + TAG_LEN);
        assert_eq!(encrypted[0], ENCRYPTION_VERSION);

        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_empty() {
        let engine = fresh_engine();

        let encrypted = engine.encrypt(b"").unwrap();
        // The smallest valid block is header + tag; decrypt must accept it.
        assert_eq!(encrypted.len(), HEADER_SIZE + TAG_LEN);
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_large_block() {
        let engine = fresh_engine();

        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encrypted = engine.encrypt(&plaintext).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_wrong_key_fails() {
        let engine1 = fresh_engine();
        let engine2 = fresh_engine();

        let encrypted = engine1.encrypt(b"secret data").unwrap();
        let result = engine2.decrypt(&encrypted);
        assert!(matches!(result, Err(EncryptionError::DecryptFailed)));
    }

    #[test]
    fn test_tampered_data_fails() {
        let engine = fresh_engine();
        let original = engine.encrypt(b"important data").unwrap();

        // Flip a bit in the tag.
        let mut tag_tampered = original.clone();
        let last = tag_tampered.len() - 1;
        tag_tampered[last] ^= 0xFF;
        assert!(matches!(
            engine.decrypt(&tag_tampered),
            Err(EncryptionError::DecryptFailed)
        ));

        // Flip a bit in the stored nonce: authentication must fail as well.
        let mut nonce_tampered = original;
        nonce_tampered[1] ^= 0x01;
        assert!(matches!(
            engine.decrypt(&nonce_tampered),
            Err(EncryptionError::DecryptFailed)
        ));
    }

    #[test]
    fn test_truncated_block_is_invalid_format() {
        let engine = fresh_engine();
        let encrypted = engine.encrypt(b"").unwrap();

        // One byte below the minimum length is rejected before any crypto runs.
        let truncated = &encrypted[..HEADER_SIZE + TAG_LEN - 1];
        assert!(matches!(
            engine.decrypt(truncated),
            Err(EncryptionError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_disabled_passthrough() {
        let engine = EncryptionEngine::disabled();
        assert!(!engine.is_enabled());

        let plaintext = b"no encryption here";
        let encrypted = engine.encrypt(plaintext).unwrap();
        assert_eq!(encrypted, plaintext);

        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_unique_nonces() {
        let engine = fresh_engine();

        let enc1 = engine.encrypt(b"same plaintext").unwrap();
        let enc2 = engine.encrypt(b"same plaintext").unwrap();

        // Distinct random nonces make identical plaintexts encrypt differently.
        assert_ne!(&enc1[1..HEADER_SIZE], &enc2[1..HEADER_SIZE]);
        assert_ne!(enc1, enc2);
    }

    #[test]
    fn test_invalid_format() {
        let engine = fresh_engine();

        assert!(matches!(
            engine.decrypt(&[1, 2, 3]),
            Err(EncryptionError::InvalidFormat(_))
        ));

        // Long enough to pass the length check, but with an unknown version byte.
        let fake = vec![99u8; 50];
        assert!(matches!(
            engine.decrypt(&fake),
            Err(EncryptionError::UnsupportedVersion(99))
        ));
    }

    #[test]
    fn test_key_zeroize() {
        let mut key = EncryptionKey::new(generate_key());
        assert_ne!(key.as_bytes(), &[0u8; 32]);

        // Exercise the same wipe that runs on drop and observe its effect.
        key.zeroize();
        assert_eq!(key.as_bytes(), &[0u8; 32]);
    }

    #[test]
    fn test_aad_roundtrip() {
        let engine = fresh_engine();
        let aad = b"v1|db-uuid|epoch=0|lsn=42";

        let ct = engine.encrypt_with_aad(b"payload", aad).unwrap();
        let pt = engine.decrypt_with_aad(&ct, aad).unwrap();
        assert_eq!(pt, b"payload");
    }

    #[test]
    fn test_aad_mismatch_fails_like_wrong_key() {
        let engine = fresh_engine();

        let ct = engine
            .encrypt_with_aad(b"committed record", b"...|lsn=42")
            .unwrap();
        let err = engine.decrypt_with_aad(&ct, b"...|lsn=43").unwrap_err();
        assert!(matches!(err, EncryptionError::DecryptFailed));
        assert!(err.is_integrity_failure());
    }

    #[test]
    fn test_no_aad_is_not_same_as_some_aad() {
        let engine = fresh_engine();

        let ct = engine.encrypt_with_aad(b"x", b"bound").unwrap();
        assert!(engine.decrypt(&ct).is_err());

        let ct2 = engine.encrypt(b"x").unwrap();
        assert_eq!(engine.decrypt(&ct2).unwrap(), b"x");
    }

    #[test]
    fn test_integrity_failure_classification() {
        assert!(EncryptionError::DecryptFailed.is_integrity_failure());
        assert!(EncryptionError::UnsupportedVersion(9).is_integrity_failure());
        assert!(EncryptionError::InvalidFormat("x".into()).is_integrity_failure());
        assert!(!EncryptionError::EncryptFailed.is_integrity_failure());
    }

    #[test]
    fn test_hkdf_deterministic_and_salt_separated() {
        let kek = b"operator-supplied-kek-material";
        let salt_a = [1u8; 16];
        let salt_b = [2u8; 16];

        // Same inputs -> same key.
        let k1 = derive_subkey(kek, &salt_a, b"sochdb/dek/v1");
        let k2 = derive_subkey(kek, &salt_a, b"sochdb/dek/v1");
        assert_eq!(k1.as_bytes(), k2.as_bytes());

        // Different salt -> different key.
        let k3 = derive_subkey(kek, &salt_b, b"sochdb/dek/v1");
        assert_ne!(k1.as_bytes(), k3.as_bytes());

        // Different purpose label -> different key.
        let k4 = derive_subkey(kek, &salt_a, b"sochdb/wrap/v1");
        assert_ne!(k1.as_bytes(), k4.as_bytes());

        let engine = EncryptionEngine::from_key(&k1);
        let ct = engine.encrypt(b"hi").unwrap();
        assert_eq!(engine.decrypt(&ct).unwrap(), b"hi");
    }

    #[test]
    fn test_encrypt_in_place() {
        let engine = fresh_engine();

        let original = b"WAL entry payload".to_vec();
        let mut buffer = original.clone();
        engine.encrypt_in_place(&mut buffer).unwrap();

        assert_ne!(buffer, original);
        let decrypted = engine.decrypt(&buffer).unwrap();
        assert_eq!(decrypted, original);
    }

    #[test]
    fn test_encrypt_in_place_disabled_is_noop() {
        let engine = EncryptionEngine::disabled();

        let mut buffer = b"left alone".to_vec();
        engine.encrypt_in_place(&mut buffer).unwrap();
        assert_eq!(buffer, b"left alone");
    }
}