Skip to main content

shadowforge_lib/domain/crypto/
mod.rs

1//! ML-KEM-1024, ML-DSA-87, Argon2id, AES-256-GCM, and secure zeroing.
2//!
3//! All functions are pure — no I/O, no file system, no network. Each
4//! function that needs randomness accepts a CSPRNG as a parameter so it
5//! can be exercised with a seeded RNG in tests.
6
7use aes_gcm::aead::Aead;
8use aes_gcm::{Aes256Gcm, Key as AesKey, KeyInit, Nonce};
9use argon2::password_hash::SaltString;
10use argon2::{Argon2, Params, PasswordHasher};
11use bytes::{BufMut, Bytes, BytesMut};
12use ml_dsa::{
13    EncodedSignature, EncodedVerifyingKey, KeyGen, MlDsa87, VerifyingKey, signature::Keypair,
14};
15use ml_kem::{
16    Decapsulate, DecapsulationKey1024, Encapsulate as _, EncapsulationKey1024, Kem as _, Key,
17    KeyExport as _, MlKem1024,
18};
19use rand_core::CryptoRng;
20use zeroize::Zeroize;
21
22use crate::domain::errors::CryptoError;
23use crate::domain::types::{KeyPair, Payload, Signature};
24
/// Size in bytes of the compact ML-KEM-1024 seed that is stored as the secret key.
const KEM_SEED_LEN: usize = 64;
/// Size in bytes of an encoded ML-KEM-1024 encapsulation (public) key.
const KEM_EK_LEN: usize = 1568;
/// Size in bytes of the ML-DSA-87 seed that is stored as the secret key.
const DSA_SEED_LEN: usize = 32;
/// Size in bytes of an encoded ML-DSA-87 verifying (public) key.
const DSA_VK_LEN: usize = 2592;
/// Size in bytes of an AES-256 key (256 bits).
const AES_KEY_LEN: usize = 256 / 8;
/// Size in bytes of an AES-GCM nonce (96 bits).
const AES_NONCE_LEN: usize = 96 / 8;
/// Size in bytes of the salt that `derive_key` requires for Argon2id.
const ARGON2_SALT_LEN: usize = 32;
39
40// ─── ML-KEM-1024 (NIST FIPS 203) ─────────────────────────────────────────────
41
42/// Generate an ML-KEM-1024 key pair using the provided CSPRNG.
43///
44/// The returned [`KeyPair`] stores the 64-byte compact seed as `secret_key`
45/// and the 1568-byte encapsulation key as `public_key`. Both fields are
46/// zeroized on drop.
47///
48/// # Errors
49/// Returns [`CryptoError::KeyGenFailed`] if the freshly generated key does
50/// not carry a recoverable seed (should never occur in practice).
51pub fn generate_kem_keypair(rng: &mut impl CryptoRng) -> Result<KeyPair, CryptoError> {
52    let (dk, ek) = MlKem1024::generate_keypair_from_rng(rng);
53    let seed = dk.to_seed().ok_or_else(|| CryptoError::KeyGenFailed {
54        reason: "freshly generated key has no seed".into(),
55    })?;
56    let ek_bytes = ek.to_bytes();
57    Ok(KeyPair {
58        public_key: (ek_bytes.as_ref() as &[u8]).to_vec(),
59        secret_key: (seed.as_ref() as &[u8]).to_vec(),
60    })
61}
62
63/// Encapsulate a shared secret for the holder of `public_key`.
64///
65/// Returns `(ciphertext, shared_secret)` — both as raw bytes.
66/// Ciphertext is 1568 bytes; shared secret is 32 bytes.
67///
68/// # Errors
69/// Returns [`CryptoError::InvalidKeyLength`] if `public_key` is not 1568
70/// bytes, or [`CryptoError::EncapsulationFailed`] if the key bytes are
71/// otherwise invalid.
72pub fn encapsulate_kem(
73    public_key: &[u8],
74    rng: &mut impl CryptoRng,
75) -> Result<(Bytes, Bytes), CryptoError> {
76    if public_key.len() != KEM_EK_LEN {
77        return Err(CryptoError::InvalidKeyLength {
78            expected: KEM_EK_LEN,
79            got: public_key.len(),
80        });
81    }
82    let key_arr: Key<EncapsulationKey1024> =
83        public_key
84            .try_into()
85            .map_err(|_| CryptoError::InvalidKeyLength {
86                expected: KEM_EK_LEN,
87                got: public_key.len(),
88            })?;
89    let ek = EncapsulationKey1024::new(&key_arr).map_err(|_| CryptoError::EncapsulationFailed {
90        reason: "invalid encapsulation key".into(),
91    })?;
92    let (ct, ss) = ek.encapsulate_with_rng(rng);
93    Ok((
94        Bytes::copy_from_slice(ct.as_ref() as &[u8]),
95        Bytes::copy_from_slice(ss.as_ref() as &[u8]),
96    ))
97}
98
99/// Decapsulate a shared secret using `secret_key` (the 64-byte seed) and
100/// `ciphertext`.
101///
102/// ML-KEM uses implicit rejection — an invalid ciphertext yields a
103/// pseudo-random (but different) shared secret rather than an error.
104///
105/// # Errors
106/// Returns [`CryptoError::InvalidKeyLength`] if `secret_key` is not 64
107/// bytes. Returns [`CryptoError::DecapsulationFailed`] if `ciphertext`
108/// has the wrong length.
109pub fn decapsulate_kem(secret_key: &[u8], ciphertext: &[u8]) -> Result<Bytes, CryptoError> {
110    if secret_key.len() != KEM_SEED_LEN {
111        return Err(CryptoError::InvalidKeyLength {
112            expected: KEM_SEED_LEN,
113            got: secret_key.len(),
114        });
115    }
116    let seed: ml_kem::Seed = secret_key
117        .try_into()
118        .map_err(|_| CryptoError::InvalidKeyLength {
119            expected: KEM_SEED_LEN,
120            got: secret_key.len(),
121        })?;
122    let dk = DecapsulationKey1024::from_seed(seed);
123    let ss = dk
124        .decapsulate_slice(ciphertext)
125        .map_err(|_| CryptoError::DecapsulationFailed {
126            reason: format!("ciphertext length {} is invalid", ciphertext.len()),
127        })?;
128    Ok(Bytes::copy_from_slice(ss.as_ref() as &[u8]))
129}
130
131// ─── ML-DSA-87 (NIST FIPS 204) ───────────────────────────────────────────────
132
133/// Generate an ML-DSA-87 key pair using the provided CSPRNG.
134///
135/// The returned [`KeyPair`] stores the 32-byte seed as `secret_key` and
136/// the 2592-byte verifying key as `public_key`.
137///
138/// # Errors
139/// This function currently always succeeds; the `Result` is kept for API
140/// uniformity with [`generate_kem_keypair`].
141pub fn generate_dsa_keypair(rng: &mut impl CryptoRng) -> Result<KeyPair, CryptoError> {
142    let signing_key = MlDsa87::key_gen(rng);
143    let mut seed = signing_key.to_seed();
144    let vk_encoded: EncodedVerifyingKey<MlDsa87> = signing_key.verifying_key().encode();
145    let public_key = (vk_encoded.as_ref() as &[u8]).to_vec();
146    let secret_key = (seed.as_ref() as &[u8]).to_vec();
147    seed.zeroize();
148    Ok(KeyPair {
149        public_key,
150        secret_key,
151    })
152}
153
154/// Sign `message` with the ML-DSA-87 secret key (32-byte seed).
155///
156/// Signing is deterministic — no per-call randomness required.
157///
158/// # Errors
159/// Returns [`CryptoError::InvalidKeyLength`] if `secret_key` is not 32
160/// bytes. Returns [`CryptoError::SigningFailed`] if the deterministic
161/// signing operation fails.
162pub fn sign_dsa(secret_key: &[u8], message: &[u8]) -> Result<Signature, CryptoError> {
163    if secret_key.len() != DSA_SEED_LEN {
164        return Err(CryptoError::InvalidKeyLength {
165            expected: DSA_SEED_LEN,
166            got: secret_key.len(),
167        });
168    }
169    let mut seed_arr: ml_dsa::B32 =
170        secret_key
171            .try_into()
172            .map_err(|_| CryptoError::InvalidKeyLength {
173                expected: DSA_SEED_LEN,
174                got: secret_key.len(),
175            })?;
176    let signing_key = MlDsa87::from_seed(&seed_arr);
177    seed_arr.zeroize();
178    let ml_sig = signing_key
179        .signing_key()
180        .sign_deterministic(message, b"")
181        .map_err(|e| CryptoError::SigningFailed {
182            reason: e.to_string(),
183        })?;
184    let encoded: EncodedSignature<MlDsa87> = ml_sig.encode();
185    Ok(Signature(Bytes::copy_from_slice(encoded.as_ref())))
186}
187
188/// Verify that `sig` is a valid ML-DSA-87 signature over `message` by
189/// `public_key`.
190///
191/// Returns `Ok(true)` for a valid signature, `Ok(false)` for an invalid one.
192///
193/// # Errors
194/// Returns [`CryptoError::InvalidKeyLength`] if `public_key` is not 2592
195/// bytes. Returns [`CryptoError::VerificationFailed`] if the signature
196/// bytes are malformed.
197pub fn verify_dsa(public_key: &[u8], message: &[u8], sig: &Signature) -> Result<bool, CryptoError> {
198    if public_key.len() != DSA_VK_LEN {
199        return Err(CryptoError::InvalidKeyLength {
200            expected: DSA_VK_LEN,
201            got: public_key.len(),
202        });
203    }
204    let enc_vk: EncodedVerifyingKey<MlDsa87> =
205        public_key
206            .try_into()
207            .map_err(|_| CryptoError::InvalidKeyLength {
208                expected: DSA_VK_LEN,
209                got: public_key.len(),
210            })?;
211    let vk = VerifyingKey::<MlDsa87>::decode(&enc_vk);
212
213    let enc_sig: EncodedSignature<MlDsa87> =
214        sig.0
215            .as_ref()
216            .try_into()
217            .map_err(|_| CryptoError::VerificationFailed {
218                reason: "invalid signature length".into(),
219            })?;
220    let ml_sig = ml_dsa::Signature::<MlDsa87>::decode(&enc_sig).ok_or_else(|| {
221        CryptoError::VerificationFailed {
222            reason: "malformed signature bytes".into(),
223        }
224    })?;
225
226    Ok(vk.verify_with_context(message, b"", &ml_sig))
227}
228
229// ─── AES-256-GCM Symmetric Encryption ────────────────────────────────────────
230
231/// Encrypt `plaintext` with AES-256-GCM using `key` and `nonce`.
232///
233/// Returns the ciphertext with authentication tag appended.
234///
235/// # Errors
236/// Returns [`CryptoError::InvalidKeyLength`] if `key` is not 32 bytes,
237/// [`CryptoError::InvalidNonceLength`] if `nonce` is not 12 bytes, or
238/// [`CryptoError::EncryptionFailed`] if encryption fails.
239pub fn encrypt_aes_gcm(key: &[u8], nonce: &[u8], plaintext: &[u8]) -> Result<Bytes, CryptoError> {
240    if key.len() != AES_KEY_LEN {
241        return Err(CryptoError::InvalidKeyLength {
242            expected: AES_KEY_LEN,
243            got: key.len(),
244        });
245    }
246    if nonce.len() != AES_NONCE_LEN {
247        return Err(CryptoError::InvalidNonceLength {
248            expected: AES_NONCE_LEN,
249            got: nonce.len(),
250        });
251    }
252
253    let aes_key = AesKey::<Aes256Gcm>::from_slice(key);
254    let cipher = Aes256Gcm::new(aes_key);
255    let aes_nonce = Nonce::from_slice(nonce);
256
257    let ciphertext =
258        cipher
259            .encrypt(aes_nonce, plaintext)
260            .map_err(|e| CryptoError::EncryptionFailed {
261                reason: e.to_string(),
262            })?;
263
264    Ok(Bytes::from(ciphertext))
265}
266
267/// Decrypt and authenticate `ciphertext` with AES-256-GCM using `key` and `nonce`.
268///
269/// # Errors
270/// Returns [`CryptoError::InvalidKeyLength`] if `key` is not 32 bytes,
271/// [`CryptoError::InvalidNonceLength`] if `nonce` is not 12 bytes, or
272/// [`CryptoError::DecryptionFailed`] if decryption or authentication fails.
273pub fn decrypt_aes_gcm(key: &[u8], nonce: &[u8], ciphertext: &[u8]) -> Result<Bytes, CryptoError> {
274    if key.len() != AES_KEY_LEN {
275        return Err(CryptoError::InvalidKeyLength {
276            expected: AES_KEY_LEN,
277            got: key.len(),
278        });
279    }
280    if nonce.len() != AES_NONCE_LEN {
281        return Err(CryptoError::InvalidNonceLength {
282            expected: AES_NONCE_LEN,
283            got: nonce.len(),
284        });
285    }
286
287    let aes_key = AesKey::<Aes256Gcm>::from_slice(key);
288    let cipher = Aes256Gcm::new(aes_key);
289    let aes_nonce = Nonce::from_slice(nonce);
290
291    let plaintext =
292        cipher
293            .decrypt(aes_nonce, ciphertext)
294            .map_err(|e| CryptoError::DecryptionFailed {
295                reason: e.to_string(),
296            })?;
297
298    Ok(Bytes::from(plaintext))
299}
300
301// ─── Argon2id Key Derivation ─────────────────────────────────────────────────
302
303/// Derive a key from `password` and `salt` using Argon2id.
304///
305/// Uses Argon2id with `time_cost=3`, `mem_cost=65536` KiB (64 MiB), `parallelism=4`.
306///
307/// # Errors
308/// Returns [`CryptoError::KdfFailed`] if key derivation fails or if
309/// `output_len` is invalid.
310pub fn derive_key(password: &[u8], salt: &[u8], output_len: usize) -> Result<Bytes, CryptoError> {
311    if salt.len() != ARGON2_SALT_LEN {
312        return Err(CryptoError::KdfFailed {
313            reason: format!("salt must be {} bytes, got {}", ARGON2_SALT_LEN, salt.len()),
314        });
315    }
316
317    // Argon2id parameters: time_cost=3, mem_cost=65536 KiB, parallelism=4
318    let params =
319        Params::new(65536, 3, 4, Some(output_len)).map_err(|e| CryptoError::KdfFailed {
320            reason: e.to_string(),
321        })?;
322
323    let argon2 = Argon2::new(argon2::Algorithm::Argon2id, argon2::Version::V0x13, params);
324
325    // Create a SaltString from the provided salt
326    let salt_str = SaltString::encode_b64(salt).map_err(|e| CryptoError::KdfFailed {
327        reason: e.to_string(),
328    })?;
329
330    let hash = argon2
331        .hash_password(password, &salt_str)
332        .map_err(|e| CryptoError::KdfFailed {
333            reason: e.to_string(),
334        })?;
335
336    // Extract the derived key from the hash
337    let hash_output = hash.hash.ok_or_else(|| CryptoError::KdfFailed {
338        reason: "no hash output".into(),
339    })?;
340
341    Ok(Bytes::copy_from_slice(hash_output.as_bytes()))
342}
343
344// ─── Full Encryption Pipeline ────────────────────────────────────────────────
345
346/// Encrypt a payload using the full hybrid cryptosystem pipeline.
347///
348/// Pipeline: ML-KEM-1024 key encapsulation → derive AES-256-GCM key from
349/// shared secret → encrypt payload → sign (KEM ciphertext || AES ciphertext)
350/// with ML-DSA-87.
351///
352/// Output format (all length-prefixed as u32 big-endian):
353/// ```text
354/// [kem_ct_len][kem_ct][nonce][sym_ct_len][sym_ct][sig_len][sig]
355/// ```
356///
357/// # Errors
358/// Returns [`CryptoError`] variants for any cryptographic operation failure.
359pub fn encrypt_payload(
360    kem_public_key: &[u8],
361    dsa_secret_key: &[u8],
362    payload: &Payload,
363    rng: &mut impl CryptoRng,
364) -> Result<Bytes, CryptoError> {
365    // 1. KEM encapsulation
366    let (kem_ct, shared_secret) = encapsulate_kem(kem_public_key, rng)?;
367
368    // 2. Derive AES key from shared secret
369    let mut salt = vec![0u8; ARGON2_SALT_LEN];
370    rng.fill_bytes(&mut salt);
371    let aes_key_bytes = derive_key(shared_secret.as_ref(), &salt, AES_KEY_LEN)?;
372
373    // 3. Generate nonce
374    let mut nonce = vec![0u8; AES_NONCE_LEN];
375    rng.fill_bytes(&mut nonce);
376
377    // 4. Encrypt payload
378    let sym_ct = encrypt_aes_gcm(&aes_key_bytes, &nonce, payload.as_bytes())?;
379
380    // 5. Sign (kem_ct || salt || nonce || sym_ct)
381    let mut message_to_sign = BytesMut::new();
382    message_to_sign.put(kem_ct.as_ref());
383    message_to_sign.put_slice(&salt);
384    message_to_sign.put_slice(&nonce);
385    message_to_sign.put(sym_ct.as_ref());
386
387    let signature = sign_dsa(dsa_secret_key, &message_to_sign)?;
388
389    // 6. Build final output: [kem_ct_len][kem_ct][salt][nonce][sym_ct_len][sym_ct][sig_len][sig]
390    let mut output = BytesMut::new();
391    #[expect(
392        clippy::cast_possible_truncation,
393        reason = "ML-KEM-1024 ciphertext is 1568 bytes"
394    )]
395    output.put_u32(kem_ct.len() as u32);
396    output.put(kem_ct);
397    output.put_slice(&salt);
398    output.put_slice(&nonce);
399    #[expect(
400        clippy::cast_possible_truncation,
401        reason = "payload sizes are bounded by protocol"
402    )]
403    output.put_u32(sym_ct.len() as u32);
404    output.put(sym_ct);
405    #[expect(
406        clippy::cast_possible_truncation,
407        reason = "ML-DSA-87 signature is 4595 bytes"
408    )]
409    output.put_u32(signature.0.len() as u32);
410    output.put(signature.0);
411
412    Ok(output.freeze())
413}
414
415/// Decrypt a payload using the full hybrid cryptosystem pipeline.
416///
417/// Reverses [`encrypt_payload`]: verify signature → KEM decapsulation →
418/// derive AES-256-GCM key → decrypt payload.
419///
420/// # Errors
421/// Returns [`CryptoError`] variants for any cryptographic operation failure,
422/// including signature verification failure.
423pub fn decrypt_payload(
424    kem_secret_key: &[u8],
425    dsa_public_key: &[u8],
426    encrypted: &[u8],
427) -> Result<Payload, CryptoError> {
428    let mut cursor = encrypted;
429    let truncated = |field: &str| CryptoError::DecryptionFailed {
430        reason: format!("truncated {field}"),
431    };
432
433    // 1. Parse KEM ciphertext
434    let kem_ct_len = {
435        let b = cursor.get(..4).ok_or_else(|| truncated("kem_ct_len"))?;
436        let arr = <[u8; 4]>::try_from(b).map_err(|_| truncated("kem_ct_len"))?;
437        cursor = cursor.get(4..).ok_or_else(|| truncated("kem_ct_len"))?;
438        u32::from_be_bytes(arr) as usize
439    };
440    let kem_ct = cursor
441        .get(..kem_ct_len)
442        .ok_or_else(|| truncated("kem_ct"))?;
443    cursor = cursor
444        .get(kem_ct_len..)
445        .ok_or_else(|| truncated("kem_ct"))?;
446
447    // 2. Parse salt
448    let salt = cursor
449        .get(..ARGON2_SALT_LEN)
450        .ok_or_else(|| truncated("salt"))?;
451    cursor = cursor
452        .get(ARGON2_SALT_LEN..)
453        .ok_or_else(|| truncated("salt"))?;
454
455    // 3. Parse nonce
456    let nonce = cursor
457        .get(..AES_NONCE_LEN)
458        .ok_or_else(|| truncated("nonce"))?;
459    cursor = cursor
460        .get(AES_NONCE_LEN..)
461        .ok_or_else(|| truncated("nonce"))?;
462
463    // 4. Parse symmetric ciphertext
464    let sym_ct_len = {
465        let b = cursor.get(..4).ok_or_else(|| truncated("sym_ct_len"))?;
466        let arr = <[u8; 4]>::try_from(b).map_err(|_| truncated("sym_ct_len"))?;
467        cursor = cursor.get(4..).ok_or_else(|| truncated("sym_ct_len"))?;
468        u32::from_be_bytes(arr) as usize
469    };
470    let sym_ct = cursor
471        .get(..sym_ct_len)
472        .ok_or_else(|| truncated("sym_ct"))?;
473    cursor = cursor
474        .get(sym_ct_len..)
475        .ok_or_else(|| truncated("sym_ct"))?;
476
477    // 5. Parse signature
478    let sig_len = {
479        let b = cursor.get(..4).ok_or_else(|| truncated("sig_len"))?;
480        let arr = <[u8; 4]>::try_from(b).map_err(|_| truncated("sig_len"))?;
481        cursor = cursor.get(4..).ok_or_else(|| truncated("sig_len"))?;
482        u32::from_be_bytes(arr) as usize
483    };
484    let sig_bytes = cursor.get(..sig_len).ok_or_else(|| truncated("sig"))?;
485    let signature = Signature(Bytes::copy_from_slice(sig_bytes));
486
487    // 6. Verify signature over (kem_ct || salt || nonce || sym_ct)
488    let mut message_to_verify = BytesMut::new();
489    message_to_verify.put_slice(kem_ct);
490    message_to_verify.put_slice(salt);
491    message_to_verify.put_slice(nonce);
492    message_to_verify.put_slice(sym_ct);
493
494    let sig_valid = verify_dsa(dsa_public_key, &message_to_verify, &signature)?;
495    if !sig_valid {
496        return Err(CryptoError::DecryptionFailed {
497            reason: "signature verification failed".into(),
498        });
499    }
500
501    // 7. KEM decapsulation
502    let shared_secret = decapsulate_kem(kem_secret_key, kem_ct)?;
503
504    // 8. Derive AES key from shared secret
505    let aes_key_bytes = derive_key(shared_secret.as_ref(), salt, AES_KEY_LEN)?;
506
507    // 9. Decrypt payload
508    let plaintext = decrypt_aes_gcm(&aes_key_bytes, nonce, sym_ct)?;
509
510    Ok(Payload::from_bytes(plaintext.to_vec()))
511}
512
513// ─── Tests ────────────────────────────────────────────────────────────────────
514
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;
    use subtle::ConstantTimeEq;

    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Message signed in the ML-DSA tests.
    const FOX: &[u8] = b"the quick brown fox jumps over the lazy dog";

    /// Fresh ChaCha20 generator seeded from the thread-local RNG.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_rng(&mut rand::rng())
    }

    /// Constant-time comparison; `1` when the slices are equal, `0` otherwise.
    fn ct_matches(left: &[u8], right: &[u8]) -> u8 {
        left.ct_eq(right).unwrap_u8()
    }

    /// Invert every bit of the first byte of `buf`; `empty_msg` is the error for an empty buffer.
    fn flip_first_byte(buf: &mut [u8], empty_msg: &'static str) -> Result<(), &'static str> {
        let first = buf.first_mut().ok_or(empty_msg)?;
        *first ^= 0xFF;
        Ok(())
    }

    /// One KEM key pair and one DSA key pair, in that order.
    fn pipeline_keys() -> Result<(KeyPair, KeyPair), CryptoError> {
        let kem = generate_kem_keypair(&mut rng())?;
        let dsa = generate_dsa_keypair(&mut rng())?;
        Ok((kem, dsa))
    }

    /// Payload used by the non-empty pipeline tests.
    fn sample_payload() -> crate::domain::types::Payload {
        crate::domain::types::Payload::from_bytes(b"secret message".to_vec())
    }

    // ─── ML-KEM ───────────────────────────────────────────────────────────────

    /// Encapsulation and decapsulation must agree on the shared secret.
    #[test]
    fn test_kem_roundtrip() -> TestResult {
        let keys = generate_kem_keypair(&mut rng())?;
        let (ciphertext, sender_secret) = encapsulate_kem(&keys.public_key, &mut rng())?;
        let receiver_secret = decapsulate_kem(&keys.secret_key, &ciphertext)?;

        assert_eq!(
            ct_matches(&sender_secret, &receiver_secret),
            1u8,
            "shared secrets must match"
        );
        Ok(())
    }

    /// A corrupted ciphertext decapsulates without error (implicit rejection)
    /// but must not reproduce the sender's shared secret.
    #[test]
    fn test_kem_wrong_ciphertext_differs() -> TestResult {
        let keys = generate_kem_keypair(&mut rng())?;
        let (ciphertext, good_secret) = encapsulate_kem(&keys.public_key, &mut rng())?;

        let mut corrupted = ciphertext.to_vec();
        flip_first_byte(&mut corrupted, "empty ciphertext")?;
        let bad_secret = decapsulate_kem(&keys.secret_key, &corrupted)?;

        assert_eq!(
            ct_matches(&good_secret, &bad_secret),
            0u8,
            "corrupted ciphertext must yield a different shared secret"
        );
        Ok(())
    }

    /// A public key of the wrong length is rejected with `InvalidKeyLength`.
    #[test]
    fn test_kem_bad_pubkey_length() {
        let outcome = encapsulate_kem(&[0u8; 42], &mut rng());
        assert!(matches!(outcome, Err(CryptoError::InvalidKeyLength { .. })));
    }

    /// A secret key of the wrong length is rejected with `InvalidKeyLength`.
    #[test]
    fn test_kem_bad_seckey_length() {
        let ciphertext = Bytes::from(vec![0u8; 1568]);
        let outcome = decapsulate_kem(&[0u8; 42], &ciphertext);
        assert!(matches!(outcome, Err(CryptoError::InvalidKeyLength { .. })));
    }

    /// Generated KEM keys have the seed and encapsulation-key sizes this module expects.
    #[test]
    fn test_kem_keypair_sizes() -> TestResult {
        let keys = generate_kem_keypair(&mut rng())?;
        assert_eq!(
            keys.secret_key.len(),
            KEM_SEED_LEN,
            "KEM seed must be 64 bytes"
        );
        assert_eq!(
            keys.public_key.len(),
            KEM_EK_LEN,
            "KEM enc key must be 1568 bytes"
        );
        Ok(())
    }

    // ─── ML-DSA ───────────────────────────────────────────────────────────────

    /// A signature verifies under the key pair that produced it.
    #[test]
    fn test_dsa_roundtrip() -> TestResult {
        let keys = generate_dsa_keypair(&mut rng())?;
        let signature = sign_dsa(&keys.secret_key, FOX)?;
        assert!(
            verify_dsa(&keys.public_key, FOX, &signature)?,
            "valid signature must verify"
        );
        Ok(())
    }

    /// A signature with a flipped first byte must not verify.
    #[test]
    fn test_dsa_tamper() -> TestResult {
        let keys = generate_dsa_keypair(&mut rng())?;
        let signature = sign_dsa(&keys.secret_key, FOX)?;

        let mut raw = signature.0.to_vec();
        flip_first_byte(&mut raw, "empty signature")?;
        let tampered = Signature(Bytes::from(raw));

        let outcome = verify_dsa(&keys.public_key, FOX, &tampered);
        assert!(
            matches!(outcome, Ok(false) | Err(_)),
            "tampered signature must not verify"
        );
        Ok(())
    }

    /// A signature must not verify under an unrelated public key.
    #[test]
    fn test_dsa_wrong_key() -> TestResult {
        let signer = generate_dsa_keypair(&mut rng())?;
        let stranger = generate_dsa_keypair(&mut rng())?;
        let signature = sign_dsa(&signer.secret_key, FOX)?;

        let outcome = verify_dsa(&stranger.public_key, FOX, &signature);
        assert!(
            matches!(outcome, Ok(false) | Err(_)),
            "sig must not verify under a different key"
        );
        Ok(())
    }

    /// Generated DSA keys have the seed and verifying-key sizes this module expects.
    #[test]
    fn test_dsa_keypair_sizes() -> TestResult {
        let keys = generate_dsa_keypair(&mut rng())?;
        assert_eq!(
            keys.secret_key.len(),
            DSA_SEED_LEN,
            "DSA seed must be 32 bytes"
        );
        assert_eq!(
            keys.public_key.len(),
            DSA_VK_LEN,
            "DSA verifying key must be 2592 bytes"
        );
        Ok(())
    }

    // ─── AES-256-GCM ──────────────────────────────────────────────────────────

    /// Decrypting what was just encrypted returns the original plaintext.
    #[test]
    fn test_aes_roundtrip() -> TestResult {
        let key = [0u8; AES_KEY_LEN];
        let nonce = [1u8; AES_NONCE_LEN];
        let plaintext: &[u8] = b"the quick brown fox";

        let sealed = encrypt_aes_gcm(&key, &nonce, plaintext)?;
        let opened = decrypt_aes_gcm(&key, &nonce, &sealed)?;
        assert_eq!(&opened[..], plaintext);
        Ok(())
    }

    /// A ciphertext with a flipped first byte must be rejected.
    #[test]
    fn test_aes_tamper() -> TestResult {
        let key = [0u8; AES_KEY_LEN];
        let nonce = [1u8; AES_NONCE_LEN];

        let mut sealed = encrypt_aes_gcm(&key, &nonce, b"the quick brown fox")?.to_vec();
        flip_first_byte(&mut sealed, "empty ciphertext")?;

        let outcome = decrypt_aes_gcm(&key, &nonce, &sealed);
        assert!(outcome.is_err(), "tampered ciphertext must fail to decrypt");
        Ok(())
    }

    /// Encrypting with a 16-byte key is rejected with `InvalidKeyLength`.
    #[test]
    fn test_aes_bad_key_length() {
        let short_key = [0u8; 16];
        let nonce = [1u8; AES_NONCE_LEN];
        let outcome = encrypt_aes_gcm(&short_key, &nonce, b"test");
        assert!(matches!(outcome, Err(CryptoError::InvalidKeyLength { .. })));
    }

    /// Encrypting with an 8-byte nonce is rejected with `InvalidNonceLength`.
    #[test]
    fn test_aes_bad_nonce_length() {
        let key = [0u8; AES_KEY_LEN];
        let short_nonce = [1u8; 8];
        let outcome = encrypt_aes_gcm(&key, &short_nonce, b"test");
        assert!(matches!(
            outcome,
            Err(CryptoError::InvalidNonceLength { .. })
        ));
    }

    // ─── Argon2id KDF ─────────────────────────────────────────────────────────

    /// The same password and salt always derive the same key.
    #[test]
    fn test_kdf_deterministic() -> TestResult {
        let salt = [0u8; ARGON2_SALT_LEN];
        let first = derive_key(b"password123", &salt, AES_KEY_LEN)?;
        let second = derive_key(b"password123", &salt, AES_KEY_LEN)?;
        assert_eq!(first, second, "KDF must be deterministic");
        Ok(())
    }

    /// Different passwords under one salt derive different keys.
    #[test]
    fn test_kdf_different_passwords() -> TestResult {
        let salt = [0u8; ARGON2_SALT_LEN];
        let first = derive_key(b"password1", &salt, AES_KEY_LEN)?;
        let second = derive_key(b"password2", &salt, AES_KEY_LEN)?;
        assert_ne!(
            first, second,
            "different passwords must yield different keys"
        );
        Ok(())
    }

    /// One password under different salts derives different keys.
    #[test]
    fn test_kdf_different_salts() -> TestResult {
        let zero_salt = [0u8; ARGON2_SALT_LEN];
        let one_salt = [1u8; ARGON2_SALT_LEN];
        let first = derive_key(b"password123", &zero_salt, AES_KEY_LEN)?;
        let second = derive_key(b"password123", &one_salt, AES_KEY_LEN)?;
        assert_ne!(first, second, "different salts must yield different keys");
        Ok(())
    }

    // ─── Full Pipeline ────────────────────────────────────────────────────────

    /// Encrypting and then decrypting a payload returns the original bytes.
    #[test]
    fn test_pipeline_roundtrip() -> TestResult {
        let (kem, dsa) = pipeline_keys()?;
        let payload = sample_payload();

        let sealed = encrypt_payload(&kem.public_key, &dsa.secret_key, &payload, &mut rng())?;
        let opened = decrypt_payload(&kem.secret_key, &dsa.public_key, &sealed)?;

        assert_eq!(opened.as_bytes(), payload.as_bytes());
        Ok(())
    }

    /// A pipeline output with one corrupted byte must be rejected.
    #[test]
    fn test_pipeline_tamper() -> TestResult {
        let (kem, dsa) = pipeline_keys()?;
        let payload = sample_payload();

        let mut sealed =
            encrypt_payload(&kem.public_key, &dsa.secret_key, &payload, &mut rng())?.to_vec();

        // Corrupt the byte at the midpoint of the blob.
        let midpoint = sealed.len() / 2;
        let target = sealed.get_mut(midpoint).ok_or("empty encrypted data")?;
        *target ^= 0xFF;

        let outcome = decrypt_payload(&kem.secret_key, &dsa.public_key, &sealed);
        assert!(outcome.is_err(), "tampered payload must fail to decrypt");
        Ok(())
    }

    /// Decrypting with a DSA public key that did not sign the blob must fail.
    #[test]
    fn test_pipeline_wrong_dsa_key() -> TestResult {
        let (kem, signer) = pipeline_keys()?;
        let stranger = generate_dsa_keypair(&mut rng())?;
        let payload = sample_payload();

        let sealed = encrypt_payload(&kem.public_key, &signer.secret_key, &payload, &mut rng())?;

        let outcome = decrypt_payload(&kem.secret_key, &stranger.public_key, &sealed);
        assert!(outcome.is_err(), "wrong DSA key must fail verification");
        Ok(())
    }

    // ─── Additional edge-case coverage ────────────────────────────────────

    /// A 16-byte salt is rejected by the KDF with `KdfFailed`.
    #[test]
    fn test_kdf_bad_salt_length() {
        let outcome = derive_key(b"password", &[0u8; 16], AES_KEY_LEN);
        assert!(matches!(outcome, Err(CryptoError::KdfFailed { .. })));
    }

    /// Signing with a 16-byte secret key is rejected with `InvalidKeyLength`.
    #[test]
    fn test_dsa_sign_bad_key_length() {
        let outcome = sign_dsa(&[0u8; 16], b"message");
        assert!(matches!(outcome, Err(CryptoError::InvalidKeyLength { .. })));
    }

    /// Verifying with a 16-byte public key is rejected with `InvalidKeyLength`.
    #[test]
    fn test_dsa_verify_bad_pubkey_length() {
        let signature = Signature(Bytes::from(vec![0u8; 64]));
        let outcome = verify_dsa(&[0u8; 16], b"message", &signature);
        assert!(matches!(outcome, Err(CryptoError::InvalidKeyLength { .. })));
    }

    /// A 42-byte KEM ciphertext is rejected with `DecapsulationFailed`.
    #[test]
    fn test_kem_bad_ciphertext_length() -> TestResult {
        let keys = generate_kem_keypair(&mut rng())?;
        let outcome = decapsulate_kem(&keys.secret_key, &[0u8; 42]);
        assert!(matches!(
            outcome,
            Err(CryptoError::DecapsulationFailed { .. })
        ));
        Ok(())
    }

    /// Decrypting with a 16-byte key is rejected with `InvalidKeyLength`.
    #[test]
    fn test_aes_decrypt_bad_key_length() {
        let outcome = decrypt_aes_gcm(&[0u8; 16], &[0u8; AES_NONCE_LEN], &[0u8; 32]);
        assert!(matches!(outcome, Err(CryptoError::InvalidKeyLength { .. })));
    }

    /// Decrypting with an 8-byte nonce is rejected with `InvalidNonceLength`.
    #[test]
    fn test_aes_decrypt_bad_nonce_length() {
        let outcome = decrypt_aes_gcm(&[0u8; AES_KEY_LEN], &[0u8; 8], &[0u8; 32]);
        assert!(matches!(
            outcome,
            Err(CryptoError::InvalidNonceLength { .. })
        ));
    }

    /// An empty pipeline input is reported as `DecryptionFailed`.
    #[test]
    fn test_decrypt_pipeline_truncated_empty() {
        let outcome = decrypt_payload(&[0u8; KEM_SEED_LEN], &[0u8; DSA_VK_LEN], &[]);
        assert!(matches!(outcome, Err(CryptoError::DecryptionFailed { .. })));
    }

    /// A pipeline input that ends right after `kem_ct_len` is reported as `DecryptionFailed`.
    #[test]
    fn test_decrypt_pipeline_truncated_after_header() {
        let outcome = decrypt_payload(&[0u8; KEM_SEED_LEN], &[0u8; DSA_VK_LEN], &[0u8; 8]);
        assert!(matches!(outcome, Err(CryptoError::DecryptionFailed { .. })));
    }

    /// A 10-byte signature is rejected with `VerificationFailed`.
    #[test]
    fn test_dsa_verify_bad_sig_length() -> TestResult {
        let keys = generate_dsa_keypair(&mut rng())?;
        let too_short = Signature(Bytes::from(vec![0u8; 10]));
        let result = verify_dsa(&keys.public_key, b"message", &too_short);
        assert!(
            matches!(result, Err(CryptoError::VerificationFailed { .. })),
            "expected VerificationFailed, got {result:?}"
        );
        Ok(())
    }

    /// An empty plaintext survives an AES-GCM round trip.
    #[test]
    fn test_aes_empty_plaintext() -> TestResult {
        let key = [0u8; AES_KEY_LEN];
        let nonce = [1u8; AES_NONCE_LEN];
        let sealed = encrypt_aes_gcm(&key, &nonce, &[])?;
        let opened = decrypt_aes_gcm(&key, &nonce, &sealed)?;
        assert!(opened.is_empty());
        Ok(())
    }

    /// An empty payload survives a full pipeline round trip.
    #[test]
    fn test_pipeline_empty_payload() -> TestResult {
        let (kem, dsa) = pipeline_keys()?;
        let payload = crate::domain::types::Payload::from_bytes(Vec::new());

        let sealed = encrypt_payload(&kem.public_key, &dsa.secret_key, &payload, &mut rng())?;
        let opened = decrypt_payload(&kem.secret_key, &dsa.public_key, &sealed)?;
        assert!(opened.is_empty());
        Ok(())
    }
}