webull_rs/endpoints/
base.rs1use crate::auth::AuthManager;
2use crate::error::{WebullError, WebullResult};
3use crate::models::response::ApiResponse;
4use crate::utils::cache::CacheManager;
5use crate::utils::rate_limit::RateLimiter;
6use reqwest::header::AUTHORIZATION;
7use reqwest::{Client, Method, RequestBuilder, StatusCode};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use std::sync::Arc;
11use std::time::Duration;
12use url::Url;
13
14pub struct BaseEndpoint {
16 client: Client,
18
19 base_url: String,
21
22 auth_manager: Arc<AuthManager>,
24
25 rate_limiter: Arc<RateLimiter>,
27
28 cache_manager: Arc<CacheManager>,
30}
31
32impl BaseEndpoint {
33 pub fn new(client: Client, base_url: String, auth_manager: Arc<AuthManager>) -> Self {
35 Self {
36 client,
37 base_url,
38 auth_manager,
39 rate_limiter: Arc::new(RateLimiter::new(60)), cache_manager: Arc::new(CacheManager::new()),
41 }
42 }
43
44 pub fn request<T>(&self, method: Method, path: &str) -> RequestBuilder
46 where
47 T: DeserializeOwned,
48 {
49 let url = self.build_url(path);
50 self.client.request(method, url)
51 }
52
53 pub async fn send_request<T>(&self, request: RequestBuilder) -> WebullResult<T>
55 where
56 T: DeserializeOwned + Clone,
57 {
58 let req_url = request
60 .try_clone()
61 .ok_or_else(|| WebullError::InvalidRequest("Failed to clone request".to_string()))?
62 .build()
63 .map_err(WebullError::NetworkError)?
64 .url()
65 .clone();
66
67 let path = req_url.path();
68
69 self.rate_limiter.wait(path).await;
71
72 let response = request.send().await.map_err(WebullError::NetworkError)?;
74
75 let status = response.status();
76
77 if status == StatusCode::TOO_MANY_REQUESTS {
79 let retry_after = response
81 .headers()
82 .get("retry-after")
83 .and_then(|h| h.to_str().ok())
84 .and_then(|s| s.parse::<u64>().ok())
85 .unwrap_or(1);
86
87 tokio::time::sleep(std::time::Duration::from_secs(retry_after)).await;
89
90 return Err(WebullError::RateLimitExceeded);
91 }
92
93 if status == StatusCode::UNAUTHORIZED {
95 return Err(WebullError::Unauthorized);
96 }
97
98 let body = response.text().await.map_err(WebullError::NetworkError)?;
100
101 if !status.is_success() {
103 return Err(WebullError::ApiError {
104 code: status.as_u16().to_string(),
105 message: body,
106 });
107 }
108
109 let api_response: ApiResponse<T> =
111 serde_json::from_str(&body).map_err(|e| WebullError::SerializationError(e))?;
112
113 if !api_response.is_success() {
115 return Err(WebullError::ApiError {
116 code: api_response.code.unwrap_or_else(|| "unknown".to_string()),
117 message: api_response
118 .message
119 .unwrap_or_else(|| "Unknown error".to_string()),
120 });
121 }
122
123 api_response
125 .get_data()
126 .cloned()
127 .ok_or_else(|| WebullError::ApiError {
128 code: "no_data".to_string(),
129 message: "Response did not contain data".to_string(),
130 })
131 }
132
133 fn build_url(&self, path: &str) -> Url {
135 let base = self.base_url.trim_end_matches('/');
136 let path = path.trim_start_matches('/');
137 let url = format!("{}/{}", base, path);
138
139 Url::parse(&url).unwrap_or_else(|_| {
140 panic!("Invalid URL: {}", url);
142 })
143 }
144
145 pub async fn authenticate_request(
147 &self,
148 request: RequestBuilder,
149 ) -> WebullResult<RequestBuilder> {
150 let token = self.auth_manager.get_token().await?;
152
153 let request = request.header(AUTHORIZATION, format!("Bearer {}", token.token));
155
156 Ok(request)
157 }
158
159 pub async fn get<T>(&self, path: &str) -> WebullResult<T>
161 where
162 T: DeserializeOwned + Clone + Send + Sync + 'static,
163 {
164 let cache = self.cache_manager.get_cache::<T>("get");
166 if let Some(cached) = cache.get("GET", path, None, None) {
167 return Ok(cached);
168 }
169
170 let request = self.request::<T>(Method::GET, path);
172 let request = self.authenticate_request(request).await?;
173 let response: T = self.send_request(request).await?;
174
175 cache.set(
177 "GET",
178 path,
179 None,
180 None,
181 response.clone(),
182 Some(Duration::from_secs(60)),
183 );
184
185 Ok(response)
186 }
187
188 pub async fn post<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
190 where
191 T: DeserializeOwned + Clone + Send + Sync + 'static,
192 B: Serialize,
193 {
194 let body_str = match serde_json::to_string(body) {
197 Ok(s) => Some(s),
198 Err(_) => None,
199 };
200
201 if let Some(body_str) = &body_str {
203 let cache = self.cache_manager.get_cache::<T>("post");
204 if let Some(cached) = cache.get("POST", path, None, Some(body_str)) {
205 return Ok(cached);
206 }
207 }
208
209 let request = self.request::<T>(Method::POST, path).json(body);
211 let request = self.authenticate_request(request).await?;
212 let response: T = self.send_request(request).await?;
213
214 if let Some(body_str) = body_str {
216 let cache = self.cache_manager.get_cache::<T>("post");
217 cache.set(
218 "POST",
219 path,
220 None,
221 Some(&body_str),
222 response.clone(),
223 Some(Duration::from_secs(60)),
224 );
225 }
226
227 Ok(response)
228 }
229
230 pub async fn put<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
232 where
233 T: DeserializeOwned + Clone + Send + Sync + 'static,
234 B: Serialize,
235 {
236 let request = self.request::<T>(Method::PUT, path).json(body);
241 let request = self.authenticate_request(request).await?;
242 let response: T = self.send_request(request).await?;
243
244 let get_cache = self.cache_manager.get_cache::<T>("get");
246 get_cache.clear();
247
248 Ok(response)
249 }
250
251 pub async fn delete<T>(&self, path: &str) -> WebullResult<T>
253 where
254 T: DeserializeOwned + Clone + Send + Sync + 'static,
255 {
256 let request = self.request::<T>(Method::DELETE, path);
261 let request = self.authenticate_request(request).await?;
262 let response: T = self.send_request(request).await?;
263
264 let get_cache = self.cache_manager.get_cache::<T>("get");
266 get_cache.clear();
267
268 let post_cache = self.cache_manager.get_cache::<T>("post");
269 post_cache.clear();
270
271 Ok(response)
272 }
273}