//! Authentication support for the Webull API client (`webull_rs/auth.rs`).

1use crate::config::WebullConfig;
2use crate::error::{WebullError, WebullResult};
3use crate::utils::crypto::{encrypt_password, generate_signature, generate_timestamp};
4use crate::utils::serialization::{from_json, to_json};
5use chrono::{DateTime, Utc};
6use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::sync::Mutex;
10
11/// Credentials for authentication.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Credentials {
14    /// Username for authentication
15    pub username: String,
16
17    /// Password for authentication
18    pub password: String,
19}
20
21/// Access token for API requests.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct AccessToken {
24    /// The access token
25    pub token: String,
26
27    /// When the token expires
28    pub expires_at: DateTime<Utc>,
29
30    /// The refresh token
31    pub refresh_token: Option<String>,
32}
33
34/// Interface for storing and retrieving tokens.
35pub trait TokenStore: Send + Sync {
36    /// Get the current access token.
37    fn get_token(&self) -> WebullResult<Option<AccessToken>>;
38
39    /// Store an access token.
40    fn store_token(&self, token: AccessToken) -> WebullResult<()>;
41
42    /// Clear the stored token.
43    fn clear_token(&self) -> WebullResult<()>;
44}
45
46/// In-memory token store.
47#[derive(Debug, Default)]
48pub struct MemoryTokenStore {
49    token: Mutex<Option<AccessToken>>,
50}
51
52impl TokenStore for MemoryTokenStore {
53    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54        Ok(self.token.lock().unwrap().clone())
55    }
56
57    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58        *self.token.lock().unwrap() = Some(token);
59        Ok(())
60    }
61
62    fn clear_token(&self) -> WebullResult<()> {
63        *self.token.lock().unwrap() = None;
64        Ok(())
65    }
66}
67
68/// Manager for authentication.
69pub struct AuthManager {
70    /// Credentials for authentication
71    credentials: Option<Credentials>,
72
73    /// Token store
74    pub token_store: Box<dyn TokenStore>,
75
76    /// Configuration
77    config: WebullConfig,
78
79    /// HTTP client
80    client: reqwest::Client,
81}
82
83impl AuthManager {
84    /// Create a new authentication manager.
85    pub fn new(
86        config: WebullConfig,
87        token_store: Box<dyn TokenStore>,
88        client: reqwest::Client,
89    ) -> Self {
90        Self {
91            credentials: None,
92            token_store,
93            config,
94            client,
95        }
96    }
97
98    /// Authenticate with username and password.
99    pub async fn authenticate(
100        &mut self,
101        username: &str,
102        password: &str,
103    ) -> WebullResult<AccessToken> {
104        // Store credentials for potential token refresh
105        self.credentials = Some(Credentials {
106            username: username.to_string(),
107            password: password.to_string(),
108        });
109
110        // Encrypt the password
111        let encrypted_password = encrypt_password(
112            password,
113            &self.config.api_secret.clone().unwrap_or_default(),
114        )?;
115
116        // Create the request body
117        let body = json!({
118            "username": username,
119            "password": encrypted_password,
120            "deviceId": self.config.device_id.clone().unwrap_or_default(),
121            "deviceName": "Rust API Client",
122            "deviceType": "Web",
123        });
124
125        // Create headers
126        let mut headers = HeaderMap::new();
127        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
128
129        // Add API key if available
130        if let Some(api_key) = &self.config.api_key {
131            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
132        }
133
134        // Generate timestamp and signature
135        let timestamp = generate_timestamp();
136        let signature = if let Some(api_secret) = &self.config.api_secret {
137            let message = format!("{}{}", timestamp, to_json(&body)?);
138            generate_signature(api_secret, &message)?
139        } else {
140            String::new()
141        };
142
143        // Add timestamp and signature to headers
144        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
145        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
146
147        // Send the request
148        let response = self
149            .client
150            .post(format!(
151                "{}/api/passport/login/v5/account",
152                self.config.base_url
153            ))
154            .headers(headers)
155            .json(&body)
156            .send()
157            .await
158            .map_err(|e| WebullError::NetworkError(e))?;
159
160        // Check for errors
161        if !response.status().is_success() {
162            let status = response.status();
163            let text = response
164                .text()
165                .await
166                .unwrap_or_else(|_| "Unknown error".to_string());
167
168            if status.as_u16() == 401 {
169                return Err(WebullError::Unauthorized);
170            } else if status.as_u16() == 429 {
171                return Err(WebullError::RateLimitExceeded);
172            } else {
173                return Err(WebullError::ApiError {
174                    code: status.as_u16().to_string(),
175                    message: text,
176                });
177            }
178        }
179
180        // Parse the response
181        let response_text = response
182            .text()
183            .await
184            .map_err(|e| WebullError::NetworkError(e))?;
185
186        #[derive(Debug, Deserialize)]
187        struct LoginResponse {
188            access_token: String,
189            refresh_token: String,
190            token_type: String,
191            expires_in: i64,
192        }
193
194        let login_response: LoginResponse = from_json(&response_text)?;
195
196        // Create the token
197        let token = AccessToken {
198            token: login_response.access_token,
199            expires_at: Utc::now() + chrono::Duration::seconds(login_response.expires_in),
200            refresh_token: Some(login_response.refresh_token),
201        };
202
203        // Store the token
204        self.token_store.store_token(token.clone())?;
205
206        Ok(token)
207    }
208
209    /// Handle multi-factor authentication.
210    pub async fn multi_factor_auth(&mut self, mfa_code: &str) -> WebullResult<AccessToken> {
211        // Check if we have credentials
212        let credentials = self.credentials.as_ref().ok_or_else(|| {
213            WebullError::InvalidRequest("No credentials available for MFA".to_string())
214        })?;
215
216        // Create the request body
217        let body = json!({
218            "username": credentials.username,
219            "verificationCode": mfa_code,
220            "deviceId": self.config.device_id.clone().unwrap_or_default(),
221        });
222
223        // Create headers
224        let mut headers = HeaderMap::new();
225        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
226
227        // Add API key if available
228        if let Some(api_key) = &self.config.api_key {
229            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
230        }
231
232        // Generate timestamp and signature
233        let timestamp = generate_timestamp();
234        let signature = if let Some(api_secret) = &self.config.api_secret {
235            let message = format!("{}{}", timestamp, to_json(&body)?);
236            generate_signature(api_secret, &message)?
237        } else {
238            String::new()
239        };
240
241        // Add timestamp and signature to headers
242        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
243        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
244
245        // Send the request
246        let response = self
247            .client
248            .post(format!(
249                "{}/api/passport/verificationCode/verify",
250                self.config.base_url
251            ))
252            .headers(headers)
253            .json(&body)
254            .send()
255            .await
256            .map_err(|e| WebullError::NetworkError(e))?;
257
258        // Check for errors
259        if !response.status().is_success() {
260            let status = response.status();
261            let text = response
262                .text()
263                .await
264                .unwrap_or_else(|_| "Unknown error".to_string());
265
266            if status.as_u16() == 401 {
267                return Err(WebullError::Unauthorized);
268            } else if status.as_u16() == 429 {
269                return Err(WebullError::RateLimitExceeded);
270            } else {
271                return Err(WebullError::ApiError {
272                    code: status.as_u16().to_string(),
273                    message: text,
274                });
275            }
276        }
277
278        // Parse the response
279        let response_text = response
280            .text()
281            .await
282            .map_err(|e| WebullError::NetworkError(e))?;
283
284        #[derive(Debug, Deserialize)]
285        struct MfaResponse {
286            access_token: String,
287            refresh_token: String,
288            token_type: String,
289            expires_in: i64,
290        }
291
292        let mfa_response: MfaResponse = from_json(&response_text)?;
293
294        // Create the token
295        let token = AccessToken {
296            token: mfa_response.access_token,
297            expires_at: Utc::now() + chrono::Duration::seconds(mfa_response.expires_in),
298            refresh_token: Some(mfa_response.refresh_token),
299        };
300
301        // Store the token
302        self.token_store.store_token(token.clone())?;
303
304        Ok(token)
305    }
306
307    /// Refresh the access token.
308    pub async fn refresh_token(&mut self) -> WebullResult<AccessToken> {
309        // Get the current token
310        let current_token = self.token_store.get_token()?.ok_or_else(|| {
311            WebullError::InvalidRequest("No token available for refresh".to_string())
312        })?;
313
314        // Check if we have a refresh token
315        let refresh_token = current_token
316            .refresh_token
317            .ok_or_else(|| WebullError::InvalidRequest("No refresh token available".to_string()))?;
318
319        // Create the request body
320        let body = json!({
321            "refreshToken": refresh_token,
322            "deviceId": self.config.device_id.clone().unwrap_or_default(),
323        });
324
325        // Create headers
326        let mut headers = HeaderMap::new();
327        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
328
329        // Add API key if available
330        if let Some(api_key) = &self.config.api_key {
331            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
332        }
333
334        // Generate timestamp and signature
335        let timestamp = generate_timestamp();
336        let signature = if let Some(api_secret) = &self.config.api_secret {
337            let message = format!("{}{}", timestamp, to_json(&body)?);
338            generate_signature(api_secret, &message)?
339        } else {
340            String::new()
341        };
342
343        // Add timestamp and signature to headers
344        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
345        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
346
347        // Send the request
348        let response = self
349            .client
350            .post(format!(
351                "{}/api/passport/refreshToken",
352                self.config.base_url
353            ))
354            .headers(headers)
355            .json(&body)
356            .send()
357            .await
358            .map_err(|e| WebullError::NetworkError(e))?;
359
360        // Check for errors
361        if !response.status().is_success() {
362            let status = response.status();
363            let text = response
364                .text()
365                .await
366                .unwrap_or_else(|_| "Unknown error".to_string());
367
368            if status.as_u16() == 401 {
369                return Err(WebullError::Unauthorized);
370            } else if status.as_u16() == 429 {
371                return Err(WebullError::RateLimitExceeded);
372            } else {
373                return Err(WebullError::ApiError {
374                    code: status.as_u16().to_string(),
375                    message: text,
376                });
377            }
378        }
379
380        // Parse the response
381        let response_text = response
382            .text()
383            .await
384            .map_err(|e| WebullError::NetworkError(e))?;
385
386        #[derive(Debug, Deserialize)]
387        struct RefreshResponse {
388            access_token: String,
389            refresh_token: String,
390            token_type: String,
391            expires_in: i64,
392        }
393
394        let refresh_response: RefreshResponse = from_json(&response_text)?;
395
396        // Create the token
397        let token = AccessToken {
398            token: refresh_response.access_token,
399            expires_at: Utc::now() + chrono::Duration::seconds(refresh_response.expires_in),
400            refresh_token: Some(refresh_response.refresh_token),
401        };
402
403        // Store the token
404        self.token_store.store_token(token.clone())?;
405
406        Ok(token)
407    }
408
409    /// Get the current access token.
410    pub async fn get_token(&self) -> WebullResult<AccessToken> {
411        match self.token_store.get_token()? {
412            Some(token) => {
413                // Check if token is expired
414                if token.expires_at <= Utc::now() {
415                    return Err(WebullError::Unauthorized);
416                }
417                Ok(token)
418            }
419            None => Err(WebullError::Unauthorized),
420        }
421    }
422
423    /// Revoke the current token.
424    pub async fn revoke_token(&mut self) -> WebullResult<()> {
425        // Get the current token
426        let current_token = match self.token_store.get_token()? {
427            Some(token) => token,
428            None => {
429                // No token to revoke
430                self.credentials = None;
431                return Ok(());
432            }
433        };
434
435        // Create the request body
436        let body = json!({
437            "accessToken": current_token.token,
438            "deviceId": self.config.device_id.clone().unwrap_or_default(),
439        });
440
441        // Create headers
442        let mut headers = HeaderMap::new();
443        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
444        headers.insert(
445            AUTHORIZATION,
446            HeaderValue::from_str(&format!("Bearer {}", current_token.token)).unwrap(),
447        );
448
449        // Add API key if available
450        if let Some(api_key) = &self.config.api_key {
451            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
452        }
453
454        // Generate timestamp and signature
455        let timestamp = generate_timestamp();
456        let signature = if let Some(api_secret) = &self.config.api_secret {
457            let message = format!("{}{}", timestamp, to_json(&body)?);
458            generate_signature(api_secret, &message)?
459        } else {
460            String::new()
461        };
462
463        // Add timestamp and signature to headers
464        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
465        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
466
467        // Send the request
468        let response = self
469            .client
470            .post(format!("{}/api/passport/logout", self.config.base_url))
471            .headers(headers)
472            .json(&body)
473            .send()
474            .await
475            .map_err(|e| WebullError::NetworkError(e))?;
476
477        // Check for errors
478        if !response.status().is_success() {
479            let status = response.status();
480            let text = response
481                .text()
482                .await
483                .unwrap_or_else(|_| "Unknown error".to_string());
484
485            if status.as_u16() == 401 {
486                // Token is already invalid, so we can just clear it
487            } else if status.as_u16() == 429 {
488                return Err(WebullError::RateLimitExceeded);
489            } else {
490                return Err(WebullError::ApiError {
491                    code: status.as_u16().to_string(),
492                    message: text,
493                });
494            }
495        }
496
497        // Clear the token and credentials
498        self.token_store.clear_token()?;
499        self.credentials = None;
500
501        Ok(())
502    }
503}