Skip to main content

sh_layer0/
encryption_engine.rs

//! # Encryption Engine
//!
//! 安全加密引擎,提供生产级加密支持。
//!
//! ## 支持的算法
//! - **AES-256-GCM**: 推荐,硬件加速
//! - **ChaCha20-Poly1305**: 软件实现高效
//!
//! ## 密钥派生
//! - Argon2id 密钥派生函数
//!
//! ## 功能
//! - 对称加密/解密
//! - 密钥派生 (KDF)
//! - 密钥轮换
//! - 安全密钥存储
17
18use aes_gcm::{
19    aead::{Aead, KeyInit, OsRng},
20    Aes256Gcm, Key, Nonce,
21};
22use argon2::{password_hash::rand_core::RngCore, Algorithm, Argon2, Params, Version};
23use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
24use chacha20poly1305::{ChaCha20Poly1305, Key as ChaChaKey, Nonce as ChaChaNonce};
25use parking_lot::RwLock;
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28use std::sync::Arc;
29use thiserror::Error;
30use zeroize::Zeroizing;
31
32/// 加密错误类型
33#[derive(Debug, Error)]
34pub enum EncryptionError {
35    #[error("Encryption failed: {0}")]
36    EncryptionFailed(String),
37
38    #[error("Decryption failed: {0}")]
39    DecryptionFailed(String),
40
41    #[error("Key derivation failed: {0}")]
42    KeyDerivationFailed(String),
43
44    #[error("Invalid key length: expected {expected}, got {actual}")]
45    InvalidKeyLength { expected: usize, actual: usize },
46
47    #[error("Invalid nonce length: expected {expected}, got {actual}")]
48    InvalidNonceLength { expected: usize, actual: usize },
49
50    #[error("Key not found: {0}")]
51    KeyNotFound(String),
52
53    #[error("Invalid ciphertext format")]
54    InvalidCiphertextFormat,
55
56    #[error("Key rotation failed: {0}")]
57    KeyRotationFailed(String),
58}
59
60/// 加密算法类型
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
62pub enum EncryptionAlgorithm {
63    /// AES-256-GCM (推荐)
64    Aes256Gcm,
65    /// ChaCha20-Poly1305
66    ChaCha20Poly1305,
67}
68
69impl Default for EncryptionAlgorithm {
70    fn default() -> Self {
71        Self::Aes256Gcm
72    }
73}
74
75impl std::fmt::Display for EncryptionAlgorithm {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            Self::Aes256Gcm => write!(f, "AES-256-GCM"),
79            Self::ChaCha20Poly1305 => write!(f, "ChaCha20-Poly1305"),
80        }
81    }
82}
83
84/// 加密密钥(安全存储)
85#[derive(Clone)]
86pub struct EncryptionKey {
87    /// 密钥数据(零化)
88    key_data: Arc<Zeroizing<Vec<u8>>>,
89    /// 算法类型
90    algorithm: EncryptionAlgorithm,
91    /// 密钥 ID
92    key_id: String,
93    /// 创建时间
94    created_at: chrono::DateTime<chrono::Utc>,
95}
96
97impl EncryptionKey {
98    /// 从原始字节创建密钥
99    pub fn from_bytes(
100        key_bytes: &[u8],
101        algorithm: EncryptionAlgorithm,
102    ) -> Result<Self, EncryptionError> {
103        let expected_len = Self::key_length(algorithm);
104        if key_bytes.len() != expected_len {
105            return Err(EncryptionError::InvalidKeyLength {
106                expected: expected_len,
107                actual: key_bytes.len(),
108            });
109        }
110
111        Ok(Self {
112            key_data: Arc::new(Zeroizing::new(key_bytes.to_vec())),
113            algorithm,
114            key_id: generate_key_id(),
115            created_at: chrono::Utc::now(),
116        })
117    }
118
119    /// 生成随机密钥
120    pub fn generate(algorithm: EncryptionAlgorithm) -> Result<Self, EncryptionError> {
121        let key_len = Self::key_length(algorithm);
122        let mut key_bytes = vec![0u8; key_len];
123        OsRng.fill_bytes(&mut key_bytes);
124
125        Self::from_bytes(&key_bytes, algorithm)
126    }
127
128    /// 从密码派生密钥
129    pub fn derive_from_password(
130        password: &str,
131        salt: &[u8],
132        algorithm: EncryptionAlgorithm,
133    ) -> Result<Self, EncryptionError> {
134        let key_len = Self::key_length(algorithm);
135        let params = Params::new(
136            Params::DEFAULT_M_COST,
137            Params::DEFAULT_T_COST,
138            Params::DEFAULT_P_COST,
139            Some(key_len),
140        )
141        .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
142
143        let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
144        let mut key_bytes = vec![0u8; key_len];
145
146        argon2
147            .hash_password_into(password.as_bytes(), salt, &mut key_bytes)
148            .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
149
150        Self::from_bytes(&key_bytes, algorithm)
151    }
152
153    /// 获取密钥长度
154    pub fn key_length(algorithm: EncryptionAlgorithm) -> usize {
155        match algorithm {
156            EncryptionAlgorithm::Aes256Gcm => 32,
157            EncryptionAlgorithm::ChaCha20Poly1305 => 32,
158        }
159    }
160
161    /// 获取 Nonce 长度
162    pub fn nonce_length(algorithm: EncryptionAlgorithm) -> usize {
163        match algorithm {
164            EncryptionAlgorithm::Aes256Gcm => 12,
165            EncryptionAlgorithm::ChaCha20Poly1305 => 12,
166        }
167    }
168
169    /// 获取密钥数据
170    pub fn as_bytes(&self) -> &[u8] {
171        &self.key_data
172    }
173
174    /// 获取密钥 ID
175    pub fn key_id(&self) -> &str {
176        &self.key_id
177    }
178
179    /// 获取算法
180    pub fn algorithm(&self) -> EncryptionAlgorithm {
181        self.algorithm
182    }
183
184    /// 获取创建时间
185    pub fn created_at(&self) -> chrono::DateTime<chrono::Utc> {
186        self.created_at
187    }
188}
189
190impl std::fmt::Debug for EncryptionKey {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        f.debug_struct("EncryptionKey")
193            .field("key_id", &self.key_id)
194            .field("algorithm", &self.algorithm)
195            .field("created_at", &self.created_at)
196            .field("key_data", &"<redacted>")
197            .finish()
198    }
199}
200
201/// 加密结果
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct EncryptedData {
204    /// 加密算法
205    pub algorithm: EncryptionAlgorithm,
206    /// Nonce(Base64)
207    pub nonce: String,
208    /// 密文(Base64)
209    pub ciphertext: String,
210    /// 密钥 ID
211    pub key_id: String,
212    /// 关联数据(可选)
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub associated_data: Option<String>,
215}
216
217impl EncryptedData {
218    /// 序列化为字符串
219    pub fn to_string(&self) -> Result<String, EncryptionError> {
220        serde_json::to_string(self).map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))
221    }
222
223    /// 从字符串解析
224    pub fn from_string(s: &str) -> Result<Self, EncryptionError> {
225        serde_json::from_str(s).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
226    }
227}
228
229/// 加密引擎配置
230#[derive(Debug, Clone)]
231pub struct EncryptionConfig {
232    /// 默认算法
233    pub default_algorithm: EncryptionAlgorithm,
234    /// 是否自动轮换密钥
235    pub auto_key_rotation: bool,
236    /// 密钥有效期(秒)
237    pub key_validity_secs: u64,
238    /// 最大密钥数量
239    pub max_keys: usize,
240}
241
242impl Default for EncryptionConfig {
243    fn default() -> Self {
244        Self {
245            default_algorithm: EncryptionAlgorithm::Aes256Gcm,
246            auto_key_rotation: false,
247            key_validity_secs: 86400 * 90, // 90 天
248            max_keys: 100,
249        }
250    }
251}
252
253/// 加密引擎
254pub struct EncryptionEngine {
255    /// 密钥存储
256    keys: RwLock<HashMap<String, EncryptionKey>>,
257    /// 默认密钥 ID
258    default_key_id: RwLock<Option<String>>,
259    /// 配置
260    config: EncryptionConfig,
261}
262
263impl EncryptionEngine {
264    /// 创建新的加密引擎
265    pub fn new() -> Self {
266        Self::with_config(EncryptionConfig::default())
267    }
268
269    /// 使用配置创建
270    pub fn with_config(config: EncryptionConfig) -> Self {
271        Self {
272            keys: RwLock::new(HashMap::new()),
273            default_key_id: RwLock::new(None),
274            config,
275        }
276    }
277
278    /// 生成并添加新密钥
279    pub fn generate_key(&self) -> Result<String, EncryptionError> {
280        self.generate_key_with_algorithm(self.config.default_algorithm)
281    }
282
283    /// 使用指定算法生成密钥
284    pub fn generate_key_with_algorithm(
285        &self,
286        algorithm: EncryptionAlgorithm,
287    ) -> Result<String, EncryptionError> {
288        let key = EncryptionKey::generate(algorithm)?;
289        let key_id = key.key_id().to_string();
290
291        // 检查密钥数量限制
292        let mut keys = self.keys.write();
293        if keys.len() >= self.config.max_keys {
294            // 移除最旧的密钥
295            if let Some((oldest_id, _)) = keys
296                .iter()
297                .min_by_key(|(_, k)| k.created_at())
298                .map(|(id, k)| (id.clone(), k.created_at()))
299            {
300                keys.remove(&oldest_id);
301            }
302        }
303
304        keys.insert(key_id.clone(), key);
305
306        // 设置为默认密钥
307        *self.default_key_id.write() = Some(key_id.clone());
308
309        tracing::info!("Generated new encryption key: {}", key_id);
310        Ok(key_id)
311    }
312
313    /// 从密码派生并添加密钥
314    pub fn derive_key_from_password(
315        &self,
316        password: &str,
317        salt: &[u8],
318    ) -> Result<String, EncryptionError> {
319        let key =
320            EncryptionKey::derive_from_password(password, salt, self.config.default_algorithm)?;
321        let key_id = key.key_id().to_string();
322
323        self.keys.write().insert(key_id.clone(), key);
324        *self.default_key_id.write() = Some(key_id.clone());
325
326        tracing::info!("Derived encryption key from password: {}", key_id);
327        Ok(key_id)
328    }
329
330    /// 添加已有密钥
331    pub fn add_key(&self, key: EncryptionKey) -> Result<String, EncryptionError> {
332        let key_id = key.key_id().to_string();
333        self.keys.write().insert(key_id.clone(), key);
334        *self.default_key_id.write() = Some(key_id.clone());
335        Ok(key_id)
336    }
337
338    /// 获取密钥
339    pub fn get_key(&self, key_id: &str) -> Result<EncryptionKey, EncryptionError> {
340        self.keys
341            .read()
342            .get(key_id)
343            .cloned()
344            .ok_or_else(|| EncryptionError::KeyNotFound(key_id.to_string()))
345    }
346
347    /// 删除密钥
348    pub fn remove_key(&self, key_id: &str) -> Result<bool, EncryptionError> {
349        let removed = self.keys.write().remove(key_id).is_some();
350
351        // 如果删除的是默认密钥,清除默认密钥设置
352        if removed {
353            let mut default_id = self.default_key_id.write();
354            if default_id.as_deref() == Some(key_id) {
355                *default_id = None;
356            }
357        }
358
359        Ok(removed)
360    }
361
362    /// 列出所有密钥 ID
363    pub fn list_keys(&self) -> Vec<String> {
364        self.keys.read().keys().cloned().collect()
365    }
366
367    /// 获取默认密钥
368    pub fn get_default_key(&self) -> Result<EncryptionKey, EncryptionError> {
369        let default_id = self
370            .default_key_id
371            .read()
372            .clone()
373            .ok_or_else(|| EncryptionError::KeyNotFound("default".to_string()))?;
374
375        self.get_key(&default_id)
376    }
377
378    /// 设置默认密钥
379    pub fn set_default_key(&self, key_id: &str) -> Result<(), EncryptionError> {
380        if !self.keys.read().contains_key(key_id) {
381            return Err(EncryptionError::KeyNotFound(key_id.to_string()));
382        }
383
384        *self.default_key_id.write() = Some(key_id.to_string());
385        Ok(())
386    }
387
388    /// 加密数据
389    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
390        let key = self.get_default_key()?;
391        self.encrypt_with_key(&key, plaintext)
392    }
393
394    /// 使用指定密钥加密
395    pub fn encrypt_with_key(
396        &self,
397        key: &EncryptionKey,
398        plaintext: &[u8],
399    ) -> Result<EncryptedData, EncryptionError> {
400        // 生成随机 Nonce
401        let nonce_len = EncryptionKey::nonce_length(key.algorithm);
402        let mut nonce_bytes = vec![0u8; nonce_len];
403        OsRng.fill_bytes(&mut nonce_bytes);
404
405        let ciphertext = match key.algorithm {
406            EncryptionAlgorithm::Aes256Gcm => {
407                let cipher_key = Key::<Aes256Gcm>::from_slice(key.as_bytes());
408                let cipher = Aes256Gcm::new(cipher_key);
409                let nonce = Nonce::from_slice(&nonce_bytes);
410
411                cipher
412                    .encrypt(nonce, plaintext)
413                    .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?
414            }
415            EncryptionAlgorithm::ChaCha20Poly1305 => {
416                let cipher_key = ChaChaKey::from_slice(key.as_bytes());
417                let cipher = ChaCha20Poly1305::new(cipher_key);
418                let nonce = ChaChaNonce::from_slice(&nonce_bytes);
419
420                cipher
421                    .encrypt(nonce, plaintext)
422                    .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?
423            }
424        };
425
426        Ok(EncryptedData {
427            algorithm: key.algorithm,
428            nonce: BASE64.encode(&nonce_bytes),
429            ciphertext: BASE64.encode(&ciphertext),
430            key_id: key.key_id().to_string(),
431            associated_data: None,
432        })
433    }
434
435    /// 加密字符串
436    pub fn encrypt_string(&self, plaintext: &str) -> Result<EncryptedData, EncryptionError> {
437        self.encrypt(plaintext.as_bytes())
438    }
439
440    /// 解密数据
441    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
442        let key = self.get_key(&encrypted.key_id)?;
443        self.decrypt_with_key(&key, encrypted)
444    }
445
446    /// 使用指定密钥解密
447    pub fn decrypt_with_key(
448        &self,
449        key: &EncryptionKey,
450        encrypted: &EncryptedData,
451    ) -> Result<Vec<u8>, EncryptionError> {
452        // 解码 Base64
453        let nonce_bytes = BASE64
454            .decode(&encrypted.nonce)
455            .map_err(|_| EncryptionError::InvalidCiphertextFormat)?;
456
457        let ciphertext = BASE64
458            .decode(&encrypted.ciphertext)
459            .map_err(|_| EncryptionError::InvalidCiphertextFormat)?;
460
461        // 验证 Nonce 长度
462        let expected_nonce_len = EncryptionKey::nonce_length(key.algorithm);
463        if nonce_bytes.len() != expected_nonce_len {
464            return Err(EncryptionError::InvalidNonceLength {
465                expected: expected_nonce_len,
466                actual: nonce_bytes.len(),
467            });
468        }
469
470        let plaintext = match key.algorithm {
471            EncryptionAlgorithm::Aes256Gcm => {
472                let cipher_key = Key::<Aes256Gcm>::from_slice(key.as_bytes());
473                let cipher = Aes256Gcm::new(cipher_key);
474                let nonce = Nonce::from_slice(&nonce_bytes);
475
476                cipher
477                    .decrypt(nonce, ciphertext.as_slice())
478                    .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?
479            }
480            EncryptionAlgorithm::ChaCha20Poly1305 => {
481                let cipher_key = ChaChaKey::from_slice(key.as_bytes());
482                let cipher = ChaCha20Poly1305::new(cipher_key);
483                let nonce = ChaChaNonce::from_slice(&nonce_bytes);
484
485                cipher
486                    .decrypt(nonce, ciphertext.as_slice())
487                    .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?
488            }
489        };
490
491        Ok(plaintext)
492    }
493
494    /// 解密为字符串
495    pub fn decrypt_to_string(&self, encrypted: &EncryptedData) -> Result<String, EncryptionError> {
496        let plaintext = self.decrypt(encrypted)?;
497        String::from_utf8(plaintext).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
498    }
499
500    /// 轮换密钥
501    pub fn rotate_key(&self, old_key_id: &str) -> Result<String, EncryptionError> {
502        // 检查旧密钥是否存在
503        if !self.keys.read().contains_key(old_key_id) {
504            return Err(EncryptionError::KeyNotFound(old_key_id.to_string()));
505        }
506
507        // 生成新密钥
508        let new_key_id = self.generate_key()?;
509
510        tracing::info!("Key rotated: {} -> {}", old_key_id, new_key_id);
511        Ok(new_key_id)
512    }
513
514    /// 获取需要轮换的密钥
515    pub fn get_keys_requiring_rotation(&self) -> Vec<String> {
516        let now = chrono::Utc::now();
517        let validity = chrono::Duration::seconds(self.config.key_validity_secs as i64);
518
519        self.keys
520            .read()
521            .iter()
522            .filter(|(_, key)| {
523                let age = now.signed_duration_since(key.created_at());
524                age > validity
525            })
526            .map(|(id, _)| id.clone())
527            .collect()
528    }
529
530    /// 导出密钥(Base64)
531    pub fn export_key(&self, key_id: &str) -> Result<String, EncryptionError> {
532        let key = self.get_key(key_id)?;
533        Ok(BASE64.encode(key.as_bytes()))
534    }
535
536    /// 导入密钥(Base64)
537    pub fn import_key(
538        &self,
539        key_b64: &str,
540        algorithm: EncryptionAlgorithm,
541    ) -> Result<String, EncryptionError> {
542        let key_bytes = BASE64
543            .decode(key_b64)
544            .map_err(|_| EncryptionError::InvalidCiphertextFormat)?;
545
546        let key = EncryptionKey::from_bytes(&key_bytes, algorithm)?;
547        self.add_key(key)
548    }
549
550    /// 密钥数量
551    pub fn key_count(&self) -> usize {
552        self.keys.read().len()
553    }
554
555    /// 是否有默认密钥
556    pub fn has_default_key(&self) -> bool {
557        self.default_key_id.read().is_some()
558    }
559}
560
561impl Default for EncryptionEngine {
562    fn default() -> Self {
563        Self::new()
564    }
565}
566
567/// 生成密钥 ID
568fn generate_key_id() -> String {
569    use uuid::Uuid;
570    format!("key_{}", Uuid::new_v4())
571}
572
573/// 生成盐值
574pub fn generate_salt() -> Vec<u8> {
575    let mut salt = vec![0u8; 16];
576    OsRng.fill_bytes(&mut salt);
577    salt
578}
579
580/// 从密码派生密钥(独立函数)
581pub fn derive_key_from_password(
582    password: &str,
583    salt: &[u8],
584    algorithm: EncryptionAlgorithm,
585) -> Result<EncryptionKey, EncryptionError> {
586    EncryptionKey::derive_from_password(password, salt, algorithm)
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_key_generation() {
        let key = EncryptionKey::generate(EncryptionAlgorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), 32);
        assert!(!key.key_id().is_empty());
    }

    #[test]
    fn test_encryption_key_from_bytes() {
        let key_bytes = vec![0u8; 32];
        let key = EncryptionKey::from_bytes(&key_bytes, EncryptionAlgorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), 32);
    }

    #[test]
    fn test_encryption_key_invalid_length() {
        let key_bytes = vec![0u8; 16];
        let result = EncryptionKey::from_bytes(&key_bytes, EncryptionAlgorithm::Aes256Gcm);
        assert!(result.is_err());
    }

    #[test]
    fn test_key_derivation() {
        let salt = generate_salt();
        let key =
            derive_key_from_password("password123", &salt, EncryptionAlgorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), 32);
    }

    #[test]
    fn test_aes_gcm_encryption_decryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let plaintext = b"Hello, World!";
        let encrypted = engine.encrypt(plaintext).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_chacha_encryption_decryption() {
        let mut config = EncryptionConfig::default();
        config.default_algorithm = EncryptionAlgorithm::ChaCha20Poly1305;
        let engine = EncryptionEngine::with_config(config);
        engine.generate_key().unwrap();

        let plaintext = b"Hello, ChaCha20!";
        let encrypted = engine.encrypt(plaintext).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_string_encryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let plaintext = "Secret message";
        let encrypted = engine.encrypt_string(plaintext).unwrap();
        let decrypted = engine.decrypt_to_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_multiple_keys() {
        let engine = EncryptionEngine::new();
        let key_id1 = engine.generate_key().unwrap();
        let key_id2 = engine
            .generate_key_with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305)
            .unwrap();

        assert_eq!(engine.key_count(), 2);
        assert!(engine.list_keys().contains(&key_id1));
        assert!(engine.list_keys().contains(&key_id2));
    }

    #[test]
    fn test_key_removal() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();

        assert!(engine.remove_key(&key_id).unwrap());
        assert!(!engine.list_keys().contains(&key_id));
    }

    #[test]
    fn test_key_rotation() {
        let engine = EncryptionEngine::new();
        let old_key_id = engine.generate_key().unwrap();

        let new_key_id = engine.rotate_key(&old_key_id).unwrap();

        assert_ne!(old_key_id, new_key_id);
        assert!(engine.list_keys().contains(&new_key_id));
    }

    #[test]
    fn test_encrypted_data_serialization() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let encrypted = engine.encrypt_string("test").unwrap();
        let serialized = encrypted.to_string().unwrap();
        let deserialized = EncryptedData::from_string(&serialized).unwrap();

        assert_eq!(encrypted.key_id, deserialized.key_id);
        assert_eq!(encrypted.nonce, deserialized.nonce);
        assert_eq!(encrypted.ciphertext, deserialized.ciphertext);
    }

    #[test]
    fn test_key_export_import() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();

        let exported = engine.export_key(&key_id).unwrap();
        engine.remove_key(&key_id).unwrap();

        let imported_id = engine
            .import_key(&exported, EncryptionAlgorithm::Aes256Gcm)
            .unwrap();
        assert!(engine.list_keys().contains(&imported_id));
    }

    #[test]
    fn test_different_plaintexts_different_ciphertexts() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let encrypted1 = engine.encrypt(b"test").unwrap();
        let encrypted2 = engine.encrypt(b"test").unwrap();

        // Identical plaintexts must still yield different ciphertexts (fresh random nonces).
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
    }

    #[test]
    fn test_large_data_encryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let large_data = vec![0u8; 1_000_000]; // 1 MB
        let encrypted = engine.encrypt(&large_data).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();

        assert_eq!(large_data, decrypted);
    }

    #[test]
    fn test_empty_data_encryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let encrypted = engine.encrypt(b"").unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();

        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_algorithm_display() {
        assert_eq!(format!("{}", EncryptionAlgorithm::Aes256Gcm), "AES-256-GCM");
        assert_eq!(
            format!("{}", EncryptionAlgorithm::ChaCha20Poly1305),
            "ChaCha20-Poly1305"
        );
    }

    #[test]
    fn test_encryption_key_debug_redaction() {
        let key = EncryptionKey::generate(EncryptionAlgorithm::Aes256Gcm).unwrap();
        let debug_output = format!("{:?}", key);

        // Debug output must never contain the key bytes.
        assert!(debug_output.contains("<redacted>"));
        assert!(!debug_output.contains(&BASE64.encode(key.as_bytes())));
    }

    #[test]
    fn test_default_key() {
        let engine = EncryptionEngine::new();
        assert!(!engine.has_default_key());

        engine.generate_key().unwrap();
        assert!(engine.has_default_key());
    }

    #[test]
    fn test_set_default_key() {
        let engine = EncryptionEngine::new();
        let key_id1 = engine.generate_key().unwrap();
        let key_id2 = engine.generate_key().unwrap();

        // The most recently generated key becomes the default.
        let default_key = engine.get_default_key().unwrap();
        assert_eq!(default_key.key_id(), key_id2);

        // Explicitly switch the default back to the first key.
        engine.set_default_key(&key_id1).unwrap();
        let default_key = engine.get_default_key().unwrap();
        assert_eq!(default_key.key_id(), key_id1);
    }

    #[test]
    fn test_wrong_algorithm_decryption() {
        let engine = EncryptionEngine::new();
        let aes_key = engine
            .generate_key_with_algorithm(EncryptionAlgorithm::Aes256Gcm)
            .unwrap();
        let chacha_key = engine
            .generate_key_with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305)
            .unwrap();

        let aes_encrypted = engine
            .encrypt_with_key(&engine.get_key(&aes_key).unwrap(), b"test")
            .unwrap();

        // Decrypting with the wrong key must fail.
        let result =
            engine.decrypt_with_key(&engine.get_key(&chacha_key).unwrap(), &aes_encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_concurrent_encryption() {
        use std::sync::Arc;
        use std::thread;

        let engine = Arc::new(EncryptionEngine::new());
        engine.generate_key().unwrap();

        let mut handles = vec![];

        for i in 0..10 {
            let e = Arc::clone(&engine);
            handles.push(thread::spawn(move || {
                let data = format!("message_{}", i);
                let encrypted = e.encrypt_string(&data).unwrap();
                let decrypted = e.decrypt_to_string(&encrypted).unwrap();
                assert_eq!(data, decrypted);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_encrypt_without_default_key_fails() {
        let engine = EncryptionEngine::new();
        assert!(matches!(
            engine.encrypt(b"data"),
            Err(EncryptionError::KeyNotFound(_))
        ));
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let mut encrypted = engine.encrypt(b"integrity matters").unwrap();
        let mut raw = BASE64.decode(&encrypted.ciphertext).unwrap();
        raw[0] ^= 0x01;
        encrypted.ciphertext = BASE64.encode(&raw);

        // The AEAD tag must detect any modification of the ciphertext.
        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::DecryptionFailed(_))
        ));
    }

    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        let mut encrypted = engine.encrypt(b"test").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);

        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::InvalidNonceLength {
                expected: 12,
                actual: 8
            })
        ));
    }

    #[test]
    fn test_decrypt_after_key_removal_fails() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();
        let encrypted = engine.encrypt(b"test").unwrap();

        engine.remove_key(&key_id).unwrap();

        assert!(!engine.has_default_key());
        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::KeyNotFound(_))
        ));
    }

    #[test]
    fn test_set_default_key_unknown_id_fails() {
        let engine = EncryptionEngine::new();
        assert!(engine.set_default_key("key_missing").is_err());
        assert!(!engine.has_default_key());
    }

    #[test]
    fn test_max_keys_limit_on_generate() {
        let config = EncryptionConfig {
            max_keys: 2,
            ..EncryptionConfig::default()
        };
        let engine = EncryptionEngine::with_config(config);
        engine.generate_key().unwrap();
        engine.generate_key().unwrap();
        let newest = engine.generate_key().unwrap();

        assert_eq!(engine.key_count(), 2);
        assert!(engine.list_keys().contains(&newest));
        assert_eq!(engine.get_default_key().unwrap().key_id(), newest);
    }

    #[test]
    fn test_imported_key_decrypts_with_explicit_key() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();
        let encrypted = engine.encrypt(b"portable").unwrap();

        let exported = engine.export_key(&key_id).unwrap();
        engine.remove_key(&key_id).unwrap();
        let imported_id = engine
            .import_key(&exported, EncryptionAlgorithm::Aes256Gcm)
            .unwrap();

        // Import assigns a fresh ID, so decrypt with the key explicitly.
        let key = engine.get_key(&imported_id).unwrap();
        assert_eq!(
            engine.decrypt_with_key(&key, &encrypted).unwrap(),
            b"portable".to_vec()
        );
    }

    #[test]
    fn test_fresh_key_does_not_require_rotation() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();

        assert!(engine.get_keys_requiring_rotation().is_empty());
    }
}