saorsa_pqc/pqc/
encryption.rs

1//! Hybrid encryption module combining ML-KEM-768 with AES-256-GCM
2//!
3//! This module provides quantum-resistant encryption by combining:
4//! - ML-KEM-768 for key encapsulation (post-quantum secure)
5//! - AES-256-GCM for symmetric encryption (classically secure)
6//!
7//! The hybrid approach ensures security against both classical and quantum attackers,
8//! following the "belt and suspenders" principle recommended by NIST.
9//!
10//! # Security Features
11//!
12//! - **Post-quantum KEM**: ML-KEM-768 provides ~192-bit quantum security
13//! - **Authenticated encryption**: AES-256-GCM provides confidentiality and integrity
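//! - **Per-message keys**: Each one-shot encryption performs a fresh ML-KEM encapsulation,
//!   so every message is protected by its own AES key. This is not forward secrecy:
//!   anyone holding the recipient's long-term ML-KEM secret key can decrypt every message
//!   sent to it.
//! - **Single AEAD key**: AES-256-GCM uses one derived key for both confidentiality and
//!   integrity; no separate MAC key is derived.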
16//! - **Domain separation**: Context-specific key derivation
17//!
18//! # Implementation Details
19//!
20//! - Uses ML-KEM-768 for key encapsulation (1088-byte ciphertext)
21//! - Derives AES-256 key using HKDF-SHA256
22//! - 96-bit nonces for AES-GCM (safe for 2^32 encryptions)
23//! - HKDF-SHA256 for proper key derivation (NIST SP 800-56C Rev. 2)
24//! - Constant-time operations where possible
25
26use crate::pqc::types::{
27    MlKemCiphertext, MlKemPublicKey, MlKemSecretKey, PqcError, PqcResult, SharedSecret,
28};
29use crate::pqc::{ml_kem::MlKem768, MlKemOperations};
30use aes_gcm::{
31    aead::{Aead, KeyInit},
32    Aes256Gcm, Key, Nonce as AesNonce,
33};
34use hkdf::Hkdf;
35use rand::{thread_rng, RngCore};
36use sha2::{Digest, Sha256};
37use std::collections::HashMap;
38
39/// Wire format for encrypted messages
40///
41/// Contains all necessary components for decryption:
42/// - ML-KEM ciphertext for key encapsulation
43/// - AES-GCM ciphertext with authentication tag
44/// - Nonce for AES-GCM
45/// - Associated data hash for integrity
46#[derive(Debug, Clone)]
47pub struct EncryptedMessage {
48    /// ML-KEM-768 ciphertext (1088 bytes)
49    pub kem_ciphertext: MlKemCiphertext,
50    /// AES-GCM encrypted data with authentication tag
51    pub aes_ciphertext: Vec<u8>,
52    /// AES-GCM nonce (12 bytes)
53    pub nonce: [u8; 12],
54    /// SHA-256 hash of associated data for verification
55    pub aad_hash: [u8; 32],
56}
57
58/// Hybrid Public Key Encryption using ML-KEM and AES-GCM
59///
60/// Provides CCA2-secure public key encryption by combining:
61/// - ML-KEM-768 for key encapsulation
62/// - AES-256-GCM for data encryption
63/// - HKDF for key derivation
64pub struct HybridPublicKeyEncryption {
65    /// ML-KEM-768 instance for key encapsulation
66    ml_kem: MlKem768,
67}
68
69impl HybridPublicKeyEncryption {
70    /// Create a new hybrid encryption instance
71    #[must_use]
72    pub const fn new() -> Self {
73        Self {
74            ml_kem: MlKem768::new(),
75        }
76    }
77
78    /// Generate a new keypair for hybrid encryption
79    ///
80    /// # Errors
81    /// Returns `PqcError` if keypair generation fails
82    pub fn generate_keypair(&self) -> PqcResult<(MlKemPublicKey, MlKemSecretKey)> {
83        self.ml_kem.generate_keypair()
84    }
85
86    /// Encrypt a message using the hybrid scheme
87    ///
88    /// # Arguments
89    /// * `public_key` - Recipient's ML-KEM-768 public key
90    /// * `plaintext` - Message to encrypt
91    /// * `associated_data` - Additional authenticated data (not encrypted)
92    ///
93    /// # Returns
94    /// An `EncryptedMessage` containing all components needed for decryption
95    ///
96    /// # Errors
97    /// Returns `PqcError` if encapsulation or encryption fails
98    pub fn encrypt(
99        &self,
100        public_key: &MlKemPublicKey,
101        plaintext: &[u8],
102        associated_data: &[u8],
103    ) -> PqcResult<EncryptedMessage> {
104        // Step 1: Encapsulate to get shared secret
105        let (kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(public_key)?;
106
107        // Step 2: Derive AES key from shared secret using HKDF
108        let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
109        let mut aes_key_bytes = [0u8; 32];
110        hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
111            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
112
113        // Step 3: Generate random nonce
114        let mut nonce = [0u8; 12];
115        thread_rng().fill_bytes(&mut nonce);
116
117        // Step 4: Encrypt with AES-GCM
118        let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
119        let cipher = Aes256Gcm::new(key);
120        let nonce_obj = AesNonce::from_slice(&nonce);
121
122        let aes_ciphertext = cipher
123            .encrypt(nonce_obj, plaintext)
124            .map_err(|_| PqcError::EncryptionFailed("AES-GCM encryption failed".to_string()))?;
125
126        // Step 5: Hash associated data for integrity check
127        let mut hasher = Sha256::new();
128        hasher.update(associated_data);
129        let aad_hash: [u8; 32] = hasher.finalize().into();
130
131        Ok(EncryptedMessage {
132            kem_ciphertext,
133            aes_ciphertext,
134            nonce,
135            aad_hash,
136        })
137    }
138
139    /// Decrypt a message using the hybrid scheme
140    ///
141    /// # Arguments
142    /// * `secret_key` - Recipient's ML-KEM-768 secret key
143    /// * `encrypted_message` - The encrypted message to decrypt
144    /// * `associated_data` - Additional authenticated data for verification
145    ///
146    /// # Returns
147    /// The decrypted plaintext if successful
148    pub fn decrypt(
149        &self,
150        secret_key: &MlKemSecretKey,
151        encrypted_message: &EncryptedMessage,
152        associated_data: &[u8],
153    ) -> PqcResult<Vec<u8>> {
154        // Step 1: Verify associated data hash
155        let mut hasher = Sha256::new();
156        hasher.update(associated_data);
157        let computed_hash: [u8; 32] = hasher.finalize().into();
158
159        if computed_hash != encrypted_message.aad_hash {
160            return Err(PqcError::DecryptionFailed(
161                "Associated data verification failed".to_string(),
162            ));
163        }
164
165        // Step 2: Decapsulate to recover shared secret
166        let shared_secret = self
167            .ml_kem
168            .decapsulate(secret_key, &encrypted_message.kem_ciphertext)?;
169
170        // Step 3: Derive AES key from shared secret
171        let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
172        let mut aes_key_bytes = [0u8; 32];
173        hk.expand(b"aes-256-gcm-key", &mut aes_key_bytes)
174            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
175
176        // Step 4: Decrypt with AES-GCM
177        let key = Key::<Aes256Gcm>::from_slice(&aes_key_bytes);
178        let cipher = Aes256Gcm::new(key);
179        let nonce_obj = AesNonce::from_slice(&encrypted_message.nonce);
180
181        let plaintext = cipher
182            .decrypt(nonce_obj, encrypted_message.aes_ciphertext.as_slice())
183            .map_err(|_| PqcError::DecryptionFailed("AES-GCM decryption failed".to_string()))?;
184
185        Ok(plaintext)
186    }
187}
188
189/// Session-based encryption for multiple messages
190///
191/// Provides efficient encryption for multiple messages to the same recipient
192/// by caching the shared secret and deriving per-message keys.
193pub struct EncryptionSession {
194    /// Shared secret for the session
195    shared_secret: SharedSecret,
196    /// Counter for message sequencing and key derivation
197    message_counter: u64,
198}
199
200impl EncryptionSession {
201    /// Create a new encryption session
202    ///
203    /// # Arguments
204    /// * `public_key` - Recipient's public key
205    ///
206    /// # Returns
207    /// A tuple of (session, KEM ciphertext) where the ciphertext must be sent to the recipient
208    pub fn new(public_key: &MlKemPublicKey) -> PqcResult<(Self, MlKemCiphertext)> {
209        let ml_kem = MlKem768::new();
210        let (kem_ciphertext, shared_secret) = ml_kem.encapsulate(public_key)?;
211
212        Ok((
213            Self {
214                shared_secret,
215                message_counter: 0,
216            },
217            kem_ciphertext,
218        ))
219    }
220
221    /// Encrypt a message in the session
222    ///
223    /// Each message gets a unique key derived from the session secret and counter
224    pub fn encrypt_message(&mut self, plaintext: &[u8]) -> PqcResult<Vec<u8>> {
225        // Derive per-message key
226        let mut key_material = Vec::new();
227        key_material.extend_from_slice(self.shared_secret.as_bytes());
228        key_material.extend_from_slice(&self.message_counter.to_be_bytes());
229
230        let hk = Hkdf::<Sha256>::new(None, &key_material);
231        let mut aes_key = [0u8; 32];
232        hk.expand(b"message-key", &mut aes_key)
233            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
234
235        // Generate nonce from counter
236        let mut nonce = [0u8; 12];
237        nonce[4..].copy_from_slice(&self.message_counter.to_be_bytes());
238
239        // Encrypt
240        let key = Key::<Aes256Gcm>::from_slice(&aes_key);
241        let cipher = Aes256Gcm::new(key);
242        let nonce_obj = AesNonce::from_slice(&nonce);
243
244        let ciphertext = cipher
245            .encrypt(nonce_obj, plaintext)
246            .map_err(|_| PqcError::EncryptionFailed("Session encryption failed".to_string()))?;
247
248        self.message_counter += 1;
249
250        // Prepend counter for decryption
251        let mut result = Vec::with_capacity(8 + ciphertext.len());
252        result.extend_from_slice(&(self.message_counter - 1).to_be_bytes());
253        result.extend_from_slice(&ciphertext);
254
255        Ok(result)
256    }
257}
258
259/// Decryption session for multiple messages
260pub struct DecryptionSession {
261    /// Shared secret for the session
262    shared_secret: SharedSecret,
263    /// Track received message counters to prevent replay attacks
264    received_counters: HashMap<u64, bool>,
265}
266
267impl DecryptionSession {
268    /// Create a new decryption session
269    ///
270    /// # Arguments
271    /// * `secret_key` - Recipient's secret key
272    /// * `kem_ciphertext` - KEM ciphertext from sender
273    pub fn new(secret_key: &MlKemSecretKey, kem_ciphertext: &MlKemCiphertext) -> PqcResult<Self> {
274        let ml_kem = MlKem768::new();
275        let shared_secret = ml_kem.decapsulate(secret_key, kem_ciphertext)?;
276
277        Ok(Self {
278            shared_secret,
279            received_counters: HashMap::new(),
280        })
281    }
282
283    /// Decrypt a message in the session
284    pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> PqcResult<Vec<u8>> {
285        if ciphertext.len() < 8 {
286            return Err(PqcError::DecryptionFailed("Invalid ciphertext".to_string()));
287        }
288
289        // Extract counter
290        let counter = u64::from_be_bytes(ciphertext[..8].try_into().unwrap());
291
292        // Check for replay
293        if self.received_counters.contains_key(&counter) {
294            return Err(PqcError::DecryptionFailed("Replay detected".to_string()));
295        }
296
297        // Derive per-message key
298        let mut key_material = Vec::new();
299        key_material.extend_from_slice(self.shared_secret.as_bytes());
300        key_material.extend_from_slice(&counter.to_be_bytes());
301
302        let hk = Hkdf::<Sha256>::new(None, &key_material);
303        let mut aes_key = [0u8; 32];
304        hk.expand(b"message-key", &mut aes_key)
305            .map_err(|_| PqcError::CryptoError("HKDF expansion failed".to_string()))?;
306
307        // Generate nonce from counter
308        let mut nonce = [0u8; 12];
309        nonce[4..].copy_from_slice(&counter.to_be_bytes());
310
311        // Decrypt
312        let key = Key::<Aes256Gcm>::from_slice(&aes_key);
313        let cipher = Aes256Gcm::new(key);
314        let nonce_obj = AesNonce::from_slice(&nonce);
315
316        let plaintext = cipher
317            .decrypt(nonce_obj, &ciphertext[8..])
318            .map_err(|_| PqcError::DecryptionFailed("Session decryption failed".to_string()))?;
319
320        // Mark counter as used
321        self.received_counters.insert(counter, true);
322
323        Ok(plaintext)
324    }
325}
326
327impl Default for HybridPublicKeyEncryption {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a connected encryption/decryption session pair over a fresh keypair.
    fn session_pair() -> (EncryptionSession, DecryptionSession) {
        let ml_kem = MlKem768::new();
        let (public_key, secret_key) = ml_kem
            .generate_keypair()
            .expect("Key generation should succeed");
        let (enc_session, kem_ct) =
            EncryptionSession::new(&public_key).expect("Session setup should succeed");
        let dec_session =
            DecryptionSession::new(&secret_key, &kem_ct).expect("Session setup should succeed");
        (enc_session, dec_session)
    }

    #[test]
    fn test_encryption_decryption_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke
            .ml_kem
            .generate_keypair()
            .expect("Key generation should succeed");

        let plaintext = b"Hello, quantum-resistant world!";
        let associated_data = b"test-context";

        let encrypted = pke
            .encrypt(&public_key, plaintext, associated_data)
            .expect("Encryption should succeed");

        let decrypted = pke
            .decrypt(&secret_key, &encrypted, associated_data)
            .expect("Decryption should succeed");

        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_wrong_aad_fails() {
        let pke = HybridPublicKeyEncryption::new();

        let (public_key, secret_key) = pke.generate_keypair().unwrap();
        let encrypted = pke.encrypt(&public_key, b"Test message", b"correct-aad").unwrap();

        let result = pke.decrypt(&secret_key, &encrypted, b"wrong-aad");
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_aes_ciphertext_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.generate_keypair().unwrap();

        let mut encrypted = pke.encrypt(&public_key, b"payload", b"ctx").unwrap();
        encrypted.aes_ciphertext[0] ^= 0x01;

        assert!(pke.decrypt(&secret_key, &encrypted, b"ctx").is_err());
    }

    #[test]
    fn test_wrong_secret_key_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _secret_key) = pke.generate_keypair().unwrap();
        let (_other_public_key, other_secret_key) = pke.generate_keypair().unwrap();

        let encrypted = pke.encrypt(&public_key, b"payload", b"ctx").unwrap();

        assert!(pke.decrypt(&other_secret_key, &encrypted, b"ctx").is_err());
    }

    #[test]
    fn test_session_encryption() {
        let (mut enc_session, mut dec_session) = session_pair();

        for i in 0..10 {
            let plaintext = format!("Message {}", i);
            let encrypted = enc_session.encrypt_message(plaintext.as_bytes()).unwrap();
            let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
            assert_eq!(plaintext.as_bytes(), decrypted);
        }
    }

    #[test]
    fn test_session_out_of_order_delivery() {
        let (mut enc_session, mut dec_session) = session_pair();

        let frames: Vec<Vec<u8>> = (0..3)
            .map(|i| {
                enc_session
                    .encrypt_message(format!("msg {}", i).as_bytes())
                    .unwrap()
            })
            .collect();

        for (i, frame) in frames.iter().enumerate().rev() {
            let decrypted = dec_session.decrypt_message(frame).unwrap();
            assert_eq!(format!("msg {}", i).as_bytes(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_session_replay_protection() {
        let (mut enc_session, mut dec_session) = session_pair();

        let plaintext = b"Test";
        let encrypted = enc_session.encrypt_message(plaintext).unwrap();

        // First decryption should succeed
        let decrypted = dec_session.decrypt_message(&encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);

        // Replay should fail
        let replay_result = dec_session.decrypt_message(&encrypted);
        assert!(replay_result.is_err());
    }

    #[test]
    fn test_session_rejects_short_frames() {
        let (_enc_session, mut dec_session) = session_pair();

        assert!(dec_session.decrypt_message(&[]).is_err());
        assert!(dec_session.decrypt_message(&[0u8; 7]).is_err());
        assert!(dec_session.decrypt_message(&[0u8; 8]).is_err());
    }

    #[test]
    fn test_session_tampered_counter_fails_without_burning_counter() {
        let (mut enc_session, mut dec_session) = session_pair();

        let mut frame = enc_session.encrypt_message(b"Test").unwrap();

        // Flip the low bit of the counter prefix (counter 0 -> 1): key and nonce no
        // longer match the ones used to produce the tag.
        frame[7] ^= 0x01;
        assert!(dec_session.decrypt_message(&frame).is_err());

        // Restoring the frame must still decrypt: the failed attempt must not have
        // marked any counter as used.
        frame[7] ^= 0x01;
        assert_eq!(dec_session.decrypt_message(&frame).unwrap(), b"Test".to_vec());
    }

    #[test]
    fn test_unique_ciphertexts() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, _secret_key) = pke.generate_keypair().unwrap();

        let plaintext = b"Same message";
        let aad = b"same-aad";

        let encrypted1 = pke.encrypt(&public_key, plaintext, aad).unwrap();
        let encrypted2 = pke.encrypt(&public_key, plaintext, aad).unwrap();

        // Same plaintext should produce different ciphertexts (due to randomness)
        assert_ne!(encrypted1.aes_ciphertext, encrypted2.aes_ciphertext);
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
    }
}