Skip to main content

tryaudex_core/
keystore.rs

1use aes_gcm::{
2    aead::{Aead, KeyInit},
3    Aes256Gcm, Nonce,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
/// Service name of the OS keychain entry that holds the encryption key.
const KEYRING_SERVICE: &str = "audex";
/// Account name of the OS keychain entry; its secret is the base64-encoded key.
const KEYRING_USER: &str = "credential-encryption-key";
/// AES-GCM nonce size in bytes (a 96-bit nonce).
const NONCE_LEN: usize = 96 / 8;

/// Per-process cache of the resolved key so `encrypt` and `decrypt` always
/// use the same key within one run.
static CACHED_KEY: std::sync::OnceLock<[u8; 32]> = std::sync::OnceLock::new();

15/// Encrypted blob stored on disk instead of plaintext JSON.
16#[derive(Debug, Serialize, Deserialize)]
17pub struct EncryptedBlob {
18    /// Base64-encoded nonce
19    pub nonce: String,
20    /// Base64-encoded ciphertext
21    pub ciphertext: String,
22}
23
24/// Get or create the encryption key from the OS keychain.
25/// Falls back to a machine-derived key if keychain is unavailable.
26/// Key is cached per-process to ensure consistency.
27fn get_or_create_key() -> [u8; 32] {
28    *CACHED_KEY.get_or_init(|| {
29        // Try OS keychain first
30        if let Ok(key) = load_key_from_keychain() {
31            return key;
32        }
33
34        // Try to generate and store a new key
35        if let Ok(key) = generate_and_store_key() {
36            return key;
37        }
38
39        // Fallback: derive from machine-specific data
40        derive_fallback_key()
41    })
42}
43
44fn load_key_from_keychain() -> std::result::Result<[u8; 32], ()> {
45    let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
46    let secret = entry.get_password().map_err(|_| ())?;
47    use base64::Engine;
48    let bytes = base64::engine::general_purpose::STANDARD
49        .decode(&secret)
50        .map_err(|_| ())?;
51    if bytes.len() != 32 {
52        return Err(());
53    }
54    let mut key = [0u8; 32];
55    key.copy_from_slice(&bytes);
56    Ok(key)
57}
58
59fn generate_and_store_key() -> std::result::Result<[u8; 32], ()> {
60    use rand::Rng;
61    let mut key = [0u8; 32];
62    rand::rng().fill(&mut key);
63
64    use base64::Engine;
65    let encoded = base64::engine::general_purpose::STANDARD.encode(key);
66
67    let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
68    entry.set_password(&encoded).map_err(|_| ())?;
69
70    Ok(key)
71}
72
73/// Derive a deterministic key from machine-specific data when keychain is unavailable.
74fn derive_fallback_key() -> [u8; 32] {
75    use sha2::Digest;
76
77    let mut hasher = sha2::Sha256::new();
78    hasher.update(b"audex-fallback-key-v1");
79
80    // Mix in username
81    if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("LOGNAME")) {
82        hasher.update(user.as_bytes());
83    }
84
85    // Mix in home directory
86    if let Some(home) = dirs::home_dir() {
87        hasher.update(home.to_string_lossy().as_bytes());
88    }
89
90    // Mix in machine-id if available (Linux)
91    if let Ok(machine_id) = std::fs::read_to_string("/etc/machine-id") {
92        hasher.update(machine_id.trim().as_bytes());
93    }
94
95    hasher.finalize().into()
96}
97
98/// Encrypt data using AES-256-GCM with a key from the OS keychain.
99pub fn encrypt(plaintext: &[u8]) -> Result<EncryptedBlob> {
100    let key = get_or_create_key();
101    let cipher = Aes256Gcm::new_from_slice(&key)
102        .map_err(|e| AvError::InvalidPolicy(format!("Encryption init failed: {}", e)))?;
103
104    use rand::Rng;
105    let mut nonce_bytes = [0u8; NONCE_LEN];
106    rand::rng().fill(&mut nonce_bytes);
107    let nonce = Nonce::from_slice(&nonce_bytes);
108
109    let ciphertext = cipher
110        .encrypt(nonce, plaintext)
111        .map_err(|e| AvError::InvalidPolicy(format!("Encryption failed: {}", e)))?;
112
113    use base64::Engine;
114    Ok(EncryptedBlob {
115        nonce: base64::engine::general_purpose::STANDARD.encode(nonce_bytes),
116        ciphertext: base64::engine::general_purpose::STANDARD.encode(ciphertext),
117    })
118}
119
120/// Decrypt data using AES-256-GCM with a key from the OS keychain.
121pub fn decrypt(blob: &EncryptedBlob) -> Result<Vec<u8>> {
122    let key = get_or_create_key();
123    let cipher = Aes256Gcm::new_from_slice(&key)
124        .map_err(|e| AvError::InvalidPolicy(format!("Decryption init failed: {}", e)))?;
125
126    use base64::Engine;
127    let nonce_bytes = base64::engine::general_purpose::STANDARD
128        .decode(&blob.nonce)
129        .map_err(|e| AvError::InvalidPolicy(format!("Invalid nonce: {}", e)))?;
130    let ciphertext = base64::engine::general_purpose::STANDARD
131        .decode(&blob.ciphertext)
132        .map_err(|e| AvError::InvalidPolicy(format!("Invalid ciphertext: {}", e)))?;
133
134    if nonce_bytes.len() != NONCE_LEN {
135        return Err(AvError::InvalidPolicy("Invalid nonce length".to_string()));
136    }
137
138    let nonce = Nonce::from_slice(&nonce_bytes);
139    cipher
140        .decrypt(nonce, ciphertext.as_ref())
141        .map_err(|e| AvError::InvalidPolicy(format!("Decryption failed: {}", e)))
142}
143
144/// Encrypt a serializable value and write to a file.
145pub fn encrypt_to_file<T: Serialize>(path: &std::path::Path, value: &T) -> Result<()> {
146    let json = serde_json::to_vec(value)?;
147    let blob = encrypt(&json)?;
148    let blob_json = serde_json::to_string(&blob)?;
149    std::fs::write(path, blob_json)?;
150    Ok(())
151}
152
153/// Read and decrypt a value from a file.
154pub fn decrypt_from_file<T: for<'de> Deserialize<'de>>(
155    path: &std::path::Path,
156) -> Result<Option<T>> {
157    if !path.exists() {
158        return Ok(None);
159    }
160    let blob_json = std::fs::read_to_string(path)?;
161
162    // Try to decrypt as encrypted blob
163    if let Ok(blob) = serde_json::from_str::<EncryptedBlob>(&blob_json) {
164        if !blob.nonce.is_empty() && !blob.ciphertext.is_empty() {
165            let plaintext = decrypt(&blob)?;
166            let value: T = serde_json::from_slice(&plaintext)?;
167            return Ok(Some(value));
168        }
169    }
170
171    // Fallback: try reading as plaintext JSON (migration from unencrypted)
172    match serde_json::from_str::<T>(&blob_json) {
173        Ok(value) => Ok(Some(value)),
174        Err(e) => Err(AvError::InvalidPolicy(format!(
175            "Failed to read cached data: {}",
176            e
177        ))),
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these tests go through `get_or_create_key`, so on a machine with a
    // working OS keychain they read (and, if absent, create) the real `audex`
    // keychain entry.

    /// Unique file path under the system temp dir for a file-based test.
    fn temp_file(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "audex-keystore-test-{}-{}.json",
            std::process::id(),
            tag
        ))
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let plaintext = b"secret credential data here";
        let blob = encrypt(plaintext).unwrap();

        // Verify it's not plaintext
        assert_ne!(blob.ciphertext.as_bytes(), plaintext);

        let decrypted = decrypt(&blob).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_json_value() {
        use serde_json::json;
        let value = json!({
            "access_key_id": "AKIAIOSFODNN7EXAMPLE",
            "secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "session_token": "FwoGZXIvYXdzEBY..."
        });

        let json_bytes = serde_json::to_vec(&value).unwrap();
        let blob = encrypt(&json_bytes).unwrap();
        let decrypted = decrypt(&blob).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&decrypted).unwrap();

        assert_eq!(parsed["access_key_id"], "AKIAIOSFODNN7EXAMPLE");
    }

    #[test]
    fn test_different_encryptions_differ() {
        let plaintext = b"same data";
        let blob1 = encrypt(plaintext).unwrap();
        let blob2 = encrypt(plaintext).unwrap();
        // Different nonces should produce different ciphertexts
        assert_ne!(blob1.ciphertext, blob2.ciphertext);
        assert_ne!(blob1.nonce, blob2.nonce);
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let blob = encrypt(b"secret").unwrap();
        let mut tampered = blob;
        // Flip a character in the ciphertext
        let mut chars: Vec<char> = tampered.ciphertext.chars().collect();
        if let Some(c) = chars.get_mut(5) {
            *c = if *c == 'A' { 'B' } else { 'A' };
        }
        tampered.ciphertext = chars.into_iter().collect();
        assert!(decrypt(&tampered).is_err());
    }

    #[test]
    fn test_invalid_nonce_length_rejected() {
        use base64::Engine;
        let mut blob = encrypt(b"secret").unwrap();
        blob.nonce = base64::engine::general_purpose::STANDARD.encode([0u8; NONCE_LEN - 1]);
        // Must return an error rather than panic inside `Nonce::from_slice`.
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_encrypted_blob_serialization() {
        let blob = encrypt(b"test data").unwrap();
        let json = serde_json::to_string(&blob).unwrap();
        let parsed: EncryptedBlob = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.nonce, blob.nonce);
        assert_eq!(parsed.ciphertext, blob.ciphertext);

        let decrypted = decrypt(&parsed).unwrap();
        assert_eq!(decrypted, b"test data");
    }

    #[test]
    fn test_fallback_key_deterministic() {
        let key1 = derive_fallback_key();
        let key2 = derive_fallback_key();
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_file_roundtrip() {
        let path = temp_file("roundtrip");
        let value = serde_json::json!({ "access_key_id": "AKIAIOSFODNN7EXAMPLE" });
        encrypt_to_file(&path, &value).unwrap();

        let on_disk = std::fs::read_to_string(&path).unwrap();
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        let _ = std::fs::remove_file(&path);

        assert!(!on_disk.contains("AKIAIOSFODNN7EXAMPLE"));
        assert_eq!(loaded, Some(value));
    }

    #[test]
    fn test_missing_file_returns_none() {
        let path = temp_file("missing");
        let _ = std::fs::remove_file(&path);
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn test_plaintext_migration_fallback() {
        let path = temp_file("plaintext");
        std::fs::write(&path, r#"{"access_key_id":"AKIAIOSFODNN7EXAMPLE"}"#).unwrap();
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        let _ = std::fs::remove_file(&path);
        assert_eq!(loaded.unwrap()["access_key_id"], "AKIAIOSFODNN7EXAMPLE");
    }
}