Skip to main content

shield_core/
ratchet.rs

1//! Forward secrecy through key ratcheting.
2//!
3//! Each message uses a new key derived from previous.
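//! Compromise of the current chain key doesn't reveal past messages.
//!
//! Based on the symmetric-key (KDF chain) ratchet of Signal's Double Ratchet.
//! There is no Diffie-Hellman ratchet step, so compromise of a current chain
//! key does expose all *future* messages on that chain (no post-compromise
//! security).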
7
// Keystream block counters are intentionally u32: 2^32 blocks x 32 bytes caps a single
// message at 128 GiB, and longer inputs are rejected explicitly before any cast.
9#![allow(clippy::cast_possible_truncation)]
10
11use ring::hmac;
12use subtle::ConstantTimeEq;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14
15use crate::error::{Result, ShieldError};
16
17/// Ratcheting session for forward secrecy.
18///
19/// Each encrypt/decrypt advances the key chain,
20/// destroying previous keys automatically.
21///
22/// Chain keys are securely zeroized from memory when dropped.
23#[derive(Zeroize, ZeroizeOnDrop)]
24pub struct RatchetSession {
25    send_chain: [u8; 32],
26    recv_chain: [u8; 32],
27    #[zeroize(skip)]
28    send_counter: u64,
29    #[zeroize(skip)]
30    recv_counter: u64,
31}
32
33impl RatchetSession {
34    /// Create a new ratchet session from shared root key.
35    ///
36    /// # Arguments
37    /// * `root_key` - Shared secret from key exchange
38    /// * `is_initiator` - True if this party initiated the session
39    #[must_use]
40    pub fn new(root_key: &[u8; 32], is_initiator: bool) -> Self {
41        // Derive separate send/receive chains
42        let (send_label, recv_label) = if is_initiator {
43            (b"send", b"recv")
44        } else {
45            (b"recv", b"send")
46        };
47
48        let send_chain = derive_chain_key(root_key, send_label);
49        let recv_chain = derive_chain_key(root_key, recv_label);
50
51        Self {
52            send_chain,
53            recv_chain,
54            send_counter: 0,
55            recv_counter: 0,
56        }
57    }
58
59    /// Encrypt a message with forward secrecy.
60    ///
61    /// Advances the send chain - previous keys are destroyed.
62    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
63        // Ratchet send chain
64        let (new_chain, msg_key) = ratchet_chain(&self.send_chain);
65        self.send_chain = new_chain;
66
67        // Counter for ordering
68        let counter = self.send_counter;
69        self.send_counter += 1;
70
71        // Encrypt with message key
72        encrypt_with_key(&msg_key, plaintext, counter)
73    }
74
75    /// Decrypt a message with forward secrecy.
76    ///
77    /// Advances the receive chain only AFTER successful MAC and counter
78    /// verification. A forged packet will not desynchronize the session.
79    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
80        // Speculatively compute next chain state without committing
81        let (new_chain, msg_key) = ratchet_chain(&self.recv_chain);
82
83        // Decrypt and verify MAC — if this fails, chain is untouched
84        let (plaintext, counter) = decrypt_with_key(&msg_key, ciphertext)?;
85
86        // Verify counter (replay protection) — if this fails, chain is untouched
87        if counter != self.recv_counter {
88            return Err(ShieldError::RatchetError(format!(
89                "out of order message: expected {}, got {}",
90                self.recv_counter, counter
91            )));
92        }
93
94        // All checks passed — now commit the chain advance
95        self.recv_chain = new_chain;
96        self.recv_counter += 1;
97
98        Ok(plaintext)
99    }
100
101    /// Get send counter (for diagnostics).
102    #[must_use]
103    pub fn send_counter(&self) -> u64 {
104        self.send_counter
105    }
106
107    /// Get receive counter (for diagnostics).
108    #[must_use]
109    pub fn recv_counter(&self) -> u64 {
110        self.recv_counter
111    }
112}
113
114/// Derive chain key from root and label using HMAC-SHA256 (keyed PRF).
115fn derive_chain_key(root: &[u8; 32], label: &[u8]) -> [u8; 32] {
116    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, root);
117    let tag = hmac::sign(&hmac_key, label);
118    let mut result = [0u8; 32];
119    result.copy_from_slice(&tag.as_ref()[..32]);
120    result
121}
122
123/// Ratchet chain forward using HMAC-SHA256, returning (`new_chain`, `message_key`).
124fn ratchet_chain(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
125    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, chain_key);
126
127    // New chain key = HMAC(chain_key, "chain")
128    let new_chain_tag = hmac::sign(&hmac_key, b"chain");
129    let mut new_chain = [0u8; 32];
130    new_chain.copy_from_slice(&new_chain_tag.as_ref()[..32]);
131
132    // Message key = HMAC(chain_key, "message")
133    let msg_tag = hmac::sign(&hmac_key, b"message");
134    let mut msg_key = [0u8; 32];
135    msg_key.copy_from_slice(&msg_tag.as_ref()[..32]);
136
137    (new_chain, msg_key)
138}
139
140/// Derive separated encryption and MAC subkeys from a message key using HMAC-SHA256.
141fn derive_msg_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
142    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
143
144    let enc_tag = hmac::sign(&hmac_key, b"shield-ratchet-encrypt");
145    let mut enc_key = [0u8; 32];
146    enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
147
148    let mac_tag = hmac::sign(&hmac_key, b"shield-ratchet-authenticate");
149    let mut mac_key = [0u8; 32];
150    mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
151
152    (enc_key, mac_key)
153}
154
155/// Encrypt with message key (includes counter).
156fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8], counter: u64) -> Result<Vec<u8>> {
157    let (enc_key, mac_key) = derive_msg_subkeys(key);
158
159    // Generate nonce
160    let nonce: [u8; 16] = crate::random::random_bytes()?;
161
162    // Counter header
163    let counter_bytes = counter.to_le_bytes();
164
165    // Data: counter || plaintext
166    let mut data = Vec::with_capacity(8 + plaintext.len());
167    data.extend_from_slice(&counter_bytes);
168    data.extend_from_slice(plaintext);
169
170    // Generate keystream using HMAC-SHA256 (keyed PRF) with enc_key
171    let num_blocks = data.len().div_ceil(32);
172    if u32::try_from(num_blocks).is_err() {
173        return Err(ShieldError::RatchetError(
174            "keystream too long: counter overflow".into(),
175        ));
176    }
177    let hmac_enc_key = hmac::Key::new(hmac::HMAC_SHA256, &enc_key);
178    let mut keystream = Vec::with_capacity(num_blocks * 32);
179    for i in 0..num_blocks {
180        let block_counter = (i as u32).to_le_bytes();
181        let mut block_data = Vec::with_capacity(nonce.len() + 4);
182        block_data.extend_from_slice(&nonce);
183        block_data.extend_from_slice(&block_counter);
184        let tag = hmac::sign(&hmac_enc_key, &block_data);
185        keystream.extend_from_slice(tag.as_ref());
186    }
187
188    // XOR encrypt
189    let ciphertext: Vec<u8> = data
190        .iter()
191        .zip(keystream.iter())
192        .map(|(p, k)| p ^ k)
193        .collect();
194
195    // HMAC with mac_key
196    let hmac_mac_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
197    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
198    hmac_data.extend_from_slice(&nonce);
199    hmac_data.extend_from_slice(&ciphertext);
200    let tag = hmac::sign(&hmac_mac_key, &hmac_data);
201
202    // Format: nonce(16) || ciphertext || mac(16)
203    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
204    result.extend_from_slice(&nonce);
205    result.extend_from_slice(&ciphertext);
206    result.extend_from_slice(&tag.as_ref()[..16]);
207
208    Ok(result)
209}
210
211/// Decrypt with message key, returns (plaintext, counter).
212fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<(Vec<u8>, u64)> {
213    if encrypted.len() < 40 {
214        return Err(ShieldError::RatchetError("ciphertext too short".into()));
215    }
216
217    let (enc_key, mac_key) = derive_msg_subkeys(key);
218
219    let nonce = &encrypted[..16];
220    let ciphertext = &encrypted[16..encrypted.len() - 16];
221    let mac = &encrypted[encrypted.len() - 16..];
222
223    // Verify MAC with mac_key
224    let hmac_mac_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
225    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
226    hmac_data.extend_from_slice(nonce);
227    hmac_data.extend_from_slice(ciphertext);
228    let expected = hmac::sign(&hmac_mac_key, &hmac_data);
229
230    if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
231        return Err(ShieldError::AuthenticationFailed);
232    }
233
234    // Generate keystream using HMAC-SHA256 (keyed PRF) with enc_key
235    let num_blocks = ciphertext.len().div_ceil(32);
236    if u32::try_from(num_blocks).is_err() {
237        return Err(ShieldError::RatchetError(
238            "keystream too long: counter overflow".into(),
239        ));
240    }
241    let hmac_enc_key = hmac::Key::new(hmac::HMAC_SHA256, &enc_key);
242    let mut keystream = Vec::with_capacity(num_blocks * 32);
243    for i in 0..num_blocks {
244        let block_counter = (i as u32).to_le_bytes();
245        let mut block_data = Vec::with_capacity(nonce.len() + 4);
246        block_data.extend_from_slice(nonce);
247        block_data.extend_from_slice(&block_counter);
248        let tag = hmac::sign(&hmac_enc_key, &block_data);
249        keystream.extend_from_slice(tag.as_ref());
250    }
251
252    // XOR decrypt
253    let decrypted: Vec<u8> = ciphertext
254        .iter()
255        .zip(keystream.iter())
256        .map(|(c, k)| c ^ k)
257        .collect();
258
259    // Parse counter
260    let counter = u64::from_le_bytes([
261        decrypted[0],
262        decrypted[1],
263        decrypted[2],
264        decrypted[3],
265        decrypted[4],
266        decrypted[5],
267        decrypted[6],
268        decrypted[7],
269    ]);
270
271    Ok((decrypted[8..].to_vec(), counter))
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a matched initiator/responder pair from the same root key.
    fn session_pair() -> (RatchetSession, RatchetSession) {
        let root = [0x42u8; 32];
        (
            RatchetSession::new(&root, true),
            RatchetSession::new(&root, false),
        )
    }

    #[test]
    fn test_ratchet_roundtrip() {
        let (mut alice, mut bob) = session_pair();

        let msg1 = b"Hello Bob!";
        let enc1 = alice.encrypt(msg1).unwrap();
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(msg1.as_slice(), dec1.as_slice());

        let msg2 = b"Second message";
        let enc2 = alice.encrypt(msg2).unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(msg2.as_slice(), dec2.as_slice());
    }

    #[test]
    fn test_ratchet_bidirectional() {
        let (mut alice, mut bob) = session_pair();

        // Responder -> initiator uses the other directional chain.
        let enc = bob.encrypt(b"Hello Alice!").unwrap();
        let dec = alice.decrypt(&enc).unwrap();
        assert_eq!(b"Hello Alice!".as_slice(), dec.as_slice());

        // Interleaving directions keeps both chains in sync.
        let enc = alice.encrypt(b"Hi Bob").unwrap();
        let dec = bob.decrypt(&enc).unwrap();
        assert_eq!(b"Hi Bob".as_slice(), dec.as_slice());
    }

    #[test]
    fn test_ratchet_counters() {
        let (mut alice, mut bob) = session_pair();

        assert_eq!(alice.send_counter(), 0);
        assert_eq!(bob.recv_counter(), 0);

        let enc = alice.encrypt(b"test").unwrap();
        assert_eq!(alice.send_counter(), 1);

        bob.decrypt(&enc).unwrap();
        assert_eq!(bob.recv_counter(), 1);
    }

    #[test]
    fn test_ratchet_different_ciphertexts() {
        let (mut alice, _) = session_pair();

        let enc1 = alice.encrypt(b"same message").unwrap();
        let enc2 = alice.encrypt(b"same message").unwrap();

        // Each message uses a fresh message key (and nonce), so identical
        // plaintexts never produce identical ciphertexts.
        assert_ne!(enc1, enc2);
    }

    #[test]
    fn test_ratchet_empty_and_multiblock_messages() {
        let (mut alice, mut bob) = session_pair();

        let empty: &[u8] = &[];
        // 1000 bytes spans many 32-byte keystream blocks.
        let long: Vec<u8> = (0..=255u8).cycle().take(1000).collect();

        for msg in [empty, long.as_slice()] {
            let enc = alice.encrypt(msg).unwrap();
            // nonce(16) || counter(8) || plaintext || tag(16)
            assert_eq!(enc.len(), 40 + msg.len());
            let dec = bob.decrypt(&enc).unwrap();
            assert_eq!(msg, dec.as_slice());
        }
    }

    #[test]
    fn test_ratchet_survives_forged_packet() {
        let (mut alice, mut bob) = session_pair();

        // Alice sends a legit message
        let enc1 = alice.encrypt(b"first").unwrap();

        // Attacker injects a forged packet — should fail MAC
        let forged = vec![0xFFu8; 64];
        assert!(bob.decrypt(&forged).is_err());

        // Bob's chain should NOT be desynchronized — legit message still works
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(b"first".as_slice(), dec1.as_slice());

        // Subsequent messages also work
        let enc2 = alice.encrypt(b"second").unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(b"second".as_slice(), dec2.as_slice());
    }

    #[test]
    fn test_ratchet_rejects_tampered_ciphertext() {
        let (mut alice, mut bob) = session_pair();
        let enc = alice.encrypt(b"integrity matters").unwrap();

        // Flip one bit in the nonce, the encrypted body, and the tag.
        for index in [0, 20, enc.len() - 1] {
            let mut tampered = enc.clone();
            tampered[index] ^= 0x01;
            assert!(matches!(
                bob.decrypt(&tampered),
                Err(ShieldError::AuthenticationFailed)
            ));
        }

        // None of the rejected attempts advanced the chain.
        assert_eq!(bob.recv_counter(), 0);
        let dec = bob.decrypt(&enc).unwrap();
        assert_eq!(b"integrity matters".as_slice(), dec.as_slice());
    }

    #[test]
    fn test_ratchet_rejects_truncated_input() {
        let (_, mut bob) = session_pair();

        assert!(bob.decrypt(&[]).is_err());
        // One byte short of nonce(16) + counter(8) + tag(16).
        assert!(bob.decrypt(&[0u8; 39]).is_err());
        assert_eq!(bob.recv_counter(), 0);
    }

    #[test]
    fn test_ratchet_rejects_replay() {
        let (mut alice, mut bob) = session_pair();

        let enc1 = alice.encrypt(b"first").unwrap();
        bob.decrypt(&enc1).unwrap();

        // Bob's receive chain has moved on, so the old message no longer
        // authenticates under the current message key.
        assert!(bob.decrypt(&enc1).is_err());
        assert_eq!(bob.recv_counter(), 1);

        // The failed replay did not disturb the session.
        let enc2 = alice.encrypt(b"second").unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(b"second".as_slice(), dec2.as_slice());
    }

    #[test]
    fn test_ratchet_replay_detection() {
        let (mut alice, mut bob) = session_pair();

        // Send two messages
        let enc1 = alice.encrypt(b"first").unwrap();
        let enc2 = alice.encrypt(b"second").unwrap();

        // Delivering the second message first fails authentication. Bob
        // derives the message key for chain position 0, but enc2 was sealed
        // with the position-1 key. Lockstep chains reject out-of-order delivery.
        assert!(matches!(
            bob.decrypt(&enc2),
            Err(ShieldError::AuthenticationFailed)
        ));

        // The rejection did not advance Bob's chain; the first message still opens.
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(b"first".as_slice(), dec1.as_slice());
    }
}