1use crate::error::{DiscordError, HttpError};
34use crate::ratelimit::RateLimiter;
35use crate::routes::{CurrentApplication, CurrentUser, GatewayBotResponse};
36
37use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
38use reqwest::{Client, Method, Response, StatusCode};
39use serde::de::DeserializeOwned;
40use simd_json::prelude::*;
41use std::sync::Arc;
42use tracing::{debug, warn};
43
/// Root of every REST route; requests are sent to `API_BASE` + route (API version 10).
const API_BASE: &str = concat!("https://discord.com", "/api/v10");
46
47const USER_AGENT_VALUE: &str = concat!(
49 "DiscordBot (https://github.com/Sh4dowNotFound/titanium-rs, ",
50 env!("CARGO_PKG_VERSION"),
51 ")"
52);
53
54pub struct HttpClient {
67 client: Client,
69 token: String,
71 rate_limiter: Arc<RateLimiter>,
73}
74
/// Initial capacity (32 KiB) of each thread's response scratch buffer.
const RESPONSE_BUFFER_INITIAL_CAPACITY: usize = 32 * 1024;

thread_local! {
    /// Per-thread scratch buffer that response bodies are copied into, because
    /// `simd_json::from_slice` needs a mutable slice; reusing it avoids a fresh
    /// allocation per response on the same thread.
    static RESPONSE_BUFFER: std::cell::RefCell<Vec<u8>> =
        std::cell::RefCell::new(Vec::with_capacity(RESPONSE_BUFFER_INITIAL_CAPACITY));
}
80
81impl HttpClient {
82 pub fn new(token: impl Into<String>) -> Result<Self, HttpError> {
84 let token = token.into();
85
86 let mut headers = HeaderMap::new();
87 headers.insert(
88 AUTHORIZATION,
89 HeaderValue::from_str(&format!("Bot {}", token))
90 .map_err(|_| HttpError::Unauthorized)?,
91 );
92 headers.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_VALUE));
93 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
94
95 let client = Client::builder()
96 .default_headers(headers)
97 .http2_adaptive_window(true)
98 .tcp_nodelay(true)
99 .timeout(std::time::Duration::from_secs(30))
100 .connect_timeout(std::time::Duration::from_secs(10))
101 .build()?;
102
103 Ok(Self {
104 client,
105 token,
106 rate_limiter: Arc::new(RateLimiter::new()),
107 })
108 }
109
110 pub fn token(&self) -> &str {
112 &self.token
113 }
114
115 pub async fn get_gateway_bot(&self) -> Result<GatewayBotResponse, HttpError> {
124 self.get("/gateway/bot").await
125 }
126
127 pub async fn get_current_user(&self) -> Result<CurrentUser, HttpError> {
133 self.get("/users/@me").await
134 }
135
136 pub async fn get_current_application(&self) -> Result<CurrentApplication, HttpError> {
138 self.get("/applications/@me").await
139 }
140
141 pub(crate) async fn get_with_query<T: DeserializeOwned, Q: serde::Serialize + ?Sized>(
147 &self,
148 route: &str,
149 query: &Q,
150 ) -> Result<T, HttpError> {
151 self.request_with_query(Method::GET, route, query, None::<()>, None)
152 .await
153 }
154
155 pub(crate) async fn get<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
157 self.request(Method::GET, route, None::<()>, None).await
158 }
159
160 #[allow(dead_code)]
162 pub(crate) async fn post<T: DeserializeOwned, B: serde::Serialize>(
163 &self,
164 route: &str,
165 body: B,
166 ) -> Result<T, HttpError> {
167 self.request(Method::POST, route, Some(body), None).await
168 }
169
170 #[allow(dead_code)]
172 pub(crate) async fn post_no_response<B: serde::Serialize>(
173 &self,
174 route: &str,
175 body: B,
176 ) -> Result<(), HttpError> {
177 let _: serde::de::IgnoredAny = self.request(Method::POST, route, Some(body), None).await?;
178 Ok(())
179 }
180
181 #[allow(dead_code)]
183 pub(crate) async fn put<T: DeserializeOwned, B: serde::Serialize>(
184 &self,
185 route: &str,
186 body: Option<B>,
187 ) -> Result<T, HttpError> {
188 self.request(Method::PUT, route, body, None).await
189 }
190
191 pub(crate) async fn put_with_headers<T: DeserializeOwned, B: serde::Serialize>(
193 &self,
194 route: &str,
195 body: Option<B>,
196 headers: Option<HeaderMap>,
197 ) -> Result<T, HttpError> {
198 self.request(Method::PUT, route, body, headers).await
199 }
200
201 #[allow(dead_code)]
203 pub(crate) async fn patch<T: DeserializeOwned, B: serde::Serialize>(
204 &self,
205 route: &str,
206 body: B,
207 ) -> Result<T, HttpError> {
208 self.request(Method::PATCH, route, Some(body), None).await
209 }
210
211 #[allow(dead_code)]
213 pub(crate) async fn delete<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
214 self.request(Method::DELETE, route, None::<()>, None).await
215 }
216
217 pub(crate) async fn delete_with_headers<T: DeserializeOwned>(
219 &self,
220 route: &str,
221 headers: Option<HeaderMap>,
222 ) -> Result<T, HttpError> {
223 self.request(Method::DELETE, route, None::<()>, headers)
224 .await
225 }
226
227 #[allow(dead_code)]
229 pub(crate) async fn post_with_query<
230 T: DeserializeOwned,
231 B: serde::Serialize,
232 Q: serde::Serialize + ?Sized,
233 >(
234 &self,
235 route: &str,
236 body: B,
237 query: &Q,
238 ) -> Result<T, HttpError> {
239 self.request_with_query(Method::POST, route, query, Some(body), None)
240 .await
241 }
242
243 async fn request<T: DeserializeOwned, B: serde::Serialize>(
245 &self,
246 method: Method,
247 route: &str,
248 body: Option<B>,
249 headers: Option<HeaderMap>,
250 ) -> Result<T, HttpError> {
251 self.request_with_query(method, route, &(), body, headers)
252 .await
253 }
254
255 async fn request_with_query<
257 T: DeserializeOwned,
258 Q: serde::Serialize + ?Sized,
259 B: serde::Serialize,
260 >(
261 &self,
262 method: Method,
263 route: &str,
264 query: &Q,
265 body: Option<B>,
266 headers: Option<HeaderMap>,
267 ) -> Result<T, HttpError> {
268 let url = format!("{}{}", API_BASE, route);
269
270 self.rate_limiter.acquire(route).await?;
272
273 let mut request = self.client.request(method.clone(), &url);
275
276 request = request.query(query);
280
281 if let Some(headers) = headers {
282 request = request.headers(headers);
283 }
284
285 if let Some(ref body) = body {
286 let body_bytes = simd_json::to_vec(body).map_err(|e| HttpError::Discord {
287 code: 0,
288 message: format!("Serialization error: {}", e),
289 })?;
290 request = request.body(body_bytes);
291 }
292
293 debug!(method = %method, route = %route, "Making request");
294
295 let response = request.send().await?;
297
298 self.handle_response(route, response).await
300 }
301
302 async fn handle_response<T: DeserializeOwned>(
304 &self,
305 route: &str,
306 response: Response,
307 ) -> Result<T, HttpError> {
308 let status = response.status();
309
310 if let Some(remaining) = response
312 .headers()
313 .get("x-ratelimit-remaining")
314 .and_then(|h| h.to_str().ok())
315 .and_then(|s| s.parse().ok())
316 {
317 let reset_after = response
318 .headers()
319 .get("x-ratelimit-reset-after")
320 .and_then(|h| h.to_str().ok())
321 .and_then(|s| s.parse::<f64>().ok())
322 .map(|f| (f * 1000.0) as u64)
323 .unwrap_or(1000);
324
325 self.rate_limiter.update(route, remaining, reset_after);
326 }
327
328 match status {
330 StatusCode::OK | StatusCode::CREATED | StatusCode::NO_CONTENT => {
331 let bytes = response.bytes().await?;
332 if bytes.is_empty() {
333 RESPONSE_BUFFER.with(|buf_cell| {
341 let mut buf = buf_cell.borrow_mut();
342 buf.clear();
343 buf.extend_from_slice(b"null");
344 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
345 code: 0,
346 message: e.to_string(),
347 })
348 })
349 } else {
350 RESPONSE_BUFFER.with(|buf_cell| {
351 let mut buf = buf_cell.borrow_mut();
352
353 if buf.capacity() > 10 * 1024 * 1024 {
355 buf.shrink_to(1024 * 1024); }
357
358 buf.clear();
359 buf.extend_from_slice(&bytes);
360 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
362 code: 0,
363 message: e.to_string(),
364 })
365 })
366 }
367 }
368 StatusCode::TOO_MANY_REQUESTS => {
369 let bytes = response.bytes().await?;
370 let body: simd_json::OwnedValue = RESPONSE_BUFFER.with(|buf_cell| {
371 let mut buf = buf_cell.borrow_mut();
372 buf.clear();
373 buf.extend_from_slice(&bytes);
374 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
375 code: 0,
376 message: e.to_string(),
377 })
378 })?;
379
380 let retry_after = body
381 .get("retry_after")
382 .and_then(|v| v.as_f64())
383 .map(|f| (f * 1000.0) as u64)
384 .unwrap_or(5000);
385
386 let global = body
387 .get("global")
388 .and_then(|v| v.as_bool())
389 .unwrap_or(false);
390
391 if global {
392 warn!(retry_after_ms = retry_after, "Global rate limit hit");
393 self.rate_limiter.set_global(retry_after);
394 }
395
396 Err(HttpError::RateLimited {
397 retry_after_ms: retry_after,
398 global,
399 })
400 }
401 StatusCode::UNAUTHORIZED => Err(HttpError::Unauthorized),
402 StatusCode::FORBIDDEN => Err(HttpError::Forbidden),
403 StatusCode::NOT_FOUND => Err(HttpError::NotFound),
404 _ if status.is_server_error() => Err(HttpError::ServerError(status.as_u16())),
405 _ => {
406 let bytes = response.bytes().await?;
407 let error: DiscordError = RESPONSE_BUFFER.with(|buf_cell| {
408 let mut buf = buf_cell.borrow_mut();
409 buf.clear();
410 buf.extend_from_slice(&bytes);
411 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
412 code: 0,
413 message: e.to_string(),
414 })
415 })?;
416
417 Err(HttpError::Discord {
418 code: error.code,
419 message: error.message,
420 })
421 }
422 }
423 }
424 pub async fn create_global_application_command(
430 &self,
431 application_id: titanium_model::Snowflake,
432 command: &titanium_model::ApplicationCommand,
433 ) -> Result<titanium_model::ApplicationCommand, HttpError> {
434 let route = format!("/applications/{}/commands", application_id);
435 self.post(&route, command).await
436 }
437
438 pub async fn create_message(
444 &self,
445 channel_id: titanium_model::Snowflake,
446 content: &titanium_model::CreateMessage<'_>,
447 ) -> Result<titanium_model::Message<'static>, HttpError> {
448 let route = format!("/channels/{}/messages", channel_id);
449 self.post(&route, content).await
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a client performs no network I/O, so a well-formed token must
    /// yield `Ok`.
    #[test]
    fn test_client_creation() {
        let created = HttpClient::new("test_token").is_ok();
        assert!(created);
    }
}