// spider_middleware/rate_limit.rs
use async_trait::async_trait;
use moka::future::Cache;
use rand::distributions::{Distribution, Uniform};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::{Instant, sleep};
use tracing::{debug, info};

use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter as GovernorRateLimiter};
use std::num::NonZeroU32;

use crate::middleware::{Middleware, MiddlewareAction};
use spider_util::error::SpiderError;
use spider_util::request::Request;
use spider_util::response::Response;
/// Granularity at which rate limits are applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single shared limiter governs every request.
    Global,
    /// One limiter per normalized request origin (domain).
    Domain,
}
44
/// Strategy for pacing outgoing requests.
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Waits until the next request is allowed to proceed.
    async fn acquire(&self);
    /// Updates internal pacing based on the observed response
    /// (e.g. backing off after throttling or server errors).
    async fn adjust(&self, response: &Response);
    /// Reports the current inter-request delay. Implementations without a
    /// meaningful per-request delay may return `Duration::ZERO`.
    async fn current_delay(&self) -> Duration;
}
55
/// Starting delay used by the default adaptive limiter.
const INITIAL_DELAY: Duration = Duration::from_millis(500);
/// Lower clamp for the adaptive delay.
const MIN_DELAY: Duration = Duration::from_millis(50);
/// Upper clamp for the adaptive delay.
const MAX_DELAY: Duration = Duration::from_secs(60);

/// Delay growth factor applied on 429 and 5xx responses.
const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
/// Delay shrink factor applied on 2xx/3xx responses.
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
/// Delay growth factor applied on 403 responses (milder than 429/5xx).
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
63
/// Mutable state guarded by the `AdaptiveLimiter` mutex.
struct AdaptiveState {
    // Current pause between consecutive requests.
    delay: Duration,
    // Earliest instant at which the next request may be released.
    next_allowed_at: Instant,
}
68
/// Rate limiter that adapts its delay to observed response statuses:
/// it backs off on 403/429/5xx and slowly speeds up on success.
pub struct AdaptiveLimiter {
    state: Mutex<AdaptiveState>,
    // When true, each wait is randomized around the nominal delay.
    jitter: bool,
}
74
75impl AdaptiveLimiter {
76 pub fn new(initial_delay: Duration, jitter: bool) -> Self {
78 Self {
79 state: Mutex::new(AdaptiveState {
80 delay: initial_delay,
81 next_allowed_at: Instant::now(),
82 }),
83 jitter,
84 }
85 }
86
87 fn apply_jitter(&self, delay: Duration) -> Duration {
88 if !self.jitter || delay.is_zero() {
89 return delay;
90 }
91
92 let max_jitter = Duration::from_millis(500);
93 let jitter_window = delay.mul_f64(0.25).min(max_jitter);
94
95 let low = delay.saturating_sub(jitter_window);
96 let high = delay + jitter_window;
97
98 let mut rng = rand::thread_rng();
99 let uniform = Uniform::new_inclusive(low, high);
100 uniform.sample(&mut rng)
101 }
102}
103
#[async_trait]
impl RateLimiter for AdaptiveLimiter {
    async fn acquire(&self) {
        // Reserve the next slot while holding the lock, but perform the
        // actual sleep outside it so concurrent callers only serialize on
        // the bookkeeping, not on the wait itself.
        let sleep_duration = {
            let mut state = self.state.lock().await;
            let now = Instant::now();

            let delay = state.delay;
            if now < state.next_allowed_at {
                // A slot is already reserved ahead of us: wait for it and
                // push the schedule forward by one delay for the next caller.
                let wait = state.next_allowed_at - now;
                state.next_allowed_at += delay;
                wait
            } else {
                // No backlog: proceed immediately and reserve the next slot.
                state.next_allowed_at = now + delay;
                Duration::ZERO
            }
        };

        let sleep_duration = self.apply_jitter(sleep_duration);
        if !sleep_duration.is_zero() {
            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
            sleep(sleep_duration).await;
        }
    }

    async fn adjust(&self, response: &Response) {
        let mut state = self.state.lock().await;

        let old_delay = state.delay;
        let status = response.status.as_u16();
        // Multiplicative decrease on success, multiplicative increase on
        // throttling/server errors; other statuses leave the delay as-is.
        let new_delay = match status {
            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
            _ => state.delay,
        };

        // Keep the delay within the configured bounds.
        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);

