rustant_core/
encryption.rs1use aes_gcm::aead::{Aead, KeyInit};
4use aes_gcm::{Aes256Gcm, Nonce};
5use rand::RngCore;
6use rand::rngs::OsRng;
7
8#[derive(Debug, thiserror::Error)]
10pub enum EncryptionError {
11 #[error("Encryption failed: {0}")]
12 EncryptFailed(String),
13 #[error("Decryption failed: {0}")]
14 DecryptFailed(String),
15 #[error("Invalid key length: expected 32 bytes, got {0}")]
16 InvalidKeyLength(usize),
17 #[error("Data too short to contain nonce")]
18 DataTooShort,
19 #[error("Keyring error: {0}")]
20 KeyringError(String),
21}
22
23pub struct SessionEncryptor {
25 cipher: Aes256Gcm,
26}
27
28impl SessionEncryptor {
29 pub fn from_key(key: &[u8; 32]) -> Self {
31 let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is always valid");
32 Self { cipher }
33 }
34
35 pub fn from_keyring() -> Result<Self, EncryptionError> {
38 let store = crate::credentials::KeyringCredentialStore::new();
39 let service = "rustant_session_encryption";
40
41 match crate::credentials::CredentialStore::get_key(&store, service) {
42 Ok(key_b64) => {
43 let key_bytes = base64_decode(&key_b64)?;
44 if key_bytes.len() != 32 {
45 return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
46 }
47 let mut key = [0u8; 32];
48 key.copy_from_slice(&key_bytes);
49 Ok(Self::from_key(&key))
50 }
51 Err(_) => {
52 let mut key = [0u8; 32];
54 OsRng.fill_bytes(&mut key);
55 let key_b64 = base64_encode(&key);
56 crate::credentials::CredentialStore::store_key(&store, service, &key_b64).map_err(
57 |e| EncryptionError::KeyringError(format!("Failed to store key: {}", e)),
58 )?;
59 Ok(Self::from_key(&key))
60 }
61 }
62 }
63
64 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
67 let mut nonce_bytes = [0u8; 12];
68 OsRng.fill_bytes(&mut nonce_bytes);
69 let nonce = Nonce::from_slice(&nonce_bytes);
70
71 let ciphertext = self
72 .cipher
73 .encrypt(nonce, plaintext)
74 .map_err(|e| EncryptionError::EncryptFailed(e.to_string()))?;
75
76 let mut result = Vec::with_capacity(12 + ciphertext.len());
78 result.extend_from_slice(&nonce_bytes);
79 result.extend_from_slice(&ciphertext);
80 Ok(result)
81 }
82
83 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
86 if data.len() < 12 {
87 return Err(EncryptionError::DataTooShort);
88 }
89
90 let (nonce_bytes, ciphertext) = data.split_at(12);
91 let nonce = Nonce::from_slice(nonce_bytes);
92
93 self.cipher
94 .decrypt(nonce, ciphertext)
95 .map_err(|e| EncryptionError::DecryptFailed(e.to_string()))
96 }
97}
98
99fn base64_encode(data: &[u8]) -> String {
100 use base64::Engine;
101 base64::engine::general_purpose::STANDARD.encode(data)
102}
103
104fn base64_decode(s: &str) -> Result<Vec<u8>, EncryptionError> {
105 use base64::Engine;
106 base64::engine::general_purpose::STANDARD
107 .decode(s)
108 .map_err(|e| EncryptionError::DecryptFailed(format!("Base64 decode error: {}", e)))
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic key whose bytes are 0, 1, ..., 31.
    fn test_key() -> [u8; 32] {
        std::array::from_fn(|index| index as u8)
    }

    /// Encrypting then decrypting returns the original bytes.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let plaintext = b"Hello, Rustant session data!";
        let encryptor = SessionEncryptor::from_key(&test_key());

        let encrypted = encryptor.encrypt(plaintext).unwrap();
        // The output must differ from the input and be longer than it.
        assert_ne!(&encrypted, plaintext);
        assert!(encrypted.len() > plaintext.len());

        let decrypted = encryptor.decrypt(&encrypted).unwrap();
        assert_eq!(&decrypted, plaintext);
    }

    /// An empty plaintext round-trips to an empty plaintext.
    #[test]
    fn test_encrypt_empty_data() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let decrypted = encryptor
            .decrypt(&encryptor.encrypt(b"").unwrap())
            .unwrap();
        assert_eq!(decrypted, b"");
    }

    /// Data encrypted under one key cannot be decrypted under another.
    #[test]
    fn test_decrypt_wrong_key_fails() {
        let mut wrong_key = test_key();
        wrong_key[0] = 255;

        let encryptor1 = SessionEncryptor::from_key(&test_key());
        let encryptor2 = SessionEncryptor::from_key(&wrong_key);

        let encrypted = encryptor1.encrypt(b"secret data").unwrap();
        assert!(encryptor2.decrypt(&encrypted).is_err());
    }

    /// Input shorter than the nonce is rejected with `DataTooShort`.
    #[test]
    fn test_decrypt_too_short_data() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let result = encryptor.decrypt(&[1, 2, 3]);
        assert!(result.is_err());

        let err = result.unwrap_err();
        if !matches!(err, EncryptionError::DataTooShort) {
            panic!("Expected DataTooShort, got: {:?}", err);
        }
    }

    /// Inverting every bit of the final byte makes decryption fail.
    #[test]
    fn test_decrypt_tampered_data_fails() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let mut encrypted = encryptor.encrypt(b"important data").unwrap();
        if let Some(byte) = encrypted.last_mut() {
            *byte = !*byte;
        }
        assert!(encryptor.decrypt(&encrypted).is_err());
    }

    /// Two encryptions of the same plaintext differ, yet both decrypt correctly.
    #[test]
    fn test_different_encryptions_produce_different_output() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let plaintext = b"same data";

        let enc1 = encryptor.encrypt(plaintext).unwrap();
        let enc2 = encryptor.encrypt(plaintext).unwrap();
        assert_ne!(enc1, enc2);

        for sealed in [&enc1, &enc2] {
            assert_eq!(encryptor.decrypt(sealed).unwrap(), plaintext);
        }
    }

    /// A 100,000-byte buffer round-trips unchanged.
    #[test]
    fn test_large_data() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        // Repeating 0..=255 pattern, 100,000 bytes in total.
        let large_data: Vec<u8> = (0..=u8::MAX).cycle().take(100_000).collect();

        let encrypted = encryptor.encrypt(&large_data).unwrap();
        let decrypted = encryptor.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, large_data);
    }
}