titanium_http/
client.rs

//! Discord REST API HTTP client.
//!
//! This module provides [`HttpClient`] for making requests to Discord's REST API.
//!
//! # Features
//!
//! - **Automatic Rate Limiting**: Respects Discord's rate limits with per-route tracking
//! - **SIMD-Accelerated JSON**: Uses `simd-json` for fast parsing on supported CPUs
//! - **Connection Pooling**: Reuses pooled connections (HTTP/2 where available) for efficiency
//! - **Thread-Local Buffers**: Reuses a per-thread parse buffer to avoid allocations on hot paths
//!
//! # Performance
//!
//! The client is optimized for high-throughput scenarios:
//! - 30-second request timeout, 10-second connect timeout
//! - HTTP/2 with adaptive window sizing
//! - TCP_NODELAY for reduced latency
//!
//! # Example
//!
//! ```no_run
//! # use titanium_http::HttpClient;
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let http = HttpClient::new("your-bot-token")?;
//!
//! // Get gateway information
//! let gateway = http.get_gateway_bot().await?;
//! println!("Recommended shards: {}", gateway.shards);
//! # Ok(())
//! # }
//! ```

33use crate::error::{DiscordError, HttpError};
34use crate::ratelimit::RateLimiter;
35use crate::routes::{CurrentApplication, CurrentUser, GatewayBotResponse};
36
37use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
38use reqwest::{Client, Method, Response, StatusCode};
39use serde::de::DeserializeOwned;
40use simd_json::prelude::*;
41use std::sync::Arc;
42use tracing::{debug, warn};
43
/// Prefix for every REST route: the Discord host followed by the pinned
/// API version path (v10).
const API_BASE: &str = concat!("https://discord.com", "/api/v10");

47/// User agent sent with all requests (required by Discord).
48const USER_AGENT_VALUE: &str = concat!(
49    "DiscordBot (https://github.com/Sh4dowNotFound/titanium-rs, ",
50    env!("CARGO_PKG_VERSION"),
51    ")"
52);
53
54/// Discord REST API client.
55///
56/// Handles all HTTP communication with Discord, including:
57/// - Authentication (Bot token)
58/// - Rate limit tracking and waiting
59/// - JSON serialization/deserialization
60/// - Error handling
61///
62/// # Thread Safety
63///
64/// `HttpClient` can be cloned and shared across tasks. The inner
65/// `reqwest::Client` uses connection pooling internally.
66pub struct HttpClient {
67    /// Inner reqwest HTTP client with connection pooling.
68    client: Client,
69    /// Bot token for authentication.
70    token: String,
71    /// Rate limiter tracking per-route and global limits.
72    rate_limiter: Arc<RateLimiter>,
73}
74
thread_local! {
    /// Reusable per-thread scratch space for decoding response bodies.
    ///
    /// simd-json parses a mutable byte slice in place, so bodies are copied
    /// into this `Vec` instead of a fresh allocation per request. It starts
    /// with 32 KiB of capacity, expected to cover most Discord API responses;
    /// larger bodies grow it on demand.
    static RESPONSE_BUFFER: std::cell::RefCell<Vec<u8>> = {
        const INITIAL_CAPACITY: usize = 32 * 1024;
        std::cell::RefCell::new(Vec::with_capacity(INITIAL_CAPACITY))
    };
}