        if old_delay != state.delay {
            debug!(
                "Adjusting delay for status {}: {:?} -> {:?}",
                status, old_delay, state.delay
            );
        }
    }

    async fn current_delay(&self) -> Duration {
        self.state.lock().await.delay
    }
}
155
/// Fixed-rate limiter backed by the `governor` token-bucket implementation.
pub struct TokenBucketLimiter {
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
160
161impl TokenBucketLimiter {
162 pub fn new(requests_per_second: u32) -> Self {
164 let quota = Quota::per_second(
165 NonZeroU32::new(requests_per_second).expect("requests_per_second must be non-zero"),
166 );
167 TokenBucketLimiter {
168 limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
169 quota,
170 &DefaultClock::default(),
171 )),
172 }
173 }
174}
175
#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    async fn acquire(&self) {
        // Waits until the token bucket can issue a permit.
        self.limiter.until_ready().await;
    }

    // The bucket rate is fixed at construction; responses never change it.
    async fn adjust(&self, _response: &Response) {
    }

    async fn current_delay(&self) -> Duration {
        // governor does not expose a per-request delay, so report zero.
        Duration::ZERO
    }
}
194
/// Middleware that throttles outgoing requests at the configured [`Scope`].
pub struct RateLimitMiddleware {
    scope: Scope,
    // Lazily-populated limiters keyed by scope key ("global" or an origin).
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    // Invoked to build a limiter the first time a scope key is seen.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
201
202impl RateLimitMiddleware {
203 pub fn builder() -> RateLimitMiddlewareBuilder {
205 RateLimitMiddlewareBuilder::default()
206 }
207
208 fn scope_key(&self, request: &Request) -> String {
209 match self.scope {
210 Scope::Global => "global".to_string(),
211 Scope::Domain => spider_util::utils::normalize_origin(request),
212 }
213 }
214}
215
#[async_trait]
impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
    fn name(&self) -> &str {
        "RateLimitMiddleware"
    }

    /// Blocks the request until its scope's limiter grants a slot.
    async fn process_request(
        &mut self,
        _client: &C,
        request: Request,
    ) -> Result<MiddlewareAction<Request>, SpiderError> {
        let key = self.scope_key(&request);

        // Look up the limiter for this key, creating it with the factory on
        // first use.
        let limiter = self
            .limiters
            .get_with(key.clone(), async { (self.limiter_factory)() })
            .await;

        let current_delay = limiter.current_delay().await;
        debug!(
            "Acquiring lock for key '{}' (delay: {:?})",
            key, current_delay
        );

        limiter.acquire().await;
        Ok(MiddlewareAction::Continue(request))
    }

    /// Lets the limiter adapt its pacing based on the response status.
    async fn process_response(
        &mut self,
        response: Response,
    ) -> Result<MiddlewareAction<Response>, SpiderError> {
        let key = self.scope_key(&response.request_from_response());

        // Only adjust if a limiter still exists for this key; a miss can
        // happen if the cache evicted the entry between request and response.
        if let Some(limiter) = self.limiters.get(&key).await {
            let old_delay = limiter.current_delay().await;
            limiter.adjust(&response).await;
            let new_delay = limiter.current_delay().await;
            if old_delay != new_delay {
                debug!(
                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
                    key, old_delay, new_delay
                );
            }
        }

        Ok(MiddlewareAction::Continue(response))
    }
}
265
/// Builder for [`RateLimitMiddleware`].
pub struct RateLimitMiddlewareBuilder {
    scope: Scope,
    // Idle time after which a per-key limiter is dropped from the cache.
    cache_ttl: Duration,
    // Maximum number of limiters kept alive at once.
    cache_capacity: u64,
    // Produces a limiter for each newly-seen scope key.
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
273
impl Default for RateLimitMiddlewareBuilder {
    fn default() -> Self {
        Self {
            // Per-domain limiting with a jittered adaptive limiter.
            scope: Scope::Domain,
            // Limiters idle for an hour are evicted (applied as
            // `time_to_idle` in `build`).
            cache_ttl: Duration::from_secs(3600),
            cache_capacity: 10_000,
            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
        }
    }
}
284
285impl RateLimitMiddlewareBuilder {
286 pub fn scope(mut self, scope: Scope) -> Self {
288 self.scope = scope;
289 self
290 }
291
292 pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
294 self.limiter_factory =
295 Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
296 self
297 }
298
299 pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
301 let arc = Arc::new(limiter);
302 self.limiter_factory = Box::new(move || arc.clone());
303 self
304 }
305
306 pub fn limiter_factory(
308 mut self,
309 factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
310 ) -> Self {
311 self.limiter_factory = Box::new(factory);
312 self
313 }
314
315 pub fn build(self) -> RateLimitMiddleware {
317 info!(
318 "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
319 self.scope, self.cache_ttl, self.cache_capacity
320 );
321 RateLimitMiddleware {
322 scope: self.scope,
323 limiters: Cache::builder()
324 .time_to_idle(self.cache_ttl)
325 .max_capacity(self.cache_capacity)
326 .build(),
327 limiter_factory: self.limiter_factory.into(),
328 }
329 }
330}