tryaudex_core/
keystore.rs1use aes_gcm::{
2 aead::{Aead, KeyInit},
3 Aes256Gcm, Nonce,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
/// Size in bytes of the random AES-GCM nonce drawn for every encryption (96 bits).
const NONCE_LEN: usize = 12;

/// Service name of the OS keychain entry holding the encryption key.
const KEYRING_SERVICE: &str = "audex";
/// Account name of the OS keychain entry; its secret is the base64-encoded 32-byte key.
const KEYRING_USER: &str = "credential-encryption-key";

/// Process-wide cache of the AES-256 key, filled once by `get_or_create_key`.
static CACHED_KEY: std::sync::OnceLock<[u8; 32]> = std::sync::OnceLock::new();
14
15#[derive(Debug, Serialize, Deserialize)]
17pub struct EncryptedBlob {
18 pub nonce: String,
20 pub ciphertext: String,
22}
23
24fn get_or_create_key() -> [u8; 32] {
28 *CACHED_KEY.get_or_init(|| {
29 if let Ok(key) = load_key_from_keychain() {
31 return key;
32 }
33
34 if let Ok(key) = generate_and_store_key() {
36 return key;
37 }
38
39 derive_fallback_key()
41 })
42}
43
44fn load_key_from_keychain() -> std::result::Result<[u8; 32], ()> {
45 let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
46 let secret = entry.get_password().map_err(|_| ())?;
47 use base64::Engine;
48 let bytes = base64::engine::general_purpose::STANDARD
49 .decode(&secret)
50 .map_err(|_| ())?;
51 if bytes.len() != 32 {
52 return Err(());
53 }
54 let mut key = [0u8; 32];
55 key.copy_from_slice(&bytes);
56 Ok(key)
57}
58
59fn generate_and_store_key() -> std::result::Result<[u8; 32], ()> {
60 use rand::Rng;
61 let mut key = [0u8; 32];
62 rand::rng().fill(&mut key);
63
64 use base64::Engine;
65 let encoded = base64::engine::general_purpose::STANDARD.encode(key);
66
67 let entry = keyring::Entry::new(KEYRING_SERVICE, KEYRING_USER).map_err(|_| ())?;
68 entry.set_password(&encoded).map_err(|_| ())?;
69
70 Ok(key)
71}
72
73fn derive_fallback_key() -> [u8; 32] {
75 use sha2::Digest;
76
77 let mut hasher = sha2::Sha256::new();
78 hasher.update(b"audex-fallback-key-v1");
79
80 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("LOGNAME")) {
82 hasher.update(user.as_bytes());
83 }
84
85 if let Some(home) = dirs::home_dir() {
87 hasher.update(home.to_string_lossy().as_bytes());
88 }
89
90 if let Ok(machine_id) = std::fs::read_to_string("/etc/machine-id") {
92 hasher.update(machine_id.trim().as_bytes());
93 }
94
95 hasher.finalize().into()
96}
97
98pub fn encrypt(plaintext: &[u8]) -> Result<EncryptedBlob> {
100 let key = get_or_create_key();
101 let cipher = Aes256Gcm::new_from_slice(&key)
102 .map_err(|e| AvError::InvalidPolicy(format!("Encryption init failed: {}", e)))?;
103
104 use rand::Rng;
105 let mut nonce_bytes = [0u8; NONCE_LEN];
106 rand::rng().fill(&mut nonce_bytes);
107 let nonce = Nonce::from_slice(&nonce_bytes);
108
109 let ciphertext = cipher
110 .encrypt(nonce, plaintext)
111 .map_err(|e| AvError::InvalidPolicy(format!("Encryption failed: {}", e)))?;
112
113 use base64::Engine;
114 Ok(EncryptedBlob {
115 nonce: base64::engine::general_purpose::STANDARD.encode(nonce_bytes),
116 ciphertext: base64::engine::general_purpose::STANDARD.encode(ciphertext),
117 })
118}
119
120pub fn decrypt(blob: &EncryptedBlob) -> Result<Vec<u8>> {
122 let key = get_or_create_key();
123 let cipher = Aes256Gcm::new_from_slice(&key)
124 .map_err(|e| AvError::InvalidPolicy(format!("Decryption init failed: {}", e)))?;
125
126 use base64::Engine;
127 let nonce_bytes = base64::engine::general_purpose::STANDARD
128 .decode(&blob.nonce)
129 .map_err(|e| AvError::InvalidPolicy(format!("Invalid nonce: {}", e)))?;
130 let ciphertext = base64::engine::general_purpose::STANDARD
131 .decode(&blob.ciphertext)
132 .map_err(|e| AvError::InvalidPolicy(format!("Invalid ciphertext: {}", e)))?;
133
134 if nonce_bytes.len() != NONCE_LEN {
135 return Err(AvError::InvalidPolicy("Invalid nonce length".to_string()));
136 }
137
138 let nonce = Nonce::from_slice(&nonce_bytes);
139 cipher
140 .decrypt(nonce, ciphertext.as_ref())
141 .map_err(|e| AvError::InvalidPolicy(format!("Decryption failed: {}", e)))
142}
143
144pub fn encrypt_to_file<T: Serialize>(path: &std::path::Path, value: &T) -> Result<()> {
146 let json = serde_json::to_vec(value)?;
147 let blob = encrypt(&json)?;
148 let blob_json = serde_json::to_string(&blob)?;
149 std::fs::write(path, blob_json)?;
150 Ok(())
151}
152
153pub fn decrypt_from_file<T: for<'de> Deserialize<'de>>(
155 path: &std::path::Path,
156) -> Result<Option<T>> {
157 if !path.exists() {
158 return Ok(None);
159 }
160 let blob_json = std::fs::read_to_string(path)?;
161
162 if let Ok(blob) = serde_json::from_str::<EncryptedBlob>(&blob_json) {
164 if !blob.nonce.is_empty() && !blob.ciphertext.is_empty() {
165 let plaintext = decrypt(&blob)?;
166 let value: T = serde_json::from_slice(&plaintext)?;
167 return Ok(Some(value));
168 }
169 }
170
171 match serde_json::from_str::<T>(&blob_json) {
173 Ok(value) => Ok(Some(value)),
174 Err(e) => Err(AvError::InvalidPolicy(format!(
175 "Failed to read cached data: {}",
176 e
177 ))),
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-test, per-process path in the temp dir so parallel test runs don't collide.
    fn temp_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "audex-keystore-{}-{}.json",
            tag,
            std::process::id()
        ))
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let plaintext = b"secret credential data here";
        let blob = encrypt(plaintext).unwrap();

        assert_ne!(blob.ciphertext.as_bytes(), plaintext);

        let decrypted = decrypt(&blob).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_json_value() {
        use serde_json::json;
        let value = json!({
            "access_key_id": "AKIAIOSFODNN7EXAMPLE",
            "secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "session_token": "FwoGZXIvYXdzEBY..."
        });

        let json_bytes = serde_json::to_vec(&value).unwrap();
        let blob = encrypt(&json_bytes).unwrap();
        let decrypted = decrypt(&blob).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&decrypted).unwrap();

        assert_eq!(parsed["access_key_id"], "AKIAIOSFODNN7EXAMPLE");
    }

    #[test]
    fn test_different_encryptions_differ() {
        let plaintext = b"same data";
        let blob1 = encrypt(plaintext).unwrap();
        let blob2 = encrypt(plaintext).unwrap();
        assert_ne!(blob1.ciphertext, blob2.ciphertext);
        assert_ne!(blob1.nonce, blob2.nonce);
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let blob = encrypt(b"secret").unwrap();
        let mut tampered = blob;
        let mut chars: Vec<char> = tampered.ciphertext.chars().collect();
        if let Some(c) = chars.get_mut(5) {
            *c = if *c == 'A' { 'B' } else { 'A' };
        }
        tampered.ciphertext = chars.into_iter().collect();
        assert!(decrypt(&tampered).is_err());
    }

    #[test]
    fn test_wrong_nonce_length_is_rejected() {
        use base64::Engine;
        let mut blob = encrypt(b"secret").unwrap();
        // Valid base64 of the wrong size must produce an error rather than reach
        // `Nonce::from_slice`, which panics on a length mismatch.
        blob.nonce = base64::engine::general_purpose::STANDARD.encode([0u8; NONCE_LEN - 4]);
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_invalid_base64_is_rejected() {
        let mut blob = encrypt(b"secret").unwrap();
        blob.ciphertext = "not base64!".to_string();
        assert!(decrypt(&blob).is_err());
    }

    #[test]
    fn test_encrypted_blob_serialization() {
        let blob = encrypt(b"test data").unwrap();
        let json = serde_json::to_string(&blob).unwrap();
        let parsed: EncryptedBlob = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.nonce, blob.nonce);
        assert_eq!(parsed.ciphertext, blob.ciphertext);

        let decrypted = decrypt(&parsed).unwrap();
        assert_eq!(decrypted, b"test data");
    }

    #[test]
    fn test_file_roundtrip_and_missing_file() {
        let path = temp_path("roundtrip");
        let _ = std::fs::remove_file(&path);
        assert!(decrypt_from_file::<serde_json::Value>(&path)
            .unwrap()
            .is_none());

        let value = serde_json::json!({ "token": "super-secret-token-value" });
        encrypt_to_file(&path, &value).unwrap();
        let raw = std::fs::read_to_string(&path).unwrap();
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        let _ = std::fs::remove_file(&path);

        assert!(!raw.contains("super-secret-token-value"));
        assert_eq!(loaded, Some(value));
    }

    #[test]
    fn test_decrypt_from_file_reads_plaintext_json() {
        let path = temp_path("plaintext");
        std::fs::write(&path, r#"{"token":"plain"}"#).unwrap();
        let loaded: Option<serde_json::Value> = decrypt_from_file(&path).unwrap();
        let _ = std::fs::remove_file(&path);

        assert_eq!(loaded, Some(serde_json::json!({ "token": "plain" })));
    }

    #[test]
    fn test_fallback_key_deterministic() {
        let key1 = derive_fallback_key();
        let key2 = derive_fallback_key();
        assert_eq!(key1, key2);
    }
}