rust_crypto_utils/
keywrap.rs

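//! Key wrapping module for secure key storage (v2.0).
//!
//! Wraps (encrypts) keys under a key-encryption key using AES-256-GCM with a
//! random 96-bit nonce per wrap, and supports multi-level key hierarchies.
//! Note: this is AEAD-based wrapping, not the RFC 3394 AES Key Wrap algorithm.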
4
5use aes_gcm::{
6    aead::{Aead, KeyInit, OsRng},
7    Aes256Gcm, Nonce,
8};
9use rand::RngCore;
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14/// Key wrapping errors
15#[derive(Error, Debug)]
16pub enum KeyWrapError {
17    #[error("Wrapping failed: {0}")]
18    WrapFailed(String),
19
20    #[error("Unwrapping failed: {0}")]
21    UnwrapFailed(String),
22
23    #[error("Invalid key length: expected {expected}, got {actual}")]
24    InvalidKeyLength { expected: usize, actual: usize },
25
26    #[error("Invalid wrapped key format")]
27    InvalidFormat,
28}
29
30/// Key Encryption Key (KEK) for wrapping/unwrapping
31#[derive(Zeroize, ZeroizeOnDrop)]
32pub struct KeyEncryptionKey {
33    key: Vec<u8>,
34}
35
36impl KeyEncryptionKey {
37    /// Generate a new random KEK
38    pub fn generate() -> Self {
39        let mut key = vec![0u8; 32];
40        OsRng.fill_bytes(&mut key);
41        Self { key }
42    }
43
44    /// Create from existing bytes
45    pub fn from_bytes(bytes: &[u8]) -> Result<Self, KeyWrapError> {
46        if bytes.len() != 32 {
47            return Err(KeyWrapError::InvalidKeyLength {
48                expected: 32,
49                actual: bytes.len(),
50            });
51        }
52        Ok(Self {
53            key: bytes.to_vec(),
54        })
55    }
56
57    /// Get key bytes
58    pub fn as_bytes(&self) -> &[u8] {
59        &self.key
60    }
61}
62
63/// Wrapped key structure
64#[derive(Clone, Serialize, Deserialize)]
65pub struct WrappedKey {
66    pub ciphertext: Vec<u8>,
67    pub nonce: [u8; 12],
68    pub key_id: String,
69    pub algorithm: String,
70    pub wrapped_at: chrono::DateTime<chrono::Utc>,
71}
72
73impl WrappedKey {
74    /// Get wrapped key as hex
75    pub fn to_hex(&self) -> String {
76        hex::encode(&self.ciphertext)
77    }
78
79    /// Get nonce as hex
80    pub fn nonce_hex(&self) -> String {
81        hex::encode(self.nonce)
82    }
83
84    /// Export as JSON
85    pub fn to_json(&self) -> Result<String, serde_json::Error> {
86        serde_json::to_string_pretty(self)
87    }
88}
89
90/// Key wrapper using AES-256-GCM
91pub struct KeyWrapper {
92    cipher: Aes256Gcm,
93}
94
95impl KeyWrapper {
96    /// Create a new key wrapper with the given KEK
97    pub fn new(kek: &KeyEncryptionKey) -> Result<Self, KeyWrapError> {
98        let cipher = Aes256Gcm::new_from_slice(kek.as_bytes())
99            .map_err(|e| KeyWrapError::WrapFailed(e.to_string()))?;
100        Ok(Self { cipher })
101    }
102
103    /// Wrap a key
104    pub fn wrap(&self, key: &[u8], key_id: &str) -> Result<WrappedKey, KeyWrapError> {
105        let mut nonce_bytes = [0u8; 12];
106        OsRng.fill_bytes(&mut nonce_bytes);
107        let nonce = Nonce::from_slice(&nonce_bytes);
108
109        let ciphertext = self
110            .cipher
111            .encrypt(nonce, key)
112            .map_err(|e| KeyWrapError::WrapFailed(e.to_string()))?;
113
114        Ok(WrappedKey {
115            ciphertext,
116            nonce: nonce_bytes,
117            key_id: key_id.to_string(),
118            algorithm: "AES-256-GCM".to_string(),
119            wrapped_at: chrono::Utc::now(),
120        })
121    }
122
123    /// Unwrap a key
124    pub fn unwrap(&self, wrapped: &WrappedKey) -> Result<Vec<u8>, KeyWrapError> {
125        let nonce = Nonce::from_slice(&wrapped.nonce);
126
127        let plaintext = self
128            .cipher
129            .decrypt(nonce, wrapped.ciphertext.as_ref())
130            .map_err(|e| KeyWrapError::UnwrapFailed(e.to_string()))?;
131
132        Ok(plaintext)
133    }
134
135    /// Rewrap a key with a new KEK
136    pub fn rewrap(
137        &self,
138        wrapped: &WrappedKey,
139        new_wrapper: &KeyWrapper,
140    ) -> Result<WrappedKey, KeyWrapError> {
141        let key = self.unwrap(wrapped)?;
142        new_wrapper.wrap(&key, &wrapped.key_id)
143    }
144}
145
146/// Key hierarchy manager for multi-level key wrapping
147pub struct KeyHierarchy {
148    master_wrapper: KeyWrapper,
149    level_keys: Vec<KeyEncryptionKey>,
150}
151
152impl KeyHierarchy {
153    /// Create a new key hierarchy with a master KEK
154    pub fn new(master_kek: KeyEncryptionKey) -> Result<Self, KeyWrapError> {
155        let master_wrapper = KeyWrapper::new(&master_kek)?;
156        Ok(Self {
157            master_wrapper,
158            level_keys: Vec::new(),
159        })
160    }
161
162    /// Add a new level to the hierarchy
163    pub fn add_level(&mut self) -> Result<WrappedKey, KeyWrapError> {
164        let level_kek = KeyEncryptionKey::generate();
165        let level_id = format!("level-{}", self.level_keys.len());
166        let wrapped = self.master_wrapper.wrap(level_kek.as_bytes(), &level_id)?;
167        self.level_keys.push(level_kek);
168        Ok(wrapped)
169    }
170
171    /// Get wrapper for a specific level
172    pub fn get_level_wrapper(&self, level: usize) -> Result<KeyWrapper, KeyWrapError> {
173        let kek = self
174            .level_keys
175            .get(level)
176            .ok_or(KeyWrapError::InvalidFormat)?;
177        KeyWrapper::new(kek)
178    }
179
180    /// Wrap a data key at a specific level
181    pub fn wrap_data_key(&self, key: &[u8], level: usize, key_id: &str) -> Result<WrappedKey, KeyWrapError> {
182        let wrapper = self.get_level_wrapper(level)?;
183        wrapper.wrap(key, key_id)
184    }
185
186    /// Unwrap a data key at a specific level
187    pub fn unwrap_data_key(&self, wrapped: &WrappedKey, level: usize) -> Result<Vec<u8>, KeyWrapError> {
188        let wrapper = self.get_level_wrapper(level)?;
189        wrapper.unwrap(wrapped)
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A data key round-trips through wrap/unwrap under the same KEK.
    #[test]
    fn test_wrap_unwrap() {
        let kek = KeyEncryptionKey::generate();
        let wrapper = KeyWrapper::new(&kek).unwrap();

        let data_key = vec![0u8; 32]; // 256-bit key
        let wrapped = wrapper.wrap(&data_key, "key-001").unwrap();
        let unwrapped = wrapper.unwrap(&wrapped).unwrap();

        assert_eq!(data_key, unwrapped);
    }

    /// The metadata recorded by `wrap` matches its inputs.
    #[test]
    fn test_wrapped_key_metadata() {
        let kek = KeyEncryptionKey::generate();
        let wrapper = KeyWrapper::new(&kek).unwrap();

        let data_key = vec![0u8; 32];
        let wrapped = wrapper.wrap(&data_key, "my-key").unwrap();

        assert_eq!(wrapped.key_id, "my-key");
        assert_eq!(wrapped.algorithm, "AES-256-GCM");
    }

    /// Unwrapping with a different KEK must fail.
    #[test]
    fn test_wrong_kek_fails() {
        let kek1 = KeyEncryptionKey::generate();
        let kek2 = KeyEncryptionKey::generate();

        let wrapper1 = KeyWrapper::new(&kek1).unwrap();
        let wrapper2 = KeyWrapper::new(&kek2).unwrap();

        let data_key = vec![0u8; 32];
        let wrapped = wrapper1.wrap(&data_key, "key-001").unwrap();
        let result = wrapper2.unwrap(&wrapped);

        assert!(result.is_err());
    }

    /// Flipping a ciphertext bit must be caught by GCM authentication.
    #[test]
    fn test_tampered_ciphertext_fails() {
        let kek = KeyEncryptionKey::generate();
        let wrapper = KeyWrapper::new(&kek).unwrap();

        let mut wrapped = wrapper.wrap(&[7u8; 32], "key-001").unwrap();
        wrapped.ciphertext[0] ^= 0x01;

        assert!(matches!(
            wrapper.unwrap(&wrapped),
            Err(KeyWrapError::UnwrapFailed(_))
        ));
    }

    /// Each wrap must use a fresh nonce, so identical inputs give different outputs.
    #[test]
    fn test_fresh_nonce_per_wrap() {
        let kek = KeyEncryptionKey::generate();
        let wrapper = KeyWrapper::new(&kek).unwrap();

        let data_key = [3u8; 32];
        let first = wrapper.wrap(&data_key, "key-001").unwrap();
        let second = wrapper.wrap(&data_key, "key-001").unwrap();

        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    /// Rewrapping moves a key to a new KEK and preserves its id.
    #[test]
    fn test_rewrap() {
        let kek1 = KeyEncryptionKey::generate();
        let kek2 = KeyEncryptionKey::generate();

        let wrapper1 = KeyWrapper::new(&kek1).unwrap();
        let wrapper2 = KeyWrapper::new(&kek2).unwrap();

        let data_key = vec![0u8; 32];
        let wrapped1 = wrapper1.wrap(&data_key, "key-001").unwrap();
        let wrapped2 = wrapper1.rewrap(&wrapped1, &wrapper2).unwrap();

        let unwrapped = wrapper2.unwrap(&wrapped2).unwrap();
        assert_eq!(data_key, unwrapped);
        assert_eq!(wrapped2.key_id, "key-001");
        // The rewrapped key is no longer readable under the old KEK.
        assert!(wrapper1.unwrap(&wrapped2).is_err());
    }

    /// Data keys round-trip through a hierarchy level.
    #[test]
    fn test_key_hierarchy() {
        let master_kek = KeyEncryptionKey::generate();
        let mut hierarchy = KeyHierarchy::new(master_kek).unwrap();

        // Add two levels
        hierarchy.add_level().unwrap();
        hierarchy.add_level().unwrap();

        // Wrap a data key at level 0
        let data_key = vec![42u8; 32];
        let wrapped = hierarchy.wrap_data_key(&data_key, 0, "data-key-001").unwrap();

        // Unwrap it
        let unwrapped = hierarchy.unwrap_data_key(&wrapped, 0).unwrap();
        assert_eq!(data_key, unwrapped);
    }

    /// Level KEKs are reported with sequential `level-<n>` ids.
    #[test]
    fn test_add_level_key_ids() {
        let mut hierarchy = KeyHierarchy::new(KeyEncryptionKey::generate()).unwrap();

        let level0 = hierarchy.add_level().unwrap();
        let level1 = hierarchy.add_level().unwrap();

        assert_eq!(level0.key_id, "level-0");
        assert_eq!(level1.key_id, "level-1");
    }

    /// A key wrapped at one level cannot be unwrapped at another.
    #[test]
    fn test_hierarchy_levels_isolated() {
        let mut hierarchy = KeyHierarchy::new(KeyEncryptionKey::generate()).unwrap();
        hierarchy.add_level().unwrap();
        hierarchy.add_level().unwrap();

        let wrapped = hierarchy.wrap_data_key(&[5u8; 32], 0, "data-key").unwrap();

        assert!(hierarchy.unwrap_data_key(&wrapped, 1).is_err());
    }

    /// Referencing a level that was never added is rejected.
    #[test]
    fn test_hierarchy_unknown_level() {
        let mut hierarchy = KeyHierarchy::new(KeyEncryptionKey::generate()).unwrap();
        hierarchy.add_level().unwrap();

        assert!(matches!(
            hierarchy.wrap_data_key(&[1u8; 32], 1, "data-key"),
            Err(KeyWrapError::InvalidFormat)
        ));
        assert!(matches!(
            hierarchy.get_level_wrapper(usize::MAX),
            Err(KeyWrapError::InvalidFormat)
        ));
    }

    /// JSON export contains the metadata.
    #[test]
    fn test_wrapped_key_json() {
        let kek = KeyEncryptionKey::generate();
        let wrapper = KeyWrapper::new(&kek).unwrap();

        let wrapped = wrapper.wrap(&[0u8; 32], "test-key").unwrap();
        let json = wrapped.to_json().unwrap();

        assert!(json.contains("test-key"));
        assert!(json.contains("AES-256-GCM"));
    }

    /// A wrapped key survives a JSON round trip and still unwraps.
    #[test]
    fn test_wrapped_key_json_roundtrip() {
        let kek = KeyEncryptionKey::generate();
        let wrapper = KeyWrapper::new(&kek).unwrap();

        let wrapped = wrapper.wrap(&[9u8; 32], "json-key").unwrap();
        let json = wrapped.to_json().unwrap();
        let parsed: WrappedKey = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.key_id, "json-key");
        assert_eq!(parsed.to_hex(), wrapped.to_hex());
        assert_eq!(parsed.nonce_hex(), wrapped.nonce_hex());
        assert_eq!(wrapper.unwrap(&parsed).unwrap(), vec![9u8; 32]);
    }

    /// `from_bytes` rejects wrong lengths with precise details and accepts 32 bytes.
    #[test]
    fn test_invalid_kek_length() {
        let result = KeyEncryptionKey::from_bytes(&[0u8; 16]); // Too short
        assert!(matches!(
            result,
            Err(KeyWrapError::InvalidKeyLength {
                expected: 32,
                actual: 16
            })
        ));
        assert!(KeyEncryptionKey::from_bytes(&[0u8; 33]).is_err());
        assert!(KeyEncryptionKey::from_bytes(&[0u8; 32]).is_ok());
    }
}