1use async_trait::async_trait;
2use moka::future::Cache;
3use rand::distributions::{Distribution, Uniform};
4use std::hash::Hash;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::Mutex;
8use tokio::time::{Instant, sleep};
9use tracing::{debug, info};
10
11use governor::{Quota, RateLimiter as GovernorRateLimiter};
12use governor::state::{InMemoryState, NotKeyed};
13use governor::clock::DefaultClock;
14use std::num::NonZeroU32;
15
16use crate::utils;
17use crate::{
18 error::SpiderError,
19 middleware::{Middleware, MiddlewareAction},
20 request::Request,
21 response::Response,
22};
23
/// Determines how rate-limit state is partitioned across requests.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single shared limiter governs every request.
    Global,
    /// An independent limiter per request origin (keyed via `utils::normalize_origin`).
    Domain,
}
32
/// Strategy interface for pacing outbound requests.
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Waits until the next request is allowed to proceed.
    async fn acquire(&self);
    /// Feeds a completed response back to the limiter so it can adapt
    /// its pacing (implementations may ignore this).
    async fn adjust(&self, response: &Response);
    /// Reports the current inter-request delay; implementations without
    /// an adaptive delay may return `Duration::ZERO`.
    async fn current_delay(&self) -> Duration;
}
43
/// Starting delay for a freshly created adaptive limiter.
const INITIAL_DELAY: Duration = Duration::from_millis(500);
/// Lower clamp applied to the adaptive delay after every adjustment.
const MIN_DELAY: Duration = Duration::from_millis(50);
/// Upper clamp applied to the adaptive delay after every adjustment.
const MAX_DELAY: Duration = Duration::from_secs(60);

/// Delay growth factor applied on 429 and 5xx responses.
const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
/// Delay shrink factor applied on 2xx/3xx responses.
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
/// Delay growth factor applied on 403 responses (softer than the error penalty).
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
51
/// Mutable pacing state guarded by `AdaptiveLimiter`'s mutex.
struct AdaptiveState {
    /// Current target gap between consecutive requests.
    delay: Duration,
    /// Instant at which the next request is permitted to fire.
    next_allowed_at: Instant,
}
56
/// A limiter that adapts its delay based on response status codes,
/// backing off on errors and decaying toward faster rates on success.
pub struct AdaptiveLimiter {
    /// Shared pacing state; a tokio mutex because it is held across awaits
    /// by async trait methods.
    state: Mutex<AdaptiveState>,
    /// When true, sleeps are randomized around the target delay.
    jitter: bool,
}
62
63impl AdaptiveLimiter {
64 pub fn new(initial_delay: Duration, jitter: bool) -> Self {
66 Self {
67 state: Mutex::new(AdaptiveState {
68 delay: initial_delay,
69 next_allowed_at: Instant::now(),
70 }),
71 jitter,
72 }
73 }
74
75 fn apply_jitter(&self, delay: Duration) -> Duration {
76 if !self.jitter || delay.is_zero() {
77 return delay;
78 }
79
80 let max_jitter = Duration::from_millis(500);
81 let jitter_window = delay.mul_f64(0.25).min(max_jitter);
82
83 let low = delay.saturating_sub(jitter_window);
84 let high = delay + jitter_window;
85
86 let mut rng = rand::thread_rng();
87 let uniform = Uniform::new_inclusive(low, high);
88 uniform.sample(&mut rng)
89 }
90}
91
92#[async_trait]
93impl RateLimiter for AdaptiveLimiter {
94 async fn acquire(&self) {
95 let sleep_duration = {
96 let mut state = self.state.lock().await;
97 let now = Instant::now();
98
99 let delay = state.delay;
100 if now < state.next_allowed_at {
101 let wait = state.next_allowed_at - now;
102 state.next_allowed_at += delay;
103 wait
104 } else {
105 state.next_allowed_at = now + delay;
106 Duration::ZERO
107 }
108 };
109
110 let sleep_duration = self.apply_jitter(sleep_duration);
111 if !sleep_duration.is_zero() {
112 debug!("Rate limiting: sleeping for {:?}", sleep_duration);
113 sleep(sleep_duration).await;
114 }
115 }
116
117 async fn adjust(&self, response: &Response) {
118 let mut state = self.state.lock().await;
119
120 let old_delay = state.delay;
121 let status = response.status.as_u16();
122 let new_delay = match status {
123 200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
124 403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
125 429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
126 _ => state.delay,
127 };
128
129 state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);
130
131 if old_delay != state.delay {
132 debug!(
133 "Adjusting delay for status {}: {:?} -> {:?}",
134 status, old_delay, state.delay
135 );
136 }
137 }
138
139 async fn current_delay(&self) -> Duration {
140 self.state.lock().await.delay
141 }
142}
143
/// A fixed-rate limiter backed by the `governor` crate's token bucket.
pub struct TokenBucketLimiter {
    /// Un-keyed (single-bucket) governor limiter shared across clones.
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
148
149impl TokenBucketLimiter {
150 pub fn new(requests_per_second: u32) -> Self {
152 let quota = Quota::per_second(NonZeroU32::new(requests_per_second).expect("requests_per_second must be non-zero"));
153 TokenBucketLimiter {
154 limiter: Arc::new(GovernorRateLimiter::direct_with_clock(quota, &DefaultClock::default())),
155 }
156 }
157}
158
#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    /// Waits until the token bucket has capacity for one request.
    async fn acquire(&self) {
        self.limiter.until_ready().await;
    }

