titanium_gateway/
ratelimit.rs

1//! Gateway rate limiting.
2//!
3//! Discord limits how quickly bots can identify on the Gateway.
4//! Large bots (150k+ guilds) get higher `max_concurrency` values.
5
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::time::sleep;
10
11/// Rate limiter for Gateway identify operations.
12///
13/// Discord allows `max_concurrency` identify operations every 5 seconds.
14/// This rate limiter ensures we don't exceed this limit.
15pub struct IdentifyRateLimiter {
16    /// Semaphore with max_concurrency permits.
17    semaphore: Arc<Semaphore>,
18
19    /// Duration to hold the permit (5 seconds per Discord docs).
20    hold_duration: Duration,
21}
22
23impl IdentifyRateLimiter {
24    /// Create a new identify rate limiter.
25    ///
26    /// # Arguments
27    /// * `max_concurrency` - Maximum concurrent identifies (from /gateway/bot).
28    pub fn new(max_concurrency: usize) -> Self {
29        Self {
30            semaphore: Arc::new(Semaphore::new(max_concurrency)),
31            hold_duration: Duration::from_secs(5),
32        }
33    }
34
35    /// Acquire permission to send an Identify payload.
36    ///
37    /// This will block until a slot is available. The slot is automatically
38    /// released after 5 seconds.
39    pub async fn acquire(&self) {
40        // Acquire a permit
41        let permit = self
42            .semaphore
43            .clone()
44            .acquire_owned()
45            .await
46            .expect("semaphore should not be closed");
47
48        // Spawn a task to release the permit after hold_duration
49        let hold_duration = self.hold_duration;
50        tokio::spawn(async move {
51            sleep(hold_duration).await;
52            drop(permit);
53        });
54    }
55
56    /// Get the number of available permits.
57    pub fn available_permits(&self) -> usize {
58        self.semaphore.available_permits()
59    }
60}
61
62impl Default for IdentifyRateLimiter {
63    fn default() -> Self {
64        // Default max_concurrency is 1 for most bots
65        Self::new(1)
66    }
67}
68
/// Compute a capped exponential backoff delay.
///
/// The delay is `base_ms * 2^attempt`, evaluated with saturating arithmetic
/// so that large attempt numbers cannot overflow, and then limited to
/// `max_ms`.
///
/// # Arguments
/// * `attempt` - Current attempt number (0-indexed).
/// * `base_ms` - Base delay in milliseconds.
/// * `max_ms` - Maximum delay in milliseconds.
///
/// # Returns
/// Duration to wait before the next retry.
pub fn exponential_backoff(attempt: u32, base_ms: u64, max_ms: u64) -> Duration {
    // `1 << attempt` is exactly 2^attempt while it fits in a u64; for shifts
    // of 64 or more the multiplier saturates at u64::MAX instead.
    let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let uncapped_ms = base_ms.saturating_mul(multiplier);
    let capped_ms = std::cmp::min(uncapped_ms, max_ms);
    Duration::from_millis(capped_ms)
}
82
83/// Add jitter to a duration.
84///
85/// # Arguments
86/// * `duration` - Base duration.
87/// * `jitter_factor` - Factor of jitter (0.0 = no jitter, 1.0 = up to 100% jitter).
88///
89/// # Returns
90/// Duration with random jitter added.
91pub fn with_jitter(duration: Duration, jitter_factor: f64) -> Duration {
92    use rand::Rng;
93
94    let jitter_range = (duration.as_millis() as f64 * jitter_factor) as u64;
95    let jitter = rand::rng().random_range(0..=jitter_range);
96    duration + Duration::from_millis(jitter)
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// The delay doubles per attempt and never exceeds the configured cap.
    #[test]
    fn test_exponential_backoff() {
        // Doubling from the base delay.
        for (attempt, expected_ms) in [(0, 1000), (1, 2000), (2, 4000), (3, 8000)] {
            assert_eq!(
                exponential_backoff(attempt, 1000, 60000),
                Duration::from_millis(expected_ms)
            );
        }

        // Test capping at max
        assert_eq!(
            exponential_backoff(10, 1000, 60000),
            Duration::from_millis(60000)
        );

        // Very large attempt numbers saturate rather than overflow.
        assert_eq!(
            exponential_backoff(u32::MAX, 1000, 60000),
            Duration::from_millis(60000)
        );
    }

    /// Each acquire checks out exactly one permit.
    #[tokio::test]
    async fn test_rate_limiter_permits() {
        let limiter = IdentifyRateLimiter::new(3);
        assert_eq!(limiter.available_permits(), 3);

        // The permit is owned by the background task until its 5 second
        // sleep finishes, so the count observed right after acquiring is
        // deterministic.
        limiter.acquire().await;
        assert_eq!(limiter.available_permits(), 2);

        limiter.acquire().await;
        assert_eq!(limiter.available_permits(), 1);
    }
}