1use crate::error::{DiscordError, HttpError};
4use crate::ratelimit::RateLimiter;
5use crate::routes::{CurrentApplication, CurrentUser, GatewayBotResponse};
6
7use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
8use reqwest::{Client, Method, Response, StatusCode};
9use serde::de::DeserializeOwned;
10use simd_json::prelude::*;
11use std::sync::Arc;
12use tracing::{debug, warn};
13
/// Base URL of Discord's REST API (v10). Routes such as `/users/@me` are
/// appended to this string verbatim to form the request URL.
const API_BASE: &str = concat!("https://discord.com/api/v10");
16
17const USER_AGENT_VALUE: &str = concat!(
19 "DiscordBot (https://github.com/Sh4dowNotFound/titanium-rs, ",
20 env!("CARGO_PKG_VERSION"),
21 ")"
22);
23
24pub struct HttpClient {
26 client: Client,
28 token: String,
30 rate_limiter: Arc<RateLimiter>,
32}
33
thread_local! {
    /// Per-thread scratch space that response bodies are copied into, because
    /// simd-json parses from a mutable byte slice. Starts with 32 KiB reserved.
    static RESPONSE_BUFFER: std::cell::RefCell<Vec<u8>> =
        std::cell::RefCell::new(Vec::<u8>::with_capacity(32 << 10));
}
39
40impl HttpClient {
41 pub fn new(token: impl Into<String>) -> Result<Self, HttpError> {
43 let token = token.into();
44
45 let mut headers = HeaderMap::new();
46 headers.insert(
47 AUTHORIZATION,
48 HeaderValue::from_str(&format!("Bot {}", token))
49 .map_err(|_| HttpError::Unauthorized)?,
50 );
51 headers.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_VALUE));
52 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
53
54 let client = Client::builder().default_headers(headers).build()?;
55
56 Ok(Self {
57 client,
58 token,
59 rate_limiter: Arc::new(RateLimiter::new()),
60 })
61 }
62
63 pub fn token(&self) -> &str {
65 &self.token
66 }
67
68 pub async fn get_gateway_bot(&self) -> Result<GatewayBotResponse, HttpError> {
77 self.get("/gateway/bot").await
78 }
79
80 pub async fn get_current_user(&self) -> Result<CurrentUser, HttpError> {
86 self.get("/users/@me").await
87 }
88
89 pub async fn get_current_application(&self) -> Result<CurrentApplication, HttpError> {
91 self.get("/applications/@me").await
92 }
93
94 pub(crate) async fn get_with_query<T: DeserializeOwned, Q: serde::Serialize + ?Sized>(
100 &self,
101 route: &str,
102 query: &Q,
103 ) -> Result<T, HttpError> {
104 self.request_with_query(Method::GET, route, query, None::<()>, None)
105 .await
106 }
107
108 pub(crate) async fn get<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
110 self.request(Method::GET, route, None::<()>, None).await
111 }
112
113 #[allow(dead_code)]
115 pub(crate) async fn post<T: DeserializeOwned, B: serde::Serialize>(
116 &self,
117 route: &str,
118 body: B,
119 ) -> Result<T, HttpError> {
120 self.request(Method::POST, route, Some(body), None).await
121 }
122
123 #[allow(dead_code)]
125 pub(crate) async fn put<T: DeserializeOwned, B: serde::Serialize>(
126 &self,
127 route: &str,
128 body: Option<B>,
129 ) -> Result<T, HttpError> {
130 self.request(Method::PUT, route, body, None).await
131 }
132
133 pub(crate) async fn put_with_headers<T: DeserializeOwned, B: serde::Serialize>(
135 &self,
136 route: &str,
137 body: Option<B>,
138 headers: Option<HeaderMap>,
139 ) -> Result<T, HttpError> {
140 self.request(Method::PUT, route, body, headers).await
141 }
142
143 #[allow(dead_code)]
145 pub(crate) async fn patch<T: DeserializeOwned, B: serde::Serialize>(
146 &self,
147 route: &str,
148 body: B,
149 ) -> Result<T, HttpError> {
150 self.request(Method::PATCH, route, Some(body), None).await
151 }
152
153 #[allow(dead_code)]
155 pub(crate) async fn delete<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
156 self.request(Method::DELETE, route, None::<()>, None).await
157 }
158
159 pub(crate) async fn delete_with_headers<T: DeserializeOwned>(
161 &self,
162 route: &str,
163 headers: Option<HeaderMap>,
164 ) -> Result<T, HttpError> {
165 self.request(Method::DELETE, route, None::<()>, headers)
166 .await
167 }
168
169 #[allow(dead_code)]
171 pub(crate) async fn post_with_query<
172 T: DeserializeOwned,
173 B: serde::Serialize,
174 Q: serde::Serialize + ?Sized,
175 >(
176 &self,
177 route: &str,
178 body: B,
179 query: &Q,
180 ) -> Result<T, HttpError> {
181 self.request_with_query(Method::POST, route, query, Some(body), None)
182 .await
183 }
184
185 async fn request<T: DeserializeOwned, B: serde::Serialize>(
187 &self,
188 method: Method,
189 route: &str,
190 body: Option<B>,
191 headers: Option<HeaderMap>,
192 ) -> Result<T, HttpError> {
193 self.request_with_query(method, route, &(), body, headers)
194 .await
195 }
196
197 async fn request_with_query<
199 T: DeserializeOwned,
200 Q: serde::Serialize + ?Sized,
201 B: serde::Serialize,
202 >(
203 &self,
204 method: Method,
205 route: &str,
206 query: &Q,
207 body: Option<B>,
208 headers: Option<HeaderMap>,
209 ) -> Result<T, HttpError> {
210 let url = format!("{}{}", API_BASE, route);
211
212 self.rate_limiter.acquire(route).await;
214
215 let mut request = self.client.request(method.clone(), &url);
217
218 request = request.query(query);
222
223 if let Some(headers) = headers {
224 request = request.headers(headers);
225 }
226
227 if let Some(ref body) = body {
228 let body_bytes = simd_json::to_vec(body).map_err(|e| HttpError::Discord {
229 code: 0,
230 message: format!("Serialization error: {}", e),
231 })?;
232 request = request.body(body_bytes);
233 }
234
235 debug!(method = %method, route = %route, "Making request");
236
237 let response = request.send().await?;
239
240 self.handle_response(route, response).await
242 }
243
244 async fn handle_response<T: DeserializeOwned>(
246 &self,
247 route: &str,
248 response: Response,
249 ) -> Result<T, HttpError> {
250 let status = response.status();
251
252 if let Some(remaining) = response
254 .headers()
255 .get("x-ratelimit-remaining")
256 .and_then(|h| h.to_str().ok())
257 .and_then(|s| s.parse().ok())
258 {
259 let reset_after = response
260 .headers()
261 .get("x-ratelimit-reset-after")
262 .and_then(|h| h.to_str().ok())
263 .and_then(|s| s.parse::<f64>().ok())
264 .map(|f| (f * 1000.0) as u64)
265 .unwrap_or(1000);
266
267 self.rate_limiter.update(route, remaining, reset_after);
268 }
269
270 match status {
272 StatusCode::OK | StatusCode::CREATED | StatusCode::NO_CONTENT => {
273 let bytes = response.bytes().await?;
274 if bytes.is_empty() {
275 RESPONSE_BUFFER.with(|buf_cell| {
283 let mut buf = buf_cell.borrow_mut();
284 buf.clear();
285 buf.extend_from_slice(b"null");
286 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
287 code: 0,
288 message: e.to_string(),
289 })
290 })
291 } else {
292 RESPONSE_BUFFER.with(|buf_cell| {
293 let mut buf = buf_cell.borrow_mut();
294 buf.clear();
295 buf.extend_from_slice(&bytes);
296 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
298 code: 0,
299 message: e.to_string(),
300 })
301 })
302 }
303 }
304 StatusCode::TOO_MANY_REQUESTS => {
305 let bytes = response.bytes().await?;
306 let body: simd_json::OwnedValue = RESPONSE_BUFFER.with(|buf_cell| {
307 let mut buf = buf_cell.borrow_mut();
308 buf.clear();
309 buf.extend_from_slice(&bytes);
310 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
311 code: 0,
312 message: e.to_string(),
313 })
314 })?;
315
316 let retry_after = body
317 .get("retry_after")
318 .and_then(|v| v.as_f64())
319 .map(|f| (f * 1000.0) as u64)
320 .unwrap_or(5000);
321
322 let global = body
323 .get("global")
324 .and_then(|v| v.as_bool())
325 .unwrap_or(false);
326
327 if global {
328 warn!(retry_after_ms = retry_after, "Global rate limit hit");
329 self.rate_limiter.set_global(retry_after);
330 }
331
332 Err(HttpError::RateLimited {
333 retry_after_ms: retry_after,
334 global,
335 })
336 }
337 StatusCode::UNAUTHORIZED => Err(HttpError::Unauthorized),
338 StatusCode::FORBIDDEN => Err(HttpError::Forbidden),
339 StatusCode::NOT_FOUND => Err(HttpError::NotFound),
340 _ if status.is_server_error() => Err(HttpError::ServerError(status.as_u16())),
341 _ => {
342 let bytes = response.bytes().await?;
343 let error: DiscordError = RESPONSE_BUFFER.with(|buf_cell| {
344 let mut buf = buf_cell.borrow_mut();
345 buf.clear();
346 buf.extend_from_slice(&bytes);
347 simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
348 code: 0,
349 message: e.to_string(),
350 })
351 })?;
352
353 Err(HttpError::Discord {
354 code: error.code,
355 message: error.message,
356 })
357 }
358 }
359 }
360 pub async fn create_global_application_command(
366 &self,
367 application_id: titanium_model::Snowflake,
368 command: &titanium_model::ApplicationCommand,
369 ) -> Result<titanium_model::ApplicationCommand, HttpError> {
370 let route = format!("/applications/{}/commands", application_id);
371 self.post(&route, command).await
372 }
373
374 pub async fn create_message(
380 &self,
381 channel_id: titanium_model::Snowflake,
382 content: &titanium_model::CreateMessage<'_>,
383 ) -> Result<titanium_model::Message<'static>, HttpError> {
384 let route = format!("/channels/{}/messages", channel_id);
385 self.post(&route, content).await
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = HttpClient::new("test_token");
        assert!(client.is_ok());
    }

    #[test]
    fn token_round_trips() {
        if let Ok(client) = HttpClient::new("test_token") {
            assert_eq!(client.token(), "test_token");
        } else {
            panic!("a plain ASCII token must build a client");
        }
    }

    #[test]
    fn token_that_is_not_a_valid_header_value_is_rejected() {
        // A newline can never appear inside an HTTP header value.
        let result = HttpClient::new("bad\ntoken");
        assert!(matches!(result, Err(HttpError::Unauthorized)));
    }
}