81impl HttpClient {
82    /// Create a new HTTP client with the given bot token.
83    pub fn new(token: impl Into<String>) -> Result<Self, HttpError> {
84        let token = token.into();
85
86        let mut headers = HeaderMap::new();
87        headers.insert(
88            AUTHORIZATION,
89            HeaderValue::from_str(&format!("Bot {}", token))
90                .map_err(|_| HttpError::Unauthorized)?,
91        );
92        headers.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_VALUE));
93        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
94
95        let client = Client::builder()
96            .default_headers(headers)
97            .http2_adaptive_window(true)
98            .tcp_nodelay(true)
99            .timeout(std::time::Duration::from_secs(30))
100            .connect_timeout(std::time::Duration::from_secs(10))
101            .build()?;
102
103        Ok(Self {
104            client,
105            token,
106            rate_limiter: Arc::new(RateLimiter::new()),
107        })
108    }
109
110    /// Get the bot token.
111    pub fn token(&self) -> &str {
112        &self.token
113    }
114
115    // =========================================================================
116    // Gateway Endpoints
117    // =========================================================================
118
119    /// Get gateway bot information.
120    ///
121    /// Returns the recommended number of shards, gateway URL, and session limits.
122    /// This is essential for large bots to determine sharding configuration.
123    pub async fn get_gateway_bot(&self) -> Result<GatewayBotResponse, HttpError> {
124        self.get("/gateway/bot").await
125    }
126
127    // =========================================================================
128    // User Endpoints
129    // =========================================================================
130
131    /// Get the current bot user.
132    pub async fn get_current_user(&self) -> Result<CurrentUser, HttpError> {
133        self.get("/users/@me").await
134    }
135
136    /// Get the current application.
137    pub async fn get_current_application(&self) -> Result<CurrentApplication, HttpError> {
138        self.get("/applications/@me").await
139    }
140
141    // =========================================================================
142    // Internal Request Methods
143    // =========================================================================
144
145    /// Make a GET request with query parameters.
146    pub(crate) async fn get_with_query<T: DeserializeOwned, Q: serde::Serialize + ?Sized>(
147        &self,
148        route: &str,
149        query: &Q,
150    ) -> Result<T, HttpError> {
151        self.request_with_query(Method::GET, route, query, None::<()>, None)
152            .await
153    }
154
155    /// Make a GET request.
156    pub(crate) async fn get<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
157        self.request(Method::GET, route, None::<()>, None).await
158    }
159
160    /// Make a POST request.
161    #[allow(dead_code)]
162    pub(crate) async fn post<T: DeserializeOwned, B: serde::Serialize>(
163        &self,
164        route: &str,
165        body: B,
166    ) -> Result<T, HttpError> {
167        self.request(Method::POST, route, Some(body), None).await
168    }
169
170    /// Make a POST request that returns no body (e.g., bulk delete).
171    #[allow(dead_code)]
172    pub(crate) async fn post_no_response<B: serde::Serialize>(
173        &self,
174        route: &str,
175        body: B,
176    ) -> Result<(), HttpError> {
177        let _: serde::de::IgnoredAny = self.request(Method::POST, route, Some(body), None).await?;
178        Ok(())
179    }
180
181    /// Make a PUT request.
182    #[allow(dead_code)]
183    pub(crate) async fn put<T: DeserializeOwned, B: serde::Serialize>(
184        &self,
185        route: &str,
186        body: Option<B>,
187    ) -> Result<T, HttpError> {
188        self.request(Method::PUT, route, body, None).await
189    }
190
191    /// Make a PUT request with headers (for bans etc).
192    pub(crate) async fn put_with_headers<T: DeserializeOwned, B: serde::Serialize>(
193        &self,
194        route: &str,
195        body: Option<B>,
196        headers: Option<HeaderMap>,
197    ) -> Result<T, HttpError> {
198        self.request(Method::PUT, route, body, headers).await
199    }
200
201    /// Make a PATCH request.
202    #[allow(dead_code)]
203    pub(crate) async fn patch<T: DeserializeOwned, B: serde::Serialize>(
204        &self,
205        route: &str,
206        body: B,
207    ) -> Result<T, HttpError> {
208        self.request(Method::PATCH, route, Some(body), None).await
209    }
210
211    /// Make a DELETE request.
212    #[allow(dead_code)]
213    pub(crate) async fn delete<T: DeserializeOwned>(&self, route: &str) -> Result<T, HttpError> {
214        self.request(Method::DELETE, route, None::<()>, None).await
215    }
216
217    /// Make a DELETE request with headers.
218    pub(crate) async fn delete_with_headers<T: DeserializeOwned>(
219        &self,
220        route: &str,
221        headers: Option<HeaderMap>,
222    ) -> Result<T, HttpError> {
223        self.request(Method::DELETE, route, None::<()>, headers)
224            .await
225    }
226
227    /// Make a POST request with query parameters.
228    #[allow(dead_code)]
229    pub(crate) async fn post_with_query<
230        T: DeserializeOwned,
231        B: serde::Serialize,
232        Q: serde::Serialize + ?Sized,
233    >(
234        &self,
235        route: &str,
236        body: B,
237        query: &Q,
238    ) -> Result<T, HttpError> {
239        self.request_with_query(Method::POST, route, query, Some(body), None)
240            .await
241    }
242
243    /// Make an HTTP request with rate limit handling.
244    async fn request<T: DeserializeOwned, B: serde::Serialize>(
245        &self,
246        method: Method,
247        route: &str,
248        body: Option<B>,
249        headers: Option<HeaderMap>,
250    ) -> Result<T, HttpError> {
251        self.request_with_query(method, route, &(), body, headers)
252            .await
253    }
254
255    /// Make an HTTP request with query params, rate limit handling, and headers.
256    async fn request_with_query<
257        T: DeserializeOwned,
258        Q: serde::Serialize + ?Sized,
259        B: serde::Serialize,
260    >(
261        &self,
262        method: Method,
263        route: &str,
264        query: &Q,
265        body: Option<B>,
266        headers: Option<HeaderMap>,
267    ) -> Result<T, HttpError> {
268        let url = format!("{}{}", API_BASE, route);
269
270        // Acquire rate limit permit
271        self.rate_limiter.acquire(route).await?;
272
273        // Build request
274        let mut request = self.client.request(method.clone(), &url);
275
276        // Add query params
277        // reqwest::RequestBuilder::query handles generic Serialize.
278        // Unit () serializes to empty/null which is ignored by reqwest for query params.
279        request = request.query(query);
280
281        if let Some(headers) = headers {
282            request = request.headers(headers);
283        }
284
285        if let Some(ref body) = body {
286            let body_bytes = simd_json::to_vec(body).map_err(|e| HttpError::Discord {
287                code: 0,
288                message: format!("Serialization error: {}", e),
289            })?;
290            request = request.body(body_bytes);
291        }
292
293        debug!(method = %method, route = %route, "Making request");
294
295        // Send request
296        let response = request.send().await?;
297
298        // Handle response
299        self.handle_response(route, response).await
300    }
301
302    /// Handle an HTTP response.
303    async fn handle_response<T: DeserializeOwned>(
304        &self,
305        route: &str,
306        response: Response,
307    ) -> Result<T, HttpError> {
308        let status = response.status();
309
310        // Extract rate limit headers
311        if let Some(remaining) = response
312            .headers()
313            .get("x-ratelimit-remaining")
314            .and_then(|h| h.to_str().ok())
315            .and_then(|s| s.parse().ok())
316        {
317            let reset_after = response
318                .headers()
319                .get("x-ratelimit-reset-after")
320                .and_then(|h| h.to_str().ok())
321                .and_then(|s| s.parse::<f64>().ok())
322                .map(|f| (f * 1000.0) as u64)
323                .unwrap_or(1000);
324
325            self.rate_limiter.update(route, remaining, reset_after);
326        }
327
328        // Handle errors
329        match status {
330            StatusCode::OK | StatusCode::CREATED | StatusCode::NO_CONTENT => {
331                let bytes = response.bytes().await?;
332                if bytes.is_empty() {
333                    // For NO_CONTENT or empty responses
334                    // simd-json might choke on empty, but "null" is better?
335                    // T might be () or Option<T>.
336                    // Let's assume empty body means "null" or default.
337                    // But if T is a struct, from_slice(b"null") might fail if not Option.
338                    // Actually, existing code used b"null".to_vec();
339
340                    RESPONSE_BUFFER.with(|buf_cell| {
341                        let mut buf = buf_cell.borrow_mut();
342                        buf.clear();
343                        buf.extend_from_slice(b"null");
344                        simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
345                            code: 0,
346                            message: e.to_string(),
347                        })
348                    })
349                } else {
350                    RESPONSE_BUFFER.with(|buf_cell| {
351                        let mut buf = buf_cell.borrow_mut();
352
353                        // Memory management: Shrink buffer if it gets too large (> 10MB)
354                        if buf.capacity() > 10 * 1024 * 1024 {
355                            buf.shrink_to(1024 * 1024); // Shrink to 1MB
356                        }
357
358                        buf.clear();
359                        buf.extend_from_slice(&bytes);
360                        // simd-json parses in-place
361                        simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
362                            code: 0,
363                            message: e.to_string(),
364                        })
365                    })
366                }
367            }
368            StatusCode::TOO_MANY_REQUESTS => {
369                let bytes = response.bytes().await?;
370                let body: simd_json::OwnedValue = RESPONSE_BUFFER.with(|buf_cell| {
371                    let mut buf = buf_cell.borrow_mut();
372                    buf.clear();
373                    buf.extend_from_slice(&bytes);
374                    simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
375                        code: 0,
376                        message: e.to_string(),
377                    })
378                })?;
379
380                let retry_after = body
381                    .get("retry_after")
382                    .and_then(|v| v.as_f64())
383                    .map(|f| (f * 1000.0) as u64)
384                    .unwrap_or(5000);
385
386                let global = body
387                    .get("global")
388                    .and_then(|v| v.as_bool())
389                    .unwrap_or(false);
390
391                if global {
392                    warn!(retry_after_ms = retry_after, "Global rate limit hit");
393                    self.rate_limiter.set_global(retry_after);
394                }
395
396                Err(HttpError::RateLimited {
397                    retry_after_ms: retry_after,
398                    global,
399                })
400            }
401            StatusCode::UNAUTHORIZED => Err(HttpError::Unauthorized),
402            StatusCode::FORBIDDEN => Err(HttpError::Forbidden),
403            StatusCode::NOT_FOUND => Err(HttpError::NotFound),
404            _ if status.is_server_error() => Err(HttpError::ServerError(status.as_u16())),
405            _ => {
406                let bytes = response.bytes().await?;
407                let error: DiscordError = RESPONSE_BUFFER.with(|buf_cell| {
408                    let mut buf = buf_cell.borrow_mut();
409                    buf.clear();
410                    buf.extend_from_slice(&bytes);
411                    simd_json::from_slice(&mut buf).map_err(|e| HttpError::Discord {
412                        code: 0,
413                        message: e.to_string(),
414                    })
415                })?;
416
417                Err(HttpError::Discord {
418                    code: error.code,
419                    message: error.message,
420                })
421            }
422        }
423    }
424    // =========================================================================
425    // Interaction Endpoints
426    // =========================================================================
427
428    /// Create a global application command.
429    pub async fn create_global_application_command(
430        &self,
431        application_id: titanium_model::Snowflake,
432        command: &titanium_model::ApplicationCommand,
433    ) -> Result<titanium_model::ApplicationCommand, HttpError> {
434        let route = format!("/applications/{}/commands", application_id);
435        self.post(&route, command).await
436    }
437
438    // =========================================================================
439    // Channel Endpoints
440    // =========================================================================
441
442    /// Create a message in a channel.
443    pub async fn create_message(
444        &self,
445        channel_id: titanium_model::Snowflake,
446        content: &titanium_model::CreateMessage<'_>,
447    ) -> Result<titanium_model::Message<'static>, HttpError> {
448        let route = format!("/channels/{}/messages", channel_id);
449        self.post(&route, content).await
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = HttpClient::new("test_token");
        assert!(client.is_ok());
    }

    #[test]
    fn test_client_rejects_token_with_control_characters() {
        // A newline is not a valid header-value byte, which `new` reports as
        // `Unauthorized` rather than panicking or sending a broken header.
        let client = HttpClient::new("bad\ntoken");
        assert!(matches!(client, Err(HttpError::Unauthorized)));
    }

    #[test]
    fn test_token_is_returned_verbatim() {
        let client = HttpClient::new("test_token");
        assert!(matches!(client, Ok(c) if c.token() == "test_token"));
    }
}