// spider_middleware/rate_limit.rs
//! Rate Limit Middleware for controlling request frequency.
//!
//! This module provides the `RateLimitMiddleware`, designed to manage the rate
//! at which HTTP requests are sent to target servers. It helps prevent
//! overloading websites and respects server-side rate limits, making crawls
//! more robust and polite.
//!
//! The middleware supports:
//! - **Different scopes:** Applying rate limits globally or per-domain.
//! - **Pluggable limiters:** Offering an `AdaptiveLimiter` that dynamically adjusts
//!   delays based on response status (e.g., increasing delay on errors, decreasing on success),
//!   and a `TokenBucketLimiter` for enforcing a fixed requests-per-second rate.
//!
//! This flexibility allows for fine-tuned control over crawl speed and server interaction.

16use async_trait::async_trait;
17use log::{debug, info};
18use moka::future::Cache;
19use rand::distributions::{Distribution, Uniform};
20use std::hash::Hash;
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::Mutex;
24use tokio::time::{Instant, sleep};
25
26use governor::clock::DefaultClock;
27use governor::state::{InMemoryState, NotKeyed};
28use governor::{Quota, RateLimiter as GovernorRateLimiter};
29use std::num::NonZeroU32;
30
31use crate::middleware::{Middleware, MiddlewareAction};
32use spider_util::constants::{
33    MIDDLEWARE_CACHE_CAPACITY, MIDDLEWARE_CACHE_TTL_SECS, RATE_LIMIT_INITIAL_DELAY_MS,
34    RATE_LIMIT_MAX_DELAY_MS, RATE_LIMIT_MAX_JITTER_MS, RATE_LIMIT_MIN_DELAY_MS,
35};
36use spider_util::error::SpiderError;
37use spider_util::request::Request;
38use spider_util::response::Response;
39
/// Determines the scope at which rate limits are applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single global rate limit for all requests, keyed by the constant "global".
    Global,
    /// A separate rate limit for each domain, keyed by the request's normalized origin.
    Domain,
}
48
/// A trait for asynchronous, stateful rate limiters.
///
/// Implementations are shared behind `Arc<dyn RateLimiter>` and may be called
/// concurrently, hence the `Send + Sync` bounds.
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Blocks until a request is allowed to proceed.
    async fn acquire(&self);
    /// Adjusts the rate limit based on the response (no-op for fixed-rate limiters).
    async fn adjust(&self, response: &Response);
    /// Returns the current delay between requests.
    async fn current_delay(&self) -> Duration;
}
59
/// Starting delay used by the default adaptive limiter.
const INITIAL_DELAY: Duration = Duration::from_millis(RATE_LIMIT_INITIAL_DELAY_MS);
/// Lower clamp bound applied when the adaptive delay decays.
const MIN_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MIN_DELAY_MS);
/// Upper clamp bound applied when the adaptive delay grows.
const MAX_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MAX_DELAY_MS);

