1#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use std::collections::HashMap;
10use subtle::ConstantTimeEq;
11use zeroize::Zeroize;
12
13use crate::error::{Result, ShieldError};
14
/// Holds a set of versioned 32-byte master keys and tracks which version is
/// used for new encryptions. Keys stored under other versions stay available
/// so ciphertexts carrying those versions can still be decrypted.
pub struct KeyRotationManager {
    /// Version whose key `encrypt` uses.
    current_version: u32,
    /// Master keys indexed by their version number.
    keys: HashMap<u32, [u8; 32]>,
}
20
21impl KeyRotationManager {
22 #[must_use]
24 pub fn new(key: [u8; 32], version: u32) -> Self {
25 let mut keys = HashMap::new();
26 keys.insert(version, key);
27
28 Self {
29 keys,
30 current_version: version,
31 }
32 }
33
34 #[must_use]
36 pub fn current_version(&self) -> u32 {
37 self.current_version
38 }
39
40 #[must_use]
42 pub fn versions(&self) -> Vec<u32> {
43 let mut v: Vec<u32> = self.keys.keys().copied().collect();
44 v.sort_unstable();
45 v
46 }
47
48 pub fn add_key(&mut self, key: [u8; 32], version: u32) -> Result<()> {
50 if self.keys.contains_key(&version) {
51 return Err(ShieldError::VersionExists(version));
52 }
53 self.keys.insert(version, key);
54 Ok(())
55 }
56
57 pub fn rotate(&mut self, new_key: [u8; 32], new_version: Option<u32>) -> Result<u32> {
59 let version = new_version.unwrap_or(self.current_version + 1);
60 if version <= self.current_version {
61 return Err(ShieldError::InvalidVersion);
62 }
63
64 self.keys.insert(version, new_key);
65 self.current_version = version;
66 Ok(version)
67 }
68
69 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
73 let key = self
74 .keys
75 .get(&self.current_version)
76 .ok_or(ShieldError::UnknownVersion(self.current_version))?;
77 let (enc_key, mac_key) = derive_rotation_subkeys(key);
78 let nonce: [u8; 16] = crate::random::random_bytes()?;
79
80 let keystream = generate_keystream(&enc_key, &nonce, plaintext.len())?;
82 let ciphertext: Vec<u8> = plaintext
83 .iter()
84 .zip(keystream.iter())
85 .map(|(p, k)| p ^ k)
86 .collect();
87
88 let version_bytes = self.current_version.to_le_bytes();
90
91 let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
93 let mut hmac_data = Vec::with_capacity(4 + 16 + ciphertext.len());
94 hmac_data.extend_from_slice(&version_bytes);
95 hmac_data.extend_from_slice(&nonce);
96 hmac_data.extend_from_slice(&ciphertext);
97 let tag = hmac::sign(&hmac_signing_key, &hmac_data);
98
99 let mut result = Vec::with_capacity(4 + 16 + ciphertext.len() + 16);
101 result.extend_from_slice(&version_bytes);
102 result.extend_from_slice(&nonce);
103 result.extend_from_slice(&ciphertext);
104 result.extend_from_slice(&tag.as_ref()[..16]);
105
106 Ok(result)
107 }
108
109 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
111 if encrypted.len() < 36 {
112 return Err(ShieldError::CiphertextTooShort {
113 expected: 36,
114 actual: encrypted.len(),
115 });
116 }
117
118 let version = u32::from_le_bytes(encrypted[..4].try_into().map_err(|_| {
120 ShieldError::CiphertextTooShort {
121 expected: 4,
122 actual: encrypted.len(),
123 }
124 })?);
125 let nonce = &encrypted[4..20];
126 let ciphertext = &encrypted[20..encrypted.len() - 16];
127 let mac = &encrypted[encrypted.len() - 16..];
128
129 let key = self
131 .keys
132 .get(&version)
133 .ok_or(ShieldError::UnknownVersion(version))?;
134 let (enc_key, mac_key) = derive_rotation_subkeys(key);
135
136 let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
138 let expected_tag = hmac::sign(&hmac_signing_key, &encrypted[..encrypted.len() - 16]);
139
140 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
141 return Err(ShieldError::AuthenticationFailed);
142 }
143
144 let keystream = generate_keystream(&enc_key, nonce, ciphertext.len())?;
146 let plaintext: Vec<u8> = ciphertext
147 .iter()
148 .zip(keystream.iter())
149 .map(|(c, k)| c ^ k)
150 .collect();
151
152 Ok(plaintext)
153 }
154
155 pub fn prune_old_keys(&mut self, keep_versions: usize) -> Vec<u32> {
157 let mut versions = self.versions();
158 versions.reverse(); let mut to_keep: std::collections::HashSet<u32> =
161 versions.iter().take(keep_versions).copied().collect();
162 to_keep.insert(self.current_version);
163
164 let mut pruned = Vec::new();
165 for v in self.keys.keys().copied().collect::<Vec<_>>() {
166 if !to_keep.contains(&v) {
167 if let Some(key) = self.keys.get_mut(&v) {
169 key.zeroize();
170 }
171 self.keys.remove(&v);
172 pruned.push(v);
173 }
174 }
175
176 pruned
177 }
178
179 pub fn re_encrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
181 let plaintext = self.decrypt(encrypted)?;
182 self.encrypt(&plaintext)
183 }
184}
185
186impl Drop for KeyRotationManager {
187 fn drop(&mut self) {
188 for key in self.keys.values_mut() {
189 key.zeroize();
190 }
191 }
192}
193
194fn derive_rotation_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
196 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
197
198 let enc_tag = hmac::sign(&hmac_key, b"shield-rotation-encrypt");
199 let mut enc_key = [0u8; 32];
200 enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
201
202 let mac_tag = hmac::sign(&hmac_key, b"shield-rotation-authenticate");
203 let mut mac_key = [0u8; 32];
204 mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
205
206 (enc_key, mac_key)
207}
208
209fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Result<Vec<u8>> {
211 let num_blocks = length.div_ceil(32);
212 if u32::try_from(num_blocks).is_err() {
213 return Err(ShieldError::StreamError(
214 "keystream too long: counter overflow".into(),
215 ));
216 }
217 let mut keystream = Vec::with_capacity(num_blocks * 32);
218 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
219
220 for i in 0..num_blocks {
221 let counter = (i as u32).to_le_bytes();
222 let mut data = Vec::with_capacity(nonce.len() + 4);
223 data.extend_from_slice(nonce);
224 data.extend_from_slice(&counter);
225
226 let tag = hmac::sign(&hmac_key, &data);
227 keystream.extend_from_slice(tag.as_ref());
228 }
229
230 keystream.truncate(length);
231 Ok(keystream)
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt() {
        let key = [42u8; 32];
        let manager = KeyRotationManager::new(key, 1);
        let plaintext = b"Hello, Rotation!";

        let encrypted = manager.encrypt(plaintext).unwrap();
        let decrypted = manager.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_envelope_length() {
        let manager = KeyRotationManager::new([42u8; 32], 1);
        // version (4) + nonce (16) + ciphertext (plaintext length) + tag (16)
        assert_eq!(manager.encrypt(b"abc").unwrap().len(), 4 + 16 + 3 + 16);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let manager = KeyRotationManager::new([42u8; 32], 1);
        let encrypted = manager.encrypt(b"").unwrap();
        assert_eq!(encrypted.len(), 36);
        assert!(manager.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_nonce_randomizes_output() {
        let manager = KeyRotationManager::new([42u8; 32], 1);
        let first = manager.encrypt(b"same").unwrap();
        let second = manager.encrypt(b"same").unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_version_embedded() {
        let key = [42u8; 32];
        let manager = KeyRotationManager::new(key, 5);
        let encrypted = manager.encrypt(b"test").unwrap();

        let version = u32::from_le_bytes(encrypted[..4].try_into().unwrap());
        assert_eq!(version, 5);
    }

    #[test]
    fn test_rotate() {
        let key1 = [1u8; 32];
        let mut manager = KeyRotationManager::new(key1, 1);
        let encrypted1 = manager.encrypt(b"message 1").unwrap();

        let key2 = [2u8; 32];
        manager.rotate(key2, None).unwrap();
        assert_eq!(manager.current_version(), 2);

        let encrypted2 = manager.encrypt(b"message 2").unwrap();

        // Old ciphertexts remain decryptable after rotation.
        assert_eq!(manager.decrypt(&encrypted1).unwrap(), b"message 1");
        assert_eq!(manager.decrypt(&encrypted2).unwrap(), b"message 2");
    }

    #[test]
    fn test_rotate_rejects_non_increasing_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 5);
        assert!(matches!(
            manager.rotate([2u8; 32], Some(5)),
            Err(ShieldError::InvalidVersion)
        ));
        assert!(matches!(
            manager.rotate([2u8; 32], Some(4)),
            Err(ShieldError::InvalidVersion)
        ));
        assert_eq!(manager.current_version(), 5);
    }

    #[test]
    fn test_add_key_rejects_duplicate() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        assert!(matches!(
            manager.add_key([9u8; 32], 1),
            Err(ShieldError::VersionExists(1))
        ));

        manager.add_key([9u8; 32], 7).unwrap();
        assert_eq!(manager.versions(), vec![1, 7]);
        assert_eq!(manager.current_version(), 1);
    }

    #[test]
    fn test_prune_old_keys() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let old = manager.encrypt(b"old").unwrap();
        manager.rotate([2u8; 32], None).unwrap();
        manager.rotate([3u8; 32], None).unwrap();
        manager.rotate([4u8; 32], None).unwrap();

        let encrypted = manager.encrypt(b"test").unwrap();
        let mut pruned = manager.prune_old_keys(2);
        pruned.sort_unstable();

        assert_eq!(pruned, vec![1, 2]);
        assert_eq!(manager.versions(), vec![3, 4]);
        assert_eq!(manager.decrypt(&encrypted).unwrap(), b"test");
        assert!(matches!(
            manager.decrypt(&old),
            Err(ShieldError::UnknownVersion(1))
        ));
    }

    #[test]
    fn test_prune_zero_keeps_current() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        manager.rotate([2u8; 32], None).unwrap();

        assert_eq!(manager.prune_old_keys(0), vec![1]);
        assert_eq!(manager.versions(), vec![2]);
    }

    #[test]
    fn test_re_encrypt() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = manager.encrypt(b"original").unwrap();

        manager.rotate([2u8; 32], None).unwrap();
        let re_encrypted = manager.re_encrypt(&encrypted).unwrap();

        let version = u32::from_le_bytes(re_encrypted[..4].try_into().unwrap());
        assert_eq!(version, 2);
        assert_eq!(manager.decrypt(&re_encrypted).unwrap(), b"original");

        // After retiring version 1, only the re-encrypted envelope decrypts.
        assert_eq!(manager.prune_old_keys(0), vec![1]);
        assert_eq!(manager.decrypt(&re_encrypted).unwrap(), b"original");
        assert!(manager.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_unknown_version() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        // Rewrite the little-endian version prefix from 1 to 99.
        encrypted[0] = 99;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::UnknownVersion(99))
        ));
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let manager = KeyRotationManager::new([7u8; 32], 1);
        let mut encrypted = manager.encrypt(b"integrity").unwrap();

        encrypted[20] ^= 0x01; // first ciphertext byte
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_tampered_nonce_and_tag_rejected() {
        let manager = KeyRotationManager::new([7u8; 32], 1);
        let encrypted = manager.encrypt(b"integrity").unwrap();

        // Index 4 is the first nonce byte; the last index is inside the tag.
        for idx in [4, encrypted.len() - 1] {
            let mut tampered = encrypted.clone();
            tampered[idx] ^= 0x80;
            assert!(matches!(
                manager.decrypt(&tampered),
                Err(ShieldError::AuthenticationFailed)
            ));
        }
    }

    #[test]
    fn test_too_short_rejected() {
        let manager = KeyRotationManager::new([7u8; 32], 1);
        assert!(matches!(
            manager.decrypt(&[0u8; 35]),
            Err(ShieldError::CiphertextTooShort {
                expected: 36,
                actual: 35
            })
        ));
    }
}