// spider_middleware/rate_limit.rs
//! Rate Limit Middleware for controlling request frequency.
//!
//! This module provides the `RateLimitMiddleware`, designed to manage the rate
//! at which HTTP requests are sent to target servers. It helps prevent
//! overloading websites and respects server-side rate limits, making crawls
//! more robust and polite.
//!
//! The middleware supports:
//! - **Different scopes:** Applying rate limits globally or per-domain.
//! - **Pluggable limiters:** Offering an `AdaptiveLimiter` that dynamically adjusts
//!   delays based on response status (e.g., increasing delay on errors, decreasing on success),
//!   and a `TokenBucketLimiter` for enforcing a fixed requests-per-second rate.
//!
//! This flexibility allows for fine-tuned control over crawl speed and server interaction.

use std::hash::Hash;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter as GovernorRateLimiter};
use moka::future::Cache;
use rand::distributions::{Distribution, Uniform};
use tokio::sync::Mutex;
use tokio::time::{Instant, sleep};
use tracing::{debug, info};

use crate::middleware::{Middleware, MiddlewareAction};
use spider_util::error::SpiderError;
use spider_util::request::Request;
use spider_util::response::Response;

36/// Determines the scope at which rate limits are applied.
37#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
38pub enum Scope {
39    /// A single global rate limit for all requests.
40    Global,
41    /// A separate rate limit for each domain.
42    Domain,
43}
44
45/// A trait for asynchronous, stateful rate limiters.
46#[async_trait]
47pub trait RateLimiter: Send + Sync {
48    /// Blocks until a request is allowed to proceed.
49    async fn acquire(&self);
50    /// Adjusts the rate limit based on the response.
51    async fn adjust(&self, response: &Response);
52    /// Returns the current delay between requests.
53    async fn current_delay(&self) -> Duration;
54}
55
56const INITIAL_DELAY: Duration = Duration::from_millis(500);
57const MIN_DELAY: Duration = Duration::from_millis(50);
58const MAX_DELAY: Duration = Duration::from_secs(60);
59
60const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
61const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
62const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
63
64struct AdaptiveState {
65    delay: Duration,
66    next_allowed_at: Instant,
67}
68
69/// An adaptive rate limiter that adjusts the delay based on response status.
70pub struct AdaptiveLimiter {
71    state: Mutex<AdaptiveState>,
72    jitter: bool,
73}
74
75impl AdaptiveLimiter {
76    /// Creates a new `AdaptiveLimiter` with a given initial delay and jitter setting.
77    pub fn new(initial_delay: Duration, jitter: bool) -> Self {
78        Self {
79            state: Mutex::new(AdaptiveState {
80                delay: initial_delay,
81                next_allowed_at: Instant::now(),
82            }),
83            jitter,
84        }
85    }
86
87    fn apply_jitter(&self, delay: Duration) -> Duration {
88        if !self.jitter || delay.is_zero() {
89            return delay;
90        }
91
92        let max_jitter = Duration::from_millis(500);
93        let jitter_window = delay.mul_f64(0.25).min(max_jitter);
94
95        let low = delay.saturating_sub(jitter_window);
96        let high = delay + jitter_window;
97
98        let mut rng = rand::thread_rng();
99        let uniform = Uniform::new_inclusive(low, high);
100        uniform.sample(&mut rng)
101    }
102}
103
104#[async_trait]
105impl RateLimiter for AdaptiveLimiter {
106    async fn acquire(&self) {
107        let sleep_duration = {
108            let mut state = self.state.lock().await;
109            let now = Instant::now();
110
111            let delay = state.delay;
112            if now < state.next_allowed_at {
113                let wait = state.next_allowed_at - now;
114                state.next_allowed_at += delay;
115                wait
116            } else {
117                state.next_allowed_at = now + delay;
118                Duration::ZERO
119            }
120        };
121
122        let sleep_duration = self.apply_jitter(sleep_duration);
123        if !sleep_duration.is_zero() {
124            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
125            sleep(sleep_duration).await;
126        }
127    }
128
129    async fn adjust(&self, response: &Response) {
130        let mut state = self.state.lock().await;
131
132        let old_delay = state.delay;
133        let status = response.status.as_u16();
134        let new_delay = match status {
135            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
136            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
137            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
138            _ => state.delay,
139        };
140
141        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);
142
143        if old_delay != state.delay {
144            debug!(
145                "Adjusting delay for status {}: {:?} -> {:?}",
146                status, old_delay, state.delay
147            );
148        }
149    }
150
151    async fn current_delay(&self) -> Duration {
152        self.state.lock().await.delay
153    }
154}
155
156/// A rate limiter that uses a token bucket algorithm for a fixed rate.
157pub struct TokenBucketLimiter {
158    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
159}
160
161impl TokenBucketLimiter {
162    /// Creates a new `TokenBucketLimiter` with the specified rate (requests per second).
163    pub fn new(requests_per_second: u32) -> Self {
164        let quota = Quota::per_second(
165            NonZeroU32::new(requests_per_second).expect("requests_per_second must be non-zero"),
166        );
167        TokenBucketLimiter {
168            limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
169                quota,
170                &DefaultClock::default(),
171            )),
172        }
173    }
174}
175
176#[async_trait]
177impl RateLimiter for TokenBucketLimiter {
178    async fn acquire(&self) {
179        self.limiter.until_ready().await;
180    }
181
182    /// A fixed-rate limiter does not adjust based on responses.
183    async fn adjust(&self, _response: &Response) {
184        // No-op for a fixed-rate limiter
185    }
186
187    async fn current_delay(&self) -> Duration {
188        // Token bucket doesn't directly expose a "current delay", but rather
189        // manages when the next request is allowed.
190        // Returning Duration::ZERO is a simplification, as delay is handled by `acquire`.
191        Duration::ZERO
192    }
193}
194
195/// A middleware for rate limiting requests.
196pub struct RateLimitMiddleware {
197    scope: Scope,
198    limiters: Cache<String, Arc<dyn RateLimiter>>,
199    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
200}
201
202impl RateLimitMiddleware {
203    /// Creates a new `RateLimitMiddlewareBuilder`.
204    pub fn builder() -> RateLimitMiddlewareBuilder {
205        RateLimitMiddlewareBuilder::default()
206    }
207
208    fn scope_key(&self, request: &Request) -> String {
209        match self.scope {
210            Scope::Global => "global".to_string(),
211            Scope::Domain => spider_util::utils::normalize_origin(request),
212        }
213    }
214}
215
216#[async_trait]
217impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
218    fn name(&self) -> &str {
219        "RateLimitMiddleware"
220    }
221
222    async fn process_request(
223        &mut self,
224        _client: &C,
225        request: Request,
226    ) -> Result<MiddlewareAction<Request>, SpiderError> {
227        let key = self.scope_key(&request);
228
229        let limiter = self
230            .limiters
231            .get_with(key.clone(), async { (self.limiter_factory)() })
232            .await;
233
234        let current_delay = limiter.current_delay().await;
235        debug!(
236            "Acquiring lock for key '{}' (delay: {:?})",
237            key, current_delay
238        );
239
240        limiter.acquire().await;
241        Ok(MiddlewareAction::Continue(request))
242    }
243
244    async fn process_response(
245        &mut self,
246        response: Response,
247    ) -> Result<MiddlewareAction<Response>, SpiderError> {
248        let key = self.scope_key(&response.request_from_response());
249
250        if let Some(limiter) = self.limiters.get(&key).await {
251            let old_delay = limiter.current_delay().await;
252            limiter.adjust(&response).await;
253            let new_delay = limiter.current_delay().await;
254            if old_delay != new_delay {
255                debug!(
256                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
257                    key, old_delay, new_delay
258                );
259            }
260        }
261
262        Ok(MiddlewareAction::Continue(response))
263    }
264}
265
266/// Builder for `RateLimitMiddleware`.
267pub struct RateLimitMiddlewareBuilder {
268    scope: Scope,
269    cache_ttl: Duration,
270    cache_capacity: u64,
271    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
272}
273
274impl Default for RateLimitMiddlewareBuilder {
275    fn default() -> Self {
276        Self {
277            scope: Scope::Domain,
278            cache_ttl: Duration::from_secs(3600),
279            cache_capacity: 10_000,
280            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
281        }
282    }
283}
284
285impl RateLimitMiddlewareBuilder {
286    /// Sets the scope for the rate limiter.
287    pub fn scope(mut self, scope: Scope) -> Self {
288        self.scope = scope;
289        self
290    }
291
292    /// Configures the builder to use a `TokenBucketLimiter` with the specified requests per second.
293    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
294        self.limiter_factory =
295            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
296        self
297    }
298
299    /// Sets a specific rate limiter instance to be used.
300    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
301        let arc = Arc::new(limiter);
302        self.limiter_factory = Box::new(move || arc.clone());
303        self
304    }
305
306    /// Sets a factory function for creating rate limiters.
307    pub fn limiter_factory(
308        mut self,
309        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
310    ) -> Self {
311        self.limiter_factory = Box::new(factory);
312        self
313    }
314
315    /// Builds the `RateLimitMiddleware`.
316    pub fn build(self) -> RateLimitMiddleware {
317        info!(
318            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
319            self.scope, self.cache_ttl, self.cache_capacity
320        );
321        RateLimitMiddleware {
322            scope: self.scope,
323            limiters: Cache::builder()
324                .time_to_idle(self.cache_ttl)
325                .max_capacity(self.cache_capacity)
326                .build(),
327            limiter_factory: self.limiter_factory.into(),
328        }
329    }
330}