1#![allow(clippy::cast_possible_truncation)]
10
11use ring::{digest, hmac, rand::{SecureRandom, SystemRandom}};
12use subtle::ConstantTimeEq;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14
15use crate::error::{Result, ShieldError};
16
17#[derive(Zeroize, ZeroizeOnDrop)]
24pub struct RatchetSession {
25 send_chain: [u8; 32],
26 recv_chain: [u8; 32],
27 #[zeroize(skip)]
28 send_counter: u64,
29 #[zeroize(skip)]
30 recv_counter: u64,
31}
32
33impl RatchetSession {
34 #[must_use]
40 pub fn new(root_key: &[u8; 32], is_initiator: bool) -> Self {
41 let (send_label, recv_label) = if is_initiator {
43 (b"send", b"recv")
44 } else {
45 (b"recv", b"send")
46 };
47
48 let send_chain = derive_chain_key(root_key, send_label);
49 let recv_chain = derive_chain_key(root_key, recv_label);
50
51 Self {
52 send_chain,
53 recv_chain,
54 send_counter: 0,
55 recv_counter: 0,
56 }
57 }
58
59 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
63 let (new_chain, msg_key) = ratchet_chain(&self.send_chain);
65 self.send_chain = new_chain;
66
67 let counter = self.send_counter;
69 self.send_counter += 1;
70
71 encrypt_with_key(&msg_key, plaintext, counter)
73 }
74
75 pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
79 let (new_chain, msg_key) = ratchet_chain(&self.recv_chain);
81 self.recv_chain = new_chain;
82
83 let (plaintext, counter) = decrypt_with_key(&msg_key, ciphertext)?;
85
86 if counter != self.recv_counter {
88 return Err(ShieldError::RatchetError(format!(
89 "out of order message: expected {}, got {}",
90 self.recv_counter, counter
91 )));
92 }
93 self.recv_counter += 1;
94
95 Ok(plaintext)
96 }
97
98 #[must_use]
100 pub fn send_counter(&self) -> u64 {
101 self.send_counter
102 }
103
104 #[must_use]
106 pub fn recv_counter(&self) -> u64 {
107 self.recv_counter
108 }
109}
110
111fn derive_chain_key(root: &[u8; 32], label: &[u8]) -> [u8; 32] {
113 let mut data = Vec::with_capacity(root.len() + label.len());
114 data.extend_from_slice(root);
115 data.extend_from_slice(label);
116
117 let hash = digest::digest(&digest::SHA256, &data);
118 let mut result = [0u8; 32];
119 result.copy_from_slice(hash.as_ref());
120 result
121}
122
123fn ratchet_chain(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
125 let mut chain_data = Vec::with_capacity(chain_key.len() + 5);
127 chain_data.extend_from_slice(chain_key);
128 chain_data.extend_from_slice(b"chain");
129 let new_chain_hash = digest::digest(&digest::SHA256, &chain_data);
130 let mut new_chain = [0u8; 32];
131 new_chain.copy_from_slice(new_chain_hash.as_ref());
132
133 let mut msg_data = Vec::with_capacity(chain_key.len() + 7);
135 msg_data.extend_from_slice(chain_key);
136 msg_data.extend_from_slice(b"message");
137 let msg_hash = digest::digest(&digest::SHA256, &msg_data);
138 let mut msg_key = [0u8; 32];
139 msg_key.copy_from_slice(msg_hash.as_ref());
140
141 (new_chain, msg_key)
142}
143
144fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8], counter: u64) -> Result<Vec<u8>> {
146 let rng = SystemRandom::new();
147
148 let mut nonce = [0u8; 16];
150 rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
151
152 let counter_bytes = counter.to_le_bytes();
154
155 let mut data = Vec::with_capacity(8 + plaintext.len());
157 data.extend_from_slice(&counter_bytes);
158 data.extend_from_slice(plaintext);
159
160 let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
162 for i in 0..data.len().div_ceil(32) {
163 let block_counter = (i as u32).to_le_bytes();
164 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
165 hash_input.extend_from_slice(key);
166 hash_input.extend_from_slice(&nonce);
167 hash_input.extend_from_slice(&block_counter);
168 let hash = digest::digest(&digest::SHA256, &hash_input);
169 keystream.extend_from_slice(hash.as_ref());
170 }
171
172 let ciphertext: Vec<u8> = data
174 .iter()
175 .zip(keystream.iter())
176 .map(|(p, k)| p ^ k)
177 .collect();
178
179 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
181 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
182 hmac_data.extend_from_slice(&nonce);
183 hmac_data.extend_from_slice(&ciphertext);
184 let tag = hmac::sign(&hmac_key, &hmac_data);
185
186 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
188 result.extend_from_slice(&nonce);
189 result.extend_from_slice(&ciphertext);
190 result.extend_from_slice(&tag.as_ref()[..16]);
191
192 Ok(result)
193}
194
195fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<(Vec<u8>, u64)> {
197 if encrypted.len() < 40 {
198 return Err(ShieldError::RatchetError("ciphertext too short".into()));
199 }
200
201 let nonce = &encrypted[..16];
202 let ciphertext = &encrypted[16..encrypted.len() - 16];
203 let mac = &encrypted[encrypted.len() - 16..];
204
205 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
207 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
208 hmac_data.extend_from_slice(nonce);
209 hmac_data.extend_from_slice(ciphertext);
210 let expected = hmac::sign(&hmac_key, &hmac_data);
211
212 if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
213 return Err(ShieldError::AuthenticationFailed);
214 }
215
216 let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
218 for i in 0..ciphertext.len().div_ceil(32) {
219 let block_counter = (i as u32).to_le_bytes();
220 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
221 hash_input.extend_from_slice(key);
222 hash_input.extend_from_slice(nonce);
223 hash_input.extend_from_slice(&block_counter);
224 let hash = digest::digest(&digest::SHA256, &hash_input);
225 keystream.extend_from_slice(hash.as_ref());
226 }
227
228 let decrypted: Vec<u8> = ciphertext
230 .iter()
231 .zip(keystream.iter())
232 .map(|(c, k)| c ^ k)
233 .collect();
234
235 let counter = u64::from_le_bytes([
237 decrypted[0], decrypted[1], decrypted[2], decrypted[3],
238 decrypted[4], decrypted[5], decrypted[6], decrypted[7],
239 ]);
240
241 Ok((decrypted[8..].to_vec(), counter))
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Root key shared by both peers in every test.
    const ROOT: [u8; 32] = [0x42; 32];

    /// Builds an `(initiator, responder)` session pair from [`ROOT`].
    fn session_pair() -> (RatchetSession, RatchetSession) {
        let initiator = RatchetSession::new(&ROOT, true);
        let responder = RatchetSession::new(&ROOT, false);
        (initiator, responder)
    }

    #[test]
    fn test_ratchet_roundtrip() {
        let (mut alice, mut bob) = session_pair();

        let messages: [&[u8]; 2] = [b"Hello Bob!", b"Second message"];
        for message in messages {
            let sealed = alice.encrypt(message).expect("encryption succeeds");
            let opened = bob.decrypt(&sealed).expect("decryption succeeds");
            assert_eq!(message, opened.as_slice());
        }
    }

    #[test]
    fn test_ratchet_counters() {
        let (mut alice, mut bob) = session_pair();
        assert_eq!((alice.send_counter(), bob.recv_counter()), (0, 0));

        let sealed = alice.encrypt(b"test").expect("encryption succeeds");
        assert_eq!(alice.send_counter(), 1);

        bob.decrypt(&sealed).expect("decryption succeeds");
        assert_eq!(bob.recv_counter(), 1);
    }

    #[test]
    fn test_ratchet_different_ciphertexts() {
        let (mut alice, _) = session_pair();

        let first = alice.encrypt(b"same message").expect("encryption succeeds");
        let second = alice.encrypt(b"same message").expect("encryption succeeds");
        assert_ne!(first, second);
    }

    #[test]
    fn test_ratchet_replay_detection() {
        let (mut alice, mut bob) = session_pair();

        // Two messages are produced but only the second is delivered; bob is
        // still waiting for the first, so the second must be rejected.
        let _skipped = alice.encrypt(b"first").expect("encryption succeeds");
        let second = alice.encrypt(b"second").expect("encryption succeeds");
        assert!(bob.decrypt(&second).is_err());
    }
}