1#![allow(clippy::cast_possible_truncation)]
10
11use ring::hmac;
12use subtle::ConstantTimeEq;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14
15use crate::error::{Result, ShieldError};
16
17#[derive(Zeroize, ZeroizeOnDrop)]
24pub struct RatchetSession {
25 send_chain: [u8; 32],
26 recv_chain: [u8; 32],
27 #[zeroize(skip)]
28 send_counter: u64,
29 #[zeroize(skip)]
30 recv_counter: u64,
31}
32
33impl RatchetSession {
34 #[must_use]
40 pub fn new(root_key: &[u8; 32], is_initiator: bool) -> Self {
41 let (send_label, recv_label) = if is_initiator {
43 (b"send", b"recv")
44 } else {
45 (b"recv", b"send")
46 };
47
48 let send_chain = derive_chain_key(root_key, send_label);
49 let recv_chain = derive_chain_key(root_key, recv_label);
50
51 Self {
52 send_chain,
53 recv_chain,
54 send_counter: 0,
55 recv_counter: 0,
56 }
57 }
58
59 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
63 let (new_chain, msg_key) = ratchet_chain(&self.send_chain);
65 self.send_chain = new_chain;
66
67 let counter = self.send_counter;
69 self.send_counter += 1;
70
71 encrypt_with_key(&msg_key, plaintext, counter)
73 }
74
75 pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
80 let (new_chain, msg_key) = ratchet_chain(&self.recv_chain);
82
83 let (plaintext, counter) = decrypt_with_key(&msg_key, ciphertext)?;
85
86 if counter != self.recv_counter {
88 return Err(ShieldError::RatchetError(format!(
89 "out of order message: expected {}, got {}",
90 self.recv_counter, counter
91 )));
92 }
93
94 self.recv_chain = new_chain;
96 self.recv_counter += 1;
97
98 Ok(plaintext)
99 }
100
101 #[must_use]
103 pub fn send_counter(&self) -> u64 {
104 self.send_counter
105 }
106
107 #[must_use]
109 pub fn recv_counter(&self) -> u64 {
110 self.recv_counter
111 }
112}
113
114fn derive_chain_key(root: &[u8; 32], label: &[u8]) -> [u8; 32] {
116 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, root);
117 let tag = hmac::sign(&hmac_key, label);
118 let mut result = [0u8; 32];
119 result.copy_from_slice(&tag.as_ref()[..32]);
120 result
121}
122
123fn ratchet_chain(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
125 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, chain_key);
126
127 let new_chain_tag = hmac::sign(&hmac_key, b"chain");
129 let mut new_chain = [0u8; 32];
130 new_chain.copy_from_slice(&new_chain_tag.as_ref()[..32]);
131
132 let msg_tag = hmac::sign(&hmac_key, b"message");
134 let mut msg_key = [0u8; 32];
135 msg_key.copy_from_slice(&msg_tag.as_ref()[..32]);
136
137 (new_chain, msg_key)
138}
139
140fn derive_msg_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
142 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
143
144 let enc_tag = hmac::sign(&hmac_key, b"shield-ratchet-encrypt");
145 let mut enc_key = [0u8; 32];
146 enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
147
148 let mac_tag = hmac::sign(&hmac_key, b"shield-ratchet-authenticate");
149 let mut mac_key = [0u8; 32];
150 mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
151
152 (enc_key, mac_key)
153}
154
155fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8], counter: u64) -> Result<Vec<u8>> {
157 let (enc_key, mac_key) = derive_msg_subkeys(key);
158
159 let nonce: [u8; 16] = crate::random::random_bytes()?;
161
162 let counter_bytes = counter.to_le_bytes();
164
165 let mut data = Vec::with_capacity(8 + plaintext.len());
167 data.extend_from_slice(&counter_bytes);
168 data.extend_from_slice(plaintext);
169
170 let num_blocks = data.len().div_ceil(32);
172 if u32::try_from(num_blocks).is_err() {
173 return Err(ShieldError::RatchetError(
174 "keystream too long: counter overflow".into(),
175 ));
176 }
177 let hmac_enc_key = hmac::Key::new(hmac::HMAC_SHA256, &enc_key);
178 let mut keystream = Vec::with_capacity(num_blocks * 32);
179 for i in 0..num_blocks {
180 let block_counter = (i as u32).to_le_bytes();
181 let mut block_data = Vec::with_capacity(nonce.len() + 4);
182 block_data.extend_from_slice(&nonce);
183 block_data.extend_from_slice(&block_counter);
184 let tag = hmac::sign(&hmac_enc_key, &block_data);
185 keystream.extend_from_slice(tag.as_ref());
186 }
187
188 let ciphertext: Vec<u8> = data
190 .iter()
191 .zip(keystream.iter())
192 .map(|(p, k)| p ^ k)
193 .collect();
194
195 let hmac_mac_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
197 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
198 hmac_data.extend_from_slice(&nonce);
199 hmac_data.extend_from_slice(&ciphertext);
200 let tag = hmac::sign(&hmac_mac_key, &hmac_data);
201
202 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
204 result.extend_from_slice(&nonce);
205 result.extend_from_slice(&ciphertext);
206 result.extend_from_slice(&tag.as_ref()[..16]);
207
208 Ok(result)
209}
210
211fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<(Vec<u8>, u64)> {
213 if encrypted.len() < 40 {
214 return Err(ShieldError::RatchetError("ciphertext too short".into()));
215 }
216
217 let (enc_key, mac_key) = derive_msg_subkeys(key);
218
219 let nonce = &encrypted[..16];
220 let ciphertext = &encrypted[16..encrypted.len() - 16];
221 let mac = &encrypted[encrypted.len() - 16..];
222
223 let hmac_mac_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
225 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
226 hmac_data.extend_from_slice(nonce);
227 hmac_data.extend_from_slice(ciphertext);
228 let expected = hmac::sign(&hmac_mac_key, &hmac_data);
229
230 if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
231 return Err(ShieldError::AuthenticationFailed);
232 }
233
234 let num_blocks = ciphertext.len().div_ceil(32);
236 if u32::try_from(num_blocks).is_err() {
237 return Err(ShieldError::RatchetError(
238 "keystream too long: counter overflow".into(),
239 ));
240 }
241 let hmac_enc_key = hmac::Key::new(hmac::HMAC_SHA256, &enc_key);
242 let mut keystream = Vec::with_capacity(num_blocks * 32);
243 for i in 0..num_blocks {
244 let block_counter = (i as u32).to_le_bytes();
245 let mut block_data = Vec::with_capacity(nonce.len() + 4);
246 block_data.extend_from_slice(nonce);
247 block_data.extend_from_slice(&block_counter);
248 let tag = hmac::sign(&hmac_enc_key, &block_data);
249 keystream.extend_from_slice(tag.as_ref());
250 }
251
252 let decrypted: Vec<u8> = ciphertext
254 .iter()
255 .zip(keystream.iter())
256 .map(|(c, k)| c ^ k)
257 .collect();
258
259 let counter = u64::from_le_bytes([
261 decrypted[0],
262 decrypted[1],
263 decrypted[2],
264 decrypted[3],
265 decrypted[4],
266 decrypted[5],
267 decrypted[6],
268 decrypted[7],
269 ]);
270
271 Ok((decrypted[8..].to_vec(), counter))
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an (initiator, responder) pair sharing the same fixed root key.
    fn session_pair() -> (RatchetSession, RatchetSession) {
        let root = [0x42u8; 32];
        (
            RatchetSession::new(&root, true),
            RatchetSession::new(&root, false),
        )
    }

    /// Messages sent in order decrypt to the original plaintexts.
    #[test]
    fn test_ratchet_roundtrip() {
        let (mut alice, mut bob) = session_pair();

        let msg1 = b"Hello Bob!";
        let enc1 = alice.encrypt(msg1).unwrap();
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(msg1.as_slice(), dec1.as_slice());

        let msg2 = b"Second message";
        let enc2 = alice.encrypt(msg2).unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(msg2.as_slice(), dec2.as_slice());
    }

    /// Counters start at zero and advance by one per successful operation.
    #[test]
    fn test_ratchet_counters() {
        let (mut alice, mut bob) = session_pair();

        assert_eq!(alice.send_counter(), 0);
        assert_eq!(bob.recv_counter(), 0);

        let enc = alice.encrypt(b"test").unwrap();
        assert_eq!(alice.send_counter(), 1);

        bob.decrypt(&enc).unwrap();
        assert_eq!(bob.recv_counter(), 1);
    }

    /// Encrypting the same plaintext twice yields different ciphertexts.
    #[test]
    fn test_ratchet_different_ciphertexts() {
        let (mut alice, _bob) = session_pair();

        let enc1 = alice.encrypt(b"same message").unwrap();
        let enc2 = alice.encrypt(b"same message").unwrap();

        assert_ne!(enc1, enc2);
    }

    /// A rejected forged packet must not disturb the receiver's state.
    #[test]
    fn test_ratchet_survives_forged_packet() {
        let (mut alice, mut bob) = session_pair();

        let enc1 = alice.encrypt(b"first").unwrap();

        let forged = vec![0xFFu8; 64];
        assert!(bob.decrypt(&forged).is_err());
        assert_eq!(bob.recv_counter(), 0);

        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(b"first".as_slice(), dec1.as_slice());

        let enc2 = alice.encrypt(b"second").unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(b"second".as_slice(), dec2.as_slice());
    }

    /// Out-of-order and replayed packets are rejected without changing state.
    #[test]
    fn test_ratchet_replay_detection() {
        let (mut alice, mut bob) = session_pair();

        let enc1 = alice.encrypt(b"first").unwrap();
        let enc2 = alice.encrypt(b"second").unwrap();

        // Delivering the second message before the first must fail.
        assert!(bob.decrypt(&enc2).is_err());
        assert_eq!(bob.recv_counter(), 0);

        // In-order delivery still works after the rejected packet.
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(b"first".as_slice(), dec1.as_slice());

        // A genuine replay of an already-accepted packet must fail.
        assert!(bob.decrypt(&enc1).is_err());
        assert_eq!(bob.recv_counter(), 1);

        // The rejected replay did not consume the next message key.
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(b"second".as_slice(), dec2.as_slice());
    }
}