saorsa_pqc/pqc/
encryption.rs

1//! Hybrid encryption module combining ML-KEM-768 with AES-256-GCM
2//!
3//! This module provides quantum-resistant encryption by combining:
4//! - ML-KEM-768 for key encapsulation (post-quantum secure)
5//! - AES-256-GCM for symmetric encryption (classically secure)
6//!
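//! "Hybrid" here is the standard KEM/DEM split: the post-quantum KEM establishes a
//! shared secret, and a symmetric AEAD keyed from that secret encrypts the data.
//! Resistance to quantum attackers rests on ML-KEM-768 and on AES-256's margin
//! against quantum key search; no classical KEM is combined with ML-KEM here.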
9//!
10//! # Security Features
11//!
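//! - **Post-quantum KEM**: ML-KEM-768 targets NIST security category 3 (comparable to AES-192)
//! - **Authenticated encryption**: AES-256-GCM provides confidentiality and integrity
//! - **Per-message keys**: every one-shot encryption performs a fresh ML-KEM encapsulation,
//!   so each message gets an independent AES key. This is *not* forward secrecy: anyone
//!   holding the recipient's ML-KEM secret key can decrypt every message sent to it.
//! - **Key derivation**: AES keys are derived from the KEM shared secret with HKDF-SHA256;
//!   AES-GCM uses that single key for both encryption and authentication
//! - **Domain separation**: distinct HKDF `info` labels for one-shot and session keys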
17//!
18//! # Implementation Details
19//!
20//! - Uses ML-KEM-768 for key encapsulation (1088-byte ciphertext)
21//! - Derives AES-256 key using HKDF-SHA256
22//! - 96-bit nonces for AES-GCM (safe for 2^32 encryptions)
23//! - HKDF-SHA256 for proper key derivation (NIST SP 800-56C Rev. 2)
//! - No constant-time logic of its own: timing behaviour is inherited from the underlying
//!   ML-KEM, HKDF and AES-GCM implementations
25
26use crate::pqc::types::{
27    MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, PqcError, PqcResult, SharedSecret,
28};
29use crate::pqc::{ml_kem::MlKem768, MlKemOperations};
30use aes_gcm::{
31    aead::{Aead, KeyInit},
32    Aes256Gcm, Key, Nonce as AesNonce,
33};
34use hkdf::Hkdf;
35use rand::{thread_rng, RngCore};
36use sha2::{Digest, Sha256};
37use std::collections::HashMap;
38
39/// Wire format for encrypted messages
40///
41/// Contains all necessary components for decryption:
42/// - ML-KEM ciphertext for key encapsulation
43/// - AES-GCM ciphertext with authentication tag
44/// - Nonce for AES-GCM
45/// - Associated data hash for integrity
46#[derive(Debug, Clone)]
47pub struct EncryptedMessage {
48    /// ML-KEM-768 ciphertext (1088 bytes)
49    pub kem_ciphertext: MlKemCiphertext,
50    /// AES-GCM encrypted data with authentication tag
51    pub aes_ciphertext: Vec<u8>,
52    /// AES-GCM nonce (12 bytes)
53    pub nonce: [u8; 12],
54    /// SHA-256 hash of associated data for verification
55    pub aad_hash: [u8; 32],
56}
57
58/// Hybrid Public Key Encryption using ML-KEM and AES-GCM
59///
60/// Provides CCA2-secure public key encryption by combining:
61/// - ML-KEM-768 for key encapsulation
62/// - AES-256-GCM for data encryption
63/// - HKDF for key derivation
64pub struct HybridPublicKeyEncryption {
65    /// ML-KEM-768 instance for key encapsulation
66    ml_kem: MlKem768,
67}
68
69impl HybridPublicKeyEncryption {
70    /// Create a new hybrid encryption instance
71    #[must_use]
72    pub const fn new() -> Self {
73        Self {
74            ml_kem: MlKem768::new(),
75        }
76    }
77
78    /// Generate a new keypair for hybrid encryption
79    ///
80    /// # Errors
81    /// Returns `PqcError` if keypair generation fails
82    pub fn generate_keypair(&self) -> PqcResult<(MlKemPublicKey, MlKemSecretKey)> {
83        self.ml_kem.generate_keypair()
84    }
85
86    /// Encrypt a message using the hybrid scheme
87    ///
88    /// # Arguments
89    /// * `public_key` - Recipient's ML-KEM-768 public key
90    /// * `plaintext` - Message to encrypt
91    /// * `associated_data` - Additional authenticated data (not encrypted)
92    ///
93    /// # Returns
94    /// An `EncryptedMessage` containing all components needed for decryption
95    ///
96    /// # Errors
97    /// Returns `PqcError` if encapsulation or encryption fails
98    pub fn encrypt(
99        &self,
100        public_key: &MlKemPublicKey,
101        plaintext: &[u8],
102        associated_data: &[u8],
103    ) -> PqcResult<EncryptedMessage> {
104        // Step 1: Encapsulate to get shared secret
105        let (kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(public_key)?;
106
107        // Step 2: Derive AES key from shared secret using HKDF
108        let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
109        let mut aes_key_bytes = [0u8; 32];
110        hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
111            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
112
113        // Step 3: Generate random nonce
114        let mut nonce = [0u8; 12];
115        thread_rng().fill_bytes(&mut nonce);
116
117        // Step 4: Encrypt with AES-GCM
118        let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
119        let cipher = Aes256Gcm::new(key);
120        let nonce_obj = AesNonce::from_slice(&nonce);
121
122        let aes_ciphertext = cipher
123            .encrypt(nonce_obj, plaintext)
124            .map_err(|_| PqcError::EncryptionFailed("AES-GCM encryption failed".to_string()))?;
125
126        // Step 5: Hash associated data for integrity check
127        let mut hasher = Sha256::new();
128        hasher.update(associated_data);
129        let aad_hash: [u8; 32] = hasher.finalize().into();
130
131        Ok(EncryptedMessage {
132            kem_ciphertext,
133            aes_ciphertext,
134            nonce,
135            aad_hash,
136        })
137    }
138
139    /// Decrypt a message using the hybrid scheme
140    ///
141    /// # Arguments
142    /// * `secret_key` - Recipient's ML-KEM-768 secret key
143    /// * `encrypted_message` - The encrypted message to decrypt
144    /// * `associated_data` - Additional authenticated data for verification
145    ///
146    /// # Returns
147    /// The decrypted plaintext if successful
148    ///
149    /// # Errors
150    /// Returns `PqcError::DecryptionFailed` if AAD verification fails or AES-GCM decryption fails.
151    /// Returns `PqcError::CryptoError` if HKDF expansion fails.
152    pub fn decrypt(
153        &self,
154        secret_key: &MlKemSecretKey,
155        encrypted_message: &EncryptedMessage,
156        associated_data: &[u8],
157    ) -> PqcResult<Vec<u8>> {
158        // Step 1: Verify associated data hash
159        let mut hasher = Sha256::new();
160        hasher.update(associated_data);
161        let computed_hash: [u8; 32] = hasher.finalize().into();
162
163        if computed_hash != encrypted_message.aad_hash {
164            return Err(PqcError::DecryptionFailed(
165                "Associated data verification failed".to_string(),
166            ));
167        }
168
169        // Step 2: Decapsulate to recover shared secret
170        let shared_secret = self
171            .ml_kem
172            .decapsulate(secret_key, &encrypted_message.kem_ciphertext)?;
173
174        // Step 3: Derive AES key from shared secret
175        let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
176        let mut aes_key_bytes = [0u8; 32];
177        hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
178            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
179
180        // Step 4: Decrypt with AES-GCM
181        let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
182        let cipher = Aes256Gcm::new(key);
183        let nonce_obj = AesNonce::from_slice(&encrypted_message.nonce);
184
185        let plaintext = cipher
186            .decrypt(nonce_obj, encrypted_message.aes_ciphertext.as_slice())
187            .map_err(|_| PqcError::DecryptionFailed("AES-GCM decryption failed".to_string()))?;
188
189        Ok(plaintext)
190    }
191}
192
193/// Session-based encryption for multiple messages
194///
195/// Provides efficient encryption for multiple messages to the same recipient
196/// by caching the shared secret and deriving per-message keys.
197pub struct EncryptionSession {
198    /// Shared secret for the session
199    shared_secret: SharedSecret,
200    /// Counter for message sequencing and key derivation
201    message_counter: u64,
202}
203
204impl EncryptionSession {
205    /// Create a new encryption session
206    ///
207    /// # Arguments
208    /// * `public_key` - Recipient's public key
209    ///
210    /// # Returns
211    /// A tuple of (session, KEM ciphertext) where the ciphertext must be sent to the recipient
212    ///
213    /// # Errors
214    /// Returns `PqcError` if ML-KEM encapsulation fails.
215    pub fn new(public_key: &MlKemPublicKey) -> PqcResult<(Self, MlKemCiphertext)> {
216        let ml_kem = MlKem768::new();
217        let (kem_ciphertext, shared_secret) = ml_kem.encapsulate(public_key)?;
218
219        Ok((
220            Self {
221                shared_secret,
222                message_counter: 0,
223            },
224            kem_ciphertext,
225        ))
226    }
227
228    /// Encrypt a message in the session
229    ///
230    /// Each message gets a unique key derived from the session secret and counter
231    ///
232    /// # Errors
233    /// Returns `PqcError::CryptoError` if HKDF expansion fails.
234    /// Returns `PqcError::EncryptionFailed` if AES-GCM encryption fails.
235    pub fn encrypt_message(&mut self, plaintext: &[u8]) -> PqcResult<Vec<u8>> {
236        // Derive per-message key
237        let mut key_material = Vec::new();
238        key_material.extend_from_slice(self.shared_secret.as_bytes());
239        key_material.extend_from_slice(&self.message_counter.to_be_bytes());
240
241        let hk = Hkdf::<Sha256>::new(None, &key_material);
242        let mut aes_key = [0u8; 32];
243        hk.expand(b"message-key", &mut aes_key)
244            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
245
246        // Generate nonce from counter
247        let mut nonce = [0u8; 12];
248        let counter_bytes = self.message_counter.to_be_bytes();
249        nonce[4..12].copy_from_slice(&counter_bytes);
250
251        // Encrypt
252        let key = Key::<Aes256Gcm>::from_slice(&aes_key);
253        let cipher = Aes256Gcm::new(key);
254        let nonce_obj = AesNonce::from_slice(&nonce);
255
256        let ciphertext = cipher
257            .encrypt(nonce_obj, plaintext)
258            .map_err(|_| PqcError::EncryptionFailed("Session encryption failed".to_string()))?;
259
260        self.message_counter = self.message_counter.saturating_add(1);
261
262        // Prepend counter for decryption
263        let mut result = Vec::with_capacity(8_usize.saturating_add(ciphertext.len()));
264        result.extend_from_slice(&(self.message_counter.saturating_sub(1)).to_be_bytes());
265        result.extend_from_slice(&ciphertext);
266
267        Ok(result)
268    }
269}
270
271/// Decryption session for multiple messages
272pub struct DecryptionSession {
273    /// Shared secret for the session
274    shared_secret: SharedSecret,
275    /// Track received message counters to prevent replay attacks
276    received_counters: HashMap<u64, bool>,
277}
278
279impl DecryptionSession {
280    /// Create a new decryption session
281    ///
282    /// # Arguments
283    /// * `secret_key` - Recipient's secret key
284    /// * `kem_ciphertext` - KEM ciphertext from sender
285    ///
286    /// # Errors
287    /// Returns `PqcError` if ML-KEM decapsulation fails.
288    pub fn new(secret_key: &MlKemSecretKey, kem_ciphertext: &MlKemCiphertext) -> PqcResult<Self> {
289        let ml_kem = MlKem768::new();
290        let shared_secret = ml_kem.decapsulate(secret_key, kem_ciphertext)?;
291
292        Ok(Self {
293            shared_secret,
294            received_counters: HashMap::new(),
295        })
296    }
297
298    /// Decrypt a message in the session
299    ///
300    /// # Errors
301    /// Returns `PqcError::DecryptionFailed` for invalid ciphertext, counter format errors, or replay attacks.
302    /// Returns `PqcError::CryptoError` if HKDF expansion fails.
303    /// Returns `PqcError::DecryptionFailed` if AES-GCM decryption fails.
304    pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> PqcResult<Vec<u8>> {
305        if ciphertext.len() < 8 {
306            return Err(PqcError::DecryptionFailed("Invalid ciphertext".to_string()));
307        }
308
309        // Extract counter
310        let counter_slice = ciphertext.get(..8).ok_or_else(|| {
311            PqcError::DecryptionFailed("Ciphertext too short for counter".to_string())
312        })?;
313        let counter_bytes: [u8; 8] = counter_slice
314            .try_into()
315            .map_err(|_| PqcError::DecryptionFailed("Invalid counter format".to_string()))?;
316        let counter = u64::from_be_bytes(counter_bytes);
317
318        // Check for replay
319        if self.received_counters.contains_key(&counter) {
320            return Err(PqcError::DecryptionFailed("Replay detected".to_string()));
321        }
322
323        // Derive per-message key
324        let mut key_material = Vec::new();
325        key_material.extend_from_slice(self.shared_secret.as_bytes());
326        key_material.extend_from_slice(&counter.to_be_bytes());
327
328        let hk = Hkdf::<Sha256>::new(None, &key_material);
329        let mut aes_key = [0u8; 32];
330        hk.expand(b"message-key", &mut aes_key)
331            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
332
333        // Generate nonce from counter
334        let mut nonce = [0u8; 12];
335        nonce[4..].copy_from_slice(&counter.to_be_bytes());
336
337        // Decrypt
338        let key = Key::<Aes256Gcm>::from_slice(&aes_key);
339        let cipher = Aes256Gcm::new(key);
340        let nonce_obj = AesNonce::from_slice(&nonce);
341
342        let ciphertext_slice = ciphertext
343            .get(8..)
344            .ok_or_else(|| PqcError::DecryptionFailed("Ciphertext too short".to_string()))?;
345        let plaintext = cipher
346            .decrypt(nonce_obj, ciphertext_slice)
347            .map_err(|_| PqcError::DecryptionFailed("Session decryption failed".to_string()))?;
348
349        // Mark counter as used
350        self.received_counters.insert(counter, true);
351
352        Ok(plaintext)
353    }
354}
355
356impl Default for HybridPublicKeyEncryption {
357    fn default() -> Self {
358        Self::new()
359    }
360}
361
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Fresh one-shot encryptor together with a matching ML-KEM-768 keypair.
    fn pke_with_keypair() -> (HybridPublicKeyEncryption, MlKemPublicKey, MlKemSecretKey) {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke
            .ml_kem
            .generate_keypair()
            .expect("Key generation should succeed");
        (pke, public_key, secret_key)
    }

    /// Sender/receiver session pair sharing a single KEM exchange.
    fn session_pair() -> (EncryptionSession, DecryptionSession) {
        let (public_key, secret_key) = MlKem768::new().generate_keypair().unwrap();
        let (sender, kem_ct) = EncryptionSession::new(&public_key).unwrap();
        let receiver = DecryptionSession::new(&secret_key, &kem_ct).unwrap();
        (sender, receiver)
    }

    #[test]
    fn test_encryption_decryption_roundtrip() {
        let (pke, public_key, secret_key) = pke_with_keypair();
        let message: &[u8] = b"Hello, quantum-resistant world!";
        let context: &[u8] = b"test-context";

        let sealed = pke
            .encrypt(&public_key, message, context)
            .expect("Encryption should succeed");
        let opened = pke
            .decrypt(&secret_key, &sealed, context)
            .expect("Decryption should succeed");

        assert_eq!(opened, message);
    }

    #[test]
    fn test_wrong_aad_fails() {
        let (pke, public_key, secret_key) = pke_with_keypair();
        let sealed = pke
            .encrypt(&public_key, b"Test message", b"correct-aad")
            .unwrap();

        assert!(pke.decrypt(&secret_key, &sealed, b"wrong-aad").is_err());
    }

    #[test]
    fn test_session_encryption() {
        let (mut sender, mut receiver) = session_pair();

        for i in 0..10 {
            let message = format!("Message {}", i);
            let wire = sender.encrypt_message(message.as_bytes()).unwrap();
            let echoed = receiver.decrypt_message(&wire).unwrap();
            assert_eq!(message.as_bytes(), echoed);
        }
    }

    #[test]
    fn test_session_replay_protection() {
        let (mut sender, mut receiver) = session_pair();
        let wire = sender.encrypt_message(b"Test").unwrap();

        // The first delivery is accepted...
        assert_eq!(receiver.decrypt_message(&wire).unwrap(), b"Test".to_vec());
        // ...and an identical second delivery is rejected as a replay.
        assert!(receiver.decrypt_message(&wire).is_err());
    }

    #[test]
    fn test_unique_ciphertexts() {
        let (pke, public_key, _secret_key) = pke_with_keypair();

        let first = pke.encrypt(&public_key, b"Same message", b"same-aad").unwrap();
        let second = pke.encrypt(&public_key, b"Same message", b"same-aad").unwrap();

        // Fresh encapsulation and a random nonce give distinct outputs for identical inputs.
        assert_ne!(first.aes_ciphertext, second.aes_ciphertext);
        assert_ne!(first.nonce, second.nonce);
    }
}