1#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashMap;
11use subtle::ConstantTimeEq;
12
13use crate::error::{Result, ShieldError};
14
/// Versioned key store used to encrypt with the newest key while still being
/// able to decrypt data produced under older key versions.
pub struct KeyRotationManager {
    /// 32-byte keys indexed by their `u32` version number.
    keys: HashMap<u32, [u8; 32]>,
    /// Version whose key is used for new encryptions.
    current_version: u32,
}

// `Debug` is written by hand rather than derived: a derived impl would print
// the raw key bytes, which must never end up in logs or panic messages.
impl std::fmt::Debug for KeyRotationManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort so the rendered output does not depend on hash-map iteration order.
        let mut versions: Vec<u32> = self.keys.keys().copied().collect();
        versions.sort_unstable();

        f.debug_struct("KeyRotationManager")
            .field("versions", &versions)
            .field("current_version", &self.current_version)
            .finish_non_exhaustive()
    }
}
20
21impl KeyRotationManager {
22 #[must_use]
24 pub fn new(key: [u8; 32], version: u32) -> Self {
25 let mut keys = HashMap::new();
26 keys.insert(version, key);
27
28 Self {
29 keys,
30 current_version: version,
31 }
32 }
33
34 #[must_use]
36 pub fn current_version(&self) -> u32 {
37 self.current_version
38 }
39
40 #[must_use]
42 pub fn versions(&self) -> Vec<u32> {
43 let mut v: Vec<u32> = self.keys.keys().copied().collect();
44 v.sort_unstable();
45 v
46 }
47
48 pub fn add_key(&mut self, key: [u8; 32], version: u32) -> Result<()> {
50 if self.keys.contains_key(&version) {
51 return Err(ShieldError::VersionExists(version));
52 }
53 self.keys.insert(version, key);
54 Ok(())
55 }
56
57 pub fn rotate(&mut self, new_key: [u8; 32], new_version: Option<u32>) -> Result<u32> {
59 let version = new_version.unwrap_or(self.current_version + 1);
60 if version <= self.current_version {
61 return Err(ShieldError::InvalidVersion);
62 }
63
64 self.keys.insert(version, new_key);
65 self.current_version = version;
66 Ok(version)
67 }
68
69 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
73 let key = self.keys.get(&self.current_version).unwrap();
74 let rng = SystemRandom::new();
75
76 let mut nonce = [0u8; 16];
77 rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
78
79 let keystream = generate_keystream(key, &nonce, plaintext.len());
81 let ciphertext: Vec<u8> = plaintext
82 .iter()
83 .zip(keystream.iter())
84 .map(|(p, k)| p ^ k)
85 .collect();
86
87 let version_bytes = self.current_version.to_le_bytes();
89
90 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
92 let mut hmac_data = Vec::with_capacity(4 + 16 + ciphertext.len());
93 hmac_data.extend_from_slice(&version_bytes);
94 hmac_data.extend_from_slice(&nonce);
95 hmac_data.extend_from_slice(&ciphertext);
96 let tag = hmac::sign(&hmac_key, &hmac_data);
97
98 let mut result = Vec::with_capacity(4 + 16 + ciphertext.len() + 16);
100 result.extend_from_slice(&version_bytes);
101 result.extend_from_slice(&nonce);
102 result.extend_from_slice(&ciphertext);
103 result.extend_from_slice(&tag.as_ref()[..16]);
104
105 Ok(result)
106 }
107
108 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
110 if encrypted.len() < 36 {
111 return Err(ShieldError::CiphertextTooShort {
112 expected: 36,
113 actual: encrypted.len(),
114 });
115 }
116
117 let version = u32::from_le_bytes(encrypted[..4].try_into().unwrap());
119 let nonce = &encrypted[4..20];
120 let ciphertext = &encrypted[20..encrypted.len() - 16];
121 let mac = &encrypted[encrypted.len() - 16..];
122
123 let key = self
125 .keys
126 .get(&version)
127 .ok_or(ShieldError::UnknownVersion(version))?;
128
129 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
131 let expected_tag = hmac::sign(&hmac_key, &encrypted[..encrypted.len() - 16]);
132
133 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
134 return Err(ShieldError::AuthenticationFailed);
135 }
136
137 let keystream = generate_keystream(key, nonce, ciphertext.len());
139 let plaintext: Vec<u8> = ciphertext
140 .iter()
141 .zip(keystream.iter())
142 .map(|(c, k)| c ^ k)
143 .collect();
144
145 Ok(plaintext)
146 }
147
148 pub fn prune_old_keys(&mut self, keep_versions: usize) -> Vec<u32> {
150 let mut versions = self.versions();
151 versions.reverse(); let mut to_keep: std::collections::HashSet<u32> =
154 versions.iter().take(keep_versions).copied().collect();
155 to_keep.insert(self.current_version);
156
157 let mut pruned = Vec::new();
158 for v in self.keys.keys().copied().collect::<Vec<_>>() {
159 if !to_keep.contains(&v) {
160 self.keys.remove(&v);
161 pruned.push(v);
162 }
163 }
164
165 pruned
166 }
167
168 pub fn re_encrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
170 let plaintext = self.decrypt(encrypted)?;
171 self.encrypt(&plaintext)
172 }
173}
174
175fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
177 let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
178 let num_blocks = length.div_ceil(32);
179
180 for i in 0..num_blocks {
181 let counter = (i as u32).to_le_bytes();
182 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
183 data.extend_from_slice(key);
184 data.extend_from_slice(nonce);
185 data.extend_from_slice(&counter);
186
187 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
188 keystream.extend_from_slice(hash.as_ref());
189 }
190
191 keystream.truncate(length);
192 keystream
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// A message encrypted by a manager decrypts back to the same bytes.
    #[test]
    fn test_encrypt_decrypt() {
        let manager = KeyRotationManager::new([42u8; 32], 1);
        let plaintext = b"Hello, Rotation!";

        let encrypted = manager.encrypt(plaintext).unwrap();
        let decrypted = manager.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    /// An empty plaintext produces the minimum-size message and round-trips.
    #[test]
    fn test_encrypt_decrypt_empty() {
        let manager = KeyRotationManager::new([42u8; 32], 1);

        let encrypted = manager.encrypt(b"").unwrap();

        // version (4) + nonce (16) + tag (16)
        assert_eq!(encrypted.len(), 36);
        assert!(manager.decrypt(&encrypted).unwrap().is_empty());
    }

    /// The first four bytes of a message carry the key version, little-endian.
    #[test]
    fn test_version_embedded() {
        let manager = KeyRotationManager::new([42u8; 32], 5);
        let encrypted = manager.encrypt(b"test").unwrap();

        let version = u32::from_le_bytes(encrypted[..4].try_into().unwrap());
        assert_eq!(version, 5);
    }

    /// After rotation, messages from both the old and the new key decrypt.
    #[test]
    fn test_rotate() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted1 = manager.encrypt(b"message 1").unwrap();

        assert_eq!(manager.rotate([2u8; 32], None).unwrap(), 2);
        assert_eq!(manager.current_version(), 2);

        let encrypted2 = manager.encrypt(b"message 2").unwrap();

        assert_eq!(manager.decrypt(&encrypted1).unwrap(), b"message 1");
        assert_eq!(manager.decrypt(&encrypted2).unwrap(), b"message 2");
    }

    /// Rotating to a version that is not strictly newer is rejected and
    /// leaves the manager unchanged.
    #[test]
    fn test_rotate_rejects_stale_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 5);

        assert!(matches!(
            manager.rotate([2u8; 32], Some(5)),
            Err(ShieldError::InvalidVersion)
        ));
        assert!(matches!(
            manager.rotate([2u8; 32], Some(3)),
            Err(ShieldError::InvalidVersion)
        ));
        assert_eq!(manager.current_version(), 5);
        assert_eq!(manager.versions(), vec![5]);
    }

    /// `add_key` refuses to overwrite an existing version and does not change
    /// which version is current.
    #[test]
    fn test_add_key() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);

        assert!(matches!(
            manager.add_key([9u8; 32], 1),
            Err(ShieldError::VersionExists(1))
        ));

        manager.add_key([7u8; 32], 7).unwrap();
        assert_eq!(manager.versions(), vec![1, 7]);
        assert_eq!(manager.current_version(), 1);
    }

    /// Pruning keeps the newest versions (including the current one) and
    /// reports exactly the versions it removed.
    #[test]
    fn test_prune_old_keys() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        manager.rotate([2u8; 32], None).unwrap();
        manager.rotate([3u8; 32], None).unwrap();
        manager.rotate([4u8; 32], None).unwrap();

        let encrypted = manager.encrypt(b"test").unwrap();
        let mut pruned = manager.prune_old_keys(2);
        pruned.sort_unstable();

        assert_eq!(pruned, vec![1, 2]);
        assert_eq!(manager.versions(), vec![3, 4]);
        assert_eq!(manager.decrypt(&encrypted).unwrap(), b"test");
    }

    /// Re-encrypting moves a message onto the current key version without
    /// changing its plaintext.
    #[test]
    fn test_re_encrypt() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = manager.encrypt(b"original").unwrap();

        manager.rotate([2u8; 32], None).unwrap();
        let re_encrypted = manager.re_encrypt(&encrypted).unwrap();

        let version = u32::from_le_bytes(re_encrypted[..4].try_into().unwrap());
        assert_eq!(version, 2);
        assert_eq!(manager.decrypt(&re_encrypted).unwrap(), b"original");
    }

    /// A message naming a version with no registered key is rejected with
    /// that version in the error.
    #[test]
    fn test_unknown_version() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        // Lowest byte of the little-endian version prefix: 1 -> 99.
        encrypted[0] = 99;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::UnknownVersion(99))
        ));
    }

    /// Flipping a ciphertext bit makes authentication fail.
    #[test]
    fn test_tampered_ciphertext() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        // Byte 20 is the first ciphertext byte (after version + nonce).
        encrypted[20] ^= 0x01;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    /// Input shorter than version + nonce + tag is rejected up front.
    #[test]
    fn test_ciphertext_too_short() {
        let manager = KeyRotationManager::new([1u8; 32], 1);

        assert!(matches!(
            manager.decrypt(&[0u8; 35]),
            Err(ShieldError::CiphertextTooShort {
                expected: 36,
                actual: 35
            })
        ));
    }
}