/// Delay multiplier applied after 429 or 5xx responses.
const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
/// Delay multiplier applied after successful (2xx/3xx) responses.
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
/// Delay multiplier applied after 403 responses.
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
67
/// Mutable pacing state kept behind the `AdaptiveLimiter`'s mutex.
struct AdaptiveState {
    // Current enforced delay between consecutive requests.
    delay: Duration,
    // Earliest instant at which the next request may proceed.
    next_allowed_at: Instant,
}
72
/// An adaptive rate limiter that adjusts the delay based on response status.
pub struct AdaptiveLimiter {
    // Pacing schedule; a tokio mutex so it can be held across awaits safely.
    state: Mutex<AdaptiveState>,
    // When true, sleep durations are randomized to avoid synchronized bursts.
    jitter: bool,
}
78
79impl AdaptiveLimiter {
80    /// Creates a new `AdaptiveLimiter` with a given initial delay and jitter setting.
81    pub fn new(initial_delay: Duration, jitter: bool) -> Self {
82        Self {
83            state: Mutex::new(AdaptiveState {
84                delay: initial_delay,
85                next_allowed_at: Instant::now(),
86            }),
87            jitter,
88        }
89    }
90
91    fn apply_jitter(&self, delay: Duration) -> Duration {
92        if !self.jitter || delay.is_zero() {
93            return delay;
94        }
95
96        let max_jitter = Duration::from_millis(RATE_LIMIT_MAX_JITTER_MS);
97        let jitter_window = delay.mul_f64(0.25).min(max_jitter);
98
99        let low = delay.saturating_sub(jitter_window);
100        let high = delay + jitter_window;
101
102        let mut rng = rand::thread_rng();
103        let uniform = Uniform::new_inclusive(low, high);
104        uniform.sample(&mut rng)
105    }
106}
107
108#[async_trait]
109impl RateLimiter for AdaptiveLimiter {
110    async fn acquire(&self) {
111        let sleep_duration = {
112            let mut state = self.state.lock().await;
113            let now = Instant::now();
114
115            let delay = state.delay;
116            if now < state.next_allowed_at {
117                let wait = state.next_allowed_at - now;
118                state.next_allowed_at += delay;
119                wait
120            } else {
121                state.next_allowed_at = now + delay;
122                Duration::ZERO
123            }
124        };
125
126        let sleep_duration = self.apply_jitter(sleep_duration);
127        if !sleep_duration.is_zero() {
128            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
129            sleep(sleep_duration).await;
130        }
131    }
132
133    async fn adjust(&self, response: &Response) {
134        let mut state = self.state.lock().await;
135
136        let old_delay = state.delay;
137        let status = response.status.as_u16();
138        let new_delay = match status {
139            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
140            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
141            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
142            _ => state.delay,
143        };
144
145        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);
146
147        if old_delay != state.delay {
148            debug!(
149                "Adjusting delay for status {}: {:?} -> {:?}",
150                status, old_delay, state.delay
151            );
152        }
153    }
154
155    async fn current_delay(&self) -> Duration {
156        self.state.lock().await.delay
157    }
158}
159
/// A rate limiter that uses a token bucket algorithm for a fixed rate.
pub struct TokenBucketLimiter {
    // Unkeyed governor limiter holding the quota and bucket state.
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
164
165impl TokenBucketLimiter {
166    /// Creates a new `TokenBucketLimiter` with the specified rate (requests per second).
167    ///
168    /// # Panics
169    ///
170    /// Panics if `requests_per_second` is `0`.
171    pub fn new(requests_per_second: u32) -> Self {
172        let requests_per_second = match NonZeroU32::new(requests_per_second) {
173            Some(rps) => rps,
174            None => panic!("requests_per_second must be non-zero"),
175        };
176        let quota = Quota::per_second(requests_per_second);
177        TokenBucketLimiter {
178            limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
179                quota,
180                &DefaultClock::default(),
181            )),
182        }
183    }
184}
185
#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    /// Waits until the token bucket has capacity for one more request.
    async fn acquire(&self) {
        self.limiter.until_ready().await;
    }

    /// A fixed-rate limiter does not adjust based on responses.
    async fn adjust(&self, _response: &Response) {}

    /// Always zero: the governor limiter tracks timing internally, so there
    /// is no explicit inter-request delay to report.
    async fn current_delay(&self) -> Duration {
        Duration::ZERO
    }
}
199
/// A middleware for rate limiting requests.
///
/// Holds one limiter per scope key in an idle-expiring cache; limiters are
/// created on demand by `limiter_factory`.
pub struct RateLimitMiddleware {
    // Whether limits apply globally or per domain.
    scope: Scope,
    // Limiters keyed by scope key ("global" or a normalized origin).
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    // Invoked to create a limiter the first time a key is seen.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
206
207impl RateLimitMiddleware {
208    /// Creates a new `RateLimitMiddlewareBuilder`.
209    pub fn builder() -> RateLimitMiddlewareBuilder {
210        RateLimitMiddlewareBuilder::default()
211    }
212
213    fn scope_key(&self, request: &Request) -> String {
214        match self.scope {
215            Scope::Global => "global".to_string(),
216            Scope::Domain => spider_util::util::normalize_origin(request),
217        }
218    }
219}
220
221#[async_trait]
222impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
223    fn name(&self) -> &str {
224        "RateLimitMiddleware"
225    }
226
227    async fn process_request(
228        &mut self,
229        _client: &C,
230        request: Request,
231    ) -> Result<MiddlewareAction<Request>, SpiderError> {
232        let key = self.scope_key(&request);
233
234        let limiter = self
235            .limiters
236            .get_with(key.clone(), async { (self.limiter_factory)() })
237            .await;
238
239        let current_delay = limiter.current_delay().await;
240        debug!(
241            "Acquiring lock for key '{}' (delay: {:?})",
242            key, current_delay
243        );
244
245        limiter.acquire().await;
246        Ok(MiddlewareAction::Continue(request))
247    }
248
249    async fn process_response(
250        &mut self,
251        response: Response,
252    ) -> Result<MiddlewareAction<Response>, SpiderError> {
253        let key = self.scope_key(&response.request_from_response());
254
255        if let Some(limiter) = self.limiters.get(&key).await {
256            let old_delay = limiter.current_delay().await;
257            limiter.adjust(&response).await;
258            let new_delay = limiter.current_delay().await;
259            if old_delay != new_delay {
260                debug!(
261                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
262                    key, old_delay, new_delay
263                );
264            }
265        }
266
267        Ok(MiddlewareAction::Continue(response))
268    }
269}
270
/// Builder for `RateLimitMiddleware`.
pub struct RateLimitMiddlewareBuilder {
    // Scope at which limits are applied (default: per-domain).
    scope: Scope,
    // Passed to the limiter cache; NOTE(review): `build` wires this to
    // `time_to_idle`, not `time_to_live` — confirm the name matches intent.
    cache_ttl: Duration,
    // Maximum number of cached limiters.
    cache_capacity: u64,
    // Factory producing a limiter for each new scope key.
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
278
279impl Default for RateLimitMiddlewareBuilder {
280    fn default() -> Self {
281        Self {
282            scope: Scope::Domain,
283            cache_ttl: Duration::from_secs(MIDDLEWARE_CACHE_TTL_SECS),
284            cache_capacity: MIDDLEWARE_CACHE_CAPACITY,
285            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
286        }
287    }
288}
289
290impl RateLimitMiddlewareBuilder {
291    /// Sets the scope for the rate limiter.
292    pub fn scope(mut self, scope: Scope) -> Self {
293        self.scope = scope;
294        self
295    }
296
297    /// Configures the builder to use a `TokenBucketLimiter` with the specified requests per second.
298    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
299        self.limiter_factory =
300            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
301        self
302    }
303
304    /// Sets a specific rate limiter instance to be used.
305    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
306        let arc = Arc::new(limiter);
307        self.limiter_factory = Box::new(move || arc.clone());
308        self
309    }
310
311    /// Sets a factory function for creating rate limiters.
312    pub fn limiter_factory(
313        mut self,
314        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
315    ) -> Self {
316        self.limiter_factory = Box::new(factory);
317        self
318    }
319
320    /// Builds the `RateLimitMiddleware`.
321    pub fn build(self) -> RateLimitMiddleware {
322        info!(
323            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
324            self.scope, self.cache_ttl, self.cache_capacity
325        );
326        RateLimitMiddleware {
327            scope: self.scope,
328            limiters: Cache::builder()
329                .time_to_idle(self.cache_ttl)
330                .max_capacity(self.cache_capacity)
331                .build(),
332            limiter_factory: self.limiter_factory.into(),
333        }
334    }
335}