saorsa_pqc/pqc/
encryption.rs

1//! Hybrid encryption module combining ML-KEM-768 with AES-256-GCM
2//!
3//! This module provides quantum-resistant encryption by combining:
4//! - ML-KEM-768 for key encapsulation (post-quantum secure)
5//! - AES-256-GCM for symmetric encryption (classically secure)
6//!
7//! The hybrid approach ensures security against both classical and quantum attackers,
8//! following the "belt and suspenders" principle recommended by NIST.
9//!
10//! # Security Features
11//!
12//! - **Post-quantum KEM**: ML-KEM-768 provides ~192-bit quantum security
13//! - **Authenticated encryption**: AES-256-GCM provides confidentiality and integrity
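//! - **Per-message keys**: Each one-shot message encapsulates a fresh ML-KEM shared
//!   secret, so it is sealed under its own AES key. This is not forward secrecy:
//!   whoever holds the recipient's ML-KEM secret key can decrypt every message sent to it.
//! - **Key derivation labels**: AES keys are derived with HKDF using distinct `info`
//!   labels for one-shot and session messages. AES-GCM itself uses a single key for both
//!   confidentiality and integrity.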
16//! - **Domain separation**: Context-specific key derivation
17//!
18//! # Implementation Details
19//!
20//! - Uses ML-KEM-768 for key encapsulation (1088-byte ciphertext)
21//! - Derives AES-256 key using HKDF-SHA256
22//! - 96-bit nonces for AES-GCM (safe for 2^32 encryptions)
23//! - HKDF-SHA256 for proper key derivation (NIST SP 800-56C Rev. 2)
24//! - Constant-time operations where possible
25
26use crate::pqc::constant_time::ct_eq;
27use crate::pqc::types::{
28    MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, PqcError, PqcResult, SharedSecret,
29};
30use crate::pqc::{ml_kem::MlKem768, MlKemOperations};
31use aes_gcm::{
32    aead::{Aead, KeyInit},
33    Aes256Gcm, Key, Nonce as AesNonce,
34};
35use hkdf::Hkdf;
36use rand::{thread_rng, RngCore};
37use sha2::{Digest, Sha256};
38use std::collections::HashMap;
39
40/// Wire format for encrypted messages
41///
42/// Contains all necessary components for decryption:
43/// - ML-KEM ciphertext for key encapsulation
44/// - AES-GCM ciphertext with authentication tag
45/// - Nonce for AES-GCM
46/// - Associated data hash for integrity
47#[derive(Debug, Clone)]
48pub struct EncryptedMessage {
49    /// ML-KEM-768 ciphertext (1088 bytes)
50    pub kem_ciphertext: MlKemCiphertext,
51    /// AES-GCM encrypted data with authentication tag
52    pub aes_ciphertext: Vec<u8>,
53    /// AES-GCM nonce (12 bytes)
54    pub nonce: [u8; 12],
55    /// SHA-256 hash of associated data for verification
56    pub aad_hash: [u8; 32],
57}
58
59/// Hybrid Public Key Encryption using ML-KEM and AES-GCM
60///
61/// Provides CCA2-secure public key encryption by combining:
62/// - ML-KEM-768 for key encapsulation
63/// - AES-256-GCM for data encryption
64/// - HKDF for key derivation
65pub struct HybridPublicKeyEncryption {
66    /// ML-KEM-768 instance for key encapsulation
67    ml_kem: MlKem768,
68}
69
70impl HybridPublicKeyEncryption {
71    /// Create a new hybrid encryption instance
72    #[must_use]
73    pub const fn new() -> Self {
74        Self {
75            ml_kem: MlKem768::new(),
76        }
77    }
78
79    /// Generate a new keypair for hybrid encryption
80    ///
81    /// # Errors
82    /// Returns `PqcError` if keypair generation fails
83    pub fn generate_keypair(&self) -> PqcResult<(MlKemPublicKey, MlKemSecretKey)> {
84        self.ml_kem.generate_keypair()
85    }
86
87    /// Encrypt a message using the hybrid scheme
88    ///
89    /// # Arguments
90    /// * `public_key` - Recipient's ML-KEM-768 public key
91    /// * `plaintext` - Message to encrypt
92    /// * `associated_data` - Additional authenticated data (not encrypted)
93    ///
94    /// # Returns
95    /// An `EncryptedMessage` containing all components needed for decryption
96    ///
97    /// # Errors
98    /// Returns `PqcError` if encapsulation or encryption fails
99    pub fn encrypt(
100        &self,
101        public_key: &MlKemPublicKey,
102        plaintext: &[u8],
103        associated_data: &[u8],
104    ) -> PqcResult<EncryptedMessage> {
105        // Step 1: Encapsulate to get shared secret
106        let (kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(public_key)?;
107
108        // Step 2: Derive AES key from shared secret using HKDF
109        let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
110        let mut aes_key_bytes = [0u8; 32];
111        hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
112            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
113
114        // Step 3: Generate random nonce
115        let mut nonce = [0u8; 12];
116        thread_rng().fill_bytes(&mut nonce);
117
118        // Step 4: Encrypt with AES-GCM
119        let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
120        let cipher = Aes256Gcm::new(key);
121        let nonce_obj = AesNonce::from_slice(&nonce);
122
123        let aes_ciphertext = cipher
124            .encrypt(nonce_obj, plaintext)
125            .map_err(|_| PqcError::EncryptionFailed("AES-GCM encryption failed".to_string()))?;
126
127        // Step 5: Hash associated data for integrity check
128        let mut hasher = Sha256::new();
129        hasher.update(associated_data);
130        let aad_hash: [u8; 32] = hasher.finalize().into();
131
132        Ok(EncryptedMessage {
133            kem_ciphertext,
134            aes_ciphertext,
135            nonce,
136            aad_hash,
137        })
138    }
139
140    /// Decrypt a message using the hybrid scheme
141    ///
142    /// # Arguments
143    /// * `secret_key` - Recipient's ML-KEM-768 secret key
144    /// * `encrypted_message` - The encrypted message to decrypt
145    /// * `associated_data` - Additional authenticated data for verification
146    ///
147    /// # Returns
148    /// The decrypted plaintext if successful
149    ///
150    /// # Errors
151    /// Returns `PqcError::DecryptionFailed` if AAD verification fails or AES-GCM decryption fails.
152    /// Returns `PqcError::CryptoError` if HKDF expansion fails.
153    pub fn decrypt(
154        &self,
155        secret_key: &MlKemSecretKey,
156        encrypted_message: &EncryptedMessage,
157        associated_data: &[u8],
158    ) -> PqcResult<Vec<u8>> {
159        // Step 1: Verify associated data hash
160        let mut hasher = Sha256::new();
161        hasher.update(associated_data);
162        let computed_hash: [u8; 32] = hasher.finalize().into();
163
164        if !ct_eq(&computed_hash, &encrypted_message.aad_hash) {
165            return Err(PqcError::DecryptionFailed(
166                "Associated data verification failed".to_string(),
167            ));
168        }
169
170        // Step 2: Decapsulate to recover shared secret
171        let shared_secret = self
172            .ml_kem
173            .decapsulate(secret_key, &encrypted_message.kem_ciphertext)?;
174
175        // Step 3: Derive AES key from shared secret
176        let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
177        let mut aes_key_bytes = [0u8; 32];
178        hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
179            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
180
181        // Step 4: Decrypt with AES-GCM
182        let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
183        let cipher = Aes256Gcm::new(key);
184        let nonce_obj = AesNonce::from_slice(&encrypted_message.nonce);
185
186        let plaintext = cipher
187            .decrypt(nonce_obj, encrypted_message.aes_ciphertext.as_slice())
188            .map_err(|_| PqcError::DecryptionFailed("AES-GCM decryption failed".to_string()))?;
189
190        Ok(plaintext)
191    }
192}
193
194/// Session-based encryption for multiple messages
195///
196/// Provides efficient encryption for multiple messages to the same recipient
197/// by caching the shared secret and deriving per-message keys.
198pub struct EncryptionSession {
199    /// Shared secret for the session
200    shared_secret: SharedSecret,
201    /// Counter for message sequencing and key derivation
202    message_counter: u64,
203}
204
205impl EncryptionSession {
206    /// Create a new encryption session
207    ///
208    /// # Arguments
209    /// * `public_key` - Recipient's public key
210    ///
211    /// # Returns
212    /// A tuple of (session, KEM ciphertext) where the ciphertext must be sent to the recipient
213    ///
214    /// # Errors
215    /// Returns `PqcError` if ML-KEM encapsulation fails.
216    pub fn new(public_key: &MlKemPublicKey) -> PqcResult<(Self, MlKemCiphertext)> {
217        let ml_kem = MlKem768::new();
218        let (kem_ciphertext, shared_secret) = ml_kem.encapsulate(public_key)?;
219
220        Ok((
221            Self {
222                shared_secret,
223                message_counter: 0,
224            },
225            kem_ciphertext,
226        ))
227    }
228
229    /// Encrypt a message in the session
230    ///
231    /// Each message gets a unique key derived from the session secret and counter
232    ///
233    /// # Errors
234    /// Returns `PqcError::CryptoError` if HKDF expansion fails.
235    /// Returns `PqcError::EncryptionFailed` if AES-GCM encryption fails.
236    pub fn encrypt_message(&mut self, plaintext: &[u8]) -> PqcResult<Vec<u8>> {
237        // Derive per-message key
238        let mut key_material = Vec::new();
239        key_material.extend_from_slice(self.shared_secret.as_bytes());
240        key_material.extend_from_slice(&self.message_counter.to_be_bytes());
241
242        let hk = Hkdf::<Sha256>::new(None, &key_material);
243        let mut aes_key = [0u8; 32];
244        hk.expand(b"message-key", &mut aes_key)
245            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
246
247        // Generate nonce from counter
248        let mut nonce = [0u8; 12];
249        let counter_bytes = self.message_counter.to_be_bytes();
250        nonce[4..12].copy_from_slice(&counter_bytes);
251
252        // Encrypt
253        let key = Key::<Aes256Gcm>::from_slice(&aes_key);
254        let cipher = Aes256Gcm::new(key);
255        let nonce_obj = AesNonce::from_slice(&nonce);
256
257        let ciphertext = cipher
258            .encrypt(nonce_obj, plaintext)
259            .map_err(|_| PqcError::EncryptionFailed("Session encryption failed".to_string()))?;
260
261        self.message_counter = self.message_counter.saturating_add(1);
262
263        // Prepend counter for decryption
264        let mut result = Vec::with_capacity(8_usize.saturating_add(ciphertext.len()));
265        result.extend_from_slice(&(self.message_counter.saturating_sub(1)).to_be_bytes());
266        result.extend_from_slice(&ciphertext);
267
268        Ok(result)
269    }
270}
271
272/// Decryption session for multiple messages
273pub struct DecryptionSession {
274    /// Shared secret for the session
275    shared_secret: SharedSecret,
276    /// Track received message counters to prevent replay attacks
277    received_counters: HashMap<u64, bool>,
278}
279
280impl DecryptionSession {
281    /// Create a new decryption session
282    ///
283    /// # Arguments
284    /// * `secret_key` - Recipient's secret key
285    /// * `kem_ciphertext` - KEM ciphertext from sender
286    ///
287    /// # Errors
288    /// Returns `PqcError` if ML-KEM decapsulation fails.
289    pub fn new(secret_key: &MlKemSecretKey, kem_ciphertext: &MlKemCiphertext) -> PqcResult<Self> {
290        let ml_kem = MlKem768::new();
291        let shared_secret = ml_kem.decapsulate(secret_key, kem_ciphertext)?;
292
293        Ok(Self {
294            shared_secret,
295            received_counters: HashMap::new(),
296        })
297    }
298
299    /// Decrypt a message in the session
300    ///
301    /// # Errors
302    /// Returns `PqcError::DecryptionFailed` for invalid ciphertext, counter format errors, or replay attacks.
303    /// Returns `PqcError::CryptoError` if HKDF expansion fails.
304    /// Returns `PqcError::DecryptionFailed` if AES-GCM decryption fails.
305    pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> PqcResult<Vec<u8>> {
306        if ciphertext.len() < 8 {
307            return Err(PqcError::DecryptionFailed("Invalid ciphertext".to_string()));
308        }
309
310        // Extract counter
311        let counter_slice = ciphertext.get(..8).ok_or_else(|| {
312            PqcError::DecryptionFailed("Ciphertext too short for counter".to_string())
313        })?;
314        let counter_bytes: [u8; 8] = counter_slice
315            .try_into()
316            .map_err(|_| PqcError::DecryptionFailed("Invalid counter format".to_string()))?;
317        let counter = u64::from_be_bytes(counter_bytes);
318
319        // Check for replay
320        if self.received_counters.contains_key(&counter) {
321            return Err(PqcError::DecryptionFailed("Replay detected".to_string()));
322        }
323
324        // Derive per-message key
325        let mut key_material = Vec::new();
326        key_material.extend_from_slice(self.shared_secret.as_bytes());
327        key_material.extend_from_slice(&counter.to_be_bytes());
328
329        let hk = Hkdf::<Sha256>::new(None, &key_material);
330        let mut aes_key = [0u8; 32];
331        hk.expand(b"message-key", &mut aes_key)
332            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
333
334        // Generate nonce from counter
335        let mut nonce = [0u8; 12];
336        nonce[4..].copy_from_slice(&counter.to_be_bytes());
337
338        // Decrypt
339        let key = Key::<Aes256Gcm>::from_slice(&aes_key);
340        let cipher = Aes256Gcm::new(key);
341        let nonce_obj = AesNonce::from_slice(&nonce);
342
343        let ciphertext_slice = ciphertext
344            .get(8..)
345            .ok_or_else(|| PqcError::DecryptionFailed("Ciphertext too short".to_string()))?;
346        let plaintext = cipher
347            .decrypt(nonce_obj, ciphertext_slice)
348            .map_err(|_| PqcError::DecryptionFailed("Session decryption failed".to_string()))?;
349
350        // Mark counter as used
351        self.received_counters.insert(counter, true);
352
353        Ok(plaintext)
354    }
355}
356
357impl Default for HybridPublicKeyEncryption {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)] // test code: panics are the failure signal
mod tests {
    use super::*;

    /// Build a session pair (sender, receiver) over a freshly generated keypair.
    fn session_pair() -> (EncryptionSession, DecryptionSession) {
        let ml_kem = MlKem768::new();
        let (public_key, secret_key) = ml_kem.generate_keypair().unwrap();
        let (enc_session, kem_ct) = EncryptionSession::new(&public_key).unwrap();
        let dec_session = DecryptionSession::new(&secret_key, &kem_ct).unwrap();
        (enc_session, dec_session)
    }

    #[test]
    fn test_encryption_decryption_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();

        // Generate keypair for testing
        let (public_key, secret_key) = pke
            .ml_kem
            .generate_keypair()
            .expect("Key generation should succeed");

        let plaintext = b"Hello, quantum-resistant world!";
        let associated_data = b"test-context";

        let encrypted = pke
            .encrypt(&public_key, plaintext, associated_data)
            .expect("Encryption should succeed");

        let decrypted = pke
            .decrypt(&secret_key, &encrypted, associated_data)
            .expect("Decryption should succeed");

        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_empty_plaintext_and_aad_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();

        let encrypted = pke.encrypt(&public_key, b"", b"").unwrap();
        let decrypted = pke.decrypt(&secret_key, &encrypted, b"").unwrap();

        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_wrong_aad_fails() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();
        let plaintext = b"Test message";
        let aad = b"correct-aad";
        let wrong_aad = b"wrong-aad";

        let encrypted = pke.encrypt(&public_key, plaintext, aad).unwrap();

        let result = pke.decrypt(&secret_key, &encrypted, wrong_aad);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();
        let aad = b"context";

        let mut encrypted = pke.encrypt(&public_key, b"Test message", aad).unwrap();
        *encrypted.aes_ciphertext.first_mut().unwrap() ^= 0x01;

        assert!(pke.decrypt(&secret_key, &encrypted, aad).is_err());
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();
        let aad = b"context";

        let mut encrypted = pke.encrypt(&public_key, b"Test message", aad).unwrap();
        encrypted.nonce[0] ^= 0x01;

        assert!(pke.decrypt(&secret_key, &encrypted, aad).is_err());
    }

    #[test]
    fn test_wrong_secret_key_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _secret_key) = pke.ml_kem.generate_keypair().unwrap();
        let (_other_public_key, other_secret_key) = pke.ml_kem.generate_keypair().unwrap();

        let encrypted = pke.encrypt(&public_key, b"Test message", b"aad").unwrap();

        assert!(pke.decrypt(&other_secret_key, &encrypted, b"aad").is_err());
    }

    #[test]
    fn test_session_encryption() {
        let (mut enc_session, mut dec_session) = session_pair();

        // Encrypt and decrypt multiple messages
        for i in 0..10 {
            let plaintext = format!("Message {}", i);
            let encrypted = enc_session.encrypt_message(plaintext.as_bytes()).unwrap();
            let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
            assert_eq!(plaintext.as_bytes(), decrypted);
        }
    }

    #[test]
    fn test_session_out_of_order_delivery() {
        let (mut enc_session, mut dec_session) = session_pair();

        let messages: Vec<Vec<u8>> = (0..3)
            .map(|i| {
                enc_session
                    .encrypt_message(format!("Message {}", i).as_bytes())
                    .unwrap()
            })
            .collect();

        // The counter travels with each message, so any delivery order works.
        for (i, encrypted) in messages.iter().enumerate().rev() {
            let decrypted = dec_session.decrypt_message(encrypted).unwrap();
            assert_eq!(format!("Message {}", i).into_bytes(), decrypted);
        }
    }

    #[test]
    fn test_session_prepends_counter() {
        let (mut enc_session, _dec_session) = session_pair();

        let first = enc_session.encrypt_message(b"a").unwrap();
        let second = enc_session.encrypt_message(b"b").unwrap();

        assert_eq!(first.get(..8).unwrap(), 0u64.to_be_bytes().as_slice());
        assert_eq!(second.get(..8).unwrap(), 1u64.to_be_bytes().as_slice());
    }

    #[test]
    fn test_session_replay_protection() {
        let (mut enc_session, mut dec_session) = session_pair();

        let plaintext = b"Test";
        let encrypted = enc_session.encrypt_message(plaintext).unwrap();

        // First decryption should succeed
        let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);

        // Replay should fail
        let replay_result = dec_session.decrypt_message(&encrypted);
        assert!(replay_result.is_err());
    }

    #[test]
    fn test_session_rejected_message_does_not_consume_counter() {
        let (mut enc_session, mut dec_session) = session_pair();

        let encrypted = enc_session.encrypt_message(b"Test").unwrap();
        let mut tampered = encrypted.clone();
        *tampered.last_mut().unwrap() ^= 0x01;

        // A forged message with a valid counter must not block the genuine one.
        assert!(dec_session.decrypt_message(&tampered).is_err());
        assert_eq!(dec_session.decrypt_message(&encrypted).unwrap(), b"Test".to_vec());
    }

    #[test]
    fn test_session_short_input_rejected() {
        let (_enc_session, mut dec_session) = session_pair();
        assert!(dec_session.decrypt_message(&[0u8; 7]).is_err());
    }

    #[test]
    fn test_unique_ciphertexts() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _secret_key) = pke.ml_kem.generate_keypair().unwrap();

        let plaintext = b"Same message";
        let aad = b"same-aad";

        let encrypted1 = pke.encrypt(&public_key, plaintext, aad).unwrap();
        let encrypted2 = pke.encrypt(&public_key, plaintext, aad).unwrap();

        // Same plaintext should produce different ciphertexts (due to randomness)
        assert_ne!(encrypted1.aes_ciphertext, encrypted2.aes_ciphertext);
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
    }
}