webull_rs/utils/
credentials.rs

1use crate::auth::{AccessToken, Credentials};
2use crate::error::{WebullError, WebullResult};
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5use std::sync::Mutex;
6
7/// Interface for storing and retrieving credentials.
8pub trait CredentialStore: Send + Sync {
9    /// Get the stored credentials.
10    fn get_credentials(&self) -> WebullResult<Option<Credentials>>;
11
12    /// Store credentials.
13    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()>;
14
15    /// Clear the stored credentials.
16    fn clear_credentials(&self) -> WebullResult<()>;
17
18    /// Get the stored access token.
19    fn get_token(&self) -> WebullResult<Option<AccessToken>>;
20
21    /// Store an access token.
22    fn store_token(&self, token: AccessToken) -> WebullResult<()>;
23
24    /// Clear the stored token.
25    fn clear_token(&self) -> WebullResult<()>;
26}
27
28/// In-memory credential store.
29#[derive(Debug, Default)]
30pub struct MemoryCredentialStore {
31    /// Stored credentials
32    credentials: Mutex<Option<Credentials>>,
33
34    /// Stored access token
35    token: Mutex<Option<AccessToken>>,
36}
37
38impl CredentialStore for MemoryCredentialStore {
39    fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
40        Ok(self.credentials.lock().unwrap().clone())
41    }
42
43    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
44        *self.credentials.lock().unwrap() = Some(credentials);
45        Ok(())
46    }
47
48    fn clear_credentials(&self) -> WebullResult<()> {
49        *self.credentials.lock().unwrap() = None;
50        Ok(())
51    }
52
53    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54        Ok(self.token.lock().unwrap().clone())
55    }
56
57    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58        *self.token.lock().unwrap() = Some(token);
59        Ok(())
60    }
61
62    fn clear_token(&self) -> WebullResult<()> {
63        *self.token.lock().unwrap() = None;
64        Ok(())
65    }
66}
67
68/// Encrypted credential store for disk-based storage.
69pub struct EncryptedCredentialStore {
70    /// Path to the credentials file
71    credentials_path: String,
72
73    /// Path to the token file
74    token_path: String,
75
76    /// Encryption key
77    encryption_key: String,
78
79    /// In-memory cache
80    memory_store: MemoryCredentialStore,
81}
82
83/// Stored credentials with encryption.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85struct StoredCredentials {
86    /// Encrypted username
87    encrypted_username: String,
88
89    /// Encrypted password
90    encrypted_password: String,
91
92    /// Initialization vector for encryption
93    iv: String,
94
95    /// Salt for encryption
96    salt: String,
97}
98
99/// Stored token with encryption.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101struct StoredToken {
102    /// Encrypted token
103    encrypted_token: String,
104
105    /// Encrypted refresh token
106    encrypted_refresh_token: Option<String>,
107
108    /// Expiration timestamp
109    expires_at: i64,
110
111    /// Initialization vector for encryption
112    iv: String,
113
114    /// Salt for encryption
115    salt: String,
116}
117
118impl EncryptedCredentialStore {
119    /// Create a new encrypted credential store.
120    pub fn new(credentials_path: String, token_path: String, encryption_key: String) -> Self {
121        Self {
122            credentials_path,
123            token_path,
124            encryption_key,
125            memory_store: MemoryCredentialStore::default(),
126        }
127    }
128
129    /// Encrypt a string.
130    fn encrypt(&self, data: &str) -> WebullResult<(String, String, String)> {
131        // Generate a random salt and IV
132        let salt = self.generate_random_string(16);
133        let iv = self.generate_random_string(16);
134
135        // Derive a key from the encryption key and salt
136        let key = self.derive_key(&self.encryption_key, &salt)?;
137
138        // Encrypt the data
139        let encrypted = self.encrypt_with_key(data, &key, &iv)?;
140
141        Ok((encrypted, iv, salt))
142    }
143
144    /// Decrypt a string.
145    fn decrypt(&self, encrypted: &str, iv: &str, salt: &str) -> WebullResult<String> {
146        // Derive a key from the encryption key and salt
147        let key = self.derive_key(&self.encryption_key, salt)?;
148
149        // Decrypt the data
150        self.decrypt_with_key(encrypted, &key, iv)
151    }
152
153    /// Generate a random string.
154    fn generate_random_string(&self, length: usize) -> String {
155        use rand::{thread_rng, Rng};
156        use rand::distributions::Alphanumeric;
157
158        thread_rng()
159            .sample_iter(&Alphanumeric)
160            .take(length)
161            .map(char::from)
162            .collect()
163    }
164
165    /// Derive a key from a password and salt.
166    fn derive_key(&self, password: &str, salt: &str) -> WebullResult<Vec<u8>> {
167        // In a real implementation, we would use a proper key derivation function
168        // like PBKDF2, Argon2, or scrypt. For simplicity, we'll just use a basic
169        // approach here.
170
171        let mut key = Vec::with_capacity(32);
172        let password_bytes = password.as_bytes();
173        let salt_bytes = salt.as_bytes();
174
175        for i in 0..32 {
176            let byte = password_bytes[i % password_bytes.len()] ^ salt_bytes[i % salt_bytes.len()];
177            key.push(byte);
178        }
179
180        Ok(key)
181    }
182
183    /// Encrypt data with a key and IV.
184    fn encrypt_with_key(&self, data: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
185        // In a real implementation, we would use a proper encryption algorithm
186        // like AES-GCM. For simplicity, we'll just use base64 encoding as a
187        // placeholder.
188
189        let encoded = base64::encode(data);
190        Ok(encoded)
191    }
192
193    /// Decrypt data with a key and IV.
194    fn decrypt_with_key(&self, encrypted: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
195        // In a real implementation, we would use a proper decryption algorithm
196        // like AES-GCM. For simplicity, we'll just use base64 decoding as a
197        // placeholder.
198
199        let decoded = base64::decode(encrypted)
200            .map_err(|e| WebullError::InvalidRequest(format!("Invalid data: {}", e)))?;
201
202        let decrypted = String::from_utf8(decoded)
203            .map_err(|e| WebullError::InvalidRequest(format!("Invalid UTF-8: {}", e)))?;
204
205        Ok(decrypted)
206    }
207
208    /// Load credentials from disk.
209    fn load_credentials(&self) -> WebullResult<Option<Credentials>> {
210        // Check if the file exists
211        let path = Path::new(&self.credentials_path);
212        if !path.exists() {
213            return Ok(None);
214        }
215
216        // Read the file
217        let contents = std::fs::read_to_string(path)
218            .map_err(|e| WebullError::InvalidRequest(format!("Failed to read credentials file: {}", e)))?;
219
220        // Parse the stored credentials
221        let stored: StoredCredentials = serde_json::from_str(&contents)
222            .map_err(|e| WebullError::SerializationError(e))?;
223
224        // Decrypt the username and password
225        let username = self.decrypt(&stored.encrypted_username, &stored.iv, &stored.salt)?;
226        let password = self.decrypt(&stored.encrypted_password, &stored.iv, &stored.salt)?;
227
228        Ok(Some(Credentials {
229            username,
230            password,
231        }))
232    }
233
234    /// Save credentials to disk.
235    fn save_credentials(&self, credentials: &Credentials) -> WebullResult<()> {
236        // Encrypt the username and password
237        let (encrypted_username, iv, salt) = self.encrypt(&credentials.username)?;
238        let (encrypted_password, _, _) = self.encrypt(&credentials.password)?;
239
240        // Create the stored credentials
241        let stored = StoredCredentials {
242            encrypted_username,
243            encrypted_password,
244            iv,
245            salt,
246        };
247
248        // Serialize to JSON
249        let json = serde_json::to_string(&stored)
250            .map_err(|e| WebullError::SerializationError(e))?;
251
252        // Write to file
253        std::fs::write(&self.credentials_path, json)
254            .map_err(|e| WebullError::InvalidRequest(format!("Failed to write credentials file: {}", e)))?;
255
256        Ok(())
257    }
258
259    /// Load token from disk.
260    fn load_token(&self) -> WebullResult<Option<AccessToken>> {
261        // Check if the file exists
262        let path = Path::new(&self.token_path);
263        if !path.exists() {
264            return Ok(None);
265        }
266
267        // Read the file
268        let contents = std::fs::read_to_string(path)
269            .map_err(|e| WebullError::InvalidRequest(format!("Failed to read token file: {}", e)))?;
270
271        // Parse the stored token
272        let stored: StoredToken = serde_json::from_str(&contents)
273            .map_err(|e| WebullError::SerializationError(e))?;
274
275        // Decrypt the token
276        let token = self.decrypt(&stored.encrypted_token, &stored.iv, &stored.salt)?;
277
278        // Decrypt the refresh token if present
279        let refresh_token = if let Some(encrypted_refresh_token) = stored.encrypted_refresh_token {
280            Some(self.decrypt(&encrypted_refresh_token, &stored.iv, &stored.salt)?)
281        } else {
282            None
283        };
284
285        // Create the access token
286        let expires_at = chrono::DateTime::from_timestamp(stored.expires_at, 0)
287            .ok_or_else(|| WebullError::InvalidRequest("Invalid timestamp".to_string()))?;
288
289        Ok(Some(AccessToken {
290            token,
291            expires_at,
292            refresh_token,
293        }))
294    }
295
296    /// Save token to disk.
297    fn save_token(&self, token: &AccessToken) -> WebullResult<()> {
298        // Encrypt the token
299        let (encrypted_token, iv, salt) = self.encrypt(&token.token)?;
300
301        // Encrypt the refresh token if present
302        let encrypted_refresh_token = if let Some(refresh_token) = &token.refresh_token {
303            Some(self.encrypt(refresh_token)?.0)
304        } else {
305            None
306        };
307
308        // Create the stored token
309        let stored = StoredToken {
310            encrypted_token,
311            encrypted_refresh_token,
312            expires_at: token.expires_at.timestamp(),
313            iv,
314            salt,
315        };
316
317        // Serialize to JSON
318        let json = serde_json::to_string(&stored)
319            .map_err(|e| WebullError::SerializationError(e))?;
320
321        // Write to file
322        std::fs::write(&self.token_path, json)
323            .map_err(|e| WebullError::InvalidRequest(format!("Failed to write token file: {}", e)))?;
324
325        Ok(())
326    }
327}
328
329impl CredentialStore for EncryptedCredentialStore {
330    fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
331        // Check if we have credentials in memory
332        if let Some(credentials) = self.memory_store.get_credentials()? {
333            return Ok(Some(credentials));
334        }
335
336        // Load credentials from disk
337        let credentials = self.load_credentials()?;
338
339        // Store in memory for future use
340        if let Some(credentials) = &credentials {
341            self.memory_store.store_credentials(credentials.clone())?;
342        }
343
344        Ok(credentials)
345    }
346
347    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
348        // Store in memory
349        self.memory_store.store_credentials(credentials.clone())?;
350
351        // Save to disk
352        self.save_credentials(&credentials)?;
353
354        Ok(())
355    }
356
357    fn clear_credentials(&self) -> WebullResult<()> {
358        // Clear from memory
359        self.memory_store.clear_credentials()?;
360
361        // Remove the file if it exists
362        let path = Path::new(&self.credentials_path);
363        if path.exists() {
364            std::fs::remove_file(path)
365                .map_err(|e| WebullError::InvalidRequest(format!("Failed to remove credentials file: {}", e)))?;
366        }
367
368        Ok(())
369    }
370
371    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
372        // Check if we have a token in memory
373        if let Some(token) = self.memory_store.get_token()? {
374            return Ok(Some(token));
375        }
376
377        // Load token from disk
378        let token = self.load_token()?;
379
380        // Store in memory for future use
381        if let Some(token) = &token {
382            self.memory_store.store_token(token.clone())?;
383        }
384
385        Ok(token)
386    }
387
388    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
389        // Store in memory
390        self.memory_store.store_token(token.clone())?;
391
392        // Save to disk
393        self.save_token(&token)?;
394
395        Ok(())
396    }
397
398    fn clear_token(&self) -> WebullResult<()> {
399        // Clear from memory
400        self.memory_store.clear_token()?;
401
402        // Remove the file if it exists
403        let path = Path::new(&self.token_path);
404        if path.exists() {
405            std::fs::remove_file(path)
406                .map_err(|e| WebullError::InvalidRequest(format!("Failed to remove token file: {}", e)))?;
407        }
408
409        Ok(())
410    }
411}