    /// Fixed-rate limiter: response feedback is intentionally ignored.
    async fn adjust(&self, _response: &Response) {
    }

    /// There is no adaptive delay to report; pacing is enforced entirely
    /// by the bucket in `acquire`.
    async fn current_delay(&self) -> Duration {
        Duration::ZERO
    }
}
177
178
/// Middleware that throttles outgoing requests, keeping one `RateLimiter`
/// per scope key in an eviction-aware cache.
pub struct RateLimitMiddleware {
    /// How limiter state is partitioned (global vs. per-domain).
    scope: Scope,
    /// Scope key -> limiter; idle entries expire per the builder's TTL.
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    /// Invoked once per new scope key to create its limiter.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
185
186impl RateLimitMiddleware {
187 pub fn builder() -> RateLimitMiddlewareBuilder {
189 RateLimitMiddlewareBuilder::default()
190 }
191
192 fn scope_key(&self, request: &Request) -> String {
193 match self.scope {
194 Scope::Global => "global".to_string(),
195 Scope::Domain => utils::normalize_origin(request),
196 }
197 }
198}
199
200#[async_trait]
201impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
202 fn name(&self) -> &str {
203 "RateLimitMiddleware"
204 }
205
206 async fn process_request(
207 &mut self,
208 _client: &C,
209 request: Request,
210 ) -> Result<MiddlewareAction<Request>, SpiderError> {
211 let key = self.scope_key(&request);
212
213 let limiter = self
214 .limiters
215 .get_with(key.clone(), async { (self.limiter_factory)() })
216 .await;
217
218 let current_delay = limiter.current_delay().await;
219 debug!(
220 "Acquiring lock for key '{}' (delay: {:?})",
221 key, current_delay
222 );
223
224 limiter.acquire().await;
225 Ok(MiddlewareAction::Continue(request))
226 }
227
228 async fn process_response(
229 &mut self,
230 response: Response,
231 ) -> Result<MiddlewareAction<Response>, SpiderError> {
232 let key = self.scope_key(&response.request_from_response());
233
234 if let Some(limiter) = self.limiters.get(&key).await {
235 let old_delay = limiter.current_delay().await;
236 limiter.adjust(&response).await;
237 let new_delay = limiter.current_delay().await;
238 if old_delay != new_delay {
239 debug!(
240 "Adjusted rate limit for key '{}': {:?} -> {:?}",
241 key, old_delay, new_delay
242 );
243 }
244 }
245
246 Ok(MiddlewareAction::Continue(response))
247 }
248}
249
/// Fluent configuration for `RateLimitMiddleware`; obtain one via
/// `RateLimitMiddleware::builder()`.
pub struct RateLimitMiddlewareBuilder {
    /// Partitioning of limiter state (global vs. per-domain).
    scope: Scope,
    /// Idle lifetime of a cached limiter before eviction.
    cache_ttl: Duration,
    /// Maximum number of limiter entries held in the cache.
    cache_capacity: u64,
    /// Factory producing a limiter for each new scope key.
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
257
258impl Default for RateLimitMiddlewareBuilder {
259 fn default() -> Self {
260 Self {
261 scope: Scope::Domain,
262 cache_ttl: Duration::from_secs(3600),
263 cache_capacity: 10_000,
264 limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
265 }
266 }
267}
268
269impl RateLimitMiddlewareBuilder {
270 pub fn scope(mut self, scope: Scope) -> Self {
272 self.scope = scope;
273 self
274 }
275
276 pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
278 self.limiter_factory = Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
279 self
280 }
281
282 pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
284 let arc = Arc::new(limiter);
285 self.limiter_factory = Box::new(move || arc.clone());
286 self
287 }
288
289 pub fn limiter_factory(
291 mut self,
292 factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
293 ) -> Self {
294 self.limiter_factory = Box::new(factory);
295 self
296 }
297
298 pub fn build(self) -> RateLimitMiddleware {
300 info!(
301 "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
302 self.scope, self.cache_ttl, self.cache_capacity
303 );
304 RateLimitMiddleware {
305 scope: self.scope,
306 limiters: Cache::builder()
307 .time_to_idle(self.cache_ttl)
308 .max_capacity(self.cache_capacity)
309 .build(),
310 limiter_factory: self.limiter_factory.into(),
311 }
312 }
313}