tryaudex_core/
keystore.rs1use aes_gcm::{
2 aead::{Aead, KeyInit},
3 Aes256Gcm, Nonce,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
/// Service name under which the encryption key is stored in the OS keychain.
const KEYRING_SERVICE: &'static str = "audex";
/// Account/user name of the keychain entry holding the base64-encoded key.
const KEYRING_USER: &'static str = "credential-encryption-key";
/// Nonce size in bytes for AES-256-GCM (the length `Nonce::from_slice` is given).
const NONCE_LEN: usize = 12;

/// Process-wide cache of the 32-byte key, filled lazily by `get_or_create_key`.
static CACHED_KEY: std::sync::OnceLock<[u8; 32]> = std::sync::OnceLock::new();
14
15#[derive(Debug, Serialize, Deserialize)]
17pub struct EncryptedBlob {
18 pub nonce: String,
20 pub ciphertext: String,
22}
23
24fn get_or_create_key() -> [u8; 32] {
28 *CACHED_KEY.get_or_init(|| {
29 if let Ok(key) = load_key_from_keychain() {
31 return key;
32 }
33
34 if let Ok(key) = generate_and_store_key() {
36 return key;
37 }
38
39 derive_fallback_key()
41 })
42}
43
44fn load_key_from_keychain() -> std::result::Result<[u8; 32], ()> {
45 let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
46 let secret = entry.get_password().map_err(|_| ())?;
47 use base64::Engine;
48 let bytes = base64::engine::general_purpose::STANDARD
49 .decode(&secret)
50 .map_err(|_| ())?;
51 if bytes.len() != 32 {
52 return Err(());
53 }
54 let mut key = [0u8; 32];
55 key.copy_from_slice(&bytes);
56 Ok(key)
57}
58
59fn generate_and_store_key() -> std::result::Result<[u8; 32], ()> {
60 use rand::Rng;
61 let mut key = [0u8; 32];
62 rand::rng().fill(&mut key);
63
64 use base64::Engine;
65 let encoded = base64::engine::general_purpose::STANDARD.encode(key);
66
67 let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
68 entry.set_password(&encoded).map_err(|_| ())?;
69
70 Ok(key)
71}
72
73fn derive_fallback_key() -> [u8; 32] {
75 use sha2::Digest;
76
77 let mut hasher = sha2::Sha256::new();
78 hasher.update(b"audex-fallback-key-v1");
79
80 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("LOGNAME")) {
82 hasher.update(user.as_bytes());
83 }
84
85 if let Some(home) = dirs::home_dir() {
87 hasher.update(home.to_string_lossy().as_bytes());
88 }
89
90 if let Ok(machine_id) = std::fs::read_to_string("/etc/machine-id") {
92 hasher.update(machine_id.trim().as_bytes());
93 }
94
95 hasher.finalize().into()
96}
97
98pub fn encrypt(plaintext: &[u8]) -> Result<EncryptedBlob> {
100 let key = get_or_create_key();
101 let cipher = Aes256Gcm::new_from_slice(&key)
102 .map_err(|e| AvError::InvalidPolicy(format!("Encryption init failed: {}", e)))?;
103
104 use rand::Rng;
105 let mut nonce_bytes = [0u8; NONCE_LEN];
106 rand::rng().fill(&mut nonce_bytes);
107 let nonce = Nonce::from_slice(&nonce_bytes);
108
109 let ciphertext = cipher
110 .encrypt(nonce, plaintext)
111 .map_err(|e| AvError::InvalidPolicy(format!("Encryption failed: {}", e)))?;
112
113 use base64::Engine;
114 Ok(EncryptedBlob {
115 nonce: base64::engine::general_purpose::STANDARD.encode(nonce_bytes),
116 ciphertext: base64::engine::general_purpose::STANDARD.encode(ciphertext),
117 })
118}
119
120pub fn decrypt(blob: &EncryptedBlob) -> Result<Vec<u8>> {
122 let key = get_or_create_key();
123 let cipher = Aes256Gcm::new_from_slice(&key)
124 .map_err(|e| AvError::InvalidPolicy(format!("Decryption init failed: {}", e)))?;
125
126 use base64::Engine;
127 let nonce_bytes = base64::engine::general_purpose::STANDARD
128 .decode(&blob.nonce)
129 .map_err(|e| AvError::InvalidPolicy(format!("Invalid nonce: {}", e)))?;
130 let ciphertext = base64::engine::general_purpose::STANDARD
131 .decode(&blob.ciphertext)
132 .map_err(|e| AvError::InvalidPolicy(format!("Invalid ciphertext: {}", e)))?;
133
134 if nonce_bytes.len() != NONCE_LEN {
135 return Err(AvError::InvalidPolicy("Invalid nonce length".to_string()));
136 }
137
138 let nonce = Nonce::from_slice(&nonce_bytes);
139 cipher
140 .decrypt(nonce, ciphertext.as_ref())
141 .map_err(|e| AvError::InvalidPolicy(format!("Decryption failed: {}", e)))
142}
143
144pub fn encrypt_to_file<T: Serialize>(path: &std::path::Path, value: &T) -> Result<()> {
146 let json = serde_json::to_vec(value)?;
147 let blob = encrypt(&json)?;
148 let blob_json = serde_json::to_string(&blob)?;
149 std::fs::write(path, blob_json)?;
150 Ok(())
151}
152
153pub fn decrypt_from_file<T: for<'de> Deserialize<'de>>(path: &std::path::Path) -> Result<Option<T>> {
155 if !path.exists() {
156 return Ok(None);
157 }
158 let blob_json = std::fs::read_to_string(path)?;
159
160 if let Ok(blob) = serde_json::from_str::<EncryptedBlob>(&blob_json) {
162 if !blob.nonce.is_empty() && !blob.ciphertext.is_empty() {
163 let plaintext = decrypt(&blob)?;
164 let value: T = serde_json::from_slice(&plaintext)?;
165 return Ok(Some(value));
166 }
167 }
168
169 match serde_json::from_str::<T>(&blob_json) {
171 Ok(value) => Ok(Some(value)),
172 Err(e) => Err(AvError::InvalidPolicy(format!("Failed to read cached data: {}", e))),
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// Decodes a base64 blob field so assertions can inspect raw bytes.
    fn b64_decode(s: &str) -> Vec<u8> {
        base64::engine::general_purpose::STANDARD.decode(s).unwrap()
    }

    /// Encodes raw bytes the same way `encrypt` does.
    fn b64_encode(bytes: &[u8]) -> String {
        base64::engine::general_purpose::STANDARD.encode(bytes)
    }

    /// Per-test scratch path under the system temp dir.
    fn scratch_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "audex-keystore-{}-{}.json",
            tag,
            std::process::id()
        ))
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let plaintext = b"secret credential data here";
        let blob = encrypt(plaintext).unwrap();

        // Compare raw bytes, not base64 text: the encoded string can never
        // equal the plaintext bytes, so that comparison would assert nothing.
        let raw = b64_decode(&blob.ciphertext);
        assert_ne!(&raw[..], &plaintext[..]);
        // AES-GCM appends a 16-byte authentication tag.
        assert_eq!(raw.len(), plaintext.len() + 16);
        assert_eq!(b64_decode(&blob.nonce).len(), NONCE_LEN);

        let decrypted = decrypt(&blob).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_json_value() {
        use serde_json::json;
        let value = json!({
            "access_key_id": "AKIAIOSFODNN7EXAMPLE",
            "secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "session_token": "FwoGZXIvYXdzEBY..."
        });

        let json_bytes = serde_json::to_vec(&value).unwrap();
        let blob = encrypt(&json_bytes).unwrap();
        let decrypted = decrypt(&blob).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&decrypted).unwrap();

        assert_eq!(parsed["access_key_id"], "AKIAIOSFODNN7EXAMPLE");
    }

    #[test]
    fn test_different_encryptions_differ() {
        let plaintext = b"same data";
        let blob1 = encrypt(plaintext).unwrap();
        let blob2 = encrypt(plaintext).unwrap();
        assert_ne!(blob1.ciphertext, blob2.ciphertext);
        assert_ne!(blob1.nonce, blob2.nonce);
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let mut blob = encrypt(b"secret").unwrap();
        // Flip one bit of the decoded ciphertext so the change is guaranteed
        // to reach the AEAD check.
        let mut raw = b64_decode(&blob.ciphertext);
        raw[0] ^= 0x01;
        blob.ciphertext = b64_encode(&raw);
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_wrong_nonce_length_rejected() {
        let blob = EncryptedBlob {
            nonce: b64_encode(&[0u8; 8]),
            ciphertext: b64_encode(&[0u8; 32]),
        };
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_invalid_base64_rejected() {
        let blob = EncryptedBlob {
            nonce: "not base64!".to_string(),
            ciphertext: b64_encode(&[0u8; 32]),
        };
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_encrypted_blob_serialization() {
        let blob = encrypt(b"test data").unwrap();
        let json = serde_json::to_string(&blob).unwrap();
        let parsed: EncryptedBlob = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.nonce, blob.nonce);
        assert_eq!(parsed.ciphertext, blob.ciphertext);

        let decrypted = decrypt(&parsed).unwrap();
        assert_eq!(decrypted, b"test data");
    }

    #[test]
    fn test_file_roundtrip_and_missing_file() {
        let path = scratch_path("roundtrip");
        let _ = std::fs::remove_file(&path);

        let missing: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        assert!(missing.is_none());

        // '-' is not in the standard base64 alphabet, so this marker cannot
        // appear in the encrypted file by chance.
        let value = serde_json::json!({ "token": "plaintext-token-value" });
        encrypt_to_file(&path, &value).unwrap();
        let on_disk = std::fs::read_to_string(&path).unwrap();
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        std::fs::remove_file(&path).unwrap();

        assert!(!on_disk.contains("plaintext-token-value"));
        assert_eq!(loaded, Some(value));
    }

    #[test]
    fn test_plaintext_json_file_is_read() {
        let path = scratch_path("legacy");
        std::fs::write(&path, r#"{"access_key_id":"AKIA"}"#).unwrap();
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        std::fs::remove_file(&path).unwrap();
        assert_eq!(loaded.unwrap()["access_key_id"], "AKIA");
    }

    #[test]
    fn test_fallback_key_deterministic() {
        let key1 = derive_fallback_key();
        let key2 = derive_fallback_key();
        assert_eq!(key1, key2);
    }
}