Skip to main content

shield_core/
rotation.rs

1//! Key rotation with version tagging.
2//!
3//! Manages multiple key versions for seamless rotation.
4
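// Keystream block counters are intentionally u32: generate_keystream rejects inputs that would need more than u32::MAX 32-byte blocks (about 128 GiB) instead of letting the counter wrap.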
6#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use std::collections::HashMap;
10use subtle::ConstantTimeEq;
11use zeroize::Zeroize;
12
13use crate::error::{Result, ShieldError};
14
15/// Key rotation manager.
16pub struct KeyRotationManager {
17    keys: HashMap<u32, [u8; 32]>,
18    current_version: u32,
19}
20
21impl KeyRotationManager {
22    /// Create new manager with initial key.
23    #[must_use]
24    pub fn new(key: [u8; 32], version: u32) -> Self {
25        let mut keys = HashMap::new();
26        keys.insert(version, key);
27
28        Self {
29            keys,
30            current_version: version,
31        }
32    }
33
34    /// Get current key version.
35    #[must_use]
36    pub fn current_version(&self) -> u32 {
37        self.current_version
38    }
39
40    /// Get all available versions.
41    #[must_use]
42    pub fn versions(&self) -> Vec<u32> {
43        let mut v: Vec<u32> = self.keys.keys().copied().collect();
44        v.sort_unstable();
45        v
46    }
47
48    /// Add a historical key.
49    pub fn add_key(&mut self, key: [u8; 32], version: u32) -> Result<()> {
50        if self.keys.contains_key(&version) {
51            return Err(ShieldError::VersionExists(version));
52        }
53        self.keys.insert(version, key);
54        Ok(())
55    }
56
57    /// Rotate to new key.
58    pub fn rotate(&mut self, new_key: [u8; 32], new_version: Option<u32>) -> Result<u32> {
59        let version = new_version.unwrap_or(self.current_version + 1);
60        if version <= self.current_version {
61            return Err(ShieldError::InvalidVersion);
62        }
63
64        self.keys.insert(version, new_key);
65        self.current_version = version;
66        Ok(version)
67    }
68
69    /// Encrypt with current key.
70    ///
71    /// Format: version(4) || nonce(16) || ciphertext || mac(16)
72    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
73        let key = self
74            .keys
75            .get(&self.current_version)
76            .ok_or(ShieldError::UnknownVersion(self.current_version))?;
77        let (enc_key, mac_key) = derive_rotation_subkeys(key);
78        let nonce: [u8; 16] = crate::random::random_bytes()?;
79
80        // Generate keystream with enc_key
81        let keystream = generate_keystream(&enc_key, &nonce, plaintext.len())?;
82        let ciphertext: Vec<u8> = plaintext
83            .iter()
84            .zip(keystream.iter())
85            .map(|(p, k)| p ^ k)
86            .collect();
87
88        // Version bytes
89        let version_bytes = self.current_version.to_le_bytes();
90
91        // HMAC over version || nonce || ciphertext with mac_key
92        let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
93        let mut hmac_data = Vec::with_capacity(4 + 16 + ciphertext.len());
94        hmac_data.extend_from_slice(&version_bytes);
95        hmac_data.extend_from_slice(&nonce);
96        hmac_data.extend_from_slice(&ciphertext);
97        let tag = hmac::sign(&hmac_signing_key, &hmac_data);
98
99        // Result: version || nonce || ciphertext || mac
100        let mut result = Vec::with_capacity(4 + 16 + ciphertext.len() + 16);
101        result.extend_from_slice(&version_bytes);
102        result.extend_from_slice(&nonce);
103        result.extend_from_slice(&ciphertext);
104        result.extend_from_slice(&tag.as_ref()[..16]);
105
106        Ok(result)
107    }
108
109    /// Decrypt with appropriate key version.
110    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
111        if encrypted.len() < 36 {
112            return Err(ShieldError::CiphertextTooShort {
113                expected: 36,
114                actual: encrypted.len(),
115            });
116        }
117
118        // Parse components
119        let version = u32::from_le_bytes(encrypted[..4].try_into().map_err(|_| {
120            ShieldError::CiphertextTooShort {
121                expected: 4,
122                actual: encrypted.len(),
123            }
124        })?);
125        let nonce = &encrypted[4..20];
126        let ciphertext = &encrypted[20..encrypted.len() - 16];
127        let mac = &encrypted[encrypted.len() - 16..];
128
129        // Get key for version
130        let key = self
131            .keys
132            .get(&version)
133            .ok_or(ShieldError::UnknownVersion(version))?;
134        let (enc_key, mac_key) = derive_rotation_subkeys(key);
135
136        // Verify MAC with mac_key
137        let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
138        let expected_tag = hmac::sign(&hmac_signing_key, &encrypted[..encrypted.len() - 16]);
139
140        if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
141            return Err(ShieldError::AuthenticationFailed);
142        }
143
144        // Decrypt with enc_key
145        let keystream = generate_keystream(&enc_key, nonce, ciphertext.len())?;
146        let plaintext: Vec<u8> = ciphertext
147            .iter()
148            .zip(keystream.iter())
149            .map(|(c, k)| c ^ k)
150            .collect();
151
152        Ok(plaintext)
153    }
154
155    /// Prune old keys, keeping specified number of most recent.
156    pub fn prune_old_keys(&mut self, keep_versions: usize) -> Vec<u32> {
157        let mut versions = self.versions();
158        versions.reverse(); // Most recent first
159
160        let mut to_keep: std::collections::HashSet<u32> =
161            versions.iter().take(keep_versions).copied().collect();
162        to_keep.insert(self.current_version);
163
164        let mut pruned = Vec::new();
165        for v in self.keys.keys().copied().collect::<Vec<_>>() {
166            if !to_keep.contains(&v) {
167                // Zeroize key material before removing from map
168                if let Some(key) = self.keys.get_mut(&v) {
169                    key.zeroize();
170                }
171                self.keys.remove(&v);
172                pruned.push(v);
173            }
174        }
175
176        pruned
177    }
178
179    /// Re-encrypt data with current key.
180    pub fn re_encrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
181        let plaintext = self.decrypt(encrypted)?;
182        self.encrypt(&plaintext)
183    }
184}
185
186impl Drop for KeyRotationManager {
187    fn drop(&mut self) {
188        for key in self.keys.values_mut() {
189            key.zeroize();
190        }
191    }
192}
193
194/// Derive separated encryption and MAC subkeys from a rotation key using HMAC-SHA256.
195fn derive_rotation_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
196    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
197
198    let enc_tag = hmac::sign(&hmac_key, b"shield-rotation-encrypt");
199    let mut enc_key = [0u8; 32];
200    enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
201
202    let mac_tag = hmac::sign(&hmac_key, b"shield-rotation-authenticate");
203    let mut mac_key = [0u8; 32];
204    mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
205
206    (enc_key, mac_key)
207}
208
209/// Generate keystream using HMAC-SHA256 (keyed PRF).
210fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Result<Vec<u8>> {
211    let num_blocks = length.div_ceil(32);
212    if u32::try_from(num_blocks).is_err() {
213        return Err(ShieldError::StreamError(
214            "keystream too long: counter overflow".into(),
215        ));
216    }
217    let mut keystream = Vec::with_capacity(num_blocks * 32);
218    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
219
220    for i in 0..num_blocks {
221        let counter = (i as u32).to_le_bytes();
222        let mut data = Vec::with_capacity(nonce.len() + 4);
223        data.extend_from_slice(nonce);
224        data.extend_from_slice(&counter);
225
226        let tag = hmac::sign(&hmac_key, &data);
227        keystream.extend_from_slice(tag.as_ref());
228    }
229
230    keystream.truncate(length);
231    Ok(keystream)
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt() {
        let key = [42u8; 32];
        let manager = KeyRotationManager::new(key, 1);
        let plaintext = b"Hello, Rotation!";

        let encrypted = manager.encrypt(plaintext).unwrap();
        let decrypted = manager.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_empty_plaintext_round_trips() {
        let manager = KeyRotationManager::new([7u8; 32], 1);
        let encrypted = manager.encrypt(b"").unwrap();

        // version(4) + nonce(16) + mac(16), with no ciphertext bytes.
        assert_eq!(encrypted.len(), 36);
        assert!(manager.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_multi_block_round_trip() {
        // Longer than one 32-byte keystream block and not a multiple of it.
        let manager = KeyRotationManager::new([9u8; 32], 1);
        let plaintext: Vec<u8> = (0..=200u8).collect();

        let encrypted = manager.encrypt(&plaintext).unwrap();

        assert_eq!(encrypted.len(), 36 + plaintext.len());
        assert_eq!(manager.decrypt(&encrypted).unwrap(), plaintext);
    }

    #[test]
    fn test_version_embedded() {
        let key = [42u8; 32];
        let manager = KeyRotationManager::new(key, 5);
        let encrypted = manager.encrypt(b"test").unwrap();

        let version = u32::from_le_bytes(encrypted[..4].try_into().unwrap());
        assert_eq!(version, 5);
    }

    #[test]
    fn test_rotate() {
        let key1 = [1u8; 32];
        let mut manager = KeyRotationManager::new(key1, 1);
        let encrypted1 = manager.encrypt(b"message 1").unwrap();

        let key2 = [2u8; 32];
        manager.rotate(key2, None).unwrap();
        assert_eq!(manager.current_version(), 2);

        let encrypted2 = manager.encrypt(b"message 2").unwrap();

        // Both decrypt
        assert_eq!(manager.decrypt(&encrypted1).unwrap(), b"message 1");
        assert_eq!(manager.decrypt(&encrypted2).unwrap(), b"message 2");
    }

    #[test]
    fn test_rotate_rejects_non_increasing_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 3);

        assert!(matches!(
            manager.rotate([2u8; 32], Some(3)),
            Err(ShieldError::InvalidVersion)
        ));
        assert!(matches!(
            manager.rotate([2u8; 32], Some(2)),
            Err(ShieldError::InvalidVersion)
        ));
        assert_eq!(manager.current_version(), 3);
    }

    #[test]
    fn test_add_key_rejects_duplicate_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);

        assert!(matches!(
            manager.add_key([2u8; 32], 1),
            Err(ShieldError::VersionExists(1))
        ));
    }

    #[test]
    fn test_add_historical_key_allows_decryption() {
        let old = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = old.encrypt(b"legacy").unwrap();

        let mut manager = KeyRotationManager::new([2u8; 32], 2);
        manager.add_key([1u8; 32], 1).unwrap();

        assert_eq!(manager.versions(), vec![1, 2]);
        assert_eq!(manager.current_version(), 2);
        assert_eq!(manager.decrypt(&encrypted).unwrap(), b"legacy");
    }

    #[test]
    fn test_prune_old_keys() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        manager.rotate([2u8; 32], None).unwrap();
        manager.rotate([3u8; 32], None).unwrap();
        manager.rotate([4u8; 32], None).unwrap();

        let encrypted = manager.encrypt(b"test").unwrap();
        let mut pruned = manager.prune_old_keys(2);
        pruned.sort_unstable();

        assert_eq!(pruned, vec![1, 2]);
        assert_eq!(manager.versions(), vec![3, 4]);
        assert_eq!(manager.decrypt(&encrypted).unwrap(), b"test");
    }

    #[test]
    fn test_prune_always_keeps_current_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        manager.rotate([2u8; 32], None).unwrap();

        let mut pruned = manager.prune_old_keys(0);
        pruned.sort_unstable();

        assert_eq!(pruned, vec![1]);
        assert_eq!(manager.versions(), vec![2]);
    }

    #[test]
    fn test_pruned_version_cannot_decrypt() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = manager.encrypt(b"old").unwrap();

        manager.rotate([2u8; 32], None).unwrap();
        manager.prune_old_keys(0);

        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::UnknownVersion(1))
        ));
    }

    #[test]
    fn test_re_encrypt() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = manager.encrypt(b"original").unwrap();

        manager.rotate([2u8; 32], None).unwrap();
        let re_encrypted = manager.re_encrypt(&encrypted).unwrap();

        let version = u32::from_le_bytes(re_encrypted[..4].try_into().unwrap());
        assert_eq!(version, 2);
        assert_eq!(manager.decrypt(&re_encrypted).unwrap(), b"original");
    }

    #[test]
    fn test_unknown_version() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        // Corrupt version: little-endian bytes [99, 0, 0, 0] decode to 99.
        encrypted[0] = 99;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::UnknownVersion(99))
        ));
    }

    #[test]
    fn test_tampering_is_detected() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = manager.encrypt(b"integrity").unwrap();

        // Flip one bit in the nonce, the first ciphertext byte, and the tag in turn.
        for index in [4, 20, encrypted.len() - 1] {
            let mut tampered = encrypted.clone();
            tampered[index] ^= 0x01;
            assert!(matches!(
                manager.decrypt(&tampered),
                Err(ShieldError::AuthenticationFailed)
            ));
        }
    }

    #[test]
    fn test_too_short_is_rejected() {
        let manager = KeyRotationManager::new([1u8; 32], 1);

        assert!(matches!(
            manager.decrypt(&[0u8; 35]),
            Err(ShieldError::CiphertextTooShort {
                expected: 36,
                actual: 35
            })
        ));
    }
}