// Source: thunkmetrc_wrapper/ratelimiter.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{Mutex, Semaphore};
4use std::time::{Duration, Instant};
5
/// Tunable limits for [`MetrcRateLimiter`].
///
/// Rates are expressed in requests per second and concurrency limits in
/// simultaneously in-flight requests. Each limit exists twice: once per
/// facility and once for the integrator as a whole.
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimiterConfig {
    /// Master switch; when `false` the limiter passes every call straight through.
    pub enabled: bool,
    /// Sustained GET requests per second allowed for a single facility.
    pub max_get_per_second_per_facility: f64,
    /// Sustained GET requests per second allowed across all facilities.
    pub max_get_per_second_integrator: f64,
    /// Maximum number of GET requests in flight at once for a single facility.
    pub max_concurrent_get_per_facility: usize,
    /// Maximum number of GET requests in flight at once across all facilities.
    pub max_concurrent_get_integrator: usize,
    /// Number of retries attempted after the initial request fails.
    pub max_retries: usize,
}

impl Default for RateLimiterConfig {
    /// Returns a configuration with limiting switched off and the stock
    /// limits (50/s and 10 concurrent per facility, 150/s and 30 concurrent
    /// for the integrator, 5 retries) pre-filled for when it is enabled.
    fn default() -> Self {
        Self {
            enabled: false,
            max_get_per_second_per_facility: 50.0,
            max_get_per_second_integrator: 150.0,
            max_concurrent_get_per_facility: 10,
            max_concurrent_get_integrator: 30,
            max_retries: 5,
        }
    }
}
28
/// A token bucket: tokens accumulate at `rate` per second up to `capacity`,
/// and each permitted operation removes some of them.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens added per second.
    rate: f64,
    /// Upper bound on stored tokens (the largest possible burst).
    capacity: f64,
    /// Tokens currently available.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// `rate` is the refill speed in tokens per second and `capacity` the
    /// maximum number of tokens the bucket can hold.
    pub fn new(rate: f64, capacity: f64) -> Self {
        Self {
            rate,
            capacity,
            tokens: capacity,
            last_refill: Instant::now(),
        }
    }

    /// Credits the tokens earned since the last refill, clamped to `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Attempts to take `amount` tokens from the bucket.
    ///
    /// Returns `None` when the tokens were taken. Otherwise nothing is taken
    /// and `Some(wait)` is returned, where `wait` is the time after which the
    /// missing tokens will have been refilled at the current rate.
    ///
    /// If that time cannot be represented (the rate is zero, negative or
    /// NaN), `Some(Duration::MAX)` is returned instead of panicking. Note
    /// that an `amount` larger than `capacity` can never be satisfied, since
    /// refills are clamped to `capacity`.
    pub fn try_consume(&mut self, amount: f64) -> Option<Duration> {
        self.refill();
        if self.tokens >= amount {
            self.tokens -= amount;
            return None;
        }

        let missing = amount - self.tokens;
        // `Duration::from_secs_f64` panics on negative, infinite or NaN input,
        // which is exactly what a non-positive rate produces here.
        let wait = Duration::try_from_secs_f64(missing / self.rate).unwrap_or(Duration::MAX);
        Some(wait)
    }
}
64
65pub struct MetrcRateLimiter {
66    config: RateLimiterConfig,
67    integrator_rate: Mutex<TokenBucket>,
68    integrator_sem: Arc<Semaphore>,
69    facility_rates: Mutex<HashMap<String, Arc<Mutex<TokenBucket>>>>,
70    facility_sems: Mutex<HashMap<String, Arc<Semaphore>>>,
71}
72
73impl MetrcRateLimiter {
74    pub fn new(config: Option<RateLimiterConfig>) -> Self {
75        let config = config.unwrap_or_default();
76        Self {
77            integrator_rate: Mutex::new(TokenBucket::new(config.max_get_per_second_integrator, config.max_get_per_second_integrator)),
78            integrator_sem: Arc::new(Semaphore::new(config.max_concurrent_get_integrator)),
79            facility_rates: Mutex::new(HashMap::new()),
80            facility_sems: Mutex::new(HashMap::new()),
81            config,
82        }
83    }
84
85    pub async fn execute<F, Fut, T>(&self, facility: Option<&str>, is_get: bool, op: F) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
86    where
87        F: Fn() -> Fut,
88        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
89    {
90        if !self.config.enabled || !is_get {
91            return op().await;
92        }
93
94        // 1. Integrator Semaphore
95        let _sem_permit = self.integrator_sem.acquire().await.map_err(|e| format!("Semaphore closed: {:?}", e)).unwrap();
96
97        // 2. Facility Semaphore
98        let _fac_permit = if let Some(f) = facility {
99             let sem = {
100                 let mut sems = self.facility_sems.lock().await;
101                 sems.entry(f.to_string())
102                     .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_concurrent_get_per_facility)))
103                     .clone()
104             };
105             Some(sem.acquire_owned().await.map_err(|e| format!("Semaphore closed: {:?}", e)).unwrap())
106        } else {
107             None
108        };
109
110        // 3. Global Rate
111        loop {
112            let mut bucket = self.integrator_rate.lock().await;
113            if let Some(wait) = bucket.try_consume(1.0) {
114                 drop(bucket);
115                 tokio::time::sleep(wait).await;
116            } else {
117                 break;
118            }
119        }
120
121        // 4. Facility Rate
122        if let Some(f) = facility {
123            let bucket_arc = {
124                let mut rates = self.facility_rates.lock().await;
125                rates.entry(f.to_string())
126                    .or_insert_with(|| Arc::new(Mutex::new(TokenBucket::new(self.config.max_get_per_second_per_facility, self.config.max_get_per_second_per_facility))))
127                    .clone()
128            };
129            loop {
130                 let mut bucket = bucket_arc.lock().await;
131                 if let Some(wait) = bucket.try_consume(1.0) {
132                      drop(bucket);
133                      tokio::time::sleep(wait).await;
134                 } else {
135                      break;
136                 }
137            }
138        }
139
140        // 5. Retry Loop with max retries
141        let mut retry_count = 0;
142        loop {
143             let res = op().await;
144             match res {
145                 Ok(v) => return Ok(v),
146                 Err(e) => {
147                      if retry_count >= self.config.max_retries {
148                          return Err(e);
149                      }
150                      retry_count += 1;
151
152                      // Check for ApiError to respect Retry-After
153                      // e is Box<dyn Error + Send + Sync>
154                      if let Some(api_err) = e.downcast_ref::<thunkmetrc_client::ApiError>() {
155                          if api_err.status == reqwest::StatusCode::TOO_MANY_REQUESTS {
156                               if let Some(retry_after) = api_err.headers.get("Retry-After") {
157                                   if let Ok(val_str) = retry_after.to_str() {
158                                       if let Ok(secs) = val_str.parse::<u64>() {
159                                           tokio::time::sleep(Duration::from_secs(secs)).await;
160                                           continue;
161                                       }
162                                   }
163                               }
164                               // Default 429 wait
165                               tokio::time::sleep(Duration::from_secs(1)).await;
166                               continue;
167                          }
168                          // Handle 5xx with exponential backoff
169                          if api_err.status.is_server_error() {
170                              tokio::time::sleep(Duration::from_millis(500)).await;
171                              continue; 
172                          }
173                      }
174                      
175                      let msg = format!("{:?}", e);
176                      if msg.contains("429") {
177                           tokio::time::sleep(Duration::from_secs(1)).await;
178                           continue;
179                      }
180                      return Err(e);
181                 }
182             }
183        }
184    }
185}