spider_lib/middlewares/rate_limit.rs

//! Rate Limit Middleware for controlling request frequency.
//!
//! This module provides the `RateLimitMiddleware`, designed to manage the rate
//! at which HTTP requests are sent to target servers. It helps prevent
//! overloading websites and respects server-side rate limits, making crawls
//! more robust and polite.
//!
//! The middleware supports:
//! - **Different scopes:** Applying rate limits globally or per-domain.
//! - **Pluggable limiters:** Offering an `AdaptiveLimiter` that dynamically adjusts
//!   delays based on response status (e.g., increasing delay on errors, decreasing on success),
//!   and a `TokenBucketLimiter` for enforcing a fixed requests-per-second rate.
//!
//! This flexibility allows for fine-tuned control over crawl speed and server interaction.
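//!
//! # Example
//!
//! A minimal construction sketch using only the types defined in this module; the
//! import path is assumed from the file location, and how the middleware is
//! registered with a crawler is left to the rest of the crate:
//!
//! ```rust,ignore
//! use spider_lib::middlewares::rate_limit::{RateLimitMiddleware, Scope}; // path assumed
//!
//! // Limit each domain to a fixed 5 requests per second.
//! let middleware = RateLimitMiddleware::builder()
//!     .scope(Scope::Domain)
//!     .use_token_bucket_limiter(5)
//!     .build();
//! ```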

use async_trait::async_trait;
use moka::future::Cache;
use rand::distributions::{Distribution, Uniform};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::{Instant, sleep};
use tracing::{debug, info};

use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter as GovernorRateLimiter};
use std::num::NonZeroU32;

use crate::utils;
use crate::{
    error::SpiderError,
    middleware::{Middleware, MiddlewareAction},
    request::Request,
    response::Response,
};

/// Determines the scope at which rate limits are applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single global rate limit for all requests.
    Global,
    /// A separate rate limit for each domain.
    Domain,
}

/// A trait for asynchronous, stateful rate limiters.
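///
/// # Example
///
/// A sketch of a custom limiter with a fixed delay (the `FixedDelayLimiter` type is
/// hypothetical and shown only to illustrate the trait contract):
///
/// ```rust,ignore
/// use std::time::Duration;
/// use async_trait::async_trait;
/// use tokio::time::sleep;
///
/// struct FixedDelayLimiter {
///     delay: Duration,
/// }
///
/// #[async_trait]
/// impl RateLimiter for FixedDelayLimiter {
///     async fn acquire(&self) {
///         // Always wait the configured delay before allowing the next request.
///         sleep(self.delay).await;
///     }
///
///     async fn adjust(&self, _response: &Response) {
///         // A fixed delay ignores response feedback.
///     }
///
///     async fn current_delay(&self) -> Duration {
///         self.delay
///     }
/// }
/// ```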
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Blocks until a request is allowed to proceed.
    async fn acquire(&self);
    /// Adjusts the rate limit based on the response.
    async fn adjust(&self, response: &Response);
    /// Returns the current delay between requests.
    async fn current_delay(&self) -> Duration;
}

// Starting delay and bounds for the adaptive limiter.
const INITIAL_DELAY: Duration = Duration::from_millis(500);
const MIN_DELAY: Duration = Duration::from_millis(50);
const MAX_DELAY: Duration = Duration::from_secs(60);

// Multipliers applied to the current delay based on response status.
const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;

struct AdaptiveState {
    /// Current delay enforced between requests.
    delay: Duration,
    /// Earliest instant at which the next request may be sent.
    next_allowed_at: Instant,
}

/// An adaptive rate limiter that adjusts the delay based on response status.
pub struct AdaptiveLimiter {
    state: Mutex<AdaptiveState>,
    jitter: bool,
}

impl AdaptiveLimiter {
    /// Creates a new `AdaptiveLimiter` with a given initial delay and jitter setting.
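    ///
    /// # Example
    ///
    /// A brief construction sketch:
    ///
    /// ```rust,ignore
    /// use std::time::Duration;
    ///
    /// // Start at a 500 ms delay and randomize each wait slightly.
    /// let limiter = AdaptiveLimiter::new(Duration::from_millis(500), true);
    /// ```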
    pub fn new(initial_delay: Duration, jitter: bool) -> Self {
        Self {
            state: Mutex::new(AdaptiveState {
                delay: initial_delay,
                next_allowed_at: Instant::now(),
            }),
            jitter,
        }
    }

    /// Randomizes a delay within ±25% of its value, capping the jitter window at 500 ms.
    fn apply_jitter(&self, delay: Duration) -> Duration {
        if !self.jitter || delay.is_zero() {
            return delay;
        }

        let max_jitter = Duration::from_millis(500);
        let jitter_window = delay.mul_f64(0.25).min(max_jitter);

        let low = delay.saturating_sub(jitter_window);
        let high = delay + jitter_window;

        let mut rng = rand::thread_rng();
        let uniform = Uniform::new_inclusive(low, high);
        uniform.sample(&mut rng)
    }
}

#[async_trait]
impl RateLimiter for AdaptiveLimiter {
    async fn acquire(&self) {
        // Compute how long this caller must wait while holding the lock, then
        // release the lock before sleeping so other tasks can claim their slots.
        let sleep_duration = {
            let mut state = self.state.lock().await;
            let now = Instant::now();

            let delay = state.delay;
            if now < state.next_allowed_at {
                let wait = state.next_allowed_at - now;
                state.next_allowed_at += delay;
                wait
            } else {
                state.next_allowed_at = now + delay;
                Duration::ZERO
            }
        };

        let sleep_duration = self.apply_jitter(sleep_duration);
        if !sleep_duration.is_zero() {
            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
            sleep(sleep_duration).await;
        }
    }

    async fn adjust(&self, response: &Response) {
        let mut state = self.state.lock().await;

        let old_delay = state.delay;
        let status = response.status.as_u16();
        let new_delay = match status {
            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
            _ => state.delay,
        };

        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);

        if old_delay != state.delay {
            debug!(
                "Adjusting delay for status {}: {:?} -> {:?}",
                status, old_delay, state.delay
            );
        }
    }

    async fn current_delay(&self) -> Duration {
        self.state.lock().await.delay
    }
}

/// A rate limiter that uses a token bucket algorithm for a fixed rate.
pub struct TokenBucketLimiter {
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}

impl TokenBucketLimiter {
    /// Creates a new `TokenBucketLimiter` with the specified rate (requests per second).
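    ///
    /// # Example
    ///
    /// A brief construction sketch:
    ///
    /// ```rust,ignore
    /// // Allow at most 10 requests per second; `acquire` blocks until a token is free.
    /// let limiter = TokenBucketLimiter::new(10);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `requests_per_second` is zero.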
    pub fn new(requests_per_second: u32) -> Self {
        let quota = Quota::per_second(
            NonZeroU32::new(requests_per_second).expect("requests_per_second must be non-zero"),
        );
        TokenBucketLimiter {
            limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
                quota,
                &DefaultClock::default(),
            )),
        }
    }
}

#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    async fn acquire(&self) {
        self.limiter.until_ready().await;
    }

    // A fixed-rate limiter does not adjust based on responses.
    async fn adjust(&self, _response: &Response) {
        // No-op for a fixed-rate limiter.
    }

    async fn current_delay(&self) -> Duration {
        // A token bucket does not expose a fixed inter-request delay; pacing is
        // handled entirely by `acquire`, so returning `Duration::ZERO` is a
        // deliberate simplification.
        Duration::ZERO
    }
}

/// A middleware for rate limiting requests.
pub struct RateLimitMiddleware {
    scope: Scope,
    /// One limiter per scope key, created lazily and evicted when idle.
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    /// Factory used to create a limiter for a key that has no entry yet.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}

impl RateLimitMiddleware {
    /// Creates a new `RateLimitMiddlewareBuilder`.
    pub fn builder() -> RateLimitMiddlewareBuilder {
        RateLimitMiddlewareBuilder::default()
    }

    /// Maps a request to the cache key its rate limit is tracked under.
    fn scope_key(&self, request: &Request) -> String {
        match self.scope {
            Scope::Global => "global".to_string(),
            Scope::Domain => utils::normalize_origin(request),
        }
    }
}

#[async_trait]
impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
    fn name(&self) -> &str {
        "RateLimitMiddleware"
    }

    async fn process_request(
        &mut self,
        _client: &C,
        request: Request,
    ) -> Result<MiddlewareAction<Request>, SpiderError> {
        let key = self.scope_key(&request);

        // Fetch the limiter for this key, creating it via the factory on first use.
        let limiter = self
            .limiters
            .get_with(key.clone(), async { (self.limiter_factory)() })
            .await;

        let current_delay = limiter.current_delay().await;
        debug!(
            "Acquiring rate limit for key '{}' (delay: {:?})",
            key, current_delay
        );

        limiter.acquire().await;
        Ok(MiddlewareAction::Continue(request))
    }

    async fn process_response(
        &mut self,
        response: Response,
    ) -> Result<MiddlewareAction<Response>, SpiderError> {
        let key = self.scope_key(&response.request_from_response());

        if let Some(limiter) = self.limiters.get(&key).await {
            let old_delay = limiter.current_delay().await;
            limiter.adjust(&response).await;
            let new_delay = limiter.current_delay().await;
            if old_delay != new_delay {
                debug!(
                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
                    key, old_delay, new_delay
                );
            }
        }

        Ok(MiddlewareAction::Continue(response))
    }
}

/// Builder for `RateLimitMiddleware`.
pub struct RateLimitMiddlewareBuilder {
    scope: Scope,
    /// How long an idle per-key limiter is kept before being evicted.
    cache_ttl: Duration,
    /// Maximum number of per-key limiters held in the cache.
    cache_capacity: u64,
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}

impl Default for RateLimitMiddlewareBuilder {
    fn default() -> Self {
        Self {
            scope: Scope::Domain,
            cache_ttl: Duration::from_secs(3600),
            cache_capacity: 10_000,
            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
        }
    }
}

impl RateLimitMiddlewareBuilder {
    /// Sets the scope for the rate limiter.
    pub fn scope(mut self, scope: Scope) -> Self {
        self.scope = scope;
        self
    }

    /// Configures the builder to use a `TokenBucketLimiter` with the specified requests per second.
    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
        self.limiter_factory =
            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
        self
    }

    /// Sets a specific rate limiter instance to be used.
    ///
    /// Note that this single instance is shared across all scope keys, so with
    /// `Scope::Domain` every domain is throttled by the same limiter.
    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
        let arc = Arc::new(limiter);
        self.limiter_factory = Box::new(move || arc.clone());
        self
    }

    /// Sets a factory function for creating rate limiters; the factory is invoked
    /// once per scope key, the first time that key is seen.
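    ///
    /// # Example
    ///
    /// A sketch of a factory that gives every scope key its own adaptive limiter
    /// with a custom starting delay:
    ///
    /// ```rust,ignore
    /// use std::sync::Arc;
    /// use std::time::Duration;
    ///
    /// let middleware = RateLimitMiddleware::builder()
    ///     .limiter_factory(|| Arc::new(AdaptiveLimiter::new(Duration::from_secs(1), false)))
    ///     .build();
    /// ```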
    pub fn limiter_factory(
        mut self,
        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
    ) -> Self {
        self.limiter_factory = Box::new(factory);
        self
    }

    /// Builds the `RateLimitMiddleware`.
    pub fn build(self) -> RateLimitMiddleware {
        info!(
            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
            self.scope, self.cache_ttl, self.cache_capacity
        );
        RateLimitMiddleware {
            scope: self.scope,
            limiters: Cache::builder()
                .time_to_idle(self.cache_ttl)
                .max_capacity(self.cache_capacity)
                .build(),
            limiter_factory: self.limiter_factory.into(),
        }
    }
}