1use async_trait::async_trait;
7use log::{debug, info};
8use moka::future::Cache;
9use rand::distributions::{Distribution, Uniform};
10use std::hash::Hash;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::Mutex;
14use tokio::time::{Instant, sleep};
15
16use governor::clock::DefaultClock;
17use governor::state::{InMemoryState, NotKeyed};
18use governor::{Quota, RateLimiter as GovernorRateLimiter};
19use std::num::NonZeroU32;
20
21use crate::middleware::{Middleware, MiddlewareAction};
22use spider_util::constants::{
23 MIDDLEWARE_CACHE_CAPACITY, MIDDLEWARE_CACHE_TTL_SECS, RATE_LIMIT_INITIAL_DELAY_MS,
24 RATE_LIMIT_MAX_DELAY_MS, RATE_LIMIT_MAX_JITTER_MS, RATE_LIMIT_MIN_DELAY_MS,
25};
26use spider_util::error::SpiderError;
27use spider_util::request::Request;
28use spider_util::response::Response;
29
/// Granularity at which rate limiting is applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single shared limiter for all requests.
    Global,
    /// A separate limiter per normalized request origin
    /// (see `spider_util::util::normalize_origin`).
    Domain,
}
38
/// Throttling strategy used by [`RateLimitMiddleware`].
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Waits until the next request is allowed to proceed.
    async fn acquire(&self);
    /// Gives the limiter a chance to adapt its pacing based on `response`
    /// (e.g. back off after 429s). May be a no-op for fixed-rate limiters.
    async fn adjust(&self, response: &Response);
    /// Current inter-request delay, or `Duration::ZERO` if the limiter
    /// does not track one.
    async fn current_delay(&self) -> Duration;
}
49
// Delay bounds for the adaptive limiter, sourced from shared spider_util constants.
const INITIAL_DELAY: Duration = Duration::from_millis(RATE_LIMIT_INITIAL_DELAY_MS);
const MIN_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MIN_DELAY_MS);
const MAX_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MAX_DELAY_MS);

