titanium_gateway/
ratelimit.rs1use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::time::sleep;
10
11pub struct IdentifyRateLimiter {
16 semaphore: Arc<Semaphore>,
18
19 hold_duration: Duration,
21}
22
23impl IdentifyRateLimiter {
24 pub fn new(max_concurrency: usize) -> Self {
29 Self {
30 semaphore: Arc::new(Semaphore::new(max_concurrency)),
31 hold_duration: Duration::from_secs(5),
32 }
33 }
34
35 pub async fn acquire(&self) {
40 let permit = self
42 .semaphore
43 .clone()
44 .acquire_owned()
45 .await
46 .expect("semaphore should not be closed");
47
48 let hold_duration = self.hold_duration;
50 tokio::spawn(async move {
51 sleep(hold_duration).await;
52 drop(permit);
53 });
54 }
55
56 pub fn available_permits(&self) -> usize {
58 self.semaphore.available_permits()
59 }
60}
61
62impl Default for IdentifyRateLimiter {
63 fn default() -> Self {
64 Self::new(1)
66 }
67}
68
/// Computes the delay before retry number `attempt` using capped exponential backoff.
///
/// The delay is `base_ms * 2^attempt` milliseconds, clamped to at most `max_ms`. Every
/// step uses saturating arithmetic. A very large `attempt` is therefore capped at
/// `max_ms` instead of overflowing or panicking.
pub fn exponential_backoff(attempt: u32, base_ms: u64, max_ms: u64) -> Duration {
    // 2^attempt, pinned at u64::MAX once it no longer fits (attempt >= 64).
    let growth = match 2u64.checked_pow(attempt) {
        Some(factor) => factor,
        None => u64::MAX,
    };
    let uncapped_ms = base_ms.saturating_mul(growth);
    Duration::from_millis(std::cmp::min(uncapped_ms, max_ms))
}
82
83pub fn with_jitter(duration: Duration, jitter_factor: f64) -> Duration {
92 use rand::Rng;
93
94 let jitter_range = (duration.as_millis() as f64 * jitter_factor) as u64;
95 let jitter = rand::rng().random_range(0..=jitter_range);
96 duration + Duration::from_millis(jitter)
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff() {
        // The delay doubles per attempt, starting from `base_ms`.
        for (attempt, expected_ms) in [(0, 1000), (1, 2000), (2, 4000), (3, 8000)] {
            assert_eq!(
                exponential_backoff(attempt, 1000, 60000),
                Duration::from_millis(expected_ms)
            );
        }

        // 1000 * 2^10 = 1_024_000 ms, which is capped at `max_ms`.
        assert_eq!(
            exponential_backoff(10, 1000, 60000),
            Duration::from_millis(60000)
        );
    }

    #[test]
    fn test_exponential_backoff_saturates_instead_of_overflowing() {
        // 2^attempt no longer fits in u64 for attempt >= 64; the result must stay capped.
        assert_eq!(
            exponential_backoff(200, 1000, 60000),
            Duration::from_millis(60000)
        );
        // A zero base stays zero even when the multiplier saturates.
        assert_eq!(exponential_backoff(200, 0, 60000), Duration::ZERO);
    }

    #[tokio::test]
    async fn test_rate_limiter_permits() {
        let limiter = IdentifyRateLimiter::new(3);
        assert_eq!(limiter.available_permits(), 3);

        // The acquired permit is held by the spawned task for the 5 s hold, which
        // outlasts this test, so exactly one permit must be consumed.
        limiter.acquire().await;
        assert_eq!(limiter.available_permits(), 2);
    }
}