Skip to main content

spider_middleware/
rate_limit.rs

//! Rate-limiting middleware.
//!
//! This module provides both adaptive and fixed-rate limiters and wraps them in
//! [`RateLimitMiddleware`], which can be applied globally or per domain.
5
6use async_trait::async_trait;
7use log::{debug, info};
8use moka::future::Cache;
9use rand::distributions::{Distribution, Uniform};
10use std::hash::Hash;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::Mutex;
14use tokio::time::{Instant, sleep};
15
16use governor::clock::DefaultClock;
17use governor::state::{InMemoryState, NotKeyed};
18use governor::{Quota, RateLimiter as GovernorRateLimiter};
19use std::num::NonZeroU32;
20
21use crate::middleware::{Middleware, MiddlewareAction};
22use spider_util::constants::{
23    MIDDLEWARE_CACHE_CAPACITY, MIDDLEWARE_CACHE_TTL_SECS, RATE_LIMIT_INITIAL_DELAY_MS,
24    RATE_LIMIT_MAX_DELAY_MS, RATE_LIMIT_MAX_JITTER_MS, RATE_LIMIT_MIN_DELAY_MS,
25};
26use spider_util::error::SpiderError;
27use spider_util::request::Request;
28use spider_util::response::Response;
29
/// Granularity at which the rate-limiting middleware keys its limiters.
///
/// With [`Scope::Global`] every request shares one limiter; with
/// [`Scope::Domain`] each domain gets a limiter of its own.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// One limiter shared by every request.
    Global,
    /// An independent limiter per domain.
    Domain,
}
38
39/// Trait for asynchronous, stateful rate limiters.
40#[async_trait]
41pub trait RateLimiter: Send + Sync {
42    /// Blocks until a request is allowed to proceed.
43    async fn acquire(&self);
44    /// Adjusts the rate limit based on the response.
45    async fn adjust(&self, response: &Response);
46    /// Returns the current delay between requests.
47    async fn current_delay(&self) -> Duration;
48}
49
50const INITIAL_DELAY: Duration = Duration::from_millis(RATE_LIMIT_INITIAL_DELAY_MS);
51const MIN_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MIN_DELAY_MS);
52const MAX_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MAX_DELAY_MS);
53
54const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
55const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
56const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
57
/// Mutable pacing state of an adaptive limiter, kept behind its mutex.
struct AdaptiveState {
    /// Spacing currently enforced between consecutive request slots.
    delay: Duration,
    /// Earliest instant at which the next request slot opens.
    next_allowed_at: Instant,
}
62
63/// Adaptive limiter that reacts to response outcomes.
64pub struct AdaptiveLimiter {
65    state: Mutex<AdaptiveState>,
66    jitter: bool,
67}
68
69impl AdaptiveLimiter {
70    /// Creates a new `AdaptiveLimiter` with a given initial delay and jitter setting.
71    pub fn new(initial_delay: Duration, jitter: bool) -> Self {
72        Self {
73            state: Mutex::new(AdaptiveState {
74                delay: initial_delay,
75                next_allowed_at: Instant::now(),
76            }),
77            jitter,
78        }
79    }
80
81    fn apply_jitter(&self, delay: Duration) -> Duration {
82        if !self.jitter || delay.is_zero() {
83            return delay;
84        }
85
86        let max_jitter = Duration::from_millis(RATE_LIMIT_MAX_JITTER_MS);
87        let jitter_window = delay.mul_f64(0.25).min(max_jitter);
88
89        let low = delay.saturating_sub(jitter_window);
90        let high = delay + jitter_window;
91
92        let mut rng = rand::thread_rng();
93        let uniform = Uniform::new_inclusive(low, high);
94        uniform.sample(&mut rng)
95    }
96}
97
98#[async_trait]
99impl RateLimiter for AdaptiveLimiter {
100    async fn acquire(&self) {
101        let sleep_duration = {
102            let mut state = self.state.lock().await;
103            let now = Instant::now();
104
105            let delay = state.delay;
106            if now < state.next_allowed_at {
107                let wait = state.next_allowed_at - now;
108                state.next_allowed_at += delay;
109                wait
110            } else {
111                state.next_allowed_at = now + delay;
112                Duration::ZERO
113            }
114        };
115
116        let sleep_duration = self.apply_jitter(sleep_duration);
117        if !sleep_duration.is_zero() {
118            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
119            sleep(sleep_duration).await;
120        }
121    }
122
123    async fn adjust(&self, response: &Response) {
124        let mut state = self.state.lock().await;
125
126        let old_delay = state.delay;
127        let status = response.status.as_u16();
128        let new_delay = match status {
129            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
130            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
131            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
132            _ => state.delay,
133        };
134
135        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);
136
137        if old_delay != state.delay {
138            debug!(
139                "Adjusting delay for status {}: {:?} -> {:?}",
140                status, old_delay, state.delay
141            );
142        }
143    }
144
145    async fn current_delay(&self) -> Duration {
146        self.state.lock().await.delay
147    }
148}
149
150/// Fixed-rate limiter backed by a token bucket.
151pub struct TokenBucketLimiter {
152    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
153}
154
155impl TokenBucketLimiter {
156    /// Creates a new `TokenBucketLimiter` with the specified rate (requests per second).
157    ///
158    /// # Panics
159    ///
160    /// Panics if `requests_per_second` is `0`.
161    pub fn new(requests_per_second: u32) -> Self {
162        let requests_per_second = match NonZeroU32::new(requests_per_second) {
163            Some(rps) => rps,
164            None => panic!("requests_per_second must be non-zero"),
165        };
166        let quota = Quota::per_second(requests_per_second);
167        TokenBucketLimiter {
168            limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
169                quota,
170                &DefaultClock::default(),
171            )),
172        }
173    }
174}
175
176#[async_trait]
177impl RateLimiter for TokenBucketLimiter {
178    async fn acquire(&self) {
179        self.limiter.until_ready().await;
180    }
181
182    /// A fixed-rate limiter does not adjust based on responses.
183    async fn adjust(&self, _response: &Response) {}
184
185    async fn current_delay(&self) -> Duration {
186        Duration::ZERO
187    }
188}
189
190/// A middleware for rate limiting requests.
191pub struct RateLimitMiddleware {
192    scope: Scope,
193    limiters: Cache<String, Arc<dyn RateLimiter>>,
194    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
195}
196
197impl RateLimitMiddleware {
198    /// Creates a new `RateLimitMiddlewareBuilder`.
199    pub fn builder() -> RateLimitMiddlewareBuilder {
200        RateLimitMiddlewareBuilder::default()
201    }
202
203    fn scope_key(&self, request: &Request) -> String {
204        match self.scope {
205            Scope::Global => "global".to_string(),
206            Scope::Domain => spider_util::util::normalize_origin(request),
207        }
208    }
209}
210
211#[async_trait]
212impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
213    fn name(&self) -> &str {
214        "RateLimitMiddleware"
215    }
216
217    async fn process_request(
218        &self,
219        _client: &C,
220        request: Request,
221    ) -> Result<MiddlewareAction<Request>, SpiderError> {
222        let key = self.scope_key(&request);
223
224        let limiter = self
225            .limiters
226            .get_with(key.clone(), async { (self.limiter_factory)() })
227            .await;
228
229        let current_delay = limiter.current_delay().await;
230        debug!(
231            "Acquiring lock for key '{}' (delay: {:?})",
232            key, current_delay
233        );
234
235        limiter.acquire().await;
236        Ok(MiddlewareAction::Continue(request))
237    }
238
239    async fn process_response(
240        &self,
241        response: Response,
242    ) -> Result<MiddlewareAction<Response>, SpiderError> {
243        let key = self.scope_key(&response.request_from_response());
244
245        if let Some(limiter) = self.limiters.get(&key).await {
246            let old_delay = limiter.current_delay().await;
247            limiter.adjust(&response).await;
248            let new_delay = limiter.current_delay().await;
249            if old_delay != new_delay {
250                debug!(
251                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
252                    key, old_delay, new_delay
253                );
254            }
255        }
256
257        Ok(MiddlewareAction::Continue(response))
258    }
259}
260
261/// Builder for `RateLimitMiddleware`.
262pub struct RateLimitMiddlewareBuilder {
263    scope: Scope,
264    cache_ttl: Duration,
265    cache_capacity: u64,
266    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
267}
268
269impl Default for RateLimitMiddlewareBuilder {
270    fn default() -> Self {
271        Self {
272            scope: Scope::Domain,
273            cache_ttl: Duration::from_secs(MIDDLEWARE_CACHE_TTL_SECS),
274            cache_capacity: MIDDLEWARE_CACHE_CAPACITY,
275            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
276        }
277    }
278}
279
280impl RateLimitMiddlewareBuilder {
281    /// Sets the scope for the rate limiter.
282    pub fn scope(mut self, scope: Scope) -> Self {
283        self.scope = scope;
284        self
285    }
286
287    /// Configures the builder to use a `TokenBucketLimiter` with the specified requests per second.
288    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
289        self.limiter_factory =
290            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
291        self
292    }
293
294    /// Sets a specific rate limiter instance to be used.
295    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
296        let arc = Arc::new(limiter);
297        self.limiter_factory = Box::new(move || arc.clone());
298        self
299    }
300
301    /// Sets a factory function for creating rate limiters.
302    pub fn limiter_factory(
303        mut self,
304        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
305    ) -> Self {
306        self.limiter_factory = Box::new(factory);
307        self
308    }
309
310    /// Builds the `RateLimitMiddleware`.
311    pub fn build(self) -> RateLimitMiddleware {
312        info!(
313            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
314            self.scope, self.cache_ttl, self.cache_capacity
315        );
316        RateLimitMiddleware {
317            scope: self.scope,
318            limiters: Cache::builder()
319                .time_to_idle(self.cache_ttl)
320                .max_capacity(self.cache_capacity)
321                .build(),
322            limiter_factory: self.limiter_factory.into(),
323        }
324    }
325}