webull_rs/endpoints/base.rs

1use crate::auth::AuthManager;
2use crate::error::{WebullError, WebullResult};
3use crate::models::response::ApiResponse;
4use crate::utils::cache::CacheManager;
5use crate::utils::rate_limit::RateLimiter;
6use reqwest::header::AUTHORIZATION;
7use reqwest::{Client, Method, RequestBuilder, StatusCode};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use std::sync::Arc;
11use std::time::Duration;
12use url::Url;
13
14/// Base endpoint for API requests.
15pub struct BaseEndpoint {
16    /// HTTP client
17    client: Client,
18
19    /// Base URL for API requests
20    base_url: String,
21
22    /// Authentication manager
23    auth_manager: Arc<AuthManager>,
24
25    /// Rate limiter
26    rate_limiter: Arc<RateLimiter>,
27
28    /// Cache manager
29    cache_manager: Arc<CacheManager>,
30}
31
32impl BaseEndpoint {
33    /// Create a new base endpoint.
34    pub fn new(client: Client, base_url: String, auth_manager: Arc<AuthManager>) -> Self {
35        Self {
36            client,
37            base_url,
38            auth_manager,
39            rate_limiter: Arc::new(RateLimiter::new(60)), // Default to 60 requests per minute
40            cache_manager: Arc::new(CacheManager::new()),
41        }
42    }
43
44    /// Build a request to the API.
45    pub fn request<T>(&self, method: Method, path: &str) -> RequestBuilder
46    where
47        T: DeserializeOwned,
48    {
49        let url = self.build_url(path);
50        self.client.request(method, url)
51    }
52
53    /// Send a request to the API and parse the response.
54    pub async fn send_request<T>(&self, request: RequestBuilder) -> WebullResult<T>
55    where
56        T: DeserializeOwned + Clone,
57    {
58        // Clone the request URL to get the path
59        let req_url = request
60            .try_clone()
61            .ok_or_else(|| WebullError::InvalidRequest("Failed to clone request".to_string()))?
62            .build()
63            .map_err(WebullError::NetworkError)?
64            .url()
65            .clone();
66
67        let path = req_url.path();
68
69        // Wait for rate limit
70        self.rate_limiter.wait(path).await;
71
72        // Send the request
73        let response = request.send().await.map_err(WebullError::NetworkError)?;
74
75        let status = response.status();
76
77        // Handle rate limiting
78        if status == StatusCode::TOO_MANY_REQUESTS {
79            // Get the retry-after header if available
80            let retry_after = response
81                .headers()
82                .get("retry-after")
83                .and_then(|h| h.to_str().ok())
84                .and_then(|s| s.parse::<u64>().ok())
85                .unwrap_or(1);
86
87            // Wait for the specified time
88            tokio::time::sleep(std::time::Duration::from_secs(retry_after)).await;
89
90            return Err(WebullError::RateLimitExceeded);
91        }
92
93        // Handle unauthorized
94        if status == StatusCode::UNAUTHORIZED {
95            return Err(WebullError::Unauthorized);
96        }
97
98        // Get the response body
99        let body = response.text().await.map_err(WebullError::NetworkError)?;
100
101        // Handle other errors
102        if !status.is_success() {
103            return Err(WebullError::ApiError {
104                code: status.as_u16().to_string(),
105                message: body,
106            });
107        }
108
109        // Parse the response
110        let api_response: ApiResponse<T> =
111            serde_json::from_str(&body).map_err(|e| WebullError::SerializationError(e))?;
112
113        // Check for API errors
114        if !api_response.is_success() {
115            return Err(WebullError::ApiError {
116                code: api_response.code.unwrap_or_else(|| "unknown".to_string()),
117                message: api_response
118                    .message
119                    .unwrap_or_else(|| "Unknown error".to_string()),
120            });
121        }
122
123        // Return the data
124        api_response
125            .get_data()
126            .cloned()
127            .ok_or_else(|| WebullError::ApiError {
128                code: "no_data".to_string(),
129                message: "Response did not contain data".to_string(),
130            })
131    }
132
133    /// Build a URL for the API.
134    fn build_url(&self, path: &str) -> Url {
135        let base = self.base_url.trim_end_matches('/');
136        let path = path.trim_start_matches('/');
137        let url = format!("{}/{}", base, path);
138
139        Url::parse(&url).unwrap_or_else(|_| {
140            // This should never happen if the base URL is valid
141            panic!("Invalid URL: {}", url);
142        })
143    }
144
145    /// Add authentication headers to a request.
146    pub async fn authenticate_request(
147        &self,
148        request: RequestBuilder,
149    ) -> WebullResult<RequestBuilder> {
150        // Get the token from the auth manager
151        let token = self.auth_manager.get_token().await?;
152
153        // Add the token to the request headers
154        let request = request.header(AUTHORIZATION, format!("Bearer {}", token.token));
155
156        Ok(request)
157    }
158
159    /// Send a GET request to the API.
160    pub async fn get<T>(&self, path: &str) -> WebullResult<T>
161    where
162        T: DeserializeOwned + Clone + Send + Sync + 'static,
163    {
164        // Check if we have a cached response
165        let cache = self.cache_manager.get_cache::<T>("get");
166        if let Some(cached) = cache.get("GET", path, None, None) {
167            return Ok(cached);
168        }
169
170        // Send the request
171        let request = self.request::<T>(Method::GET, path);
172        let request = self.authenticate_request(request).await?;
173        let response: T = self.send_request(request).await?;
174
175        // Cache the response
176        cache.set(
177            "GET",
178            path,
179            None,
180            None,
181            response.clone(),
182            Some(Duration::from_secs(60)),
183        );
184
185        Ok(response)
186    }
187
188    /// Send a POST request to the API.
189    pub async fn post<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
190    where
191        T: DeserializeOwned + Clone + Send + Sync + 'static,
192        B: Serialize,
193    {
194        // For POST requests, we only cache if the body is cacheable
195        // For simplicity, we'll just serialize the body and use it as part of the cache key
196        let body_str = match serde_json::to_string(body) {
197            Ok(s) => Some(s),
198            Err(_) => None,
199        };
200
201        // Check if we have a cached response
202        if let Some(body_str) = &body_str {
203            let cache = self.cache_manager.get_cache::<T>("post");
204            if let Some(cached) = cache.get("POST", path, None, Some(body_str)) {
205                return Ok(cached);
206            }
207        }
208
209        // Send the request
210        let request = self.request::<T>(Method::POST, path).json(body);
211        let request = self.authenticate_request(request).await?;
212        let response: T = self.send_request(request).await?;
213
214        // Cache the response if the body is cacheable
215        if let Some(body_str) = body_str {
216            let cache = self.cache_manager.get_cache::<T>("post");
217            cache.set(
218                "POST",
219                path,
220                None,
221                Some(&body_str),
222                response.clone(),
223                Some(Duration::from_secs(60)),
224            );
225        }
226
227        Ok(response)
228    }
229
230    /// Send a PUT request to the API.
231    pub async fn put<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
232    where
233        T: DeserializeOwned + Clone + Send + Sync + 'static,
234        B: Serialize,
235    {
236        // For PUT requests, we don't cache the response, but we invalidate any cached GET responses
237        // for the same path
238
239        // Send the request
240        let request = self.request::<T>(Method::PUT, path).json(body);
241        let request = self.authenticate_request(request).await?;
242        let response: T = self.send_request(request).await?;
243
244        // Invalidate any cached GET responses for this path
245        let get_cache = self.cache_manager.get_cache::<T>("get");
246        get_cache.clear();
247
248        Ok(response)
249    }
250
251    /// Send a DELETE request to the API.
252    pub async fn delete<T>(&self, path: &str) -> WebullResult<T>
253    where
254        T: DeserializeOwned + Clone + Send + Sync + 'static,
255    {
256        // For DELETE requests, we don't cache the response, but we invalidate any cached responses
257        // for the same path
258
259        // Send the request
260        let request = self.request::<T>(Method::DELETE, path);
261        let request = self.authenticate_request(request).await?;
262        let response: T = self.send_request(request).await?;
263
264        // Invalidate any cached responses for this path
265        let get_cache = self.cache_manager.get_cache::<T>("get");
266        get_cache.clear();
267
268        let post_cache = self.cache_manager.get_cache::<T>("post");
269        post_cache.clear();
270
271        Ok(response)
272    }
273}