// webull_rs/endpoints/base.rs
use crate::auth::AuthManager;
use crate::error::{WebullError, WebullResult};
use crate::models::response::ApiResponse;
use crate::utils::cache::CacheManager;
use crate::utils::rate_limit::RateLimiter;
use reqwest::{Client, Method, RequestBuilder, StatusCode};
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::sync::Arc;
use std::time::Duration;
use url::Url;

14pub struct BaseEndpoint {
16 client: Client,
18
19 base_url: String,
21
22 auth_manager: Arc<AuthManager>,
24
25 rate_limiter: Arc<RateLimiter>,
27
28 cache_manager: Arc<CacheManager>,
30}
31
32impl BaseEndpoint {
33 pub fn new(client: Client, base_url: String, auth_manager: Arc<AuthManager>) -> Self {
35 Self {
36 client,
37 base_url,
38 auth_manager,
39 rate_limiter: Arc::new(RateLimiter::new(60)), cache_manager: Arc::new(CacheManager::new()),
41 }
42 }
43
44 pub fn request<T>(&self, method: Method, path: &str) -> RequestBuilder
46 where
47 T: DeserializeOwned,
48 {
49 let url = self.build_url(path);
50 self.client.request(method, url)
51 }
52
53 pub async fn send_request<T>(&self, request: RequestBuilder) -> WebullResult<T>
55 where
56 T: DeserializeOwned + Clone,
57 {
58 let req_url = request.try_clone()
60 .ok_or_else(|| WebullError::InvalidRequest("Failed to clone request".to_string()))?
61 .build()
62 .map_err(WebullError::NetworkError)?
63 .url()
64 .clone();
65
66 let path = req_url.path();
67
68 self.rate_limiter.wait(path).await;
70
71 let response = request.send().await.map_err(WebullError::NetworkError)?;
73
74 let status = response.status();
75
76 if status == StatusCode::TOO_MANY_REQUESTS {
78 let retry_after = response.headers()
80 .get("retry-after")
81 .and_then(|h| h.to_str().ok())
82 .and_then(|s| s.parse::<u64>().ok())
83 .unwrap_or(1);
84
85 tokio::time::sleep(std::time::Duration::from_secs(retry_after)).await;
87
88 return Err(WebullError::RateLimitExceeded);
89 }
90
91 if status == StatusCode::UNAUTHORIZED {
93 return Err(WebullError::Unauthorized);
94 }
95
96 let body = response.text().await.map_err(WebullError::NetworkError)?;
98
99 if !status.is_success() {
101 return Err(WebullError::ApiError {
102 code: status.as_u16().to_string(),
103 message: body,
104 });
105 }
106
107 let api_response: ApiResponse<T> = serde_json::from_str(&body)
109 .map_err(|e| WebullError::SerializationError(e))?;
110
111 if !api_response.is_success() {
113 return Err(WebullError::ApiError {
114 code: api_response.code.unwrap_or_else(|| "unknown".to_string()),
115 message: api_response.message.unwrap_or_else(|| "Unknown error".to_string()),
116 });
117 }
118
119 api_response.get_data().cloned().ok_or_else(|| WebullError::ApiError {
121 code: "no_data".to_string(),
122 message: "Response did not contain data".to_string(),
123 })
124 }
125
126 fn build_url(&self, path: &str) -> Url {
128 let base = self.base_url.trim_end_matches('/');
129 let path = path.trim_start_matches('/');
130 let url = format!("{}/{}", base, path);
131
132 Url::parse(&url).unwrap_or_else(|_| {
133 panic!("Invalid URL: {}", url);
135 })
136 }
137
138 pub async fn authenticate_request(&self, request: RequestBuilder) -> WebullResult<RequestBuilder> {
140 let token = self.auth_manager.get_token().await?;
142
143 let request = request.header(AUTHORIZATION, format!("Bearer {}", token.token));
145
146 Ok(request)
147 }
148
149 pub async fn get<T>(&self, path: &str) -> WebullResult<T>
151 where
152 T: DeserializeOwned + Clone + Send + Sync + 'static,
153 {
154 let cache = self.cache_manager.get_cache::<T>("get");
156 if let Some(cached) = cache.get("GET", path, None, None) {
157 return Ok(cached);
158 }
159
160 let request = self.request::<T>(Method::GET, path);
162 let request = self.authenticate_request(request).await?;
163 let response: T = self.send_request(request).await?;
164
165 cache.set("GET", path, None, None, response.clone(), Some(Duration::from_secs(60)));
167
168 Ok(response)
169 }
170
171 pub async fn post<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
173 where
174 T: DeserializeOwned + Clone + Send + Sync + 'static,
175 B: Serialize,
176 {
177 let body_str = match serde_json::to_string(body) {
180 Ok(s) => Some(s),
181 Err(_) => None,
182 };
183
184 if let Some(body_str) = &body_str {
186 let cache = self.cache_manager.get_cache::<T>("post");
187 if let Some(cached) = cache.get("POST", path, None, Some(body_str)) {
188 return Ok(cached);
189 }
190 }
191
192 let request = self.request::<T>(Method::POST, path).json(body);
194 let request = self.authenticate_request(request).await?;
195 let response: T = self.send_request(request).await?;
196
197 if let Some(body_str) = body_str {
199 let cache = self.cache_manager.get_cache::<T>("post");
200 cache.set("POST", path, None, Some(&body_str), response.clone(), Some(Duration::from_secs(60)));
201 }
202
203 Ok(response)
204 }
205
206 pub async fn put<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
208 where
209 T: DeserializeOwned + Clone + Send + Sync + 'static,
210 B: Serialize,
211 {
212 let request = self.request::<T>(Method::PUT, path).json(body);
217 let request = self.authenticate_request(request).await?;
218 let response: T = self.send_request(request).await?;
219
220 let get_cache = self.cache_manager.get_cache::<T>("get");
222 get_cache.clear();
223
224 Ok(response)
225 }
226
227 pub async fn delete<T>(&self, path: &str) -> WebullResult<T>
229 where
230 T: DeserializeOwned + Clone + Send + Sync + 'static,
231 {
232 let request = self.request::<T>(Method::DELETE, path);
237 let request = self.authenticate_request(request).await?;
238 let response: T = self.send_request(request).await?;
239
240 let get_cache = self.cache_manager.get_cache::<T>("get");
242 get_cache.clear();
243
244 let post_cache = self.cache_manager.get_cache::<T>("post");
245 post_cache.clear();
246
247 Ok(response)
248 }
249}