webull_rs/endpoints/
base.rs

1use crate::auth::AuthManager;
2use crate::error::{WebullError, WebullResult};
3use crate::models::response::ApiResponse;
4use crate::utils::cache::CacheManager;
5use crate::utils::rate_limit::RateLimiter;
6use reqwest::{Client, Method, RequestBuilder, StatusCode};
7use reqwest::header::AUTHORIZATION;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use std::sync::Arc;
11use std::time::Duration;
12use url::Url;
13
/// Base endpoint for API requests.
///
/// Bundles the HTTP client, base URL, authentication manager, rate limiter
/// and cache manager that the request helpers on this type rely on.
pub struct BaseEndpoint {
    /// HTTP client used to create every outgoing request.
    client: Client,

    /// Base URL for API requests; request paths are appended to it.
    base_url: String,

    /// Authentication manager that supplies the token for the
    /// `Authorization` header.
    auth_manager: Arc<AuthManager>,

    /// Rate limiter awaited (keyed by URL path) before a request is sent.
    rate_limiter: Arc<RateLimiter>,

    /// Cache manager providing the named response caches ("get", "post").
    cache_manager: Arc<CacheManager>,
}
31
32impl BaseEndpoint {
33    /// Create a new base endpoint.
34    pub fn new(client: Client, base_url: String, auth_manager: Arc<AuthManager>) -> Self {
35        Self {
36            client,
37            base_url,
38            auth_manager,
39            rate_limiter: Arc::new(RateLimiter::new(60)), // Default to 60 requests per minute
40            cache_manager: Arc::new(CacheManager::new()),
41        }
42    }
43
44    /// Build a request to the API.
45    pub fn request<T>(&self, method: Method, path: &str) -> RequestBuilder
46    where
47        T: DeserializeOwned,
48    {
49        let url = self.build_url(path);
50        self.client.request(method, url)
51    }
52
53    /// Send a request to the API and parse the response.
54    pub async fn send_request<T>(&self, request: RequestBuilder) -> WebullResult<T>
55    where
56        T: DeserializeOwned + Clone,
57    {
58        // Clone the request URL to get the path
59        let req_url = request.try_clone()
60            .ok_or_else(|| WebullError::InvalidRequest("Failed to clone request".to_string()))?
61            .build()
62            .map_err(WebullError::NetworkError)?
63            .url()
64            .clone();
65
66        let path = req_url.path();
67
68        // Wait for rate limit
69        self.rate_limiter.wait(path).await;
70
71        // Send the request
72        let response = request.send().await.map_err(WebullError::NetworkError)?;
73
74        let status = response.status();
75
76        // Handle rate limiting
77        if status == StatusCode::TOO_MANY_REQUESTS {
78            // Get the retry-after header if available
79            let retry_after = response.headers()
80                .get("retry-after")
81                .and_then(|h| h.to_str().ok())
82                .and_then(|s| s.parse::<u64>().ok())
83                .unwrap_or(1);
84
85            // Wait for the specified time
86            tokio::time::sleep(std::time::Duration::from_secs(retry_after)).await;
87
88            return Err(WebullError::RateLimitExceeded);
89        }
90
91        // Handle unauthorized
92        if status == StatusCode::UNAUTHORIZED {
93            return Err(WebullError::Unauthorized);
94        }
95
96        // Get the response body
97        let body = response.text().await.map_err(WebullError::NetworkError)?;
98
99        // Handle other errors
100        if !status.is_success() {
101            return Err(WebullError::ApiError {
102                code: status.as_u16().to_string(),
103                message: body,
104            });
105        }
106
107        // Parse the response
108        let api_response: ApiResponse<T> = serde_json::from_str(&body)
109            .map_err(|e| WebullError::SerializationError(e))?;
110
111        // Check for API errors
112        if !api_response.is_success() {
113            return Err(WebullError::ApiError {
114                code: api_response.code.unwrap_or_else(|| "unknown".to_string()),
115                message: api_response.message.unwrap_or_else(|| "Unknown error".to_string()),
116            });
117        }
118
119        // Return the data
120        api_response.get_data().cloned().ok_or_else(|| WebullError::ApiError {
121            code: "no_data".to_string(),
122            message: "Response did not contain data".to_string(),
123        })
124    }
125
126    /// Build a URL for the API.
127    fn build_url(&self, path: &str) -> Url {
128        let base = self.base_url.trim_end_matches('/');
129        let path = path.trim_start_matches('/');
130        let url = format!("{}/{}", base, path);
131
132        Url::parse(&url).unwrap_or_else(|_| {
133            // This should never happen if the base URL is valid
134            panic!("Invalid URL: {}", url);
135        })
136    }
137
138    /// Add authentication headers to a request.
139    pub async fn authenticate_request(&self, request: RequestBuilder) -> WebullResult<RequestBuilder> {
140        // Get the token from the auth manager
141        let token = self.auth_manager.get_token().await?;
142
143        // Add the token to the request headers
144        let request = request.header(AUTHORIZATION, format!("Bearer {}", token.token));
145
146        Ok(request)
147    }
148
149    /// Send a GET request to the API.
150    pub async fn get<T>(&self, path: &str) -> WebullResult<T>
151    where
152        T: DeserializeOwned + Clone + Send + Sync + 'static,
153    {
154        // Check if we have a cached response
155        let cache = self.cache_manager.get_cache::<T>("get");
156        if let Some(cached) = cache.get("GET", path, None, None) {
157            return Ok(cached);
158        }
159
160        // Send the request
161        let request = self.request::<T>(Method::GET, path);
162        let request = self.authenticate_request(request).await?;
163        let response: T = self.send_request(request).await?;
164
165        // Cache the response
166        cache.set("GET", path, None, None, response.clone(), Some(Duration::from_secs(60)));
167
168        Ok(response)
169    }
170
171    /// Send a POST request to the API.
172    pub async fn post<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
173    where
174        T: DeserializeOwned + Clone + Send + Sync + 'static,
175        B: Serialize,
176    {
177        // For POST requests, we only cache if the body is cacheable
178        // For simplicity, we'll just serialize the body and use it as part of the cache key
179        let body_str = match serde_json::to_string(body) {
180            Ok(s) => Some(s),
181            Err(_) => None,
182        };
183
184        // Check if we have a cached response
185        if let Some(body_str) = &body_str {
186            let cache = self.cache_manager.get_cache::<T>("post");
187            if let Some(cached) = cache.get("POST", path, None, Some(body_str)) {
188                return Ok(cached);
189            }
190        }
191
192        // Send the request
193        let request = self.request::<T>(Method::POST, path).json(body);
194        let request = self.authenticate_request(request).await?;
195        let response: T = self.send_request(request).await?;
196
197        // Cache the response if the body is cacheable
198        if let Some(body_str) = body_str {
199            let cache = self.cache_manager.get_cache::<T>("post");
200            cache.set("POST", path, None, Some(&body_str), response.clone(), Some(Duration::from_secs(60)));
201        }
202
203        Ok(response)
204    }
205
206    /// Send a PUT request to the API.
207    pub async fn put<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
208    where
209        T: DeserializeOwned + Clone + Send + Sync + 'static,
210        B: Serialize,
211    {
212        // For PUT requests, we don't cache the response, but we invalidate any cached GET responses
213        // for the same path
214
215        // Send the request
216        let request = self.request::<T>(Method::PUT, path).json(body);
217        let request = self.authenticate_request(request).await?;
218        let response: T = self.send_request(request).await?;
219
220        // Invalidate any cached GET responses for this path
221        let get_cache = self.cache_manager.get_cache::<T>("get");
222        get_cache.clear();
223
224        Ok(response)
225    }
226
227    /// Send a DELETE request to the API.
228    pub async fn delete<T>(&self, path: &str) -> WebullResult<T>
229    where
230        T: DeserializeOwned + Clone + Send + Sync + 'static,
231    {
232        // For DELETE requests, we don't cache the response, but we invalidate any cached responses
233        // for the same path
234
235        // Send the request
236        let request = self.request::<T>(Method::DELETE, path);
237        let request = self.authenticate_request(request).await?;
238        let response: T = self.send_request(request).await?;
239
240        // Invalidate any cached responses for this path
241        let get_cache = self.cache_manager.get_cache::<T>("get");
242        get_cache.clear();
243
244        let post_cache = self.cache_manager.get_cache::<T>("post");
245        post_cache.clear();
246
247        Ok(response)
248    }
249}