use crate::pqc::constant_time::ct_eq;
use crate::pqc::types::{
    MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, PqcError, PqcResult, SharedSecret,
};
use crate::pqc::{ml_kem::MlKem768, MlKemOperations};
use aes_gcm::{
    aead::{Aead, KeyInit, Payload},
    Aes256Gcm, Key, Nonce as AesNonce,
};
use hkdf::Hkdf;
use rand::{thread_rng, RngCore};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
39
40#[derive(Debug, Clone)]
48pub struct EncryptedMessage {
49 pub kem_ciphertext: MlKemCiphertext,
51 pub aes_ciphertext: Vec<u8>,
53 pub nonce: [u8; 12],
55 pub aad_hash: [u8; 32],
57}
58
59pub struct HybridPublicKeyEncryption {
66 ml_kem: MlKem768,
68}
69
70impl HybridPublicKeyEncryption {
71 #[must_use]
73 pub const fn new() -> Self {
74 Self {
75 ml_kem: MlKem768::new(),
76 }
77 }
78
79 pub fn generate_keypair(&self) -> PqcResult<(MlKemPublicKey, MlKemSecretKey)> {
84 self.ml_kem.generate_keypair()
85 }
86
87 pub fn encrypt(
100 &self,
101 public_key: &MlKemPublicKey,
102 plaintext: &[u8],
103 associated_data: &[u8],
104 ) -> PqcResult<EncryptedMessage> {
105 let (kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(public_key)?;
107
108 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
110 let mut aes_key_bytes = [0u8; 32];
111 hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
112 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
113
114 let mut nonce = [0u8; 12];
116 thread_rng().fill_bytes(&mut nonce);
117
118 let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
120 let cipher = Aes256Gcm::new(key);
121 let nonce_obj = AesNonce::from_slice(&nonce);
122
123 let aes_ciphertext = cipher
124 .encrypt(nonce_obj, plaintext)
125 .map_err(|_| PqcError::EncryptionFailed("AES-GCM encryption failed".to_string()))?;
126
127 let mut hasher = Sha256::new();
129 hasher.update(associated_data);
130 let aad_hash: [u8; 32] = hasher.finalize().into();
131
132 Ok(EncryptedMessage {
133 kem_ciphertext,
134 aes_ciphertext,
135 nonce,
136 aad_hash,
137 })
138 }
139
140 pub fn decrypt(
154 &self,
155 secret_key: &MlKemSecretKey,
156 encrypted_message: &EncryptedMessage,
157 associated_data: &[u8],
158 ) -> PqcResult<Vec<u8>> {
159 let mut hasher = Sha256::new();
161 hasher.update(associated_data);
162 let computed_hash: [u8; 32] = hasher.finalize().into();
163
164 if !ct_eq(&computed_hash, &encrypted_message.aad_hash) {
165 return Err(PqcError::DecryptionFailed(
166 "Associated data verification failed".to_string(),
167 ));
168 }
169
170 let shared_secret = self
172 .ml_kem
173 .decapsulate(secret_key, &encrypted_message.kem_ciphertext)?;
174
175 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
177 let mut aes_key_bytes = [0u8; 32];
178 hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
179 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
180
181 let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
183 let cipher = Aes256Gcm::new(key);
184 let nonce_obj = AesNonce::from_slice(&encrypted_message.nonce);
185
186 let plaintext = cipher
187 .decrypt(nonce_obj, encrypted_message.aes_ciphertext.as_slice())
188 .map_err(|_| PqcError::DecryptionFailed("AES-GCM decryption failed".to_string()))?;
189
190 Ok(plaintext)
191 }
192}
193
194pub struct EncryptionSession {
199 shared_secret: SharedSecret,
201 message_counter: u64,
203}
204
205impl EncryptionSession {
206 pub fn new(public_key: &MlKemPublicKey) -> PqcResult<(Self, MlKemCiphertext)> {
217 let ml_kem = MlKem768::new();
218 let (kem_ciphertext, shared_secret) = ml_kem.encapsulate(public_key)?;
219
220 Ok((
221 Self {
222 shared_secret,
223 message_counter: 0,
224 },
225 kem_ciphertext,
226 ))
227 }
228
229 pub fn encrypt_message(&mut self, plaintext: &[u8]) -> PqcResult<Vec<u8>> {
237 let mut key_material = Vec::new();
239 key_material.extend_from_slice(self.shared_secret.as_bytes());
240 key_material.extend_from_slice(&self.message_counter.to_be_bytes());
241
242 let hk = Hkdf::<Sha256>::new(None, &key_material);
243 let mut aes_key = [0u8; 32];
244 hk.expand(b"message-key", &mut aes_key)
245 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
246
247 let mut nonce = [0u8; 12];
249 let counter_bytes = self.message_counter.to_be_bytes();
250 nonce[4..12].copy_from_slice(&counter_bytes);
251
252 let key = Key::<Aes256Gcm>::from_slice(&aes_key);
254 let cipher = Aes256Gcm::new(key);
255 let nonce_obj = AesNonce::from_slice(&nonce);
256
257 let ciphertext = cipher
258 .encrypt(nonce_obj, plaintext)
259 .map_err(|_| PqcError::EncryptionFailed("Session encryption failed".to_string()))?;
260
261 self.message_counter = self.message_counter.saturating_add(1);
262
263 let mut result = Vec::with_capacity(8_usize.saturating_add(ciphertext.len()));
265 result.extend_from_slice(&(self.message_counter.saturating_sub(1)).to_be_bytes());
266 result.extend_from_slice(&ciphertext);
267
268 Ok(result)
269 }
270}
271
272pub struct DecryptionSession {
274 shared_secret: SharedSecret,
276 received_counters: HashMap<u64, bool>,
278}
279
280impl DecryptionSession {
281 pub fn new(secret_key: &MlKemSecretKey, kem_ciphertext: &MlKemCiphertext) -> PqcResult<Self> {
290 let ml_kem = MlKem768::new();
291 let shared_secret = ml_kem.decapsulate(secret_key, kem_ciphertext)?;
292
293 Ok(Self {
294 shared_secret,
295 received_counters: HashMap::new(),
296 })
297 }
298
299 pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> PqcResult<Vec<u8>> {
306 if ciphertext.len() < 8 {
307 return Err(PqcError::DecryptionFailed("Invalid ciphertext".to_string()));
308 }
309
310 let counter_slice = ciphertext.get(..8).ok_or_else(|| {
312 PqcError::DecryptionFailed("Ciphertext too short for counter".to_string())
313 })?;
314 let counter_bytes: [u8; 8] = counter_slice
315 .try_into()
316 .map_err(|_| PqcError::DecryptionFailed("Invalid counter format".to_string()))?;
317 let counter = u64::from_be_bytes(counter_bytes);
318
319 if self.received_counters.contains_key(&counter) {
321 return Err(PqcError::DecryptionFailed("Replay detected".to_string()));
322 }
323
324 let mut key_material = Vec::new();
326 key_material.extend_from_slice(self.shared_secret.as_bytes());
327 key_material.extend_from_slice(&counter.to_be_bytes());
328
329 let hk = Hkdf::<Sha256>::new(None, &key_material);
330 let mut aes_key = [0u8; 32];
331 hk.expand(b"message-key", &mut aes_key)
332 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
333
334 let mut nonce = [0u8; 12];
336 nonce[4..].copy_from_slice(&counter.to_be_bytes());
337
338 let key = Key::<Aes256Gcm>::from_slice(&aes_key);
340 let cipher = Aes256Gcm::new(key);
341 let nonce_obj = AesNonce::from_slice(&nonce);
342
343 let ciphertext_slice = ciphertext
344 .get(8..)
345 .ok_or_else(|| PqcError::DecryptionFailed("Ciphertext too short".to_string()))?;
346 let plaintext = cipher
347 .decrypt(nonce_obj, ciphertext_slice)
348 .map_err(|_| PqcError::DecryptionFailed("Session decryption failed".to_string()))?;
349
350 self.received_counters.insert(counter, true);
352
353 Ok(plaintext)
354 }
355}
356
357impl Default for HybridPublicKeyEncryption {
358 fn default() -> Self {
359 Self::new()
360 }
361}
362
#[cfg(test)]
// Tests intentionally panic on unexpected errors so failures surface immediately.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a matched encryption/decryption session pair over a fresh key pair.
    fn session_pair() -> (EncryptionSession, DecryptionSession) {
        let ml_kem = MlKem768::new();
        let (public_key, secret_key) = ml_kem.generate_keypair().unwrap();
        let (enc_session, kem_ct) = EncryptionSession::new(&public_key).unwrap();
        let dec_session = DecryptionSession::new(&secret_key, &kem_ct).unwrap();
        (enc_session, dec_session)
    }

    #[test]
    fn test_encryption_decryption_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke
            .ml_kem
            .generate_keypair()
            .expect("Key generation should succeed");

        let plaintext = b"Hello, quantum-resistant world!";
        let associated_data = b"test-context";

        let encrypted = pke
            .encrypt(&public_key, plaintext, associated_data)
            .expect("Encryption should succeed");

        let decrypted = pke
            .decrypt(&secret_key, &encrypted, associated_data)
            .expect("Decryption should succeed");

        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_wrong_aad_fails() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();
        let plaintext = b"Test message";
        let aad = b"correct-aad";
        let wrong_aad = b"wrong-aad";

        let encrypted = pke.encrypt(&public_key, plaintext, aad).unwrap();

        let result = pke.decrypt(&secret_key, &encrypted, wrong_aad);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.generate_keypair().unwrap();
        let aad = b"tamper-aad";

        let mut encrypted = pke.encrypt(&public_key, b"Integrity matters", aad).unwrap();
        let first = encrypted
            .aes_ciphertext
            .first_mut()
            .expect("AES-GCM output always carries a tag");
        *first ^= 0x01;

        assert!(pke.decrypt(&secret_key, &encrypted, aad).is_err());
    }

    #[test]
    fn test_session_encryption() {
        let (mut enc_session, mut dec_session) = session_pair();

        for i in 0..10 {
            let plaintext = format!("Message {}", i);
            let encrypted = enc_session.encrypt_message(plaintext.as_bytes()).unwrap();
            let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
            assert_eq!(plaintext.as_bytes(), decrypted);
        }
    }

    #[test]
    fn test_session_out_of_order_delivery() {
        let (mut enc_session, mut dec_session) = session_pair();

        let messages: Vec<Vec<u8>> = (0..3)
            .map(|i| format!("Message {}", i).into_bytes())
            .collect();
        let encrypted: Vec<Vec<u8>> = messages
            .iter()
            .map(|m| enc_session.encrypt_message(m).unwrap())
            .collect();

        for &index in &[2usize, 0, 1] {
            let decrypted = dec_session.decrypt_message(&encrypted[index]).unwrap();
            assert_eq!(messages[index], decrypted);
        }
    }

    #[test]
    fn test_session_replay_protection() {
        let (mut enc_session, mut dec_session) = session_pair();

        let plaintext = b"Test";
        let encrypted = enc_session.encrypt_message(plaintext).unwrap();

        let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);

        let replay_result = dec_session.decrypt_message(&encrypted);
        assert!(replay_result.is_err());
    }

    #[test]
    fn test_session_rejects_truncated_ciphertext() {
        let (_enc_session, mut dec_session) = session_pair();

        assert!(dec_session.decrypt_message(&[]).is_err());
        assert!(dec_session.decrypt_message(&[0u8; 7]).is_err());
        // A well-formed counter header with no authenticated body must not decrypt.
        assert!(dec_session.decrypt_message(&[0u8; 8]).is_err());
    }

    #[test]
    fn test_session_tampered_message_does_not_consume_counter() {
        let (mut enc_session, mut dec_session) = session_pair();

        let encrypted = enc_session.encrypt_message(b"Authentic").unwrap();
        let mut tampered = encrypted.clone();
        *tampered.last_mut().expect("message is non-empty") ^= 0x80;

        assert!(dec_session.decrypt_message(&tampered).is_err());
        // The forged attempt must not mark the counter as seen.
        let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
        assert_eq!(b"Authentic".to_vec(), decrypted);
    }

    #[test]
    fn test_unique_ciphertexts() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _secret_key) = pke.ml_kem.generate_keypair().unwrap();

        let plaintext = b"Same message";
        let aad = b"same-aad";

        let encrypted1 = pke.encrypt(&public_key, plaintext, aad).unwrap();
        let encrypted2 = pke.encrypt(&public_key, plaintext, aad).unwrap();

        assert_ne!(encrypted1.aes_ciphertext, encrypted2.aes_ciphertext);
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
    }
}