// Multipliers applied to the current delay in `AdaptiveLimiter::adjust`:
// grow on 429/5xx, grow mildly on 403, shrink on 2xx/3xx.
const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
57
/// Mutable pacing state guarded by `AdaptiveLimiter::state`.
struct AdaptiveState {
    /// Current delay between consecutive requests.
    delay: Duration,
    /// Earliest instant at which the next request may be sent.
    next_allowed_at: Instant,
}
62
/// Rate limiter that adapts its inter-request delay to response status
/// codes: successes shrink the delay, errors grow it, within
/// `[MIN_DELAY, MAX_DELAY]`.
pub struct AdaptiveLimiter {
    state: Mutex<AdaptiveState>,
    /// When true, `acquire` randomizes each sleep around the nominal delay.
    jitter: bool,
}
68
69impl AdaptiveLimiter {
70 pub fn new(initial_delay: Duration, jitter: bool) -> Self {
72 Self {
73 state: Mutex::new(AdaptiveState {
74 delay: initial_delay,
75 next_allowed_at: Instant::now(),
76 }),
77 jitter,
78 }
79 }
80
81 fn apply_jitter(&self, delay: Duration) -> Duration {
82 if !self.jitter || delay.is_zero() {
83 return delay;
84 }
85
86 let max_jitter = Duration::from_millis(RATE_LIMIT_MAX_JITTER_MS);
87 let jitter_window = delay.mul_f64(0.25).min(max_jitter);
88
89 let low = delay.saturating_sub(jitter_window);
90 let high = delay + jitter_window;
91
92 let mut rng = rand::thread_rng();
93 let uniform = Uniform::new_inclusive(low, high);
94 uniform.sample(&mut rng)
95 }
96}
97
98#[async_trait]
99impl RateLimiter for AdaptiveLimiter {
100 async fn acquire(&self) {
101 let sleep_duration = {
102 let mut state = self.state.lock().await;
103 let now = Instant::now();
104
105 let delay = state.delay;
106 if now < state.next_allowed_at {
107 let wait = state.next_allowed_at - now;
108 state.next_allowed_at += delay;
109 wait
110 } else {
111 state.next_allowed_at = now + delay;
112 Duration::ZERO
113 }
114 };
115
116 let sleep_duration = self.apply_jitter(sleep_duration);
117 if !sleep_duration.is_zero() {
118 debug!("Rate limiting: sleeping for {:?}", sleep_duration);
119 sleep(sleep_duration).await;
120 }
121 }
122
123 async fn adjust(&self, response: &Response) {
124 let mut state = self.state.lock().await;
125
126 let old_delay = state.delay;
127 let status = response.status.as_u16();
128 let new_delay = match status {
129 200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
130 403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
131 429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
132 _ => state.delay,
133 };
134
135 state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);
136
137 if old_delay != state.delay {
138 debug!(
139 "Adjusting delay for status {}: {:?} -> {:?}",
140 status, old_delay, state.delay
141 );
142 }
143 }
144
145 async fn current_delay(&self) -> Duration {
146 self.state.lock().await.delay
147 }
148}
149
/// Fixed-rate limiter backed by the `governor` crate's token-bucket
/// implementation, refilling at a constant requests-per-second quota.
pub struct TokenBucketLimiter {
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
154
155impl TokenBucketLimiter {
156 pub fn new(requests_per_second: u32) -> Self {
162 let requests_per_second = match NonZeroU32::new(requests_per_second) {
163 Some(rps) => rps,
164 None => panic!("requests_per_second must be non-zero"),
165 };
166 let quota = Quota::per_second(requests_per_second);
167 TokenBucketLimiter {
168 limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
169 quota,
170 &DefaultClock::default(),
171 )),
172 }
173 }
174}
175
#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    async fn acquire(&self) {
        // Waits until the bucket has a token available.
        self.limiter.until_ready().await;
    }

    // A fixed-rate bucket does not adapt to responses.
    async fn adjust(&self, _response: &Response) {}

    async fn current_delay(&self) -> Duration {
        // The bucket exposes no single inter-request delay to report.
        Duration::ZERO
    }
}
189
/// Middleware that throttles outgoing requests via per-scope rate limiters.
pub struct RateLimitMiddleware {
    /// Whether one global limiter is used, or one per domain.
    scope: Scope,
    /// Lazily-populated cache of limiters, keyed by `scope_key`.
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    /// Produces a new limiter for a key on first use.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
196
197impl RateLimitMiddleware {
198 pub fn builder() -> RateLimitMiddlewareBuilder {
200 RateLimitMiddlewareBuilder::default()
201 }
202
203 fn scope_key(&self, request: &Request) -> String {
204 match self.scope {
205 Scope::Global => "global".to_string(),
206 Scope::Domain => spider_util::util::normalize_origin(request),
207 }
208 }
209}
210
211#[async_trait]
212impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
213 fn name(&self) -> &str {
214 "RateLimitMiddleware"
215 }
216
217 async fn process_request(
218 &self,
219 _client: &C,
220 request: Request,
221 ) -> Result<MiddlewareAction<Request>, SpiderError> {
222 let key = self.scope_key(&request);
223
224 let limiter = self
225 .limiters
226 .get_with(key.clone(), async { (self.limiter_factory)() })
227 .await;
228
229 let current_delay = limiter.current_delay().await;
230 debug!(
231 "Acquiring lock for key '{}' (delay: {:?})",
232 key, current_delay
233 );
234
235 limiter.acquire().await;
236 Ok(MiddlewareAction::Continue(request))
237 }
238
239 async fn process_response(
240 &self,
241 response: Response,
242 ) -> Result<MiddlewareAction<Response>, SpiderError> {
243 let key = self.scope_key(&response.request_from_response());
244
245 if let Some(limiter) = self.limiters.get(&key).await {
246 let old_delay = limiter.current_delay().await;
247 limiter.adjust(&response).await;
248 let new_delay = limiter.current_delay().await;
249 if old_delay != new_delay {
250 debug!(
251 "Adjusted rate limit for key '{}': {:?} -> {:?}",
252 key, old_delay, new_delay
253 );
254 }
255 }
256
257 Ok(MiddlewareAction::Continue(response))
258 }
259}
260
/// Builder for [`RateLimitMiddleware`].
pub struct RateLimitMiddlewareBuilder {
    /// Limiter granularity (global vs per-domain).
    scope: Scope,
    /// Idle TTL after which an unused limiter is evicted from the cache
    /// (applied via `time_to_idle` in `build`).
    cache_ttl: Duration,
    /// Maximum number of cached limiters.
    cache_capacity: u64,
    /// Produces a new limiter for each cache key on first use.
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
268
269impl Default for RateLimitMiddlewareBuilder {
270 fn default() -> Self {
271 Self {
272 scope: Scope::Domain,
273 cache_ttl: Duration::from_secs(MIDDLEWARE_CACHE_TTL_SECS),
274 cache_capacity: MIDDLEWARE_CACHE_CAPACITY,
275 limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
276 }
277 }
278}
279
280impl RateLimitMiddlewareBuilder {
281 pub fn scope(mut self, scope: Scope) -> Self {
283 self.scope = scope;
284 self
285 }
286
287 pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
289 self.limiter_factory =
290 Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
291 self
292 }
293
294 pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
296 let arc = Arc::new(limiter);
297 self.limiter_factory = Box::new(move || arc.clone());
298 self
299 }
300
301 pub fn limiter_factory(
303 mut self,
304 factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
305 ) -> Self {
306 self.limiter_factory = Box::new(factory);
307 self
308 }
309
310 pub fn build(self) -> RateLimitMiddleware {
312 info!(
313 "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
314 self.scope, self.cache_ttl, self.cache_capacity
315 );
316 RateLimitMiddleware {
317 scope: self.scope,
318 limiters: Cache::builder()
319 .time_to_idle(self.cache_ttl)
320 .max_capacity(self.cache_capacity)
321 .build(),
322 limiter_factory: self.limiter_factory.into(),
323 }
324 }
325}