1use crate::auth::{AccessToken, Credentials};
2use crate::error::{WebullError, WebullResult};
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5use std::sync::Mutex;
6
7pub trait CredentialStore: Send + Sync {
9 fn get_credentials(&self) -> WebullResult<Option<Credentials>>;
11
12 fn store_credentials(&self, credentials: Credentials) -> WebullResult<()>;
14
15 fn clear_credentials(&self) -> WebullResult<()>;
17
18 fn get_token(&self) -> WebullResult<Option<AccessToken>>;
20
21 fn store_token(&self, token: AccessToken) -> WebullResult<()>;
23
24 fn clear_token(&self) -> WebullResult<()>;
26}
27
28#[derive(Debug, Default)]
30pub struct MemoryCredentialStore {
31 credentials: Mutex<Option<Credentials>>,
33
34 token: Mutex<Option<AccessToken>>,
36}
37
38impl CredentialStore for MemoryCredentialStore {
39 fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
40 Ok(self.credentials.lock().unwrap().clone())
41 }
42
43 fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
44 *self.credentials.lock().unwrap() = Some(credentials);
45 Ok(())
46 }
47
48 fn clear_credentials(&self) -> WebullResult<()> {
49 *self.credentials.lock().unwrap() = None;
50 Ok(())
51 }
52
53 fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54 Ok(self.token.lock().unwrap().clone())
55 }
56
57 fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58 *self.token.lock().unwrap() = Some(token);
59 Ok(())
60 }
61
62 fn clear_token(&self) -> WebullResult<()> {
63 *self.token.lock().unwrap() = None;
64 Ok(())
65 }
66}
67
68pub struct EncryptedCredentialStore {
70 credentials_path: String,
72
73 token_path: String,
75
76 encryption_key: String,
78
79 memory_store: MemoryCredentialStore,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85struct StoredCredentials {
86 encrypted_username: String,
88
89 encrypted_password: String,
91
92 iv: String,
94
95 salt: String,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101struct StoredToken {
102 encrypted_token: String,
104
105 encrypted_refresh_token: Option<String>,
107
108 expires_at: i64,
110
111 iv: String,
113
114 salt: String,
116}
117
118impl EncryptedCredentialStore {
119 pub fn new(credentials_path: String, token_path: String, encryption_key: String) -> Self {
121 Self {
122 credentials_path,
123 token_path,
124 encryption_key,
125 memory_store: MemoryCredentialStore::default(),
126 }
127 }
128
129 fn encrypt(&self, data: &str) -> WebullResult<(String, String, String)> {
131 let salt = self.generate_random_string(16);
133 let iv = self.generate_random_string(16);
134
135 let key = self.derive_key(&self.encryption_key, &salt)?;
137
138 let encrypted = self.encrypt_with_key(data, &key, &iv)?;
140
141 Ok((encrypted, iv, salt))
142 }
143
144 fn decrypt(&self, encrypted: &str, iv: &str, salt: &str) -> WebullResult<String> {
146 let key = self.derive_key(&self.encryption_key, salt)?;
148
149 self.decrypt_with_key(encrypted, &key, iv)
151 }
152
153 fn generate_random_string(&self, length: usize) -> String {
155 use rand::{thread_rng, Rng};
156 use rand::distributions::Alphanumeric;
157
158 thread_rng()
159 .sample_iter(&Alphanumeric)
160 .take(length)
161 .map(char::from)
162 .collect()
163 }
164
165 fn derive_key(&self, password: &str, salt: &str) -> WebullResult<Vec<u8>> {
167 let mut key = Vec::with_capacity(32);
172 let password_bytes = password.as_bytes();
173 let salt_bytes = salt.as_bytes();
174
175 for i in 0..32 {
176 let byte = password_bytes[i % password_bytes.len()] ^ salt_bytes[i % salt_bytes.len()];
177 key.push(byte);
178 }
179
180 Ok(key)
181 }
182
183 fn encrypt_with_key(&self, data: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
185 let encoded = base64::encode(data);
190 Ok(encoded)
191 }
192
193 fn decrypt_with_key(&self, encrypted: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
195 let decoded = base64::decode(encrypted)
200 .map_err(|e| WebullError::InvalidRequest(format!("Invalid data: {}", e)))?;
201
202 let decrypted = String::from_utf8(decoded)
203 .map_err(|e| WebullError::InvalidRequest(format!("Invalid UTF-8: {}", e)))?;
204
205 Ok(decrypted)
206 }
207
208 fn load_credentials(&self) -> WebullResult<Option<Credentials>> {
210 let path = Path::new(&self.credentials_path);
212 if !path.exists() {
213 return Ok(None);
214 }
215
216 let contents = std::fs::read_to_string(path)
218 .map_err(|e| WebullError::InvalidRequest(format!("Failed to read credentials file: {}", e)))?;
219
220 let stored: StoredCredentials = serde_json::from_str(&contents)
222 .map_err(|e| WebullError::SerializationError(e))?;
223
224 let username = self.decrypt(&stored.encrypted_username, &stored.iv, &stored.salt)?;
226 let password = self.decrypt(&stored.encrypted_password, &stored.iv, &stored.salt)?;
227
228 Ok(Some(Credentials {
229 username,
230 password,
231 }))
232 }
233
234 fn save_credentials(&self, credentials: &Credentials) -> WebullResult<()> {
236 let (encrypted_username, iv, salt) = self.encrypt(&credentials.username)?;
238 let (encrypted_password, _, _) = self.encrypt(&credentials.password)?;
239
240 let stored = StoredCredentials {
242 encrypted_username,
243 encrypted_password,
244 iv,
245 salt,
246 };
247
248 let json = serde_json::to_string(&stored)
250 .map_err(|e| WebullError::SerializationError(e))?;
251
252 std::fs::write(&self.credentials_path, json)
254 .map_err(|e| WebullError::InvalidRequest(format!("Failed to write credentials file: {}", e)))?;
255
256 Ok(())
257 }
258
259 fn load_token(&self) -> WebullResult<Option<AccessToken>> {
261 let path = Path::new(&self.token_path);
263 if !path.exists() {
264 return Ok(None);
265 }
266
267 let contents = std::fs::read_to_string(path)
269 .map_err(|e| WebullError::InvalidRequest(format!("Failed to read token file: {}", e)))?;
270
271 let stored: StoredToken = serde_json::from_str(&contents)
273 .map_err(|e| WebullError::SerializationError(e))?;
274
275 let token = self.decrypt(&stored.encrypted_token, &stored.iv, &stored.salt)?;
277
278 let refresh_token = if let Some(encrypted_refresh_token) = stored.encrypted_refresh_token {
280 Some(self.decrypt(&encrypted_refresh_token, &stored.iv, &stored.salt)?)
281 } else {
282 None
283 };
284
285 let expires_at = chrono::DateTime::from_timestamp(stored.expires_at, 0)
287 .ok_or_else(|| WebullError::InvalidRequest("Invalid timestamp".to_string()))?;
288
289 Ok(Some(AccessToken {
290 token,
291 expires_at,
292 refresh_token,
293 }))
294 }
295
296 fn save_token(&self, token: &AccessToken) -> WebullResult<()> {
298 let (encrypted_token, iv, salt) = self.encrypt(&token.token)?;
300
301 let encrypted_refresh_token = if let Some(refresh_token) = &token.refresh_token {
303 Some(self.encrypt(refresh_token)?.0)
304 } else {
305 None
306 };
307
308 let stored = StoredToken {
310 encrypted_token,
311 encrypted_refresh_token,
312 expires_at: token.expires_at.timestamp(),
313 iv,
314 salt,
315 };
316
317 let json = serde_json::to_string(&stored)
319 .map_err(|e| WebullError::SerializationError(e))?;
320
321 std::fs::write(&self.token_path, json)
323 .map_err(|e| WebullError::InvalidRequest(format!("Failed to write token file: {}", e)))?;
324
325 Ok(())
326 }
327}
328
329impl CredentialStore for EncryptedCredentialStore {
330 fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
331 if let Some(credentials) = self.memory_store.get_credentials()? {
333 return Ok(Some(credentials));
334 }
335
336 let credentials = self.load_credentials()?;
338
339 if let Some(credentials) = &credentials {
341 self.memory_store.store_credentials(credentials.clone())?;
342 }
343
344 Ok(credentials)
345 }
346
347 fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
348 self.memory_store.store_credentials(credentials.clone())?;
350
351 self.save_credentials(&credentials)?;
353
354 Ok(())
355 }
356
357 fn clear_credentials(&self) -> WebullResult<()> {
358 self.memory_store.clear_credentials()?;
360
361 let path = Path::new(&self.credentials_path);
363 if path.exists() {
364 std::fs::remove_file(path)
365 .map_err(|e| WebullError::InvalidRequest(format!("Failed to remove credentials file: {}", e)))?;
366 }
367
368 Ok(())
369 }
370
371 fn get_token(&self) -> WebullResult<Option<AccessToken>> {
372 if let Some(token) = self.memory_store.get_token()? {
374 return Ok(Some(token));
375 }
376
377 let token = self.load_token()?;
379
380 if let Some(token) = &token {
382 self.memory_store.store_token(token.clone())?;
383 }
384
385 Ok(token)
386 }
387
388 fn store_token(&self, token: AccessToken) -> WebullResult<()> {
389 self.memory_store.store_token(token.clone())?;
391
392 self.save_token(&token)?;
394
395 Ok(())
396 }
397
398 fn clear_token(&self) -> WebullResult<()> {
399 self.memory_store.clear_token()?;
401
402 let path = Path::new(&self.token_path);
404 if path.exists() {
405 std::fs::remove_file(path)
406 .map_err(|e| WebullError::InvalidRequest(format!("Failed to remove token file: {}", e)))?;
407 }
408
409 Ok(())
410 }
411}