Skip to main content

tryaudex_core/
keystore.rs

1use aes_gcm::{
2    aead::{Aead, KeyInit},
3    Aes256Gcm, Nonce,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
/// Size in bytes of the nonce generated for every encryption and required
/// of every blob on decryption.
const NONCE_LEN: usize = 12;

/// Service name under which the encryption key is stored in the OS keychain.
const KEYRING_SERVICE: &str = "audex";
/// Account name of the keychain entry that holds the encryption key.
const KEYRING_USER: &str = "credential-encryption-key";

/// Per-process cache of the resolved key, filled on first use so that every
/// encrypt/decrypt call in this process works with the same bytes.
static CACHED_KEY: std::sync::OnceLock<[u8; 32]> = std::sync::OnceLock::new();
14
15/// Encrypted blob stored on disk instead of plaintext JSON.
16#[derive(Debug, Serialize, Deserialize)]
17pub struct EncryptedBlob {
18    /// Base64-encoded nonce
19    pub nonce: String,
20    /// Base64-encoded ciphertext
21    pub ciphertext: String,
22}
23
24/// Get or create the encryption key from the OS keychain.
25/// Falls back to a machine-derived key if keychain is unavailable.
26/// Key is cached per-process to ensure consistency.
27fn get_or_create_key() -> [u8; 32] {
28    *CACHED_KEY.get_or_init(|| {
29        // Try OS keychain first
30        if let Ok(key) = load_key_from_keychain() {
31            return key;
32        }
33
34        // Try to generate and store a new key
35        if let Ok(key) = generate_and_store_key() {
36            return key;
37        }
38
39        // Fallback: derive from machine-specific data
40        derive_fallback_key()
41    })
42}
43
44fn load_key_from_keychain() -> std::result::Result<[u8; 32], ()> {
45    let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
46    let secret = entry.get_password().map_err(|_| ())?;
47    use base64::Engine;
48    let bytes = base64::engine::general_purpose::STANDARD
49        .decode(&secret)
50        .map_err(|_| ())?;
51    if bytes.len() != 32 {
52        return Err(());
53    }
54    let mut key = [0u8; 32];
55    key.copy_from_slice(&bytes);
56    Ok(key)
57}
58
59fn generate_and_store_key() -> std::result::Result<[u8; 32], ()> {
60    use rand::Rng;
61    let mut key = [0u8; 32];
62    rand::rng().fill(&mut key);
63
64    use base64::Engine;
65    let encoded = base64::engine::general_purpose::STANDARD.encode(key);
66
67    let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
68    entry.set_password(&encoded).map_err(|_| ())?;
69
70    Ok(key)
71}
72
73/// Derive a deterministic key from machine-specific data when keychain is unavailable.
74fn derive_fallback_key() -> [u8; 32] {
75    use sha2::Digest;
76
77    let mut hasher = sha2::Sha256::new();
78    hasher.update(b"audex-fallback-key-v1");
79
80    // Mix in username
81    if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("LOGNAME")) {
82        hasher.update(user.as_bytes());
83    }
84
85    // Mix in home directory
86    if let Some(home) = dirs::home_dir() {
87        hasher.update(home.to_string_lossy().as_bytes());
88    }
89
90    // Mix in machine-id if available (Linux)
91    if let Ok(machine_id) = std::fs::read_to_string("/etc/machine-id") {
92        hasher.update(machine_id.trim().as_bytes());
93    }
94
95    hasher.finalize().into()
96}
97
98/// Encrypt data using AES-256-GCM with a key from the OS keychain.
99pub fn encrypt(plaintext: &[u8]) -> Result<EncryptedBlob> {
100    let key = get_or_create_key();
101    let cipher = Aes256Gcm::new_from_slice(&key)
102        .map_err(|e| AvError::InvalidPolicy(format!("Encryption init failed: {}", e)))?;
103
104    use rand::Rng;
105    let mut nonce_bytes = [0u8; NONCE_LEN];
106    rand::rng().fill(&mut nonce_bytes);
107    let nonce = Nonce::from_slice(&nonce_bytes);
108
109    let ciphertext = cipher
110        .encrypt(nonce, plaintext)
111        .map_err(|e| AvError::InvalidPolicy(format!("Encryption failed: {}", e)))?;
112
113    use base64::Engine;
114    Ok(EncryptedBlob {
115        nonce: base64::engine::general_purpose::STANDARD.encode(nonce_bytes),
116        ciphertext: base64::engine::general_purpose::STANDARD.encode(ciphertext),
117    })
118}
119
120/// Decrypt data using AES-256-GCM with a key from the OS keychain.
121pub fn decrypt(blob: &EncryptedBlob) -> Result<Vec<u8>> {
122    let key = get_or_create_key();
123    let cipher = Aes256Gcm::new_from_slice(&key)
124        .map_err(|e| AvError::InvalidPolicy(format!("Decryption init failed: {}", e)))?;
125
126    use base64::Engine;
127    let nonce_bytes = base64::engine::general_purpose::STANDARD
128        .decode(&blob.nonce)
129        .map_err(|e| AvError::InvalidPolicy(format!("Invalid nonce: {}", e)))?;
130    let ciphertext = base64::engine::general_purpose::STANDARD
131        .decode(&blob.ciphertext)
132        .map_err(|e| AvError::InvalidPolicy(format!("Invalid ciphertext: {}", e)))?;
133
134    if nonce_bytes.len() != NONCE_LEN {
135        return Err(AvError::InvalidPolicy("Invalid nonce length".to_string()));
136    }
137
138    let nonce = Nonce::from_slice(&nonce_bytes);
139    cipher
140        .decrypt(nonce, ciphertext.as_ref())
141        .map_err(|e| AvError::InvalidPolicy(format!("Decryption failed: {}", e)))
142}
143
144/// Encrypt a serializable value and write to a file.
145pub fn encrypt_to_file<T: Serialize>(path: &std::path::Path, value: &T) -> Result<()> {
146    let json = serde_json::to_vec(value)?;
147    let blob = encrypt(&json)?;
148    let blob_json = serde_json::to_string(&blob)?;
149    std::fs::write(path, blob_json)?;
150    Ok(())
151}
152
153/// Read and decrypt a value from a file.
154pub fn decrypt_from_file<T: for<'de> Deserialize<'de>>(path: &std::path::Path) -> Result<Option<T>> {
155    if !path.exists() {
156        return Ok(None);
157    }
158    let blob_json = std::fs::read_to_string(path)?;
159
160    // Try to decrypt as encrypted blob
161    if let Ok(blob) = serde_json::from_str::<EncryptedBlob>(&blob_json) {
162        if !blob.nonce.is_empty() && !blob.ciphertext.is_empty() {
163            let plaintext = decrypt(&blob)?;
164            let value: T = serde_json::from_slice(&plaintext)?;
165            return Ok(Some(value));
166        }
167    }
168
169    // Fallback: try reading as plaintext JSON (migration from unencrypted)
170    match serde_json::from_str::<T>(&blob_json) {
171        Ok(value) => Ok(Some(value)),
172        Err(e) => Err(AvError::InvalidPolicy(format!("Failed to read cached data: {}", e))),
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// Decodes one base64 field of an `EncryptedBlob` into raw bytes.
    fn decode_b64(field: &str) -> Vec<u8> {
        base64::engine::general_purpose::STANDARD
            .decode(field)
            .expect("blob fields are valid base64")
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let plaintext: &[u8] = b"secret credential data here";
        let blob = encrypt(plaintext).unwrap();

        // Compare raw bytes, not the base64 text: the encoded string differs
        // from the plaintext even for an identity "cipher", so checking it
        // would prove nothing.
        let raw = decode_b64(&blob.ciphertext);
        assert!(!raw.windows(plaintext.len()).any(|w| w == plaintext));
        assert_eq!(decode_b64(&blob.nonce).len(), NONCE_LEN);

        let decrypted = decrypt(&blob).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_json_value() {
        use serde_json::json;
        let value = json!({
            "access_key_id": "AKIAIOSFODNN7EXAMPLE",
            "secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "session_token": "FwoGZXIvYXdzEBY..."
        });

        let json_bytes = serde_json::to_vec(&value).unwrap();
        let blob = encrypt(&json_bytes).unwrap();
        let decrypted = decrypt(&blob).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&decrypted).unwrap();

        assert_eq!(parsed["access_key_id"], "AKIAIOSFODNN7EXAMPLE");
    }

    #[test]
    fn test_different_encryptions_differ() {
        let plaintext = b"same data";
        let blob1 = encrypt(plaintext).unwrap();
        let blob2 = encrypt(plaintext).unwrap();
        // Different nonces should produce different ciphertexts
        assert_ne!(blob1.ciphertext, blob2.ciphertext);
        assert_ne!(blob1.nonce, blob2.nonce);
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let blob = encrypt(b"secret").unwrap();
        let mut tampered = blob;
        // Flip a character in the ciphertext
        let mut chars: Vec<char> = tampered.ciphertext.chars().collect();
        if let Some(c) = chars.get_mut(5) {
            *c = if *c == 'A' { 'B' } else { 'A' };
        }
        tampered.ciphertext = chars.into_iter().collect();
        assert!(decrypt(&tampered).is_err());
    }

    #[test]
    fn test_invalid_nonce_length_fails() {
        let mut blob = encrypt(b"secret").unwrap();
        // Valid base64, but shorter than the required NONCE_LEN bytes.
        blob.nonce = base64::engine::general_purpose::STANDARD.encode([0u8; NONCE_LEN - 4]);
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_encrypted_blob_serialization() {
        let blob = encrypt(b"test data").unwrap();
        let json = serde_json::to_string(&blob).unwrap();
        let parsed: EncryptedBlob = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.nonce, blob.nonce);
        assert_eq!(parsed.ciphertext, blob.ciphertext);

        let decrypted = decrypt(&parsed).unwrap();
        assert_eq!(decrypted, b"test data");
    }

    #[test]
    fn test_fallback_key_deterministic() {
        let key1 = derive_fallback_key();
        let key2 = derive_fallback_key();
        assert_eq!(key1, key2);
    }
}