1use crate::pqc::types::{
27 MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, PqcError, PqcResult, SharedSecret,
28};
29use crate::pqc::{ml_kem::MlKem768, MlKemOperations};
30use aes_gcm::{
31 aead::{Aead, KeyInit},
32 Aes256Gcm, Key, Nonce as AesNonce,
33};
34use hkdf::Hkdf;
35use rand::{thread_rng, RngCore};
36use sha2::{Digest, Sha256};
37use std::collections::HashMap;
38
/// Self-contained result of [`HybridPublicKeyEncryption::encrypt`]: everything
/// the holder of the matching ML-KEM secret key needs (together with the
/// original associated data) to recover the plaintext.
#[derive(Debug, Clone)]
pub struct EncryptedMessage {
    /// ML-KEM-768 ciphertext encapsulating the shared secret from which the
    /// AES-256 key is derived (HKDF-SHA256).
    pub kem_ciphertext: MlKemCiphertext,
    /// Output of AES-256-GCM `Aead::encrypt` over the plaintext (presumably
    /// ciphertext with the GCM tag appended, per the `aes_gcm` crate — confirm).
    pub aes_ciphertext: Vec<u8>,
    /// 96-bit AES-GCM nonce; `encrypt` fills it from `thread_rng()` per message.
    pub nonce: [u8; 12],
    /// SHA-256 digest of the associated data given at encryption time.
    /// This field is not itself an input to the AES-GCM tag.
    pub aad_hash: [u8; 32],
}
57
58pub struct HybridPublicKeyEncryption {
65 ml_kem: MlKem768,
67}
68
69impl HybridPublicKeyEncryption {
70 #[must_use]
72 pub const fn new() -> Self {
73 Self {
74 ml_kem: MlKem768::new(),
75 }
76 }
77
78 pub fn generate_keypair(&self) -> PqcResult<(MlKemPublicKey, MlKemSecretKey)> {
83 self.ml_kem.generate_keypair()
84 }
85
86 pub fn encrypt(
99 &self,
100 public_key: &MlKemPublicKey,
101 plaintext: &[u8],
102 associated_data: &[u8],
103 ) -> PqcResult<EncryptedMessage> {
104 let (kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(public_key)?;
106
107 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
109 let mut aes_key_bytes = [0u8; 32];
110 hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
111 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
112
113 let mut nonce = [0u8; 12];
115 thread_rng().fill_bytes(&mut nonce);
116
117 let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
119 let cipher = Aes256Gcm::new(key);
120 let nonce_obj = AesNonce::from_slice(&nonce);
121
122 let aes_ciphertext = cipher
123 .encrypt(nonce_obj, plaintext)
124 .map_err(|_| PqcError::EncryptionFailed("AES-GCM encryption failed".to_string()))?;
125
126 let mut hasher = Sha256::new();
128 hasher.update(associated_data);
129 let aad_hash: [u8; 32] = hasher.finalize().into();
130
131 Ok(EncryptedMessage {
132 kem_ciphertext,
133 aes_ciphertext,
134 nonce,
135 aad_hash,
136 })
137 }
138
139 pub fn decrypt(
153 &self,
154 secret_key: &MlKemSecretKey,
155 encrypted_message: &EncryptedMessage,
156 associated_data: &[u8],
157 ) -> PqcResult<Vec<u8>> {
158 let mut hasher = Sha256::new();
160 hasher.update(associated_data);
161 let computed_hash: [u8; 32] = hasher.finalize().into();
162
163 if computed_hash != encrypted_message.aad_hash {
164 return Err(PqcError::DecryptionFailed(
165 "Associated data verification failed".to_string(),
166 ));
167 }
168
169 let shared_secret = self
171 .ml_kem
172 .decapsulate(secret_key, &encrypted_message.kem_ciphertext)?;
173
174 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
176 let mut aes_key_bytes = [0u8; 32];
177 hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
178 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
179
180 let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
182 let cipher = Aes256Gcm::new(key);
183 let nonce_obj = AesNonce::from_slice(&encrypted_message.nonce);
184
185 let plaintext = cipher
186 .decrypt(nonce_obj, encrypted_message.aes_ciphertext.as_slice())
187 .map_err(|_| PqcError::DecryptionFailed("AES-GCM decryption failed".to_string()))?;
188
189 Ok(plaintext)
190 }
191}
192
193pub struct EncryptionSession {
198 shared_secret: SharedSecret,
200 message_counter: u64,
202}
203
204impl EncryptionSession {
205 pub fn new(public_key: &MlKemPublicKey) -> PqcResult<(Self, MlKemCiphertext)> {
216 let ml_kem = MlKem768::new();
217 let (kem_ciphertext, shared_secret) = ml_kem.encapsulate(public_key)?;
218
219 Ok((
220 Self {
221 shared_secret,
222 message_counter: 0,
223 },
224 kem_ciphertext,
225 ))
226 }
227
228 pub fn encrypt_message(&mut self, plaintext: &[u8]) -> PqcResult<Vec<u8>> {
236 let mut key_material = Vec::new();
238 key_material.extend_from_slice(self.shared_secret.as_bytes());
239 key_material.extend_from_slice(&self.message_counter.to_be_bytes());
240
241 let hk = Hkdf::<Sha256>::new(None, &key_material);
242 let mut aes_key = [0u8; 32];
243 hk.expand(b"message-key", &mut aes_key)
244 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
245
246 let mut nonce = [0u8; 12];
248 let counter_bytes = self.message_counter.to_be_bytes();
249 nonce[4..12].copy_from_slice(&counter_bytes);
250
251 let key = Key::<Aes256Gcm>::from_slice(&aes_key);
253 let cipher = Aes256Gcm::new(key);
254 let nonce_obj = AesNonce::from_slice(&nonce);
255
256 let ciphertext = cipher
257 .encrypt(nonce_obj, plaintext)
258 .map_err(|_| PqcError::EncryptionFailed("Session encryption failed".to_string()))?;
259
260 self.message_counter = self.message_counter.saturating_add(1);
261
262 let mut result = Vec::with_capacity(8_usize.saturating_add(ciphertext.len()));
264 result.extend_from_slice(&(self.message_counter.saturating_sub(1)).to_be_bytes());
265 result.extend_from_slice(&ciphertext);
266
267 Ok(result)
268 }
269}
270
271pub struct DecryptionSession {
273 shared_secret: SharedSecret,
275 received_counters: HashMap<u64, bool>,
277}
278
279impl DecryptionSession {
280 pub fn new(secret_key: &MlKemSecretKey, kem_ciphertext: &MlKemCiphertext) -> PqcResult<Self> {
289 let ml_kem = MlKem768::new();
290 let shared_secret = ml_kem.decapsulate(secret_key, kem_ciphertext)?;
291
292 Ok(Self {
293 shared_secret,
294 received_counters: HashMap::new(),
295 })
296 }
297
298 pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> PqcResult<Vec<u8>> {
305 if ciphertext.len() < 8 {
306 return Err(PqcError::DecryptionFailed("Invalid ciphertext".to_string()));
307 }
308
309 let counter_slice = ciphertext.get(..8).ok_or_else(|| {
311 PqcError::DecryptionFailed("Ciphertext too short for counter".to_string())
312 })?;
313 let counter_bytes: [u8; 8] = counter_slice
314 .try_into()
315 .map_err(|_| PqcError::DecryptionFailed("Invalid counter format".to_string()))?;
316 let counter = u64::from_be_bytes(counter_bytes);
317
318 if self.received_counters.contains_key(&counter) {
320 return Err(PqcError::DecryptionFailed("Replay detected".to_string()));
321 }
322
323 let mut key_material = Vec::new();
325 key_material.extend_from_slice(self.shared_secret.as_bytes());
326 key_material.extend_from_slice(&counter.to_be_bytes());
327
328 let hk = Hkdf::<Sha256>::new(None, &key_material);
329 let mut aes_key = [0u8; 32];
330 hk.expand(b"message-key", &mut aes_key)
331 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
332
333 let mut nonce = [0u8; 12];
335 nonce[4..].copy_from_slice(&counter.to_be_bytes());
336
337 let key = Key::<Aes256Gcm>::from_slice(&aes_key);
339 let cipher = Aes256Gcm::new(key);
340 let nonce_obj = AesNonce::from_slice(&nonce);
341
342 let ciphertext_slice = ciphertext
343 .get(8..)
344 .ok_or_else(|| PqcError::DecryptionFailed("Ciphertext too short".to_string()))?;
345 let plaintext = cipher
346 .decrypt(nonce_obj, ciphertext_slice)
347 .map_err(|_| PqcError::DecryptionFailed("Session decryption failed".to_string()))?;
348
349 self.received_counters.insert(counter, true);
351
352 Ok(plaintext)
353 }
354}
355
356impl Default for HybridPublicKeyEncryption {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
#[cfg(test)]
// Tests may panic on failure by design, so unwrap/expect are acceptable here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a sender/receiver session pair over a fresh ML-KEM-768 key pair.
    fn session_pair() -> (EncryptionSession, DecryptionSession) {
        let (public_key, secret_key) = MlKem768::new()
            .generate_keypair()
            .expect("Key generation should succeed");
        let (sender, kem_ct) =
            EncryptionSession::new(&public_key).expect("Session setup should succeed");
        let receiver =
            DecryptionSession::new(&secret_key, &kem_ct).expect("Session open should succeed");
        (sender, receiver)
    }

    #[test]
    fn test_encryption_decryption_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke
            .generate_keypair()
            .expect("Key generation should succeed");

        let plaintext = b"Hello, quantum-resistant world!";
        let associated_data = b"test-context";

        let sealed = pke
            .encrypt(&public_key, plaintext, associated_data)
            .expect("Encryption should succeed");
        let opened = pke
            .decrypt(&secret_key, &sealed, associated_data)
            .expect("Decryption should succeed");

        assert_eq!(opened, plaintext.to_vec());
    }

    #[test]
    fn test_wrong_aad_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.generate_keypair().unwrap();

        let sealed = pke
            .encrypt(&public_key, b"Test message", b"correct-aad")
            .unwrap();

        assert!(pke.decrypt(&secret_key, &sealed, b"wrong-aad").is_err());
    }

    #[test]
    fn test_session_encryption() {
        let (mut sender, mut receiver) = session_pair();

        for i in 0..10 {
            let plaintext = format!("Message {}", i);
            let frame = sender.encrypt_message(plaintext.as_bytes()).unwrap();
            let opened = receiver.decrypt_message(&frame).unwrap();
            assert_eq!(opened, plaintext.as_bytes());
        }
    }

    #[test]
    fn test_session_replay_protection() {
        let (mut sender, mut receiver) = session_pair();

        let frame = sender.encrypt_message(b"Test").unwrap();
        assert_eq!(receiver.decrypt_message(&frame).unwrap(), b"Test".to_vec());

        // Delivering the identical frame a second time must be rejected.
        assert!(receiver.decrypt_message(&frame).is_err());
    }

    #[test]
    fn test_unique_ciphertexts() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _) = pke.generate_keypair().unwrap();

        let first = pke
            .encrypt(&public_key, b"Same message", b"same-aad")
            .unwrap();
        let second = pke
            .encrypt(&public_key, b"Same message", b"same-aad")
            .unwrap();

        // Fresh encapsulation and a random nonce per call: outputs must differ.
        assert_ne!(first.aes_ciphertext, second.aes_ciphertext);
        assert_ne!(first.nonce, second.nonce);
    }
}