1#![allow(clippy::cast_possible_truncation)]
10
11use ring::{
12 digest, hmac,
13 rand::{SecureRandom, SystemRandom},
14};
15use subtle::ConstantTimeEq;
16use zeroize::{Zeroize, ZeroizeOnDrop};
17
18use crate::error::{Result, ShieldError};
19
20#[derive(Zeroize, ZeroizeOnDrop)]
27pub struct RatchetSession {
28 send_chain: [u8; 32],
29 recv_chain: [u8; 32],
30 #[zeroize(skip)]
31 send_counter: u64,
32 #[zeroize(skip)]
33 recv_counter: u64,
34}
35
36impl RatchetSession {
37 #[must_use]
43 pub fn new(root_key: &[u8; 32], is_initiator: bool) -> Self {
44 let (send_label, recv_label) = if is_initiator {
46 (b"send", b"recv")
47 } else {
48 (b"recv", b"send")
49 };
50
51 let send_chain = derive_chain_key(root_key, send_label);
52 let recv_chain = derive_chain_key(root_key, recv_label);
53
54 Self {
55 send_chain,
56 recv_chain,
57 send_counter: 0,
58 recv_counter: 0,
59 }
60 }
61
62 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
66 let (new_chain, msg_key) = ratchet_chain(&self.send_chain);
68 self.send_chain = new_chain;
69
70 let counter = self.send_counter;
72 self.send_counter += 1;
73
74 encrypt_with_key(&msg_key, plaintext, counter)
76 }
77
78 pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
82 let (new_chain, msg_key) = ratchet_chain(&self.recv_chain);
84 self.recv_chain = new_chain;
85
86 let (plaintext, counter) = decrypt_with_key(&msg_key, ciphertext)?;
88
89 if counter != self.recv_counter {
91 return Err(ShieldError::RatchetError(format!(
92 "out of order message: expected {}, got {}",
93 self.recv_counter, counter
94 )));
95 }
96 self.recv_counter += 1;
97
98 Ok(plaintext)
99 }
100
101 #[must_use]
103 pub fn send_counter(&self) -> u64 {
104 self.send_counter
105 }
106
107 #[must_use]
109 pub fn recv_counter(&self) -> u64 {
110 self.recv_counter
111 }
112}
113
114fn derive_chain_key(root: &[u8; 32], label: &[u8]) -> [u8; 32] {
116 let mut data = Vec::with_capacity(root.len() + label.len());
117 data.extend_from_slice(root);
118 data.extend_from_slice(label);
119
120 let hash = digest::digest(&digest::SHA256, &data);
121 let mut result = [0u8; 32];
122 result.copy_from_slice(hash.as_ref());
123 result
124}
125
126fn ratchet_chain(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
128 let mut chain_data = Vec::with_capacity(chain_key.len() + 5);
130 chain_data.extend_from_slice(chain_key);
131 chain_data.extend_from_slice(b"chain");
132 let new_chain_hash = digest::digest(&digest::SHA256, &chain_data);
133 let mut new_chain = [0u8; 32];
134 new_chain.copy_from_slice(new_chain_hash.as_ref());
135
136 let mut msg_data = Vec::with_capacity(chain_key.len() + 7);
138 msg_data.extend_from_slice(chain_key);
139 msg_data.extend_from_slice(b"message");
140 let msg_hash = digest::digest(&digest::SHA256, &msg_data);
141 let mut msg_key = [0u8; 32];
142 msg_key.copy_from_slice(msg_hash.as_ref());
143
144 (new_chain, msg_key)
145}
146
147fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8], counter: u64) -> Result<Vec<u8>> {
149 let rng = SystemRandom::new();
150
151 let mut nonce = [0u8; 16];
153 rng.fill(&mut nonce)
154 .map_err(|_| ShieldError::RandomFailed)?;
155
156 let counter_bytes = counter.to_le_bytes();
158
159 let mut data = Vec::with_capacity(8 + plaintext.len());
161 data.extend_from_slice(&counter_bytes);
162 data.extend_from_slice(plaintext);
163
164 let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
166 for i in 0..data.len().div_ceil(32) {
167 let block_counter = (i as u32).to_le_bytes();
168 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
169 hash_input.extend_from_slice(key);
170 hash_input.extend_from_slice(&nonce);
171 hash_input.extend_from_slice(&block_counter);
172 let hash = digest::digest(&digest::SHA256, &hash_input);
173 keystream.extend_from_slice(hash.as_ref());
174 }
175
176 let ciphertext: Vec<u8> = data
178 .iter()
179 .zip(keystream.iter())
180 .map(|(p, k)| p ^ k)
181 .collect();
182
183 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
185 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
186 hmac_data.extend_from_slice(&nonce);
187 hmac_data.extend_from_slice(&ciphertext);
188 let tag = hmac::sign(&hmac_key, &hmac_data);
189
190 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
192 result.extend_from_slice(&nonce);
193 result.extend_from_slice(&ciphertext);
194 result.extend_from_slice(&tag.as_ref()[..16]);
195
196 Ok(result)
197}
198
199fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<(Vec<u8>, u64)> {
201 if encrypted.len() < 40 {
202 return Err(ShieldError::RatchetError("ciphertext too short".into()));
203 }
204
205 let nonce = &encrypted[..16];
206 let ciphertext = &encrypted[16..encrypted.len() - 16];
207 let mac = &encrypted[encrypted.len() - 16..];
208
209 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
211 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
212 hmac_data.extend_from_slice(nonce);
213 hmac_data.extend_from_slice(ciphertext);
214 let expected = hmac::sign(&hmac_key, &hmac_data);
215
216 if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
217 return Err(ShieldError::AuthenticationFailed);
218 }
219
220 let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
222 for i in 0..ciphertext.len().div_ceil(32) {
223 let block_counter = (i as u32).to_le_bytes();
224 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
225 hash_input.extend_from_slice(key);
226 hash_input.extend_from_slice(nonce);
227 hash_input.extend_from_slice(&block_counter);
228 let hash = digest::digest(&digest::SHA256, &hash_input);
229 keystream.extend_from_slice(hash.as_ref());
230 }
231
232 let decrypted: Vec<u8> = ciphertext
234 .iter()
235 .zip(keystream.iter())
236 .map(|(c, k)| c ^ k)
237 .collect();
238
239 let counter = u64::from_le_bytes([
241 decrypted[0],
242 decrypted[1],
243 decrypted[2],
244 decrypted[3],
245 decrypted[4],
246 decrypted[5],
247 decrypted[6],
248 decrypted[7],
249 ]);
250
251 Ok((decrypted[8..].to_vec(), counter))
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matching initiator/responder pair from a fixed root key.
    fn pair() -> (RatchetSession, RatchetSession) {
        let root = [0x42u8; 32];
        (
            RatchetSession::new(&root, true),
            RatchetSession::new(&root, false),
        )
    }

    #[test]
    fn test_ratchet_roundtrip() {
        let (mut alice, mut bob) = pair();

        let msg1 = b"Hello Bob!";
        let enc1 = alice.encrypt(msg1).unwrap();
        let dec1 = bob.decrypt(&enc1).unwrap();
        assert_eq!(msg1.as_slice(), dec1.as_slice());

        let msg2 = b"Second message";
        let enc2 = alice.encrypt(msg2).unwrap();
        let dec2 = bob.decrypt(&enc2).unwrap();
        assert_eq!(msg2.as_slice(), dec2.as_slice());
    }

    /// Each side's send chain must match the other side's receive chain.
    #[test]
    fn test_ratchet_bidirectional() {
        let (mut alice, mut bob) = pair();

        let to_bob = alice.encrypt(b"ping").unwrap();
        assert_eq!(bob.decrypt(&to_bob).unwrap(), b"ping");

        let to_alice = bob.encrypt(b"pong").unwrap();
        assert_eq!(alice.decrypt(&to_alice).unwrap(), b"pong");
    }

    /// An empty plaintext yields the minimum-size (40-byte) message and still
    /// round-trips.
    #[test]
    fn test_ratchet_empty_message() {
        let (mut alice, mut bob) = pair();

        let enc = alice.encrypt(b"").unwrap();
        assert_eq!(enc.len(), 40);
        assert!(bob.decrypt(&enc).unwrap().is_empty());
    }

    #[test]
    fn test_ratchet_counters() {
        let (mut alice, mut bob) = pair();

        assert_eq!(alice.send_counter(), 0);
        assert_eq!(bob.recv_counter(), 0);

        let enc = alice.encrypt(b"test").unwrap();
        assert_eq!(alice.send_counter(), 1);

        bob.decrypt(&enc).unwrap();
        assert_eq!(bob.recv_counter(), 1);
    }

    #[test]
    fn test_ratchet_different_ciphertexts() {
        let (mut alice, _bob) = pair();

        let enc1 = alice.encrypt(b"same message").unwrap();
        let enc2 = alice.encrypt(b"same message").unwrap();

        assert_ne!(enc1, enc2);
    }

    /// A message that arrives before its predecessor is rejected.
    #[test]
    fn test_ratchet_rejects_skipped_message() {
        let (mut alice, mut bob) = pair();

        let _enc1 = alice.encrypt(b"first").unwrap();
        let enc2 = alice.encrypt(b"second").unwrap();

        assert!(bob.decrypt(&enc2).is_err());
    }

    /// Delivering the same ciphertext twice fails the second time.
    #[test]
    fn test_ratchet_replay_detection() {
        let (mut alice, mut bob) = pair();

        let enc = alice.encrypt(b"once").unwrap();
        assert!(bob.decrypt(&enc).is_ok());
        assert!(bob.decrypt(&enc).is_err());
    }

    /// Flipping a single ciphertext bit must fail authentication.
    #[test]
    fn test_ratchet_rejects_tampered_message() {
        let (mut alice, mut bob) = pair();

        let mut enc = alice.encrypt(b"integrity").unwrap();
        let mid = enc.len() / 2;
        enc[mid] ^= 0x01;

        assert!(bob.decrypt(&enc).is_err());
    }

    /// Inputs shorter than nonce + counter + tag are rejected, not indexed.
    #[test]
    fn test_ratchet_rejects_short_input() {
        let (_alice, mut bob) = pair();

        assert!(bob.decrypt(&[0u8; 39]).is_err());
        assert!(bob.decrypt(&[]).is_err());
    }
}