titanium_http/
client.rs

1//! Discord HTTP client implementation.
2
3use crate::error::{DiscordError, HttpError};
4use crate::ratelimit::RateLimiter;
5use crate::routes::{CurrentApplication, CurrentUser, GatewayBotResponse};
6
7use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
8use reqwest::{Client, Method, Response, StatusCode};
9use serde::de::DeserializeOwned;
10use simd_json::prelude::*;
11use std::sync::Arc;
12use tracing::{debug, warn};
13
14/// Discord API base URL.
15const API_BASE: &str = "https://discord.com/api/v10";
16
17/// User agent for requests.
18const USER_AGENT_VALUE: &str = concat!(
19    "DiscordBot (https://github.com/Sh4dowNotFound/titanium-rs, ",
20    env!("CARGO_PKG_VERSION"),
21    ")"
22);
23
24/// Discord REST API client.
25pub struct HttpClient {
26    /// Inner HTTP client.
27    client: Client,
28    /// Bot token.
29    token: String,
30    /// Rate limiter.
31    rate_limiter: Arc<RateLimiter>,
32}
33
thread_local! {
    /// Reusable scratch bytes for decoding response bodies on this thread.
    ///
    /// simd-json parses in place and therefore needs a mutable copy of each body; keeping one
    /// buffer per thread avoids allocating a fresh `Vec` for every response. 32 KiB are
    /// reserved up front.
    static RESPONSE_BUFFER: std::cell::RefCell<Vec<u8>> =
        std::cell::RefCell::new(Vec::<u8>::with_capacity(32_768));
}
39
40impl HttpClient {
41    /// Create a new HTTP client with the given bot token.
42    pub fn new(token: impl Into<String>) -> Result<Self, HttpError> {
43        let token = token.into();
44
45        let mut headers = HeaderMap::new();
46        headers.insert(
47            AUTHORIZATION,
48            HeaderValue::from_str(&format!("Bot {}", token))
49                .map_err(|_| HttpError::Unauthorized)?,
50        );
51        headers.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_VALUE));
52        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
53
54        let client = Client::builder().default_headers(headers).build()?;
55
56        Ok(Self {
57            client,
58            token,
59            rate_limiter: Arc::new(RateLimiter::new()),
60        })
61    }
62
63    /// Get the bot token.
64    pub fn token(&self) -> &str {
65        &self.token
66    }
67
68    // =========================================================================
69    // Gateway Endpoints
70    // =========================================================================
71
72    /// Get gateway bot information.
73    ///
74    /// Returns the recommended number of shards, gateway URL, and session limits.
75    /// This is essential for large bots to determine sharding configuration.
76    pub async fn get_gateway_bot(&self) -> Result<GatewayBotResponse, HttpError> {
77        self.get("/gateway/bot").await
78    }
79
80    // =========================================================================
81    // User Endpoints
82    // =========================================================================
83
84    /// Get the current bot user.
85    pub async fn get_current_user(&self) -> Result<CurrentUser, HttpError> {
86        self.get("/users/@me").await
87    }
88
89    /// Get the current application.
90    pub async fn get_current_application(&self) -> Result<CurrentApplication, HttpError> {
91        self.get("/applications/@me").await
92    }
93
94    // =========================================================================
95    // Internal Request Methods
96    // =========================================================================
97
98    /// Make a GET request with query parameters.
99    pub(crate) async fn get_with_query<T: DeserializeOwned, Q: serde::Serialize + ?Sized>(
100        &self,
101        route: &str,
102        query: &Q,
103    ) -> Result<T, HttpError> {
104        self.request_with_query(Method::GET, route, query, None::<()>, None)
105            .await
106    }
107
108    /// Make a GET request.
109    pub(crate) async fn get<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
110        self.request(Method::GET, route, None::<()>, None).await
111    }
112
113    /// Make a POST request.
114    #[allow(dead_code)]
115    pub(crate) async fn post<T: DeserializeOwned, B: serde::Serialize>(
116        &self,
117        route: &str,
118        body: B,
119    ) -> Result<T, HttpError> {
120        self.request(Method::POST, route, Some(body), None).await
121    }
122
123    /// Make a PUT request.
124    #[allow(dead_code)]
125    pub(crate) async fn put<T: DeserializeOwned, B: serde::Serialize>(
126        &self,
127        route: &str,
128        body: Option<B>,
129    ) -> Result<T, HttpError> {
130        self.request(Method::PUT, route, body, None).await
131    }
132
133    /// Make a PUT request with headers (for bans etc).
134    pub(crate) async fn put_with_headers<T: DeserializeOwned, B: serde::Serialize>(
135        &self,
136        route: &str,
137        body: Option<B>,
138        headers: Option<HeaderMap>,
139    ) -> Result<T, HttpError> {
140        self.request(Method::PUT, route, body, headers).await
141    }
142
143    /// Make a PATCH request.
144    #[allow(dead_code)]
145    pub(crate) async fn patch<T: DeserializeOwned, B: serde::Serialize>(
146        &self,
147        route: &str,
148        body: B,
149    ) -> Result<T, HttpError> {
150        self.request(Method::PATCH, route, Some(body), None).await
151    }
152
153    /// Make a DELETE request.
154    #[allow(dead_code)]
155    pub(crate) async fn delete<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
156        self.request(Method::DELETE, route, None::<()>, None).await
157    }
158
159    /// Make a DELETE request with headers.
160    pub(crate) async fn delete_with_headers<T: DeserializeOwned>(
161        &self,
162        route: &str,
163        headers: Option<HeaderMap>,
164    ) -> Result<T, HttpError> {
165        self.request(Method::DELETE, route, None::<()>, headers)
166            .await
167    }
168
169    /// Make a POST request with query parameters.
170    #[allow(dead_code)]
171    pub(crate) async fn post_with_query<
172        T: DeserializeOwned,
173        B: serde::Serialize,
174        Q: serde::Serialize + ?Sized,
175    >(
176        &self,
177        route: &str,
178        body: B,
179        query: &Q,
180    ) -> Result<T, HttpError> {
181        self.request_with_query(Method::POST, route, query, Some(body), None)
182            .await
183    }
184
185    /// Make an HTTP request with rate limit handling.
186    async fn request<T: DeserializeOwned, B: serde::Serialize>(
187        &self,
188        method: Method,
189        route: &str,
190        body: Option<B>,
191        headers: Option<HeaderMap>,
192    ) -> Result<T, HttpError> {
193        self.request_with_query(method, route, &(), body, headers)
194            .await
195    }
196
197    /// Make an HTTP request with query params, rate limit handling, and headers.
198    async fn request_with_query<
199        T: DeserializeOwned,
200        Q: serde::Serialize + ?Sized,
201        B: serde::Serialize,
202    >(
203        &self,
204        method: Method,
205        route: &str,
206        query: &Q,
207        body: Option<B>,
208        headers: Option<HeaderMap>,
209    ) -> Result<T, HttpError> {
210        let url = format!("{}{}", API_BASE, route);
211
212        // Acquire rate limit permit
213        self.rate_limiter.acquire(route).await;
214
215        // Build request
216        let mut request = self.client.request(method.clone(), &url);
217
218        // Add query params
219        // reqwest::RequestBuilder::query handles generic Serialize.
220        // Unit () serializes to empty/null which is ignored by reqwest for query params.
221        request = request.query(query);
222
223        if let Some(headers) = headers {
224            request = request.headers(headers);
225        }
226
227        if let Some(ref body) = body {
228            let body_bytes = simd_json::to_vec(body).map_err(|e| HttpError::Discord {
229                code: 0,
230                message: format!("Serialization error: {}", e),
231            })?;
232            request = request.body(body_bytes);
233        }
234
235        debug!(method = %method, route = %route, "Making request");
236
237        // Send request
238        let response = request.send().await?;
239
240        // Handle response
241        self.handle_response(route, response).await
242    }
243
244    /// Handle an HTTP response.
245    async fn handle_response<T: DeserializeOwned>(
246        &self,
247        route: &str,
248        response: Response,
249    ) -> Result<T, HttpError> {
250        let status = response.status();
251
252        // Extract rate limit headers
253        if let Some(remaining) = response
254            .headers()
255            .get("x-ratelimit-remaining")
256            .and_then(|h| h.to_str().ok())
257            .and_then(|s| s.parse().ok())
258        {
259            let reset_after = response
260                .headers()
261                .get("x-ratelimit-reset-after")
262                .and_then(|h| h.to_str().ok())
263                .and_then(|s| s.parse::<f64>().ok())
264                .map(|f| (f * 1000.0) as u64)
265                .unwrap_or(1000);
266
267            self.rate_limiter.update(route, remaining, reset_after);
268        }
269
270        // Handle errors
271        match status {
272            StatusCode::OK | StatusCode::CREATED | StatusCode::NO_CONTENT => {
273                let bytes = response.bytes().await?;
274                if bytes.is_empty() {
275                    // For NO_CONTENT or empty responses
276                    // simd-json might choke on empty, but "null" is better?
277                    // T might be () or Option<T>.
278                    // Let's assume empty body means "null" or default.
279                    // But if T is a struct, from_slice(b"null") might fail if not Option.
280                    // Actually, existing code used b"null".to_vec();
281
282                    RESPONSE_BUFFER.with(|buf_cell| {
283                        let mut buf = buf_cell.borrow_mut();
284                        buf.clear();
285                        buf.extend_from_slice(b"null");
286                        simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
287                            code: 0,
288                            message: e.to_string(),
289                        })
290                    })
291                } else {
292                    RESPONSE_BUFFER.with(|buf_cell| {
293                        let mut buf = buf_cell.borrow_mut();
294                        buf.clear();
295                        buf.extend_from_slice(&bytes);
296                        // simd-json parses in-place
297                        simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
298                            code: 0,
299                            message: e.to_string(),
300                        })
301                    })
302                }
303            }
304            StatusCode::TOO_MANY_REQUESTS => {
305                let bytes = response.bytes().await?;
306                let body: simd_json::OwnedValue = RESPONSE_BUFFER.with(|buf_cell| {
307                    let mut buf = buf_cell.borrow_mut();
308                    buf.clear();
309                    buf.extend_from_slice(&bytes);
310                    simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
311                        code: 0,
312                        message: e.to_string(),
313                    })
314                })?;
315
316                let retry_after = body
317                    .get("retry_after")
318                    .and_then(|v| v.as_f64())
319                    .map(|f| (f * 1000.0) as u64)
320                    .unwrap_or(5000);
321
322                let global = body
323                    .get("global")
324                    .and_then(|v| v.as_bool())
325                    .unwrap_or(false);
326
327                if global {
328                    warn!(retry_after_ms = retry_after, "Global rate limit hit");
329                    self.rate_limiter.set_global(retry_after);
330                }
331
332                Err(HttpError::RateLimited {
333                    retry_after_ms: retry_after,
334                    global,
335                })
336            }
337            StatusCode::UNAUTHORIZED => Err(HttpError::Unauthorized),
338            StatusCode::FORBIDDEN => Err(HttpError::Forbidden),
339            StatusCode::NOT_FOUND => Err(HttpError::NotFound),
340            _ if status.is_server_error() => Err(HttpError::ServerError(status.as_u16())),
341            _ => {
342                let bytes = response.bytes().await?;
343                let error: DiscordError = RESPONSE_BUFFER.with(|buf_cell| {
344                    let mut buf = buf_cell.borrow_mut();
345                    buf.clear();
346                    buf.extend_from_slice(&bytes);
347                    simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
348                        code: 0,
349                        message: e.to_string(),
350                    })
351                })?;
352
353                Err(HttpError::Discord {
354                    code: error.code,
355                    message: error.message,
356                })
357            }
358        }
359    }
360    // =========================================================================
361    // Interaction Endpoints
362    // =========================================================================
363
364    /// Create a global application command.
365    pub async fn create_global_application_command(
366        &self,
367        application_id: titanium_model::Snowflake,
368        command: &titanium_model::ApplicationCommand,
369    ) -> Result<titanium_model::ApplicationCommand, HttpError> {
370        let route = format!("/applications/{}/commands", application_id);
371        self.post(&route, command).await
372    }
373
374    // =========================================================================
375    // Channel Endpoints
376    // =========================================================================
377
378    /// Create a message in a channel.
379    pub async fn create_message(
380        &self,
381        channel_id: titanium_model::Snowflake,
382        content: &titanium_model::CreateMessage<'_>,
383    ) -> Result<titanium_model::Message<'static>, HttpError> {
384        let route = format!("/channels/{}/messages", channel_id);
385        self.post(&route, content).await
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = HttpClient::new("test_token");
        assert!(client.is_ok());
    }

    #[test]
    fn test_token_is_retained() {
        match HttpClient::new("test_token") {
            Ok(client) => assert_eq!(client.token(), "test_token"),
            Err(_) => panic!("a plain ASCII token must produce a client"),
        }
    }

    #[test]
    fn test_invalid_token_is_rejected() {
        // A newline cannot appear in an HTTP header value, so the Authorization header
        // cannot be built and construction reports `Unauthorized`.
        let result = HttpClient::new("bad\ntoken");
        assert!(matches!(result, Err(HttpError::Unauthorized)));
    }
}