use crate::pqc::types::{
    MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, PqcError, PqcResult, SharedSecret,
};
use crate::pqc::{ml_kem::MlKem768, MlKemOperations};
use aes_gcm::{
    aead::{Aead, KeyInit, Payload},
    Aes256Gcm, Key, Nonce as AesNonce,
};
use hkdf::Hkdf;
use rand::{thread_rng, RngCore};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
38
39#[derive(Debug, Clone)]
47pub struct EncryptedMessage {
48 pub kem_ciphertext: MlKemCiphertext,
50 pub aes_ciphertext: Vec<u8>,
52 pub nonce: [u8; 12],
54 pub aad_hash: [u8; 32],
56}
57
58pub struct HybridPublicKeyEncryption {
65 ml_kem: MlKem768,
67}
68
69impl HybridPublicKeyEncryption {
70 #[must_use]
72 pub const fn new() -> Self {
73 Self {
74 ml_kem: MlKem768::new(),
75 }
76 }
77
78 pub fn generate_keypair(&self) -> PqcResult<(MlKemPublicKey, MlKemSecretKey)> {
83 self.ml_kem.generate_keypair()
84 }
85
86 pub fn encrypt(
99 &self,
100 public_key: &MlKemPublicKey,
101 plaintext: &[u8],
102 associated_data: &[u8],
103 ) -> PqcResult<EncryptedMessage> {
104 let (kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(public_key)?;
106
107 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
109 let mut aes_key_bytes = [0u8; 32];
110 hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
111 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
112
113 let mut nonce = [0u8; 12];
115 thread_rng().fill_bytes(&mut nonce);
116
117 let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
119 let cipher = Aes256Gcm::new(key);
120 let nonce_obj = AesNonce::from_slice(&nonce);
121
122 let aes_ciphertext = cipher
123 .encrypt(nonce_obj, plaintext)
124 .map_err(|_| PqcError::EncryptionFailed("AES-GCM encryption failed".to_string()))?;
125
126 let mut hasher = Sha256::new();
128 hasher.update(associated_data);
129 let aad_hash: [u8; 32] = hasher.finalize().into();
130
131 Ok(EncryptedMessage {
132 kem_ciphertext,
133 aes_ciphertext,
134 nonce,
135 aad_hash,
136 })
137 }
138
139 pub fn decrypt(
149 &self,
150 secret_key: &MlKemSecretKey,
151 encrypted_message: &EncryptedMessage,
152 associated_data: &[u8],
153 ) -> PqcResult<Vec<u8>> {
154 let mut hasher = Sha256::new();
156 hasher.update(associated_data);
157 let computed_hash: [u8; 32] = hasher.finalize().into();
158
159 if computed_hash != encrypted_message.aad_hash {
160 return Err(PqcError::DecryptionFailed(
161 "Associated data verification failed".to_string(),
162 ));
163 }
164
165 let shared_secret = self
167 .ml_kem
168 .decapsulate(secret_key, &encrypted_message.kem_ciphertext)?;
169
170 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
172 let mut aes_key_bytes = [0u8; 32];
173 hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
174 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
175
176 let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
178 let cipher = Aes256Gcm::new(key);
179 let nonce_obj = AesNonce::from_slice(&encrypted_message.nonce);
180
181 let plaintext = cipher
182 .decrypt(nonce_obj, encrypted_message.aes_ciphertext.as_slice())
183 .map_err(|_| PqcError::DecryptionFailed("AES-GCM decryption failed".to_string()))?;
184
185 Ok(plaintext)
186 }
187}
188
189pub struct EncryptionSession {
194 shared_secret: SharedSecret,
196 message_counter: u64,
198}
199
200impl EncryptionSession {
201 pub fn new(public_key: &MlKemPublicKey) -> PqcResult<(Self, MlKemCiphertext)> {
209 let ml_kem = MlKem768::new();
210 let (kem_ciphertext, shared_secret) = ml_kem.encapsulate(public_key)?;
211
212 Ok((
213 Self {
214 shared_secret,
215 message_counter: 0,
216 },
217 kem_ciphertext,
218 ))
219 }
220
221 pub fn encrypt_message(&mut self, plaintext: &[u8]) -> PqcResult<Vec<u8>> {
225 let mut key_material = Vec::new();
227 key_material.extend_from_slice(self.shared_secret.as_bytes());
228 key_material.extend_from_slice(&self.message_counter.to_be_bytes());
229
230 let hk = Hkdf::<Sha256>::new(None, &key_material);
231 let mut aes_key = [0u8; 32];
232 hk.expand(b"message-key", &mut aes_key)
233 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
234
235 let mut nonce = [0u8; 12];
237 nonce[4..].copy_from_slice(&self.message_counter.to_be_bytes());
238
239 let key = Key::<Aes256Gcm>::from_slice(&aes_key);
241 let cipher = Aes256Gcm::new(key);
242 let nonce_obj = AesNonce::from_slice(&nonce);
243
244 let ciphertext = cipher
245 .encrypt(nonce_obj, plaintext)
246 .map_err(|_| PqcError::EncryptionFailed("Session encryption failed".to_string()))?;
247
248 self.message_counter += 1;
249
250 let mut result = Vec::with_capacity(8 + ciphertext.len());
252 result.extend_from_slice(&(self.message_counter - 1).to_be_bytes());
253 result.extend_from_slice(&ciphertext);
254
255 Ok(result)
256 }
257}
258
259pub struct DecryptionSession {
261 shared_secret: SharedSecret,
263 received_counters: HashMap<u64, bool>,
265}
266
267impl DecryptionSession {
268 pub fn new(secret_key: &MlKemSecretKey, kem_ciphertext: &MlKemCiphertext) -> PqcResult<Self> {
274 let ml_kem = MlKem768::new();
275 let shared_secret = ml_kem.decapsulate(secret_key, kem_ciphertext)?;
276
277 Ok(Self {
278 shared_secret,
279 received_counters: HashMap::new(),
280 })
281 }
282
283 pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> PqcResult<Vec<u8>> {
285 if ciphertext.len() < 8 {
286 return Err(PqcError::DecryptionFailed("Invalid ciphertext".to_string()));
287 }
288
289 let counter = u64::from_be_bytes(ciphertext[..8].try_into().unwrap());
291
292 if self.received_counters.contains_key(&counter) {
294 return Err(PqcError::DecryptionFailed("Replay detected".to_string()));
295 }
296
297 let mut key_material = Vec::new();
299 key_material.extend_from_slice(self.shared_secret.as_bytes());
300 key_material.extend_from_slice(&counter.to_be_bytes());
301
302 let hk = Hkdf::<Sha256>::new(None, &key_material);
303 let mut aes_key = [0u8; 32];
304 hk.expand(b"message-key", &mut aes_key)
305 .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
306
307 let mut nonce = [0u8; 12];
309 nonce[4..].copy_from_slice(&counter.to_be_bytes());
310
311 let key = Key::<Aes256Gcm>::from_slice(&aes_key);
313 let cipher = Aes256Gcm::new(key);
314 let nonce_obj = AesNonce::from_slice(&nonce);
315
316 let plaintext = cipher
317 .decrypt(nonce_obj, &ciphertext[8..])
318 .map_err(|_| PqcError::DecryptionFailed("Session decryption failed".to_string()))?;
319
320 self.received_counters.insert(counter, true);
322
323 Ok(plaintext)
324 }
325}
326
327impl Default for HybridPublicKeyEncryption {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matched encryption/decryption session pair over a fresh ML-KEM-768 key pair.
    fn session_pair() -> (EncryptionSession, DecryptionSession) {
        let (public_key, secret_key) = MlKem768::new()
            .generate_keypair()
            .expect("Key generation should succeed");
        let (enc_session, kem_ct) =
            EncryptionSession::new(&public_key).expect("Session setup should succeed");
        let dec_session =
            DecryptionSession::new(&secret_key, &kem_ct).expect("Session setup should succeed");
        (enc_session, dec_session)
    }

    #[test]
    fn test_encryption_decryption_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke
            .generate_keypair()
            .expect("Key generation should succeed");

        let plaintext = b"Hello, quantum-resistant world!";
        let associated_data = b"test-context";

        let encrypted = pke
            .encrypt(&public_key, plaintext, associated_data)
            .expect("Encryption should succeed");

        let decrypted = pke
            .decrypt(&secret_key, &encrypted, associated_data)
            .expect("Decryption should succeed");

        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.generate_keypair().unwrap();

        let encrypted = pke.encrypt(&public_key, b"", b"").unwrap();
        let decrypted = pke.decrypt(&secret_key, &encrypted, b"").unwrap();

        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_wrong_aad_fails() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke.generate_keypair().unwrap();
        let plaintext = b"Test message";
        let aad = b"correct-aad";
        let wrong_aad = b"wrong-aad";

        let encrypted = pke.encrypt(&public_key, plaintext, aad).unwrap();

        let result = pke.decrypt(&secret_key, &encrypted, wrong_aad);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.generate_keypair().unwrap();
        let aad = b"context";

        let mut encrypted = pke.encrypt(&public_key, b"Integrity matters", aad).unwrap();
        encrypted.aes_ciphertext[0] ^= 0x01;

        assert!(pke.decrypt(&secret_key, &encrypted, aad).is_err());
    }

    #[test]
    fn test_session_encryption() {
        let (mut enc_session, mut dec_session) = session_pair();

        for i in 0..10 {
            let plaintext = format!("Message {}", i);
            let encrypted = enc_session.encrypt_message(plaintext.as_bytes()).unwrap();
            let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
            assert_eq!(plaintext.as_bytes(), decrypted);
        }
    }

    #[test]
    fn test_session_out_of_order_delivery() {
        let (mut enc_session, mut dec_session) = session_pair();

        let messages: Vec<Vec<u8>> = (0..3)
            .map(|i| {
                enc_session
                    .encrypt_message(format!("Message {}", i).as_bytes())
                    .unwrap()
            })
            .collect();

        for &i in &[2usize, 0, 1] {
            let decrypted = dec_session.decrypt_message(&messages[i]).unwrap();
            assert_eq!(format!("Message {}", i).as_bytes(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_session_rejects_short_input() {
        let (_enc_session, mut dec_session) = session_pair();

        assert!(dec_session.decrypt_message(&[]).is_err());
        assert!(dec_session.decrypt_message(&[0u8; 7]).is_err());
    }

    #[test]
    fn test_session_tampered_message_does_not_block_genuine() {
        let (mut enc_session, mut dec_session) = session_pair();

        let genuine = enc_session.encrypt_message(b"Test").unwrap();
        let mut forged = genuine.clone();
        let last = forged.len() - 1;
        forged[last] ^= 0x01;

        assert!(dec_session.decrypt_message(&forged).is_err());
        // A failed authentication must not mark the counter as used.
        assert_eq!(dec_session.decrypt_message(&genuine).unwrap(), b"Test".to_vec());
    }

    #[test]
    fn test_session_replay_protection() {
        let (mut enc_session, mut dec_session) = session_pair();

        let plaintext = b"Test";
        let encrypted = enc_session.encrypt_message(plaintext).unwrap();

        let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);

        let replay_result = dec_session.decrypt_message(&encrypted);
        assert!(replay_result.is_err());
    }

    #[test]
    fn test_unique_ciphertexts() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _secret_key) = pke.generate_keypair().unwrap();

        let plaintext = b"Same message";
        let aad = b"same-aad";

        let encrypted1 = pke.encrypt(&public_key, plaintext, aad).unwrap();
        let encrypted2 = pke.encrypt(&public_key, plaintext, aad).unwrap();

        assert_ne!(encrypted1.aes_ciphertext, encrypted2.aes_ciphertext);
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
    }
}