// webull_rs/utils/credentials.rs

use crate::auth::{AccessToken, Credentials};
use crate::error::{WebullError, WebullResult};
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::path::Path;
use std::sync::Mutex;

7/// Interface for storing and retrieving credentials.
8pub trait CredentialStore: Send + Sync {
9    /// Get the stored credentials.
10    fn get_credentials(&self) -> WebullResult<Option<Credentials>>;
11
12    /// Store credentials.
13    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()>;
14
15    /// Clear the stored credentials.
16    fn clear_credentials(&self) -> WebullResult<()>;
17
18    /// Get the stored access token.
19    fn get_token(&self) -> WebullResult<Option<AccessToken>>;
20
21    /// Store an access token.
22    fn store_token(&self, token: AccessToken) -> WebullResult<()>;
23
24    /// Clear the stored token.
25    fn clear_token(&self) -> WebullResult<()>;
26}
27
28/// In-memory credential store.
29#[derive(Debug, Default)]
30pub struct MemoryCredentialStore {
31    /// Stored credentials
32    credentials: Mutex<Option<Credentials>>,
33
34    /// Stored access token
35    token: Mutex<Option<AccessToken>>,
36}
37
38impl CredentialStore for MemoryCredentialStore {
39    fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
40        Ok(self.credentials.lock().unwrap().clone())
41    }
42
43    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
44        *self.credentials.lock().unwrap() = Some(credentials);
45        Ok(())
46    }
47
48    fn clear_credentials(&self) -> WebullResult<()> {
49        *self.credentials.lock().unwrap() = None;
50        Ok(())
51    }
52
53    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54        Ok(self.token.lock().unwrap().clone())
55    }
56
57    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58        *self.token.lock().unwrap() = Some(token);
59        Ok(())
60    }
61
62    fn clear_token(&self) -> WebullResult<()> {
63        *self.token.lock().unwrap() = None;
64        Ok(())
65    }
66}
67
68/// Encrypted credential store for disk-based storage.
69pub struct EncryptedCredentialStore {
70    /// Path to the credentials file
71    credentials_path: String,
72
73    /// Path to the token file
74    token_path: String,
75
76    /// Encryption key
77    encryption_key: String,
78
79    /// In-memory cache
80    memory_store: MemoryCredentialStore,
81}
82
83/// Stored credentials with encryption.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85struct StoredCredentials {
86    /// Encrypted username
87    encrypted_username: String,
88
89    /// Encrypted password
90    encrypted_password: String,
91
92    /// Initialization vector for encryption
93    iv: String,
94
95    /// Salt for encryption
96    salt: String,
97}
98
99/// Stored token with encryption.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101struct StoredToken {
102    /// Encrypted token
103    encrypted_token: String,
104
105    /// Encrypted refresh token
106    encrypted_refresh_token: Option<String>,
107
108    /// Expiration timestamp
109    expires_at: i64,
110
111    /// Initialization vector for encryption
112    iv: String,
113
114    /// Salt for encryption
115    salt: String,
116}
117
118impl EncryptedCredentialStore {
119    /// Create a new encrypted credential store.
120    pub fn new(credentials_path: String, token_path: String, encryption_key: String) -> Self {
121        Self {
122            credentials_path,
123            token_path,
124            encryption_key,
125            memory_store: MemoryCredentialStore::default(),
126        }
127    }
128
129    /// Encrypt a string.
130    fn encrypt(&self, data: &str) -> WebullResult<(String, String, String)> {
131        // Generate a random salt and IV
132        let salt = self.generate_random_string(16);
133        let iv = self.generate_random_string(16);
134
135        // Derive a key from the encryption key and salt
136        let key = self.derive_key(&self.encryption_key, &salt)?;
137
138        // Encrypt the data
139        let encrypted = self.encrypt_with_key(data, &key, &iv)?;
140
141        Ok((encrypted, iv, salt))
142    }
143
144    /// Decrypt a string.
145    fn decrypt(&self, encrypted: &str, iv: &str, salt: &str) -> WebullResult<String> {
146        // Derive a key from the encryption key and salt
147        let key = self.derive_key(&self.encryption_key, salt)?;
148
149        // Decrypt the data
150        self.decrypt_with_key(encrypted, &key, iv)
151    }
152
153    /// Generate a random string.
154    fn generate_random_string(&self, length: usize) -> String {
155        use rand::distributions::Alphanumeric;
156        use rand::{thread_rng, Rng};
157
158        thread_rng()
159            .sample_iter(&Alphanumeric)
160            .take(length)
161            .map(char::from)
162            .collect()
163    }
164
165    /// Derive a key from a password and salt.
166    fn derive_key(&self, password: &str, salt: &str) -> WebullResult<Vec<u8>> {
167        // In a real implementation, we would use a proper key derivation function
168        // like PBKDF2, Argon2, or scrypt. For simplicity, we'll just use a basic
169        // approach here.
170
171        let mut key = Vec::with_capacity(32);
172        let password_bytes = password.as_bytes();
173        let salt_bytes = salt.as_bytes();
174
175        for i in 0..32 {
176            let byte = password_bytes[i % password_bytes.len()] ^ salt_bytes[i % salt_bytes.len()];
177            key.push(byte);
178        }
179
180        Ok(key)
181    }
182
183    /// Encrypt data with a key and IV.
184    fn encrypt_with_key(&self, data: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
185        // In a real implementation, we would use a proper encryption algorithm
186        // like AES-GCM. For simplicity, we'll just use base64 encoding as a
187        // placeholder.
188
189        let encoded = base64::encode(data);
190        Ok(encoded)
191    }
192
193    /// Decrypt data with a key and IV.
194    fn decrypt_with_key(&self, encrypted: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
195        // In a real implementation, we would use a proper decryption algorithm
196        // like AES-GCM. For simplicity, we'll just use base64 decoding as a
197        // placeholder.
198
199        let decoded = base64::decode(encrypted)
200            .map_err(|e| WebullError::InvalidRequest(format!("Invalid data: {}", e)))?;
201
202        let decrypted = String::from_utf8(decoded)
203            .map_err(|e| WebullError::InvalidRequest(format!("Invalid UTF-8: {}", e)))?;
204
205        Ok(decrypted)
206    }
207
208    /// Load credentials from disk.
209    fn load_credentials(&self) -> WebullResult<Option<Credentials>> {
210        // Check if the file exists
211        let path = Path::new(&self.credentials_path);
212        if !path.exists() {
213            return Ok(None);
214        }
215
216        // Read the file
217        let contents = std::fs::read_to_string(path).map_err(|e| {
218            WebullError::InvalidRequest(format!("Failed to read credentials file: {}", e))
219        })?;
220
221        // Parse the stored credentials
222        let stored: StoredCredentials =
223            serde_json::from_str(&contents).map_err(|e| WebullError::SerializationError(e))?;
224
225        // Decrypt the username and password
226        let username = self.decrypt(&stored.encrypted_username, &stored.iv, &stored.salt)?;
227        let password = self.decrypt(&stored.encrypted_password, &stored.iv, &stored.salt)?;
228
229        Ok(Some(Credentials { username, password }))
230    }
231
232    /// Save credentials to disk.
233    fn save_credentials(&self, credentials: &Credentials) -> WebullResult<()> {
234        // Encrypt the username and password
235        let (encrypted_username, iv, salt) = self.encrypt(&credentials.username)?;
236        let (encrypted_password, _, _) = self.encrypt(&credentials.password)?;
237
238        // Create the stored credentials
239        let stored = StoredCredentials {
240            encrypted_username,
241            encrypted_password,
242            iv,
243            salt,
244        };
245
246        // Serialize to JSON
247        let json =
248            serde_json::to_string(&stored).map_err(|e| WebullError::SerializationError(e))?;
249
250        // Write to file
251        std::fs::write(&self.credentials_path, json).map_err(|e| {
252            WebullError::InvalidRequest(format!("Failed to write credentials file: {}", e))
253        })?;
254
255        Ok(())
256    }
257
258    /// Load token from disk.
259    fn load_token(&self) -> WebullResult<Option<AccessToken>> {
260        // Check if the file exists
261        let path = Path::new(&self.token_path);
262        if !path.exists() {
263            return Ok(None);
264        }
265
266        // Read the file
267        let contents = std::fs::read_to_string(path).map_err(|e| {
268            WebullError::InvalidRequest(format!("Failed to read token file: {}", e))
269        })?;
270
271        // Parse the stored token
272        let stored: StoredToken =
273            serde_json::from_str(&contents).map_err(|e| WebullError::SerializationError(e))?;
274
275        // Decrypt the token
276        let token = self.decrypt(&stored.encrypted_token, &stored.iv, &stored.salt)?;
277
278        // Decrypt the refresh token if present
279        let refresh_token = if let Some(encrypted_refresh_token) = stored.encrypted_refresh_token {
280            Some(self.decrypt(&encrypted_refresh_token, &stored.iv, &stored.salt)?)
281        } else {
282            None
283        };
284
285        // Create the access token
286        let expires_at = chrono::DateTime::from_timestamp(stored.expires_at, 0)
287            .ok_or_else(|| WebullError::InvalidRequest("Invalid timestamp".to_string()))?;
288
289        Ok(Some(AccessToken {
290            token,
291            expires_at,
292            refresh_token,
293        }))
294    }
295
296    /// Save token to disk.
297    fn save_token(&self, token: &AccessToken) -> WebullResult<()> {
298        // Encrypt the token
299        let (encrypted_token, iv, salt) = self.encrypt(&token.token)?;
300
301        // Encrypt the refresh token if present
302        let encrypted_refresh_token = if let Some(refresh_token) = &token.refresh_token {
303            Some(self.encrypt(refresh_token)?.0)
304        } else {
305            None
306        };
307
308        // Create the stored token
309        let stored = StoredToken {
310            encrypted_token,
311            encrypted_refresh_token,
312            expires_at: token.expires_at.timestamp(),
313            iv,
314            salt,
315        };
316
317        // Serialize to JSON
318        let json =
319            serde_json::to_string(&stored).map_err(|e| WebullError::SerializationError(e))?;
320
321        // Write to file
322        std::fs::write(&self.token_path, json).map_err(|e| {
323            WebullError::InvalidRequest(format!("Failed to write token file: {}", e))
324        })?;
325
326        Ok(())
327    }
328}
329
330impl CredentialStore for EncryptedCredentialStore {
331    fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
332        // Check if we have credentials in memory
333        if let Some(credentials) = self.memory_store.get_credentials()? {
334            return Ok(Some(credentials));
335        }
336
337        // Load credentials from disk
338        let credentials = self.load_credentials()?;
339
340        // Store in memory for future use
341        if let Some(credentials) = &credentials {
342            self.memory_store.store_credentials(credentials.clone())?;
343        }
344
345        Ok(credentials)
346    }
347
348    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
349        // Store in memory
350        self.memory_store.store_credentials(credentials.clone())?;
351
352        // Save to disk
353        self.save_credentials(&credentials)?;
354
355        Ok(())
356    }
357
358    fn clear_credentials(&self) -> WebullResult<()> {
359        // Clear from memory
360        self.memory_store.clear_credentials()?;
361
362        // Remove the file if it exists
363        let path = Path::new(&self.credentials_path);
364        if path.exists() {
365            std::fs::remove_file(path).map_err(|e| {
366                WebullError::InvalidRequest(format!("Failed to remove credentials file: {}", e))
367            })?;
368        }
369
370        Ok(())
371    }
372
373    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
374        // Check if we have a token in memory
375        if let Some(token) = self.memory_store.get_token()? {
376            return Ok(Some(token));
377        }
378
379        // Load token from disk
380        let token = self.load_token()?;
381
382        // Store in memory for future use
383        if let Some(token) = &token {
384            self.memory_store.store_token(token.clone())?;
385        }
386
387        Ok(token)
388    }
389
390    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
391        // Store in memory
392        self.memory_store.store_token(token.clone())?;
393
394        // Save to disk
395        self.save_token(&token)?;
396
397        Ok(())
398    }
399
400    fn clear_token(&self) -> WebullResult<()> {
401        // Clear from memory
402        self.memory_store.clear_token()?;
403
404        // Remove the file if it exists
405        let path = Path::new(&self.token_path);
406        if path.exists() {
407            std::fs::remove_file(path).map_err(|e| {
408                WebullError::InvalidRequest(format!("Failed to remove token file: {}", e))
409            })?;
410        }
411
412        Ok(())
413    }
414}