shield_core/
ratchet.rs

1//! Forward secrecy through key ratcheting.
2//!
//! Each message uses a fresh key derived from the previous chain key.
//! Compromising the current chain key does not reveal past messages.
5//!
6//! Based on Signal's Double Ratchet (simplified symmetric version).
7
// Keystream block counters are intentionally u32: a single message would need more than
// 128 GiB (2^32 blocks of 32 bytes) to overflow them.
9#![allow(clippy::cast_possible_truncation)]
10
11use ring::{digest, hmac, rand::{SecureRandom, SystemRandom}};
12use subtle::ConstantTimeEq;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14
15use crate::error::{Result, ShieldError};
16
17/// Ratcheting session for forward secrecy.
18///
19/// Each encrypt/decrypt advances the key chain,
20/// destroying previous keys automatically.
21///
22/// Chain keys are securely zeroized from memory when dropped.
23#[derive(Zeroize, ZeroizeOnDrop)]
24pub struct RatchetSession {
25    send_chain: [u8; 32],
26    recv_chain: [u8; 32],
27    #[zeroize(skip)]
28    send_counter: u64,
29    #[zeroize(skip)]
30    recv_counter: u64,
31}
32
33impl RatchetSession {
34    /// Create a new ratchet session from shared root key.
35    ///
36    /// # Arguments
37    /// * `root_key` - Shared secret from key exchange
38    /// * `is_initiator` - True if this party initiated the session
39    #[must_use]
40    pub fn new(root_key: &[u8; 32], is_initiator: bool) -> Self {
41        // Derive separate send/receive chains
42        let (send_label, recv_label) = if is_initiator {
43            (b"send", b"recv")
44        } else {
45            (b"recv", b"send")
46        };
47
48        let send_chain = derive_chain_key(root_key, send_label);
49        let recv_chain = derive_chain_key(root_key, recv_label);
50
51        Self {
52            send_chain,
53            recv_chain,
54            send_counter: 0,
55            recv_counter: 0,
56        }
57    }
58
59    /// Encrypt a message with forward secrecy.
60    ///
61    /// Advances the send chain - previous keys are destroyed.
62    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
63        // Ratchet send chain
64        let (new_chain, msg_key) = ratchet_chain(&self.send_chain);
65        self.send_chain = new_chain;
66
67        // Counter for ordering
68        let counter = self.send_counter;
69        self.send_counter += 1;
70
71        // Encrypt with message key
72        encrypt_with_key(&msg_key, plaintext, counter)
73    }
74
75    /// Decrypt a message with forward secrecy.
76    ///
77    /// Advances the receive chain - previous keys are destroyed.
78    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
79        // Ratchet receive chain
80        let (new_chain, msg_key) = ratchet_chain(&self.recv_chain);
81        self.recv_chain = new_chain;
82
83        // Decrypt with message key
84        let (plaintext, counter) = decrypt_with_key(&msg_key, ciphertext)?;
85
86        // Verify counter (replay protection)
87        if counter != self.recv_counter {
88            return Err(ShieldError::RatchetError(format!(
89                "out of order message: expected {}, got {}",
90                self.recv_counter, counter
91            )));
92        }
93        self.recv_counter += 1;
94
95        Ok(plaintext)
96    }
97
98    /// Get send counter (for diagnostics).
99    #[must_use]
100    pub fn send_counter(&self) -> u64 {
101        self.send_counter
102    }
103
104    /// Get receive counter (for diagnostics).
105    #[must_use]
106    pub fn recv_counter(&self) -> u64 {
107        self.recv_counter
108    }
109}
110
111/// Derive chain key from root and label.
112fn derive_chain_key(root: &[u8; 32], label: &[u8]) -> [u8; 32] {
113    let mut data = Vec::with_capacity(root.len() + label.len());
114    data.extend_from_slice(root);
115    data.extend_from_slice(label);
116
117    let hash = digest::digest(&digest::SHA256, &data);
118    let mut result = [0u8; 32];
119    result.copy_from_slice(hash.as_ref());
120    result
121}
122
123/// Ratchet chain forward, returning (`new_chain`, `message_key`).
124fn ratchet_chain(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
125    // New chain key
126    let mut chain_data = Vec::with_capacity(chain_key.len() + 5);
127    chain_data.extend_from_slice(chain_key);
128    chain_data.extend_from_slice(b"chain");
129    let new_chain_hash = digest::digest(&digest::SHA256, &chain_data);
130    let mut new_chain = [0u8; 32];
131    new_chain.copy_from_slice(new_chain_hash.as_ref());
132
133    // Message key
134    let mut msg_data = Vec::with_capacity(chain_key.len() + 7);
135    msg_data.extend_from_slice(chain_key);
136    msg_data.extend_from_slice(b"message");
137    let msg_hash = digest::digest(&digest::SHA256, &msg_data);
138    let mut msg_key = [0u8; 32];
139    msg_key.copy_from_slice(msg_hash.as_ref());
140
141    (new_chain, msg_key)
142}
143
144/// Encrypt with message key (includes counter).
145fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8], counter: u64) -> Result<Vec<u8>> {
146    let rng = SystemRandom::new();
147
148    // Generate nonce
149    let mut nonce = [0u8; 16];
150    rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
151
152    // Counter header
153    let counter_bytes = counter.to_le_bytes();
154
155    // Data: counter || plaintext
156    let mut data = Vec::with_capacity(8 + plaintext.len());
157    data.extend_from_slice(&counter_bytes);
158    data.extend_from_slice(plaintext);
159
160    // Generate keystream
161    let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
162    for i in 0..data.len().div_ceil(32) {
163        let block_counter = (i as u32).to_le_bytes();
164        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
165        hash_input.extend_from_slice(key);
166        hash_input.extend_from_slice(&nonce);
167        hash_input.extend_from_slice(&block_counter);
168        let hash = digest::digest(&digest::SHA256, &hash_input);
169        keystream.extend_from_slice(hash.as_ref());
170    }
171
172    // XOR encrypt
173    let ciphertext: Vec<u8> = data
174        .iter()
175        .zip(keystream.iter())
176        .map(|(p, k)| p ^ k)
177        .collect();
178
179    // HMAC
180    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
181    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
182    hmac_data.extend_from_slice(&nonce);
183    hmac_data.extend_from_slice(&ciphertext);
184    let tag = hmac::sign(&hmac_key, &hmac_data);
185
186    // Format: nonce(16) || ciphertext || mac(16)
187    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
188    result.extend_from_slice(&nonce);
189    result.extend_from_slice(&ciphertext);
190    result.extend_from_slice(&tag.as_ref()[..16]);
191
192    Ok(result)
193}
194
195/// Decrypt with message key, returns (plaintext, counter).
196fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<(Vec<u8>, u64)> {
197    if encrypted.len() < 40 {
198        return Err(ShieldError::RatchetError("ciphertext too short".into()));
199    }
200
201    let nonce = &encrypted[..16];
202    let ciphertext = &encrypted[16..encrypted.len() - 16];
203    let mac = &encrypted[encrypted.len() - 16..];
204
205    // Verify MAC
206    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
207    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
208    hmac_data.extend_from_slice(nonce);
209    hmac_data.extend_from_slice(ciphertext);
210    let expected = hmac::sign(&hmac_key, &hmac_data);
211
212    if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
213        return Err(ShieldError::AuthenticationFailed);
214    }
215
216    // Generate keystream
217    let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
218    for i in 0..ciphertext.len().div_ceil(32) {
219        let block_counter = (i as u32).to_le_bytes();
220        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
221        hash_input.extend_from_slice(key);
222        hash_input.extend_from_slice(nonce);
223        hash_input.extend_from_slice(&block_counter);
224        let hash = digest::digest(&digest::SHA256, &hash_input);
225        keystream.extend_from_slice(hash.as_ref());
226    }
227
228    // XOR decrypt
229    let decrypted: Vec<u8> = ciphertext
230        .iter()
231        .zip(keystream.iter())
232        .map(|(c, k)| c ^ k)
233        .collect();
234
235    // Parse counter
236    let counter = u64::from_le_bytes([
237        decrypted[0], decrypted[1], decrypted[2], decrypted[3],
238        decrypted[4], decrypted[5], decrypted[6], decrypted[7],
239    ]);
240
241    Ok((decrypted[8..].to_vec(), counter))
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an initiator/responder pair that share the same root key.
    fn session_pair() -> (RatchetSession, RatchetSession) {
        let root = [0x42u8; 32];
        (
            RatchetSession::new(&root, true),
            RatchetSession::new(&root, false),
        )
    }

    #[test]
    fn test_ratchet_roundtrip() {
        let (mut alice, mut bob) = session_pair();

        let msg1 = b"Hello Bob!";
        let enc1 = alice.encrypt(msg1).unwrap();
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(msg1.as_slice(), dec1.as_slice());

        let msg2 = b"Second message";
        let enc2 = alice.encrypt(msg2).unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(msg2.as_slice(), dec2.as_slice());
    }

    #[test]
    fn test_ratchet_responder_to_initiator() {
        let (mut alice, mut bob) = session_pair();

        let enc = bob.encrypt(b"Hi Alice").unwrap();
        let dec = alice.decrypt(&enc).unwrap();
        assert_eq!(dec.as_slice(), b"Hi Alice".as_slice());
    }

    #[test]
    fn test_ratchet_counters() {
        let (mut alice, mut bob) = session_pair();

        assert_eq!(alice.send_counter(), 0);
        assert_eq!(bob.recv_counter(), 0);

        let enc = alice.encrypt(b"test").unwrap();
        assert_eq!(alice.send_counter(), 1);

        bob.decrypt(&enc).unwrap();
        assert_eq!(bob.recv_counter(), 1);
    }

    #[test]
    fn test_ratchet_different_ciphertexts() {
        let (mut alice, _bob) = session_pair();

        let enc1 = alice.encrypt(b"same message").unwrap();
        let enc2 = alice.encrypt(b"same message").unwrap();

        // Each message uses a fresh nonce and a fresh ratcheted key.
        assert_ne!(enc1, enc2);
    }

    #[test]
    fn test_ratchet_replay_detection() {
        let (mut alice, mut bob) = session_pair();

        // Send two messages
        let _enc1 = alice.encrypt(b"first").unwrap();
        let enc2 = alice.encrypt(b"second").unwrap();

        // Decrypting the second message first fails: Bob derives the key for
        // message 0, so message 1's MAC does not verify (the counter check is
        // never reached).
        assert!(matches!(
            bob.decrypt(&enc2),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_ratchet_duplicate_rejected() {
        let (mut alice, mut bob) = session_pair();

        let enc = alice.encrypt(b"once").unwrap();
        bob.decrypt(&enc).unwrap();
        // The receive chain has moved on, so a replayed copy no longer verifies.
        assert!(bob.decrypt(&enc).is_err());
    }

    #[test]
    fn test_ratchet_tampered_ciphertext_rejected() {
        let (mut alice, mut bob) = session_pair();

        let mut enc = alice.encrypt(b"attack at dawn").unwrap();
        enc[20] ^= 0x01; // flip a bit inside the encrypted body
        assert!(matches!(
            bob.decrypt(&enc),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_ratchet_truncated_ciphertext_rejected() {
        let (_alice, mut bob) = session_pair();

        // Minimum is nonce(16) + counter(8) + tag(16) = 40 bytes.
        assert!(bob.decrypt(&[0u8; 39]).is_err());
    }

    #[test]
    fn test_ratchet_empty_plaintext_roundtrip() {
        let (mut alice, mut bob) = session_pair();

        let enc = alice.encrypt(b"").unwrap();
        assert_eq!(enc.len(), 40);
        let dec = bob.decrypt(&enc).unwrap();
        assert!(dec.is_empty());
    }
}