1use aes_gcm::{
8 aead::{Aead, AeadInPlace, KeyInit, Payload},
9 Aes256Gcm, Nonce, Tag,
10};
11use rand::RngCore;
12use crate::{CryptoError, CryptoResult};
13
/// Length in bytes of the random AES-GCM nonce that [`encrypt`] prepends to every output.
const NONCE_SIZE: usize = 12;
/// Length in bytes of the AES-GCM authentication tag stored at the end of every output.
const TAG_SIZE: usize = 16;

// Associated-data (AAD) labels. The label passed to `encrypt` is authenticated
// together with the ciphertext, so a ciphertext sealed under one label does not
// decrypt under a different one. NOTE(review): the `:v1` suffix looks like a
// label-format version; confirm the migration story before introducing a `v2`.

/// AAD label `void:commit:v1`; presumably used when sealing commit objects (call sites not shown here).
pub const AAD_COMMIT: &[u8] = b"void:commit:v1";
/// AAD label `void:metadata:v1`; presumably used when sealing metadata (call sites not shown here).
pub const AAD_METADATA: &[u8] = b"void:metadata:v1";
/// AAD label `void:shard:v1`; presumably used when sealing shards (call sites not shown here).
pub const AAD_SHARD: &[u8] = b"void:shard:v1";
/// AAD label `void:index:v1`; presumably used when sealing the index (call sites not shown here).
pub const AAD_INDEX: &[u8] = b"void:index:v1";
/// AAD label `void:stash:v1`; presumably used when sealing stash entries (call sites not shown here).
pub const AAD_STASH: &[u8] = b"void:stash:v1";
/// AAD label `void:staged:v1`; presumably used when sealing staged content (call sites not shown here).
pub const AAD_STAGED: &[u8] = b"void:staged:v1";
/// AAD label `void:manifest:v1`; presumably used when sealing manifests (call sites not shown here).
pub const AAD_MANIFEST: &[u8] = b"void:manifest:v1";
/// AAD label used by [`wrap_shard_key`] and [`unwrap_shard_key`] to bind wrapped shard keys.
pub const AAD_SHARD_KEY: &[u8] = b"void:shard-key:v1";
/// AAD label `void:repo-manifest:v1`; presumably used for the repository manifest (call sites not shown here).
pub const AAD_REPO_MANIFEST: &[u8] = b"void:repo-manifest:v1";
38
39pub fn encrypt(key: &[u8; 32], plaintext: &[u8], aad: &[u8]) -> CryptoResult<Vec<u8>> {
52 let cipher =
53 Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::Encryption(e.to_string()))?;
54
55 let mut nonce_bytes = [0u8; NONCE_SIZE];
56 rand::thread_rng().fill_bytes(&mut nonce_bytes);
57 let nonce = Nonce::from_slice(&nonce_bytes);
58
59 let payload = Payload {
60 msg: plaintext,
61 aad,
62 };
63
64 let ciphertext = cipher
65 .encrypt(nonce, payload)
66 .map_err(|e| CryptoError::Encryption(e.to_string()))?;
67
68 let mut output = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
69 output.extend_from_slice(&nonce_bytes);
70 output.extend_from_slice(&ciphertext);
71
72 Ok(output)
73}
74
75pub fn decrypt(key: &[u8; 32], ciphertext: &[u8], aad: &[u8]) -> CryptoResult<Vec<u8>> {
81 if ciphertext.len() < NONCE_SIZE + TAG_SIZE {
82 return Err(CryptoError::Decryption("ciphertext too short".into()));
83 }
84
85 let cipher =
86 Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::Decryption(e.to_string()))?;
87
88 let nonce = Nonce::from_slice(&ciphertext[..NONCE_SIZE]);
89 let encrypted = &ciphertext[NONCE_SIZE..];
90
91 let payload = Payload {
92 msg: encrypted,
93 aad,
94 };
95
96 cipher
97 .decrypt(nonce, payload)
98 .map_err(|e| CryptoError::Decryption(e.to_string()))
99}
100
101pub fn decrypt_to_vec(key: &[u8; 32], ciphertext: &[u8], aad: &[u8]) -> CryptoResult<Vec<u8>> {
106 if ciphertext.len() < NONCE_SIZE + TAG_SIZE {
107 return Err(CryptoError::Decryption("ciphertext too short".into()));
108 }
109
110 let cipher =
111 Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::Decryption(e.to_string()))?;
112
113 let nonce = Nonce::from_slice(&ciphertext[..NONCE_SIZE]);
114 let encrypted = &ciphertext[NONCE_SIZE..];
115 let (body, tag_bytes) = encrypted.split_at(encrypted.len() - TAG_SIZE);
116 let tag = Tag::from_slice(tag_bytes);
117
118 let mut output = body.to_vec();
119
120 cipher
121 .decrypt_in_place_detached(nonce, aad, &mut output, tag)
122 .map_err(|e| CryptoError::Decryption(e.to_string()))?;
123
124 Ok(output)
125}
126
127pub fn decrypt_and_parse<T>(key: &[u8; 32], ciphertext: &[u8], aad: &[u8]) -> CryptoResult<T>
131where
132 T: serde::de::DeserializeOwned,
133{
134 let plaintext = decrypt(key, ciphertext, aad)?;
135 ciborium::from_reader(&plaintext[..])
136 .map_err(|e| CryptoError::Serialization(format!("CBOR deserialization failed: {e}")))
137}
138
139pub fn wrap_shard_key(
143 content_key: &crate::ContentKey,
144 shard_key: &[u8; 32],
145) -> CryptoResult<crate::WrappedKey> {
146 let bytes = encrypt(content_key.as_bytes(), shard_key, AAD_SHARD_KEY)?;
147 Ok(crate::WrappedKey::from_bytes(bytes))
148}
149
150pub fn unwrap_shard_key(
152 content_key: &crate::ContentKey,
153 wrapped: &crate::WrappedKey,
154) -> CryptoResult<[u8; 32]> {
155 let plain = decrypt(content_key.as_bytes(), wrapped.as_bytes(), AAD_SHARD_KEY)?;
156 plain
157 .try_into()
158 .map_err(|_| CryptoError::Decryption("unwrapped shard key is not 32 bytes".into()))
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_AAD: &[u8] = b"void:test:v1";

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = [0x42u8; 32];
        let plaintext = b"hello, void!";

        let ciphertext = encrypt(&key, plaintext, TEST_AAD).unwrap();
        let decrypted = decrypt(&key, &ciphertext, TEST_AAD).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn encrypt_decrypt_empty_plaintext() {
        let key = [0x42u8; 32];

        let ciphertext = encrypt(&key, b"", TEST_AAD).unwrap();

        // The smallest valid envelope is a nonce plus a tag, with no body. Both
        // decrypt paths must accept exactly this boundary length.
        assert_eq!(ciphertext.len(), NONCE_SIZE + TAG_SIZE);
        assert!(decrypt(&key, &ciphertext, TEST_AAD).unwrap().is_empty());
        assert!(decrypt_to_vec(&key, &ciphertext, TEST_AAD).unwrap().is_empty());
    }

    #[test]
    fn encrypt_uses_fresh_nonce() {
        let key = [0x42u8; 32];

        let first = encrypt(&key, b"same", TEST_AAD).unwrap();
        let second = encrypt(&key, b"same", TEST_AAD).unwrap();

        // Reusing a nonce under one key is catastrophic for GCM, so identical
        // inputs must still yield different nonces (and therefore different outputs).
        assert_ne!(&first[..NONCE_SIZE], &second[..NONCE_SIZE]);
        assert_eq!(decrypt(&key, &first, TEST_AAD).unwrap(), b"same");
        assert_eq!(decrypt(&key, &second, TEST_AAD).unwrap(), b"same");
    }

    #[test]
    fn decrypt_wrong_key_fails() {
        let key1 = [0x42u8; 32];
        let key2 = [0x43u8; 32];
        let plaintext = b"secret";

        let ciphertext = encrypt(&key1, plaintext, TEST_AAD).unwrap();

        assert!(decrypt(&key2, &ciphertext, TEST_AAD).is_err());
        assert!(decrypt_to_vec(&key2, &ciphertext, TEST_AAD).is_err());
    }

    #[test]
    fn decrypt_wrong_aad_fails() {
        let key = [0x42u8; 32];
        let plaintext = b"secret";

        // Cross-domain confusion: a commit ciphertext must not open as a shard.
        let ciphertext = encrypt(&key, plaintext, AAD_COMMIT).unwrap();

        assert!(
            decrypt(&key, &ciphertext, AAD_SHARD).is_err(),
            "Decryption with wrong AAD should fail"
        );
        assert!(
            decrypt_to_vec(&key, &ciphertext, AAD_SHARD).is_err(),
            "decrypt_to_vec with wrong AAD should fail"
        );
    }

    #[test]
    fn decrypt_tampered_fails() {
        let key = [0x42u8; 32];
        let plaintext = b"secret";
        let ciphertext = encrypt(&key, plaintext, TEST_AAD).unwrap();

        // Flip one byte in each region of the envelope: nonce, body, and tag.
        for pos in [0, NONCE_SIZE, ciphertext.len() - 1] {
            let mut tampered = ciphertext.clone();
            tampered[pos] ^= 0xff;

            assert!(
                decrypt(&key, &tampered, TEST_AAD).is_err(),
                "decrypt accepted tampering at byte {pos}"
            );
            assert!(
                decrypt_to_vec(&key, &tampered, TEST_AAD).is_err(),
                "decrypt_to_vec accepted tampering at byte {pos}"
            );
        }
    }

    #[test]
    fn ciphertext_too_short_fails() {
        let key = [0x42u8; 32];

        // Every length below nonce + tag must be rejected up front by both decrypt
        // paths, without panicking on slice arithmetic.
        for len in [0, 10, NONCE_SIZE, NONCE_SIZE + TAG_SIZE - 1] {
            let short = vec![0u8; len];
            assert!(decrypt(&key, &short, TEST_AAD).is_err(), "len {len}");
            assert!(decrypt_to_vec(&key, &short, TEST_AAD).is_err(), "len {len}");
        }
    }

    #[test]
    fn decrypt_to_vec_roundtrip() {
        let key = [0x42u8; 32];
        let plaintext = b"hello, void!";

        let ciphertext = encrypt(&key, plaintext, TEST_AAD).unwrap();
        let decrypted = decrypt_to_vec(&key, &ciphertext, TEST_AAD).unwrap();

        assert_eq!(&decrypted, plaintext);
    }

    #[test]
    fn decrypt_and_parse_with_simple_type() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct TestData {
            value: u64,
            name: String,
        }

        let key = [0x42u8; 32];
        let data = TestData {
            value: 42,
            name: "test".to_string(),
        };

        let mut serialized = Vec::new();
        ciborium::into_writer(&data, &mut serialized).unwrap();
        let ciphertext = encrypt(&key, &serialized, TEST_AAD).unwrap();

        let parsed: TestData = decrypt_and_parse(&key, &ciphertext, TEST_AAD).unwrap();

        assert_eq!(parsed, data);
    }

    #[test]
    fn decrypt_and_parse_reports_error_kind() {
        let key = [0x42u8; 32];

        // An authentic but empty plaintext is not a CBOR value, so this is a serialization error.
        let empty = encrypt(&key, b"", TEST_AAD).unwrap();
        assert!(matches!(
            decrypt_and_parse::<u64>(&key, &empty, TEST_AAD),
            Err(CryptoError::Serialization(_))
        ));

        // An authentication failure must surface as a decryption error before any parsing.
        assert!(matches!(
            decrypt_and_parse::<u64>(&key, &empty, AAD_COMMIT),
            Err(CryptoError::Decryption(_))
        ));
    }
}
258
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    const TEST_AAD: &[u8] = b"void:proptest:v1";

    proptest! {
        #[test]
        fn encrypt_decrypt_roundtrip(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..10000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            let decrypted = decrypt(&key, &ciphertext, TEST_AAD).expect("decryption should succeed");
            prop_assert_eq!(decrypted, plaintext);
        }

        #[test]
        fn single_bit_flip_detected(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 1..1000),
            flip_position in any::<usize>(),
            flip_bit in 0u8..8
        ) {
            let mut tampered = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            let pos = flip_position % tampered.len();
            tampered[pos] ^= 1 << flip_bit;
            // Both decryption paths must authenticate before releasing plaintext.
            prop_assert!(decrypt(&key, &tampered, TEST_AAD).is_err());
            prop_assert!(decrypt_to_vec(&key, &tampered, TEST_AAD).is_err());
        }

        #[test]
        fn truncated_ciphertext_fails(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..1000),
            cut in any::<usize>()
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            // `keep` is always strictly shorter than the full envelope.
            let keep = cut % ciphertext.len();
            let truncated = &ciphertext[..keep];
            prop_assert!(decrypt(&key, truncated, TEST_AAD).is_err());
            prop_assert!(decrypt_to_vec(&key, truncated, TEST_AAD).is_err());
        }

        #[test]
        fn ciphertext_expansion(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..5000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            // Literal sizes on purpose: this pins the stored envelope format
            // (12-byte nonce + 16-byte tag) independently of the module constants.
            let expected_len = plaintext.len() + 12 + 16;
            prop_assert_eq!(ciphertext.len(), expected_len);
        }

        #[test]
        fn decrypt_to_vec_matches_regular(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..5000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            let regular = decrypt(&key, &ciphertext, TEST_AAD).expect("regular decryption should succeed");
            let via_to_vec = decrypt_to_vec(&key, &ciphertext, TEST_AAD).expect("decrypt_to_vec should succeed");
            prop_assert_eq!(regular, via_to_vec);
        }

        #[test]
        fn wrong_aad_fails(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, AAD_COMMIT).expect("encryption should succeed");
            prop_assert!(decrypt(&key, &ciphertext, AAD_SHARD).is_err());
            prop_assert!(decrypt_to_vec(&key, &ciphertext, AAD_SHARD).is_err());
        }
    }
}