Skip to main content

rustant_core/
encryption.rs

1//! Session encryption — AES-256-GCM for encrypting session data at rest.
2
3use aes_gcm::aead::{Aead, KeyInit};
4use aes_gcm::{Aes256Gcm, Nonce};
5use rand::RngCore;
6use rand::rngs::OsRng;
7
8/// Errors that can occur during encryption/decryption.
9#[derive(Debug, thiserror::Error)]
10pub enum EncryptionError {
11    #[error("Encryption failed: {0}")]
12    EncryptFailed(String),
13    #[error("Decryption failed: {0}")]
14    DecryptFailed(String),
15    #[error("Invalid key length: expected 32 bytes, got {0}")]
16    InvalidKeyLength(usize),
17    #[error("Data too short to contain nonce")]
18    DataTooShort,
19    #[error("Keyring error: {0}")]
20    KeyringError(String),
21}
22
23/// Encrypts and decrypts session data using AES-256-GCM.
24pub struct SessionEncryptor {
25    cipher: Aes256Gcm,
26}
27
28impl SessionEncryptor {
29    /// Create an encryptor from a raw 32-byte key.
30    pub fn from_key(key: &[u8; 32]) -> Self {
31        let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is always valid");
32        Self { cipher }
33    }
34
35    /// Create an encryptor from the system keyring.
36    /// If no key exists, generates and stores a new one.
37    pub fn from_keyring() -> Result<Self, EncryptionError> {
38        let store = crate::credentials::KeyringCredentialStore::new();
39        let service = "rustant_session_encryption";
40
41        match crate::credentials::CredentialStore::get_key(&store, service) {
42            Ok(key_b64) => {
43                let key_bytes = base64_decode(&key_b64)?;
44                if key_bytes.len() != 32 {
45                    return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
46                }
47                let mut key = [0u8; 32];
48                key.copy_from_slice(&key_bytes);
49                Ok(Self::from_key(&key))
50            }
51            Err(_) => {
52                // Generate a new key
53                let mut key = [0u8; 32];
54                OsRng.fill_bytes(&mut key);
55                let key_b64 = base64_encode(&key);
56                crate::credentials::CredentialStore::store_key(&store, service, &key_b64).map_err(
57                    |e| EncryptionError::KeyringError(format!("Failed to store key: {}", e)),
58                )?;
59                Ok(Self::from_key(&key))
60            }
61        }
62    }
63
64    /// Encrypt plaintext data.
65    /// Returns nonce (12 bytes) prepended to ciphertext.
66    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
67        let mut nonce_bytes = [0u8; 12];
68        OsRng.fill_bytes(&mut nonce_bytes);
69        let nonce = Nonce::from_slice(&nonce_bytes);
70
71        let ciphertext = self
72            .cipher
73            .encrypt(nonce, plaintext)
74            .map_err(|e| EncryptionError::EncryptFailed(e.to_string()))?;
75
76        // Prepend nonce to ciphertext
77        let mut result = Vec::with_capacity(12 + ciphertext.len());
78        result.extend_from_slice(&nonce_bytes);
79        result.extend_from_slice(&ciphertext);
80        Ok(result)
81    }
82
83    /// Decrypt data that was encrypted with `encrypt()`.
84    /// Expects nonce (12 bytes) prepended to ciphertext.
85    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
86        if data.len() < 12 {
87            return Err(EncryptionError::DataTooShort);
88        }
89
90        let (nonce_bytes, ciphertext) = data.split_at(12);
91        let nonce = Nonce::from_slice(nonce_bytes);
92
93        self.cipher
94            .decrypt(nonce, ciphertext)
95            .map_err(|e| EncryptionError::DecryptFailed(e.to_string()))
96    }
97}
98
99fn base64_encode(data: &[u8]) -> String {
100    use base64::Engine;
101    base64::engine::general_purpose::STANDARD.encode(data)
102}
103
104fn base64_decode(s: &str) -> Result<Vec<u8>, EncryptionError> {
105    use base64::Engine;
106    base64::engine::general_purpose::STANDARD
107        .decode(s)
108        .map_err(|e| EncryptionError::DecryptFailed(format!("Base64 decode error: {}", e)))
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic key `[0, 1, ..., 31]` so test runs are reproducible.
    fn test_key() -> [u8; 32] {
        let mut key = [0u8; 32];
        for (i, byte) in key.iter_mut().enumerate() {
            *byte = i as u8;
        }
        key
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let plaintext = b"Hello, Rustant session data!";

        let encrypted = encryptor.encrypt(plaintext).unwrap();
        assert_ne!(&encrypted, plaintext);
        assert!(encrypted.len() > plaintext.len()); // nonce + tag overhead

        let decrypted = encryptor.decrypt(&encrypted).unwrap();
        assert_eq!(&decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_empty_data() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let encrypted = encryptor.encrypt(b"").unwrap();
        let decrypted = encryptor.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, b"");
    }

    #[test]
    fn test_output_layout_is_nonce_then_ciphertext_and_tag() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let plaintext = b"layout check";
        let encrypted = encryptor.encrypt(plaintext).unwrap();
        // 12-byte nonce prefix, ciphertext as long as the plaintext, 16-byte GCM tag.
        assert_eq!(encrypted.len(), 12 + plaintext.len() + 16);
    }

    #[test]
    fn test_decrypt_wrong_key_fails() {
        let encryptor1 = SessionEncryptor::from_key(&test_key());
        let mut wrong_key = test_key();
        wrong_key[0] = 255;
        let encryptor2 = SessionEncryptor::from_key(&wrong_key);

        let encrypted = encryptor1.encrypt(b"secret data").unwrap();
        let result = encryptor2.decrypt(&encrypted);
        assert!(
            matches!(result, Err(EncryptionError::DecryptFailed(_))),
            "Expected DecryptFailed, got: {:?}",
            result
        );
    }

    #[test]
    fn test_decrypt_too_short_data() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let result = encryptor.decrypt(&[1, 2, 3]); // Less than 12 bytes
        assert!(
            matches!(result, Err(EncryptionError::DataTooShort)),
            "Expected DataTooShort, got: {:?}",
            result
        );
    }

    #[test]
    fn test_decrypt_nonce_only_fails_authentication() {
        // Exactly 12 bytes passes the length check but leaves no room for the tag.
        let encryptor = SessionEncryptor::from_key(&test_key());
        let result = encryptor.decrypt(&[0u8; 12]);
        assert!(
            matches!(result, Err(EncryptionError::DecryptFailed(_))),
            "Expected DecryptFailed, got: {:?}",
            result
        );
    }

    #[test]
    fn test_decrypt_tampered_data_fails() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let mut encrypted = encryptor.encrypt(b"important data").unwrap();
        // Tamper with ciphertext
        if let Some(last) = encrypted.last_mut() {
            *last ^= 0xFF;
        }
        let result = encryptor.decrypt(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_tampered_nonce_fails() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let mut encrypted = encryptor.encrypt(b"important data").unwrap();
        // Flip one bit of the nonce prefix; the tag was computed under the original nonce.
        encrypted[0] ^= 0x01;
        let result = encryptor.decrypt(&encrypted);
        assert!(
            matches!(result, Err(EncryptionError::DecryptFailed(_))),
            "Expected DecryptFailed, got: {:?}",
            result
        );
    }

    #[test]
    fn test_different_encryptions_produce_different_output() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let plaintext = b"same data";
        let enc1 = encryptor.encrypt(plaintext).unwrap();
        let enc2 = encryptor.encrypt(plaintext).unwrap();
        // Different nonces should produce different ciphertext
        assert_ne!(enc1, enc2);
        // But both should decrypt to the same plaintext
        assert_eq!(encryptor.decrypt(&enc1).unwrap(), plaintext);
        assert_eq!(encryptor.decrypt(&enc2).unwrap(), plaintext);
    }

    #[test]
    fn test_large_data() {
        let encryptor = SessionEncryptor::from_key(&test_key());
        let large_data: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
        let encrypted = encryptor.encrypt(&large_data).unwrap();
        let decrypted = encryptor.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, large_data);
    }
}