shield_core/
ratchet.rs

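//! Forward secrecy through key ratcheting.
//!
//! Each message is encrypted under a fresh message key derived from the
//! current chain key, and the chain key is then replaced by a one-way hash of
//! itself. Compromise of the current key therefore doesn't reveal past messages.
//!
//! Based on Signal's Double Ratchet (simplified symmetric version). There is no
//! Diffie-Hellman step, so a compromised chain key does expose all later
//! messages on that chain. Messages must be decrypted in the order they were sent.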
7
// Keystream block counters are intentionally u32: 2^32 blocks of 32 bytes caps a single message at 128 GiB.
9#![allow(clippy::cast_possible_truncation)]
10
11use ring::{
12    digest, hmac,
13    rand::{SecureRandom, SystemRandom},
14};
15use subtle::ConstantTimeEq;
16use zeroize::{Zeroize, ZeroizeOnDrop};
17
18use crate::error::{Result, ShieldError};
19
20/// Ratcheting session for forward secrecy.
21///
22/// Each encrypt/decrypt advances the key chain,
23/// destroying previous keys automatically.
24///
25/// Chain keys are securely zeroized from memory when dropped.
26#[derive(Zeroize, ZeroizeOnDrop)]
27pub struct RatchetSession {
28    send_chain: [u8; 32],
29    recv_chain: [u8; 32],
30    #[zeroize(skip)]
31    send_counter: u64,
32    #[zeroize(skip)]
33    recv_counter: u64,
34}
35
36impl RatchetSession {
37    /// Create a new ratchet session from shared root key.
38    ///
39    /// # Arguments
40    /// * `root_key` - Shared secret from key exchange
41    /// * `is_initiator` - True if this party initiated the session
42    #[must_use]
43    pub fn new(root_key: &[u8; 32], is_initiator: bool) -> Self {
44        // Derive separate send/receive chains
45        let (send_label, recv_label) = if is_initiator {
46            (b"send", b"recv")
47        } else {
48            (b"recv", b"send")
49        };
50
51        let send_chain = derive_chain_key(root_key, send_label);
52        let recv_chain = derive_chain_key(root_key, recv_label);
53
54        Self {
55            send_chain,
56            recv_chain,
57            send_counter: 0,
58            recv_counter: 0,
59        }
60    }
61
62    /// Encrypt a message with forward secrecy.
63    ///
64    /// Advances the send chain - previous keys are destroyed.
65    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
66        // Ratchet send chain
67        let (new_chain, msg_key) = ratchet_chain(&self.send_chain);
68        self.send_chain = new_chain;
69
70        // Counter for ordering
71        let counter = self.send_counter;
72        self.send_counter += 1;
73
74        // Encrypt with message key
75        encrypt_with_key(&msg_key, plaintext, counter)
76    }
77
78    /// Decrypt a message with forward secrecy.
79    ///
80    /// Advances the receive chain - previous keys are destroyed.
81    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
82        // Ratchet receive chain
83        let (new_chain, msg_key) = ratchet_chain(&self.recv_chain);
84        self.recv_chain = new_chain;
85
86        // Decrypt with message key
87        let (plaintext, counter) = decrypt_with_key(&msg_key, ciphertext)?;
88
89        // Verify counter (replay protection)
90        if counter != self.recv_counter {
91            return Err(ShieldError::RatchetError(format!(
92                "out of order message: expected {}, got {}",
93                self.recv_counter, counter
94            )));
95        }
96        self.recv_counter += 1;
97
98        Ok(plaintext)
99    }
100
101    /// Get send counter (for diagnostics).
102    #[must_use]
103    pub fn send_counter(&self) -> u64 {
104        self.send_counter
105    }
106
107    /// Get receive counter (for diagnostics).
108    #[must_use]
109    pub fn recv_counter(&self) -> u64 {
110        self.recv_counter
111    }
112}
113
114/// Derive chain key from root and label.
115fn derive_chain_key(root: &[u8; 32], label: &[u8]) -> [u8; 32] {
116    let mut data = Vec::with_capacity(root.len() + label.len());
117    data.extend_from_slice(root);
118    data.extend_from_slice(label);
119
120    let hash = digest::digest(&digest::SHA256, &data);
121    let mut result = [0u8; 32];
122    result.copy_from_slice(hash.as_ref());
123    result
124}
125
126/// Ratchet chain forward, returning (`new_chain`, `message_key`).
127fn ratchet_chain(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
128    // New chain key
129    let mut chain_data = Vec::with_capacity(chain_key.len() + 5);
130    chain_data.extend_from_slice(chain_key);
131    chain_data.extend_from_slice(b"chain");
132    let new_chain_hash = digest::digest(&digest::SHA256, &chain_data);
133    let mut new_chain = [0u8; 32];
134    new_chain.copy_from_slice(new_chain_hash.as_ref());
135
136    // Message key
137    let mut msg_data = Vec::with_capacity(chain_key.len() + 7);
138    msg_data.extend_from_slice(chain_key);
139    msg_data.extend_from_slice(b"message");
140    let msg_hash = digest::digest(&digest::SHA256, &msg_data);
141    let mut msg_key = [0u8; 32];
142    msg_key.copy_from_slice(msg_hash.as_ref());
143
144    (new_chain, msg_key)
145}
146
147/// Encrypt with message key (includes counter).
148fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8], counter: u64) -> Result<Vec<u8>> {
149    let rng = SystemRandom::new();
150
151    // Generate nonce
152    let mut nonce = [0u8; 16];
153    rng.fill(&mut nonce)
154        .map_err(|_| ShieldError::RandomFailed)?;
155
156    // Counter header
157    let counter_bytes = counter.to_le_bytes();
158
159    // Data: counter || plaintext
160    let mut data = Vec::with_capacity(8 + plaintext.len());
161    data.extend_from_slice(&counter_bytes);
162    data.extend_from_slice(plaintext);
163
164    // Generate keystream
165    let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
166    for i in 0..data.len().div_ceil(32) {
167        let block_counter = (i as u32).to_le_bytes();
168        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
169        hash_input.extend_from_slice(key);
170        hash_input.extend_from_slice(&nonce);
171        hash_input.extend_from_slice(&block_counter);
172        let hash = digest::digest(&digest::SHA256, &hash_input);
173        keystream.extend_from_slice(hash.as_ref());
174    }
175
176    // XOR encrypt
177    let ciphertext: Vec<u8> = data
178        .iter()
179        .zip(keystream.iter())
180        .map(|(p, k)| p ^ k)
181        .collect();
182
183    // HMAC
184    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
185    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
186    hmac_data.extend_from_slice(&nonce);
187    hmac_data.extend_from_slice(&ciphertext);
188    let tag = hmac::sign(&hmac_key, &hmac_data);
189
190    // Format: nonce(16) || ciphertext || mac(16)
191    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
192    result.extend_from_slice(&nonce);
193    result.extend_from_slice(&ciphertext);
194    result.extend_from_slice(&tag.as_ref()[..16]);
195
196    Ok(result)
197}
198
199/// Decrypt with message key, returns (plaintext, counter).
200fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<(Vec<u8>, u64)> {
201    if encrypted.len() < 40 {
202        return Err(ShieldError::RatchetError("ciphertext too short".into()));
203    }
204
205    let nonce = &encrypted[..16];
206    let ciphertext = &encrypted[16..encrypted.len() - 16];
207    let mac = &encrypted[encrypted.len() - 16..];
208
209    // Verify MAC
210    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
211    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
212    hmac_data.extend_from_slice(nonce);
213    hmac_data.extend_from_slice(ciphertext);
214    let expected = hmac::sign(&hmac_key, &hmac_data);
215
216    if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
217        return Err(ShieldError::AuthenticationFailed);
218    }
219
220    // Generate keystream
221    let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
222    for i in 0..ciphertext.len().div_ceil(32) {
223        let block_counter = (i as u32).to_le_bytes();
224        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
225        hash_input.extend_from_slice(key);
226        hash_input.extend_from_slice(nonce);
227        hash_input.extend_from_slice(&block_counter);
228        let hash = digest::digest(&digest::SHA256, &hash_input);
229        keystream.extend_from_slice(hash.as_ref());
230    }
231
232    // XOR decrypt
233    let decrypted: Vec<u8> = ciphertext
234        .iter()
235        .zip(keystream.iter())
236        .map(|(c, k)| c ^ k)
237        .collect();
238
239    // Parse counter
240    let counter = u64::from_le_bytes([
241        decrypted[0],
242        decrypted[1],
243        decrypted[2],
244        decrypted[3],
245        decrypted[4],
246        decrypted[5],
247        decrypted[6],
248        decrypted[7],
249    ]);
250
251    Ok((decrypted[8..].to_vec(), counter))
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ratchet_roundtrip() {
        let root = [0x42u8; 32];
        let mut alice = RatchetSession::new(&root, true);
        let mut bob = RatchetSession::new(&root, false);

        let msg1 = b"Hello Bob!";
        let enc1 = alice.encrypt(msg1).unwrap();
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(msg1.as_slice(), dec1.as_slice());

        let msg2 = b"Second message";
        let enc2 = alice.encrypt(msg2).unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(msg2.as_slice(), dec2.as_slice());
    }

    #[test]
    fn test_ratchet_bidirectional() {
        let root = [0x42u8; 32];
        let mut alice = RatchetSession::new(&root, true);
        let mut bob = RatchetSession::new(&root, false);

        // The responder's send chain must match the initiator's receive chain.
        let to_bob = alice.encrypt(b"ping").unwrap();
        assert_eq!(b"ping".as_slice(), bob.decrypt(&to_bob).unwrap().as_slice());

        let to_alice = bob.encrypt(b"pong").unwrap();
        assert_eq!(b"pong".as_slice(), alice.decrypt(&to_alice).unwrap().as_slice());
    }

    #[test]
    fn test_ratchet_counters() {
        let root = [0x42u8; 32];
        let mut alice = RatchetSession::new(&root, true);
        let mut bob = RatchetSession::new(&root, false);

        assert_eq!(alice.send_counter(), 0);
        assert_eq!(bob.recv_counter(), 0);

        let enc = alice.encrypt(b"test").unwrap();
        assert_eq!(alice.send_counter(), 1);

        bob.decrypt(&enc).unwrap();
        assert_eq!(bob.recv_counter(), 1);
    }

    #[test]
    fn test_ratchet_different_ciphertexts() {
        let root = [0x42u8; 32];
        let mut alice = RatchetSession::new(&root, true);

        let enc1 = alice.encrypt(b"same message").unwrap();
        let enc2 = alice.encrypt(b"same message").unwrap();

        // Different ciphertext for same plaintext (forward secrecy)
        assert_ne!(enc1, enc2);
    }

    #[test]
    fn test_ratchet_replay_detection() {
        let root = [0x42u8; 32];
        let mut alice = RatchetSession::new(&root, true);
        let mut bob = RatchetSession::new(&root, false);

        // Send two messages
        let _enc1 = alice.encrypt(b"first").unwrap();
        let enc2 = alice.encrypt(b"second").unwrap();

        // Try to decrypt second message first (out of order)
        // This should fail because Bob expects counter 0, but gets counter 1
        assert!(bob.decrypt(&enc2).is_err());
    }

    #[test]
    fn test_ratchet_rejects_tampered_ciphertext() {
        let root = [0x42u8; 32];
        let mut alice = RatchetSession::new(&root, true);
        let mut bob = RatchetSession::new(&root, false);

        let mut enc = alice.encrypt(b"integrity").unwrap();
        // Flip one bit in the encrypted body, just past the 16-byte nonce.
        enc[16] ^= 0x01;
        assert!(matches!(
            bob.decrypt(&enc),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_ratchet_rejects_short_ciphertext() {
        let root = [0x42u8; 32];
        let mut bob = RatchetSession::new(&root, false);

        // 39 bytes is one short of nonce(16) + counter(8) + mac(16).
        assert!(bob.decrypt(&[0u8; 39]).is_err());
    }

    #[test]
    fn test_ratchet_wrong_root_key_fails() {
        let mut alice = RatchetSession::new(&[0x42u8; 32], true);
        let mut eve = RatchetSession::new(&[0x24u8; 32], false);

        let enc = alice.encrypt(b"secret").unwrap();
        assert!(matches!(
            eve.decrypt(&enc),
            Err(ShieldError::AuthenticationFailed)
        ));
    }
}