// sochdb_storage/encryption.rs
use aes_gcm_siv::{
    aead::{Aead, KeyInit, OsRng},
    Aes256GcmSiv, Nonce,
};
use rand::RngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Format version byte written as the first byte of every encrypted block.
const ENCRYPTION_VERSION: u8 = 1;

/// Length in bytes of the random per-block nonce stored after the version byte.
const NONCE_SIZE: usize = 12;

/// Number of framing bytes in front of the ciphertext: the `u8` version byte
/// followed by the nonce.
const HEADER_SIZE: usize = NONCE_SIZE + std::mem::size_of::<u8>();
48pub struct EncryptionEngine {
53 cipher: Aes256GcmSiv,
54 enabled: bool,
56}
57
58impl EncryptionEngine {
59 pub fn new(key: &[u8; 32]) -> Self {
64 let cipher = Aes256GcmSiv::new_from_slice(key)
65 .expect("AES-256-GCM-SIV key must be 32 bytes");
66 Self {
67 cipher,
68 enabled: true,
69 }
70 }
71
72 pub fn disabled() -> Self {
76 let key = [0u8; 32];
78 let cipher = Aes256GcmSiv::new_from_slice(&key)
79 .expect("AES-256-GCM-SIV key must be 32 bytes");
80 Self {
81 cipher,
82 enabled: false,
83 }
84 }
85
86 pub fn is_enabled(&self) -> bool {
88 self.enabled
89 }
90
91 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
100 if !self.enabled {
101 return Ok(plaintext.to_vec());
102 }
103
104 let mut nonce_bytes = [0u8; NONCE_SIZE];
106 OsRng.fill_bytes(&mut nonce_bytes);
107 let nonce = Nonce::from_slice(&nonce_bytes);
108
109 let ciphertext = self
110 .cipher
111 .encrypt(nonce, plaintext)
112 .map_err(|_| EncryptionError::EncryptFailed)?;
113
114 let mut output = Vec::with_capacity(HEADER_SIZE + ciphertext.len());
116 output.push(ENCRYPTION_VERSION);
117 output.extend_from_slice(&nonce_bytes);
118 output.extend_from_slice(&ciphertext);
119
120 Ok(output)
121 }
122
123 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>, EncryptionError> {
127 if !self.enabled {
128 return Ok(encrypted.to_vec());
129 }
130
131 if encrypted.len() < HEADER_SIZE + 16 {
132 return Err(EncryptionError::InvalidFormat(
133 "Data too short for encrypted block".into(),
134 ));
135 }
136
137 let version = encrypted[0];
138 if version != ENCRYPTION_VERSION {
139 return Err(EncryptionError::UnsupportedVersion(version));
140 }
141
142 let nonce = Nonce::from_slice(&encrypted[1..HEADER_SIZE]);
143 let ciphertext = &encrypted[HEADER_SIZE..];
144
145 self.cipher
146 .decrypt(nonce, ciphertext)
147 .map_err(|_| EncryptionError::DecryptFailed)
148 }
149
150 pub fn encrypt_in_place(&self, buffer: &mut Vec<u8>) -> Result<(), EncryptionError> {
155 if !self.enabled {
156 return Ok(());
157 }
158
159 let encrypted = self.encrypt(buffer)?;
160 *buffer = encrypted;
161 Ok(())
162 }
163}
164
/// Errors produced by the storage encryption engine.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare against a specific
/// variant, and `Clone` so an error can be recorded and still returned.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncryptionError {
    /// The AEAD cipher returned an error while encrypting.
    EncryptFailed,
    /// Authentication failed while decrypting: wrong key or tampered data.
    DecryptFailed,
    /// The input is not a well-formed encrypted block; the message explains why.
    InvalidFormat(String),
    /// The block's leading version byte is not a supported format version.
    UnsupportedVersion(u8),
}
178impl std::fmt::Display for EncryptionError {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 match self {
181 EncryptionError::EncryptFailed => write!(f, "Encryption failed"),
182 EncryptionError::DecryptFailed => {
183 write!(f, "Decryption failed (wrong key or tampered data)")
184 }
185 EncryptionError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
186 EncryptionError::UnsupportedVersion(v) => {
187 write!(f, "Unsupported encryption version: {}", v)
188 }
189 }
190 }
191}
192
193impl std::error::Error for EncryptionError {}
194
195pub fn generate_key() -> [u8; 32] {
201 let mut key = [0u8; 32];
202 OsRng.fill_bytes(&mut key);
203 key
204}
205
206#[derive(Zeroize)]
208#[zeroize(drop)]
209pub struct EncryptionKey {
210 bytes: [u8; 32],
211}
212
213impl EncryptionKey {
214 pub fn new(bytes: [u8; 32]) -> Self {
215 Self { bytes }
216 }
217
218 pub fn as_bytes(&self) -> &[u8; 32] {
219 &self.bytes
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// AES-GCM-SIV authentication tag length appended to every ciphertext.
    const TAG_LEN: usize = 16;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = generate_key();
        let engine = EncryptionEngine::new(&key);

        let plaintext = b"Hello, SochDB enterprise encryption!";
        let encrypted = engine.encrypt(plaintext).unwrap();

        // Framing: version byte + nonce, then ciphertext (plaintext-sized) + tag.
        assert_eq!(encrypted.len(), HEADER_SIZE + plaintext.len() + TAG_LEN);
        assert_eq!(encrypted[0], ENCRYPTION_VERSION);

        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_empty() {
        let key = generate_key();
        let engine = EncryptionEngine::new(&key);

        let encrypted = engine.encrypt(b"").unwrap();
        // The smallest valid block: header plus tag, no payload bytes.
        assert_eq!(encrypted.len(), HEADER_SIZE + TAG_LEN);

        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_large_block() {
        let key = generate_key();
        let engine = EncryptionEngine::new(&key);

        let plaintext: Vec<u8> = (0..1_000_000u32).map(|i| (i % 256) as u8).collect();
        let encrypted = engine.encrypt(&plaintext).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_wrong_key_fails() {
        let engine1 = EncryptionEngine::new(&generate_key());
        let engine2 = EncryptionEngine::new(&generate_key());

        let encrypted = engine1.encrypt(b"secret data").unwrap();
        let result = engine2.decrypt(&encrypted);
        assert!(matches!(result, Err(EncryptionError::DecryptFailed)));
    }

    #[test]
    fn test_tampered_data_fails() {
        let engine = EncryptionEngine::new(&generate_key());

        let mut encrypted = engine.encrypt(b"important data").unwrap();
        // Flip every bit of the last byte (part of the authentication tag).
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0xFF;

        let result = engine.decrypt(&encrypted);
        assert!(matches!(result, Err(EncryptionError::DecryptFailed)));
    }

    #[test]
    fn test_version_byte_checked() {
        let engine = EncryptionEngine::new(&generate_key());

        let mut encrypted = engine.encrypt(b"versioned data").unwrap();
        encrypted[0] = ENCRYPTION_VERSION.wrapping_add(1);

        let result = engine.decrypt(&encrypted);
        assert!(matches!(
            result,
            Err(EncryptionError::UnsupportedVersion(v)) if v == ENCRYPTION_VERSION.wrapping_add(1)
        ));
    }

    #[test]
    fn test_disabled_passthrough() {
        let engine = EncryptionEngine::disabled();
        assert!(!engine.is_enabled());

        let plaintext = b"no encryption here";
        let encrypted = engine.encrypt(plaintext).unwrap();
        assert_eq!(encrypted, plaintext);

        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);

        let mut buffer = plaintext.to_vec();
        engine.encrypt_in_place(&mut buffer).unwrap();
        assert_eq!(buffer, plaintext);
    }

    #[test]
    fn test_unique_nonces() {
        let engine = EncryptionEngine::new(&generate_key());

        let enc1 = engine.encrypt(b"same plaintext").unwrap();
        let enc2 = engine.encrypt(b"same plaintext").unwrap();

        // The nonce occupies the header bytes after the version byte.
        assert_ne!(enc1[1..HEADER_SIZE], enc2[1..HEADER_SIZE]);
        // A different nonce must yield a different block for identical plaintext.
        assert_ne!(enc1, enc2);
    }

    #[test]
    fn test_invalid_format() {
        let engine = EncryptionEngine::new(&generate_key());

        // Shorter than header + tag.
        assert!(matches!(
            engine.decrypt(&[1, 2, 3]),
            Err(EncryptionError::InvalidFormat(_))
        ));

        // Long enough, but the version byte (99) is not supported.
        let fake = vec![99u8; 50];
        assert!(matches!(
            engine.decrypt(&fake),
            Err(EncryptionError::UnsupportedVersion(99))
        ));
    }

    #[test]
    fn test_key_zeroize() {
        let mut key = EncryptionKey::new(generate_key());
        assert_ne!(key.as_bytes(), &[0u8; 32]);

        // Drop-time wiping cannot be observed safely, so exercise the same
        // `Zeroize` implementation explicitly.
        key.zeroize();
        assert_eq!(key.as_bytes(), &[0u8; 32]);
    }

    #[test]
    fn test_encrypt_in_place() {
        let engine = EncryptionEngine::new(&generate_key());

        let original = b"WAL entry payload".to_vec();
        let mut buffer = original.clone();
        engine.encrypt_in_place(&mut buffer).unwrap();

        assert_ne!(buffer, original);
        assert_eq!(buffer.len(), HEADER_SIZE + original.len() + TAG_LEN);
        let decrypted = engine.decrypt(&buffer).unwrap();
        assert_eq!(decrypted, original);
    }
}