use async_trait::async_trait;
use moka::future::Cache;
use rand::distributions::{Distribution, Uniform};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::{Instant, sleep};
use tracing::{debug, info};

use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter as GovernorRateLimiter};
use std::num::NonZeroU32;

use crate::utils;
use crate::{
    error::SpiderError,
    middleware::{Middleware, MiddlewareAction},
    request::Request,
    response::Response,
};

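/// Granularity at which a rate limiter is applied.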
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single limiter shared by all requests.
    Global,
    /// One limiter per request origin, as derived by `utils::normalize_origin`.
    Domain,
}

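/// A throttling strategy consulted by [`RateLimitMiddleware`] before each request.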
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Waits until the next request is allowed to proceed.
    async fn acquire(&self);
    /// Updates the limiter based on the response that came back.
    async fn adjust(&self, response: &Response);
    /// Reports the delay currently imposed between requests.
    async fn current_delay(&self) -> Duration;
}

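// Tuning for the adaptive strategy: errors grow the delay multiplicatively,
// successes shrink it slowly, and the result is clamped to [MIN_DELAY, MAX_DELAY].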
const INITIAL_DELAY: Duration = Duration::from_millis(500);
const MIN_DELAY: Duration = Duration::from_millis(50);
const MAX_DELAY: Duration = Duration::from_secs(60);

const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;

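/// Per-limiter mutable state: the current base delay and the earliest instant
/// at which the next request may start.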
struct AdaptiveState {
    delay: Duration,
    next_allowed_at: Instant,
}

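/// A self-tuning limiter: the delay grows on 403/429/5xx responses, decays on
/// success, and can be jittered so request timing does not form a fixed pattern.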
pub struct AdaptiveLimiter {
    state: Mutex<AdaptiveState>,
    jitter: bool,
}

impl AdaptiveLimiter {
    pub fn new(initial_delay: Duration, jitter: bool) -> Self {
        Self {
            state: Mutex::new(AdaptiveState {
                delay: initial_delay,
                next_allowed_at: Instant::now(),
            }),
            jitter,
        }
    }

    /// Spreads the delay uniformly within a window of up to 25% of its value
    /// (capped at 500 ms) so concurrent tasks do not all wake at the same instant.
    fn apply_jitter(&self, delay: Duration) -> Duration {
        if !self.jitter || delay.is_zero() {
            return delay;
        }

        let max_jitter = Duration::from_millis(500);
        let jitter_window = delay.mul_f64(0.25).min(max_jitter);

        let low = delay.saturating_sub(jitter_window);
        let high = delay + jitter_window;

        let mut rng = rand::thread_rng();
        let uniform = Uniform::new_inclusive(low, high);
        uniform.sample(&mut rng)
    }
}

#[async_trait]
impl RateLimiter for AdaptiveLimiter {
    async fn acquire(&self) {
        let sleep_duration = {
            let mut state = self.state.lock().await;
            let now = Instant::now();

            // If the next allowed instant is still in the future, wait for it
            // and push the schedule forward by one delay; otherwise start
            // immediately and reserve the next slot.
            let delay = state.delay;
            if now < state.next_allowed_at {
                let wait = state.next_allowed_at - now;
                state.next_allowed_at += delay;
                wait
            } else {
                state.next_allowed_at = now + delay;
                Duration::ZERO
            }
        };

        let sleep_duration = self.apply_jitter(sleep_duration);
        if !sleep_duration.is_zero() {
            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
            sleep(sleep_duration).await;
        }
    }

    async fn adjust(&self, response: &Response) {
        let mut state = self.state.lock().await;

        let old_delay = state.delay;
        let status = response.status.as_u16();
        // Back off on throttling and server errors, back off more gently on
        // 403s, and slowly speed up again on success.
        let new_delay = match status {
            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
            _ => state.delay,
        };

        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);

        if old_delay != state.delay {
            debug!(
                "Adjusting delay for status {}: {:?} -> {:?}",
                status, old_delay, state.delay
            );
        }
    }

    async fn current_delay(&self) -> Duration {
        self.state.lock().await.delay
    }
}

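/// Fixed-rate limiter backed by the `governor` crate's token-bucket implementation.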
pub struct TokenBucketLimiter {
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}

impl TokenBucketLimiter {
    pub fn new(requests_per_second: u32) -> Self {
        let quota = Quota::per_second(
            NonZeroU32::new(requests_per_second).expect("requests_per_second must be non-zero"),
        );
        TokenBucketLimiter {
            limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
                quota,
                &DefaultClock::default(),
            )),
        }
    }
}

#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    async fn acquire(&self) {
        self.limiter.until_ready().await;
    }

    async fn adjust(&self, _response: &Response) {
        // Fixed-rate limiter: responses do not change the configured rate.
    }

    async fn current_delay(&self) -> Duration {
        // There is no adaptive delay to report; pacing is handled entirely by
        // the token bucket in `acquire`.
        Duration::ZERO
    }
}

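/// Middleware that throttles outgoing requests, keeping one [`RateLimiter`]
/// per scope key in a `moka` cache and consulting it before every request.
///
/// A minimal construction sketch (how the middleware is registered with a
/// client is up to the surrounding crate):
///
/// ```ignore
/// let middleware = RateLimitMiddleware::builder()
///     .scope(Scope::Domain)
///     .use_token_bucket_limiter(5) // at most 5 requests per second per domain
///     .build();
/// ```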
pub struct RateLimitMiddleware {
    scope: Scope,
    /// One limiter per scope key, evicted after a period of inactivity.
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    /// Creates a limiter the first time a scope key is seen.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}

impl RateLimitMiddleware {
    pub fn builder() -> RateLimitMiddlewareBuilder {
        RateLimitMiddlewareBuilder::default()
    }

    /// Maps a request to the cache key its limiter is stored under.
    fn scope_key(&self, request: &Request) -> String {
        match self.scope {
            Scope::Global => "global".to_string(),
            Scope::Domain => utils::normalize_origin(request),
        }
    }
}

#[async_trait]
impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
    fn name(&self) -> &str {
        "RateLimitMiddleware"
    }

    async fn process_request(
        &mut self,
        _client: &C,
        request: Request,
    ) -> Result<MiddlewareAction<Request>, SpiderError> {
        let key = self.scope_key(&request);

        // Fetch the limiter for this key, creating it on first use.
        let limiter = self
            .limiters
            .get_with(key.clone(), async { (self.limiter_factory)() })
            .await;

        let current_delay = limiter.current_delay().await;
        debug!(
            "Acquiring rate limit for key '{}' (delay: {:?})",
            key, current_delay
        );

        limiter.acquire().await;
        Ok(MiddlewareAction::Continue(request))
    }

    async fn process_response(
        &mut self,
        response: Response,
    ) -> Result<MiddlewareAction<Response>, SpiderError> {
        let key = self.scope_key(&response.request_from_response());

        // Let the limiter react to the response (e.g. back off on 429s).
        if let Some(limiter) = self.limiters.get(&key).await {
            let old_delay = limiter.current_delay().await;
            limiter.adjust(&response).await;
            let new_delay = limiter.current_delay().await;
            if old_delay != new_delay {
                debug!(
                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
                    key, old_delay, new_delay
                );
            }
        }

        Ok(MiddlewareAction::Continue(response))
    }
}

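/// Builder for [`RateLimitMiddleware`]. Defaults to per-domain scoping with a
/// jittered [`AdaptiveLimiter`], a one-hour idle TTL, and room for 10,000 keys.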
pub struct RateLimitMiddlewareBuilder {
    scope: Scope,
    cache_ttl: Duration,
    cache_capacity: u64,
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}

impl Default for RateLimitMiddlewareBuilder {
    fn default() -> Self {
        Self {
            scope: Scope::Domain,
            cache_ttl: Duration::from_secs(3600),
            cache_capacity: 10_000,
            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
        }
    }
}

impl RateLimitMiddlewareBuilder {
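    /// Sets whether one global limiter or one limiter per domain is used.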
    pub fn scope(mut self, scope: Scope) -> Self {
        self.scope = scope;
        self
    }

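    /// Replaces the default adaptive limiter with a fixed-rate token bucket;
    /// each scope key gets its own bucket. Panics when a limiter is created if
    /// `requests_per_second` is zero.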
    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
        self.limiter_factory =
            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
        self
    }

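    /// Uses a single limiter instance: the factory hands out clones of one
    /// `Arc`, so every scope key shares the same limiter state.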
    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
        let arc = Arc::new(limiter);
        self.limiter_factory = Box::new(move || arc.clone());
        self
    }

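    /// Supplies a factory that builds a fresh limiter for each new scope key.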
    pub fn limiter_factory(
        mut self,
        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
    ) -> Self {
        self.limiter_factory = Box::new(factory);
        self
    }

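    /// Builds the middleware, creating the per-key limiter cache with the
    /// configured idle TTL and capacity.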
    pub fn build(self) -> RateLimitMiddleware {
        info!(
            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
            self.scope, self.cache_ttl, self.cache_capacity
        );
        RateLimitMiddleware {
            scope: self.scope,
            limiters: Cache::builder()
                .time_to_idle(self.cache_ttl)
                .max_capacity(self.cache_capacity)
                .build(),
            // Convert the boxed factory into an Arc so it can be shared.
            limiter_factory: self.limiter_factory.into(),
        }
    }
}