shield_core/
rotation.rs

1//! Key rotation with version tagging.
2//!
3//! Manages multiple key versions for seamless rotation.
4
// Keystream block counters are intentionally u32: each SHA-256 block yields 32 bytes,
// so the counter spans 2^32 * 32 bytes = 128 GiB per message.
6#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashMap;
11use subtle::ConstantTimeEq;
12
13use crate::error::{Result, ShieldError};
14
/// Versioned key store used for encrypting with the newest key while still
/// being able to decrypt data produced under older key versions.
///
/// The methods of this type keep `current_version` present in `keys`.
pub struct KeyRotationManager {
    /// Version number of the key that new encryptions use.
    current_version: u32,
    /// Every registered 32-byte key, indexed by its version number.
    keys: HashMap<u32, [u8; 32]>,
}
20
21impl KeyRotationManager {
22    /// Create new manager with initial key.
23    #[must_use] 
24    pub fn new(key: [u8; 32], version: u32) -> Self {
25        let mut keys = HashMap::new();
26        keys.insert(version, key);
27
28        Self {
29            keys,
30            current_version: version,
31        }
32    }
33
34    /// Get current key version.
35    #[must_use] 
36    pub fn current_version(&self) -> u32 {
37        self.current_version
38    }
39
40    /// Get all available versions.
41    #[must_use] 
42    pub fn versions(&self) -> Vec<u32> {
43        let mut v: Vec<u32> = self.keys.keys().copied().collect();
44        v.sort_unstable();
45        v
46    }
47
48    /// Add a historical key.
49    pub fn add_key(&mut self, key: [u8; 32], version: u32) -> Result<()> {
50        if self.keys.contains_key(&version) {
51            return Err(ShieldError::VersionExists(version));
52        }
53        self.keys.insert(version, key);
54        Ok(())
55    }
56
57    /// Rotate to new key.
58    pub fn rotate(&mut self, new_key: [u8; 32], new_version: Option<u32>) -> Result<u32> {
59        let version = new_version.unwrap_or(self.current_version + 1);
60        if version <= self.current_version {
61            return Err(ShieldError::InvalidVersion);
62        }
63
64        self.keys.insert(version, new_key);
65        self.current_version = version;
66        Ok(version)
67    }
68
69    /// Encrypt with current key.
70    ///
71    /// Format: version(4) || nonce(16) || ciphertext || mac(16)
72    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
73        let key = self.keys.get(&self.current_version).unwrap();
74        let rng = SystemRandom::new();
75
76        let mut nonce = [0u8; 16];
77        rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
78
79        // Generate keystream
80        let keystream = generate_keystream(key, &nonce, plaintext.len());
81        let ciphertext: Vec<u8> = plaintext
82            .iter()
83            .zip(keystream.iter())
84            .map(|(p, k)| p ^ k)
85            .collect();
86
87        // Version bytes
88        let version_bytes = self.current_version.to_le_bytes();
89
90        // HMAC over version || nonce || ciphertext
91        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
92        let mut hmac_data = Vec::with_capacity(4 + 16 + ciphertext.len());
93        hmac_data.extend_from_slice(&version_bytes);
94        hmac_data.extend_from_slice(&nonce);
95        hmac_data.extend_from_slice(&ciphertext);
96        let tag = hmac::sign(&hmac_key, &hmac_data);
97
98        // Result: version || nonce || ciphertext || mac
99        let mut result = Vec::with_capacity(4 + 16 + ciphertext.len() + 16);
100        result.extend_from_slice(&version_bytes);
101        result.extend_from_slice(&nonce);
102        result.extend_from_slice(&ciphertext);
103        result.extend_from_slice(&tag.as_ref()[..16]);
104
105        Ok(result)
106    }
107
108    /// Decrypt with appropriate key version.
109    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
110        if encrypted.len() < 36 {
111            return Err(ShieldError::CiphertextTooShort {
112                expected: 36,
113                actual: encrypted.len(),
114            });
115        }
116
117        // Parse components
118        let version = u32::from_le_bytes(encrypted[..4].try_into().unwrap());
119        let nonce = &encrypted[4..20];
120        let ciphertext = &encrypted[20..encrypted.len() - 16];
121        let mac = &encrypted[encrypted.len() - 16..];
122
123        // Get key for version
124        let key = self
125            .keys
126            .get(&version)
127            .ok_or(ShieldError::UnknownVersion(version))?;
128
129        // Verify MAC
130        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
131        let expected_tag = hmac::sign(&hmac_key, &encrypted[..encrypted.len() - 16]);
132
133        if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
134            return Err(ShieldError::AuthenticationFailed);
135        }
136
137        // Decrypt
138        let keystream = generate_keystream(key, nonce, ciphertext.len());
139        let plaintext: Vec<u8> = ciphertext
140            .iter()
141            .zip(keystream.iter())
142            .map(|(c, k)| c ^ k)
143            .collect();
144
145        Ok(plaintext)
146    }
147
148    /// Prune old keys, keeping specified number of most recent.
149    pub fn prune_old_keys(&mut self, keep_versions: usize) -> Vec<u32> {
150        let mut versions = self.versions();
151        versions.reverse(); // Most recent first
152
153        let mut to_keep: std::collections::HashSet<u32> =
154            versions.iter().take(keep_versions).copied().collect();
155        to_keep.insert(self.current_version);
156
157        let mut pruned = Vec::new();
158        for v in self.keys.keys().copied().collect::<Vec<_>>() {
159            if !to_keep.contains(&v) {
160                self.keys.remove(&v);
161                pruned.push(v);
162            }
163        }
164
165        pruned
166    }
167
168    /// Re-encrypt data with current key.
169    pub fn re_encrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
170        let plaintext = self.decrypt(encrypted)?;
171        self.encrypt(&plaintext)
172    }
173}
174
175/// Generate keystream using SHA256.
176fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
177    let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
178    let num_blocks = length.div_ceil(32);
179
180    for i in 0..num_blocks {
181        let counter = (i as u32).to_le_bytes();
182        let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
183        data.extend_from_slice(key);
184        data.extend_from_slice(nonce);
185        data.extend_from_slice(&counter);
186
187        let hash = ring::digest::digest(&ring::digest::SHA256, &data);
188        keystream.extend_from_slice(hash.as_ref());
189    }
190
191    keystream.truncate(length);
192    keystream
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Offset of the first ciphertext byte: version (4) + nonce (16).
    const CIPHERTEXT_OFFSET: usize = 20;

    #[test]
    fn test_encrypt_decrypt() {
        let manager = KeyRotationManager::new([42u8; 32], 1);
        let plaintext = b"Hello, Rotation!";

        let encrypted = manager.encrypt(plaintext).unwrap();
        // version(4) || nonce(16) || ciphertext || mac(16)
        assert_eq!(encrypted.len(), 4 + 16 + plaintext.len() + 16);

        let decrypted = manager.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let manager = KeyRotationManager::new([7u8; 32], 1);
        let encrypted = manager.encrypt(b"").unwrap();

        assert_eq!(encrypted.len(), 36);
        assert!(manager.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_nonce_randomizes_output() {
        let manager = KeyRotationManager::new([7u8; 32], 1);
        assert_ne!(manager.encrypt(b"same").unwrap(), manager.encrypt(b"same").unwrap());
    }

    #[test]
    fn test_version_embedded() {
        let manager = KeyRotationManager::new([42u8; 32], 5);
        let encrypted = manager.encrypt(b"test").unwrap();

        let version = u32::from_le_bytes(encrypted[..4].try_into().unwrap());
        assert_eq!(version, 5);
    }

    #[test]
    fn test_rotate() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted1 = manager.encrypt(b"message 1").unwrap();

        manager.rotate([2u8; 32], None).unwrap();
        assert_eq!(manager.current_version(), 2);

        let encrypted2 = manager.encrypt(b"message 2").unwrap();

        // Both decrypt
        assert_eq!(manager.decrypt(&encrypted1).unwrap(), b"message 1");
        assert_eq!(manager.decrypt(&encrypted2).unwrap(), b"message 2");
    }

    #[test]
    fn test_rotate_rejects_non_increasing_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 5);

        assert!(matches!(
            manager.rotate([2u8; 32], Some(5)),
            Err(ShieldError::InvalidVersion)
        ));
        assert!(matches!(
            manager.rotate([2u8; 32], Some(3)),
            Err(ShieldError::InvalidVersion)
        ));
        assert_eq!(manager.current_version(), 5);
        assert_eq!(manager.versions(), vec![5]);
    }

    #[test]
    fn test_add_key_rejects_existing_version() {
        let mut manager = KeyRotationManager::new([1u8; 32], 2);

        assert!(matches!(
            manager.add_key([9u8; 32], 2),
            Err(ShieldError::VersionExists(2))
        ));
        manager.add_key([3u8; 32], 1).unwrap();
        assert_eq!(manager.versions(), vec![1, 2]);
        assert_eq!(manager.current_version(), 2);
    }

    #[test]
    fn test_add_key_enables_decrypting_legacy_data() {
        let legacy = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = legacy.encrypt(b"legacy").unwrap();

        let mut manager = KeyRotationManager::new([2u8; 32], 2);
        manager.add_key([1u8; 32], 1).unwrap();
        assert_eq!(manager.decrypt(&encrypted).unwrap(), b"legacy");
    }

    #[test]
    fn test_prune_old_keys() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        manager.rotate([2u8; 32], None).unwrap();
        manager.rotate([3u8; 32], None).unwrap();
        manager.rotate([4u8; 32], None).unwrap();

        let encrypted = manager.encrypt(b"test").unwrap();
        let mut pruned = manager.prune_old_keys(2);
        pruned.sort_unstable();

        assert_eq!(pruned, vec![1, 2]);
        assert_eq!(manager.versions(), vec![3, 4]);
        assert_eq!(manager.decrypt(&encrypted).unwrap(), b"test");
    }

    #[test]
    fn test_prune_zero_keeps_current() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let old = manager.encrypt(b"old").unwrap();
        manager.rotate([2u8; 32], None).unwrap();

        let mut pruned = manager.prune_old_keys(0);
        pruned.sort_unstable();

        assert_eq!(pruned, vec![1]);
        assert_eq!(manager.versions(), vec![2]);
        assert!(matches!(
            manager.decrypt(&old),
            Err(ShieldError::UnknownVersion(1))
        ));
    }

    #[test]
    fn test_re_encrypt() {
        let mut manager = KeyRotationManager::new([1u8; 32], 1);
        let encrypted = manager.encrypt(b"original").unwrap();

        manager.rotate([2u8; 32], None).unwrap();
        let re_encrypted = manager.re_encrypt(&encrypted).unwrap();

        let version = u32::from_le_bytes(re_encrypted[..4].try_into().unwrap());
        assert_eq!(version, 2);
        assert_eq!(manager.decrypt(&re_encrypted).unwrap(), b"original");
    }

    #[test]
    fn test_unknown_version() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        // Corrupt version: little-endian 1 becomes 99.
        encrypted[0] = 99;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::UnknownVersion(99))
        ));
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        encrypted[CIPHERTEXT_OFFSET] ^= 0x01;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_tampered_mac_rejected() {
        let manager = KeyRotationManager::new([1u8; 32], 1);
        let mut encrypted = manager.encrypt(b"test").unwrap();

        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x80;
        assert!(matches!(
            manager.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_wrong_key_rejected() {
        let sender = KeyRotationManager::new([1u8; 32], 1);
        let receiver = KeyRotationManager::new([2u8; 32], 1);
        let encrypted = sender.encrypt(b"test").unwrap();

        assert!(matches!(
            receiver.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_too_short_rejected() {
        let manager = KeyRotationManager::new([1u8; 32], 1);

        assert!(matches!(
            manager.decrypt(&[0u8; 35]),
            Err(ShieldError::CiphertextTooShort { expected: 36, actual: 35 })
        ));
    }
}