Skip to main content

void_crypto/
aead.rs

//! AES-256-GCM authenticated encryption with AAD support.
//!
//! All encryption uses Additional Authenticated Data (AAD) to bind ciphertext
//! to its intended purpose. This prevents type confusion attacks where one
//! ciphertext type is misinterpreted as another.
6
7use aes_gcm::{
8    aead::{Aead, AeadInPlace, KeyInit, Payload},
9    Aes256Gcm, Nonce, Tag,
10};
11use rand::RngCore;
12use crate::{CryptoError, CryptoResult};
13
/// Length in bytes of the random GCM nonce that prefixes every ciphertext (96 bits).
const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of the GCM authentication tag that ends every ciphertext (128 bits).
const TAG_SIZE: usize = 128 / 8;

// Domain-separation labels passed as AAD. Each object type has its own label,
// so a ciphertext sealed as one type (e.g. a shard) fails authentication when
// opened as another (e.g. a commit). Every label ends in a version tag (`:v1`).

/// AAD for commit objects.
pub const AAD_COMMIT: &[u8] = b"void:commit:v1";
/// AAD for metadata bundle objects.
pub const AAD_METADATA: &[u8] = b"void:metadata:v1";
/// AAD for shard objects.
pub const AAD_SHARD: &[u8] = b"void:shard:v1";
/// AAD for index objects.
pub const AAD_INDEX: &[u8] = b"void:index:v1";
/// AAD for stash objects.
pub const AAD_STASH: &[u8] = b"void:stash:v1";
/// AAD for staged content blobs.
pub const AAD_STAGED: &[u8] = b"void:staged:v1";
/// AAD for tree manifest objects.
pub const AAD_MANIFEST: &[u8] = b"void:manifest:v1";
/// AAD used when a content key wraps a shard key.
pub const AAD_SHARD_KEY: &[u8] = b"void:shard-key:v1";
/// AAD for the repo (collaboration) manifest JSON.
pub const AAD_REPO_MANIFEST: &[u8] = b"void:repo-manifest:v1";
38
39/// Encrypts plaintext using AES-256-GCM with Additional Authenticated Data (AAD).
40///
41/// The AAD is authenticated but not encrypted - it binds the ciphertext to a
42/// specific context (e.g., object type) without being included in the output.
43///
44/// Returns: nonce (12 bytes) || ciphertext || tag (16 bytes)
45///
46/// # Security
47///
48/// Uses a random 12-byte nonce. Total ciphertext overhead is 28 bytes
49/// (12 nonce + 16 tag). Always use the appropriate `AAD_*` constant
50/// for the object type being encrypted.
51pub fn encrypt(key: &[u8; 32], plaintext: &[u8], aad: &[u8]) -> CryptoResult<Vec<u8>> {
52    let cipher =
53        Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::Encryption(e.to_string()))?;
54
55    let mut nonce_bytes = [0u8; NONCE_SIZE];
56    rand::thread_rng().fill_bytes(&mut nonce_bytes);
57    let nonce = Nonce::from_slice(&nonce_bytes);
58
59    let payload = Payload {
60        msg: plaintext,
61        aad,
62    };
63
64    let ciphertext = cipher
65        .encrypt(nonce, payload)
66        .map_err(|e| CryptoError::Encryption(e.to_string()))?;
67
68    let mut output = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
69    output.extend_from_slice(&nonce_bytes);
70    output.extend_from_slice(&ciphertext);
71
72    Ok(output)
73}
74
75/// Decrypts ciphertext using AES-256-GCM with Additional Authenticated Data (AAD).
76///
77/// The AAD must match what was used during encryption, otherwise decryption fails.
78///
79/// Expects input format: nonce (12 bytes) || ciphertext || tag (16 bytes)
80pub fn decrypt(key: &[u8; 32], ciphertext: &[u8], aad: &[u8]) -> CryptoResult<Vec<u8>> {
81    if ciphertext.len() < NONCE_SIZE + TAG_SIZE {
82        return Err(CryptoError::Decryption("ciphertext too short".into()));
83    }
84
85    let cipher =
86        Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::Decryption(e.to_string()))?;
87
88    let nonce = Nonce::from_slice(&ciphertext[..NONCE_SIZE]);
89    let encrypted = &ciphertext[NONCE_SIZE..];
90
91    let payload = Payload {
92        msg: encrypted,
93        aad,
94    };
95
96    cipher
97        .decrypt(nonce, payload)
98        .map_err(|e| CryptoError::Decryption(e.to_string()))
99}
100
101/// Decrypts ciphertext into a byte buffer.
102///
103/// Functionally identical to `decrypt()` but uses in-place decryption
104/// (avoids an extra allocation for the tag-stripped copy).
105pub fn decrypt_to_vec(key: &[u8; 32], ciphertext: &[u8], aad: &[u8]) -> CryptoResult<Vec<u8>> {
106    if ciphertext.len() < NONCE_SIZE + TAG_SIZE {
107        return Err(CryptoError::Decryption("ciphertext too short".into()));
108    }
109
110    let cipher =
111        Aes256Gcm::new_from_slice(key).map_err(|e| CryptoError::Decryption(e.to_string()))?;
112
113    let nonce = Nonce::from_slice(&ciphertext[..NONCE_SIZE]);
114    let encrypted = &ciphertext[NONCE_SIZE..];
115    let (body, tag_bytes) = encrypted.split_at(encrypted.len() - TAG_SIZE);
116    let tag = Tag::from_slice(tag_bytes);
117
118    let mut output = body.to_vec();
119
120    cipher
121        .decrypt_in_place_detached(nonce, aad, &mut output, tag)
122        .map_err(|e| CryptoError::Decryption(e.to_string()))?;
123
124    Ok(output)
125}
126
127/// Decrypt and parse a CBOR-encoded type.
128///
129/// Decrypts the ciphertext, then deserializes the plaintext as CBOR.
130pub fn decrypt_and_parse<T>(key: &[u8; 32], ciphertext: &[u8], aad: &[u8]) -> CryptoResult<T>
131where
132    T: serde::de::DeserializeOwned,
133{
134    let plaintext = decrypt(key, ciphertext, aad)?;
135    ciborium::from_reader(&plaintext[..])
136        .map_err(|e| CryptoError::Serialization(format!("CBOR deserialization failed: {e}")))
137}
138
139/// Wrap a shard key under a content key using AES-256-GCM.
140///
141/// Returns a `WrappedKey` containing nonce || ciphertext || tag.
142pub fn wrap_shard_key(
143    content_key: &crate::ContentKey,
144    shard_key: &[u8; 32],
145) -> CryptoResult<crate::WrappedKey> {
146    let bytes = encrypt(content_key.as_bytes(), shard_key, AAD_SHARD_KEY)?;
147    Ok(crate::WrappedKey::from_bytes(bytes))
148}
149
150/// Unwrap a shard key that was wrapped under a content key.
151pub fn unwrap_shard_key(
152    content_key: &crate::ContentKey,
153    wrapped: &crate::WrappedKey,
154) -> CryptoResult<[u8; 32]> {
155    let plain = decrypt(content_key.as_bytes(), wrapped.as_bytes(), AAD_SHARD_KEY)?;
156    plain
157        .try_into()
158        .map_err(|_| CryptoError::Decryption("unwrapped shard key is not 32 bytes".into()))
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_AAD: &[u8] = b"void:test:v1";

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = [0x42u8; 32];
        let plaintext = b"hello, void!";

        let ciphertext = encrypt(&key, plaintext, TEST_AAD).unwrap();
        let decrypted = decrypt(&key, &ciphertext, TEST_AAD).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn decrypt_wrong_key_fails() {
        let key1 = [0x42u8; 32];
        let key2 = [0x43u8; 32];
        let plaintext = b"secret";

        let ciphertext = encrypt(&key1, plaintext, TEST_AAD).unwrap();
        let result = decrypt(&key2, &ciphertext, TEST_AAD);

        assert!(result.is_err());
    }

    #[test]
    fn decrypt_wrong_aad_fails() {
        let key = [0x42u8; 32];
        let plaintext = b"secret";

        let ciphertext = encrypt(&key, plaintext, b"void:commit:v1").unwrap();
        let result = decrypt(&key, &ciphertext, b"void:shard:v1");

        assert!(result.is_err(), "Decryption with wrong AAD should fail");
    }

    #[test]
    fn decrypt_tampered_fails() {
        let key = [0x42u8; 32];
        let plaintext = b"secret";

        let mut ciphertext = encrypt(&key, plaintext, TEST_AAD).unwrap();
        if let Some(byte) = ciphertext.last_mut() {
            *byte ^= 0xff;
        }

        let result = decrypt(&key, &ciphertext, TEST_AAD);
        assert!(result.is_err());
    }

    #[test]
    fn ciphertext_too_short_fails() {
        let key = [0x42u8; 32];
        let short = vec![0u8; 10];

        let result = decrypt(&key, &short, TEST_AAD);
        assert!(result.is_err());
    }

    #[test]
    fn decrypt_to_vec_too_short_fails() {
        let key = [0x42u8; 32];
        // One byte below the minimum (nonce + tag). `decrypt_to_vec` splits at
        // `len - TAG_SIZE`, so its length guard must return an error here
        // instead of letting that subtraction underflow and panic.
        let short = vec![0u8; NONCE_SIZE + TAG_SIZE - 1];

        assert!(decrypt(&key, &short, TEST_AAD).is_err());
        assert!(decrypt_to_vec(&key, &short, TEST_AAD).is_err());
    }

    #[test]
    fn empty_plaintext_roundtrip() {
        let key = [0x42u8; 32];

        // An empty plaintext yields a ciphertext of exactly the minimum length,
        // which both decryption paths must accept.
        let ciphertext = encrypt(&key, b"", TEST_AAD).unwrap();
        assert_eq!(ciphertext.len(), NONCE_SIZE + TAG_SIZE);

        assert!(decrypt(&key, &ciphertext, TEST_AAD).unwrap().is_empty());
        assert!(decrypt_to_vec(&key, &ciphertext, TEST_AAD).unwrap().is_empty());
    }

    #[test]
    fn decrypt_to_vec_rejects_wrong_aad_and_tampering() {
        let key = [0x42u8; 32];
        let mut ciphertext = encrypt(&key, b"secret", AAD_COMMIT).unwrap();

        assert!(decrypt_to_vec(&key, &ciphertext, AAD_SHARD).is_err());

        // Flip a bit in the first body byte (just past the nonce).
        ciphertext[NONCE_SIZE] ^= 0x01;
        assert!(decrypt_to_vec(&key, &ciphertext, AAD_COMMIT).is_err());
    }

    #[test]
    fn encrypt_uses_fresh_nonce() {
        let key = [0x42u8; 32];

        let first = encrypt(&key, b"same", TEST_AAD).unwrap();
        let second = encrypt(&key, b"same", TEST_AAD).unwrap();

        // Random 96-bit nonces collide with negligible probability.
        assert_ne!(&first[..NONCE_SIZE], &second[..NONCE_SIZE]);
    }

    #[test]
    fn decrypt_to_vec_roundtrip() {
        let key = [0x42u8; 32];
        let plaintext = b"hello, void!";

        let ciphertext = encrypt(&key, plaintext, TEST_AAD).unwrap();
        let decrypted = decrypt_to_vec(&key, &ciphertext, TEST_AAD).unwrap();

        assert_eq!(&decrypted, plaintext);
    }

    #[test]
    fn decrypt_and_parse_with_simple_type() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct TestData {
            value: u64,
            name: String,
        }

        let key = [0x42u8; 32];
        let data = TestData {
            value: 42,
            name: "test".to_string(),
        };

        let mut serialized = Vec::new();
        ciborium::into_writer(&data, &mut serialized).unwrap();
        let ciphertext = encrypt(&key, &serialized, TEST_AAD).unwrap();

        let parsed: TestData = decrypt_and_parse(&key, &ciphertext, TEST_AAD).unwrap();

        assert_eq!(parsed, data);
    }
}
258
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    const TEST_AAD: &[u8] = b"void:proptest:v1";

    proptest! {
        #[test]
        fn encrypt_decrypt_roundtrip(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..10000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            let decrypted = decrypt(&key, &ciphertext, TEST_AAD).expect("decryption should succeed");
            prop_assert_eq!(decrypted, plaintext);
        }

        #[test]
        fn single_bit_flip_detected(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 1..1000),
            flip_position in any::<usize>(),
            flip_bit in 0u8..8
        ) {
            let mut tampered = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            let pos = flip_position % tampered.len();
            tampered[pos] ^= 1 << flip_bit;
            // A flipped bit anywhere (nonce, body, or tag) must be rejected by
            // both decryption paths.
            prop_assert!(decrypt(&key, &tampered, TEST_AAD).is_err());
            prop_assert!(decrypt_to_vec(&key, &tampered, TEST_AAD).is_err());
        }

        #[test]
        fn ciphertext_expansion(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..5000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            // Literal sizes on purpose: this pins the on-disk wire format
            // (12-byte nonce + 16-byte tag) independently of NONCE_SIZE/TAG_SIZE.
            let expected_len = plaintext.len() + 12 + 16;
            prop_assert_eq!(ciphertext.len(), expected_len);
        }

        #[test]
        fn decrypt_to_vec_matches_regular(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 0..5000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, TEST_AAD).expect("encryption should succeed");
            let regular = decrypt(&key, &ciphertext, TEST_AAD).expect("regular decryption should succeed");
            let via_to_vec = decrypt_to_vec(&key, &ciphertext, TEST_AAD).expect("decrypt_to_vec should succeed");
            prop_assert_eq!(regular, via_to_vec);
        }

        #[test]
        fn wrong_aad_fails(
            key in prop::array::uniform32(any::<u8>()),
            plaintext in prop::collection::vec(any::<u8>(), 1..1000)
        ) {
            let ciphertext = encrypt(&key, &plaintext, b"void:commit:v1").expect("encryption should succeed");
            prop_assert!(decrypt(&key, &ciphertext, b"void:shard:v1").is_err());
            prop_assert!(decrypt_to_vec(&key, &ciphertext, b"void:shard:v1").is_err());
        }
    }
}