thunkmetrc_wrapper/
ratelimiter.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{Mutex, Semaphore};
4use std::time::{Duration, Instant};
5
/// Tuning knobs for [`MetrcRateLimiter`].
///
/// Limits apply only to GET requests, and only when `enabled` is true.
#[derive(Clone, Debug)]
pub struct RateLimiterConfig {
    /// Master switch; when `false`, calls run directly with no limiting and no retries.
    pub enabled: bool,
    /// Token-bucket refill rate (GET requests per second), tracked separately per facility.
    pub max_get_per_second_per_facility: f64,
    /// Token-bucket refill rate (GET requests per second) shared by all facilities.
    pub max_get_per_second_integrator: f64,
    /// Maximum GET requests in flight at once for a single facility.
    pub max_concurrent_get_per_facility: usize,
    /// Maximum GET requests in flight at once across all facilities combined.
    pub max_concurrent_get_integrator: usize,
    /// How many times a retryable failure is retried after the first attempt.
    pub max_retries: usize,
}

impl Default for RateLimiterConfig {
    /// Limiting disabled; 50 GET/s and 10 concurrent per facility, 150 GET/s and 30
    /// concurrent overall, and up to 5 retries.
    fn default() -> Self {
        Self {
            enabled: false,
            max_get_per_second_per_facility: 50.0,
            max_get_per_second_integrator: 150.0,
            max_concurrent_get_per_facility: 10,
            max_concurrent_get_integrator: 30,
            max_retries: 5,
        }
    }
}
28
/// A token bucket that refills continuously at `rate` tokens per second and holds at most
/// `capacity` tokens. A new bucket starts full.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens added per second.
    rate: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens currently available.
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that refills at `rate` tokens/second and starts with `capacity` tokens.
    pub fn new(rate: f64, capacity: f64) -> Self {
        Self {
            rate,
            capacity,
            tokens: capacity,
            last_refill: Instant::now(),
        }
    }

    /// Credits the tokens accrued since the last refill, clamped to `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.rate).min(self.capacity);
        self.last_refill = now;
    }

    /// Tries to take `amount` tokens.
    ///
    /// Returns `None` when the tokens were taken. Otherwise nothing is taken and it returns
    /// `Some(wait)`, the estimated time until enough tokens will have accrued.
    ///
    /// Because `tokens` never exceeds `capacity`, an `amount` larger than `capacity` can never
    /// succeed, even though a finite wait is still reported. A non-positive `rate` never
    /// refills; the wait is then reported as [`Duration::MAX`] instead of panicking.
    pub fn try_consume(&mut self, amount: f64) -> Option<Duration> {
        self.refill();
        if self.tokens >= amount {
            self.tokens -= amount;
            return None;
        }
        let secs = (amount - self.tokens) / self.rate;
        // `Duration::from_secs_f64` panics on NaN, infinite, negative, or too-large input
        // (e.g. rate == 0 gives +inf). Only convert values it can represent.
        let wait = if (0.0..u64::MAX as f64).contains(&secs) {
            Duration::from_secs_f64(secs)
        } else {
            Duration::MAX
        };
        Some(wait)
    }
}
64
65pub struct MetrcRateLimiter {
66 config: RateLimiterConfig,
67 integrator_rate: Mutex<TokenBucket>,
68 integrator_sem: Arc<Semaphore>,
69 facility_rates: Mutex<HashMap<String, Arc<Mutex<TokenBucket>>>>,
70 facility_sems: Mutex<HashMap<String, Arc<Semaphore>>>,
71}
72
73impl MetrcRateLimiter {
74 pub fn new(config: Option<RateLimiterConfig>) -> Self {
75 let config = config.unwrap_or_default();
76 Self {
77 integrator_rate: Mutex::new(TokenBucket::new(config.max_get_per_second_integrator, config.max_get_per_second_integrator)),
78 integrator_sem: Arc::new(Semaphore::new(config.max_concurrent_get_integrator)),
79 facility_rates: Mutex::new(HashMap::new()),
80 facility_sems: Mutex::new(HashMap::new()),
81 config,
82 }
83 }
84
85 pub async fn execute<F, Fut, T>(&self, facility: Option<&str>, is_get: bool, op: F) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
86 where
87 F: Fn() -> Fut,
88 Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>,
89 {
90 if !self.config.enabled || !is_get {
91 return op().await;
92 }
93
94 let _sem_permit = self.integrator_sem.acquire().await.map_err(|e| format!("Semaphore closed: {:?}", e)).unwrap();
96
97 let _fac_permit = if let Some(f) = facility {
99 let sem = {
100 let mut sems = self.facility_sems.lock().await;
101 sems.entry(f.to_string())
102 .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_concurrent_get_per_facility)))
103 .clone()
104 };
105 Some(sem.acquire_owned().await.map_err(|e| format!("Semaphore closed: {:?}", e)).unwrap())
106 } else {
107 None
108 };
109
110 loop {
112 let mut bucket = self.integrator_rate.lock().await;
113 if let Some(wait) = bucket.try_consume(1.0) {
114 drop(bucket);
115 tokio::time::sleep(wait).await;
116 } else {
117 break;
118 }
119 }
120
121 if let Some(f) = facility {
123 let bucket_arc = {
124 let mut rates = self.facility_rates.lock().await;
125 rates.entry(f.to_string())
126 .or_insert_with(|| Arc::new(Mutex::new(TokenBucket::new(self.config.max_get_per_second_per_facility, self.config.max_get_per_second_per_facility))))
127 .clone()
128 };
129 loop {
130 let mut bucket = bucket_arc.lock().await;
131 if let Some(wait) = bucket.try_consume(1.0) {
132 drop(bucket);
133 tokio::time::sleep(wait).await;
134 } else {
135 break;
136 }
137 }
138 }
139
140 let mut retry_count = 0;
142 loop {
143 let res = op().await;
144 match res {
145 Ok(v) => return Ok(v),
146 Err(e) => {
147 if retry_count >= self.config.max_retries {
148 return Err(e);
149 }
150 retry_count += 1;
151
152 if let Some(api_err) = e.downcast_ref::<thunkmetrc_client::ApiError>() {
155 if api_err.status == reqwest::StatusCode::TOO_MANY_REQUESTS {
156 if let Some(retry_after) = api_err.headers.get("Retry-After") {
157 if let Ok(val_str) = retry_after.to_str() {
158 if let Ok(secs) = val_str.parse::<u64>() {
159 tokio::time::sleep(Duration::from_secs(secs)).await;
160 continue;
161 }
162 }
163 }
164 tokio::time::sleep(Duration::from_secs(1)).await;
166 continue;
167 }
168 if api_err.status.is_server_error() {
170 tokio::time::sleep(Duration::from_millis(500)).await;
171 continue;
172 }
173 }
174
175 let msg = format!("{:?}", e);
176 if msg.contains("429") {
177 tokio::time::sleep(Duration::from_secs(1)).await;
178 continue;
179 }
180 return Err(e);
181 }
182 }
183 }
184 }
185}