
spider_lib/middlewares/rate_limit.rs

use async_trait::async_trait;
use moka::future::Cache;
use rand::distributions::{Distribution, Uniform};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::{Instant, sleep};
use tracing::{debug, info};

use governor::{Quota, RateLimiter as GovernorRateLimiter};
use governor::state::{InMemoryState, NotKeyed};
use governor::clock::DefaultClock;
use std::num::NonZeroU32;

use crate::utils;
use crate::{
    error::SpiderError,
    middleware::{Middleware, MiddlewareAction},
    request::Request,
    response::Response,
};

/// Determines the scope at which rate limits are applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single global rate limit for all requests.
    Global,
    /// A separate rate limit for each domain.
    Domain,
}

/// A trait for asynchronous, stateful rate limiters.
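///
/// Implementors are expected to be driven in an acquire-then-adjust loop:
/// `acquire` is awaited before a request is sent, and `adjust` is given the
/// resulting response so the limiter can react to it.
///
/// # Examples
///
/// A minimal sketch of the intended call pattern using [`AdaptiveLimiter`];
/// obtaining `response` from the crate's HTTP client is elided, so the
/// example is not compiled:
///
/// ```ignore
/// use std::time::Duration;
///
/// let limiter = AdaptiveLimiter::new(Duration::from_millis(500), true);
/// limiter.acquire().await;         // wait until the next request is allowed
/// // ... send the request and obtain `response` ...
/// limiter.adjust(&response).await; // feed the status back into the limiter
/// ```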
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Blocks until a request is allowed to proceed.
    async fn acquire(&self);
    /// Adjusts the rate limit based on the response.
    async fn adjust(&self, response: &Response);
    /// Returns the current delay between requests.
    async fn current_delay(&self) -> Duration;
}

const INITIAL_DELAY: Duration = Duration::from_millis(500);
const MIN_DELAY: Duration = Duration::from_millis(50);
const MAX_DELAY: Duration = Duration::from_secs(60);

const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;

struct AdaptiveState {
    delay: Duration,
    next_allowed_at: Instant,
}

/// An adaptive rate limiter that adjusts the delay based on response status.
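///
/// Successful responses (2xx/3xx) gradually shrink the delay, while 403, 429,
/// and 5xx responses grow it, clamped to `[MIN_DELAY, MAX_DELAY]`.
///
/// # Examples
///
/// A small sketch of constructing the limiter and pacing two calls; the
/// import path is assumed, so the example is not compiled:
///
/// ```ignore
/// use std::time::Duration;
///
/// // 500 ms base delay with jitter enabled.
/// let limiter = AdaptiveLimiter::new(Duration::from_millis(500), true);
/// limiter.acquire().await; // returns immediately on the first call
/// limiter.acquire().await; // sleeps until roughly 500 ms after the first call (with jitter)
/// assert_eq!(limiter.current_delay().await, Duration::from_millis(500));
/// ```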
pub struct AdaptiveLimiter {
    state: Mutex<AdaptiveState>,
    jitter: bool,
}

impl AdaptiveLimiter {
    /// Creates a new `AdaptiveLimiter` with a given initial delay and jitter setting.
    pub fn new(initial_delay: Duration, jitter: bool) -> Self {
        Self {
            state: Mutex::new(AdaptiveState {
                delay: initial_delay,
                next_allowed_at: Instant::now(),
            }),
            jitter,
        }
    }

    fn apply_jitter(&self, delay: Duration) -> Duration {
        if !self.jitter || delay.is_zero() {
            return delay;
        }

        // Sample uniformly within +/-25% of the delay, capping the jitter
        // window at 500 ms so large delays do not swing too widely.
        let max_jitter = Duration::from_millis(500);
        let jitter_window = delay.mul_f64(0.25).min(max_jitter);

        let low = delay.saturating_sub(jitter_window);
        let high = delay + jitter_window;

        let mut rng = rand::thread_rng();
        let uniform = Uniform::new_inclusive(low, high);
        uniform.sample(&mut rng)
    }
}

#[async_trait]
impl RateLimiter for AdaptiveLimiter {
    async fn acquire(&self) {
        let sleep_duration = {
            let mut state = self.state.lock().await;
            let now = Instant::now();

            let delay = state.delay;
            if now < state.next_allowed_at {
                // The next slot is already reserved: wait for it, and push the
                // following slot one delay further out.
                let wait = state.next_allowed_at - now;
                state.next_allowed_at += delay;
                wait
            } else {
                // No pending slot: proceed immediately and reserve the next one.
                state.next_allowed_at = now + delay;
                Duration::ZERO
            }
        };

        let sleep_duration = self.apply_jitter(sleep_duration);
        if !sleep_duration.is_zero() {
            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
            sleep(sleep_duration).await;
        }
    }

    async fn adjust(&self, response: &Response) {
        let mut state = self.state.lock().await;

        let old_delay = state.delay;
        let status = response.status.as_u16();
        let new_delay = match status {
            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
            _ => state.delay,
        };

        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);

        if old_delay != state.delay {
            debug!(
                "Adjusting delay for status {}: {:?} -> {:?}",
                status, old_delay, state.delay
            );
        }
    }

    async fn current_delay(&self) -> Duration {
        self.state.lock().await.delay
    }
}

/// A rate limiter that uses a token bucket algorithm for a fixed rate.
pub struct TokenBucketLimiter {
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}

impl TokenBucketLimiter {
    /// Creates a new `TokenBucketLimiter` with the specified rate (requests per second).
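    ///
    /// # Panics
    ///
    /// Panics if `requests_per_second` is zero, since the underlying quota
    /// requires a non-zero rate.
    ///
    /// # Examples
    ///
    /// A short sketch of a fixed limit of five requests per second; the import
    /// path is assumed, so the example is not compiled:
    ///
    /// ```ignore
    /// let limiter = TokenBucketLimiter::new(5);
    /// limiter.acquire().await; // waits only once the 5 req/s budget is exhausted
    /// ```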
    pub fn new(requests_per_second: u32) -> Self {
        let rate = NonZeroU32::new(requests_per_second)
            .expect("requests_per_second must be non-zero");
        let quota = Quota::per_second(rate);
        TokenBucketLimiter {
            limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
                quota,
                &DefaultClock::default(),
            )),
        }
    }
}

#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    async fn acquire(&self) {
        self.limiter.until_ready().await;
    }

    // A fixed-rate limiter does not adjust based on responses.
    async fn adjust(&self, _response: &Response) {
        // No-op for a fixed-rate limiter.
    }

    async fn current_delay(&self) -> Duration {
        // A token bucket doesn't expose a "current delay"; it only decides when
        // the next request is allowed. Returning `Duration::ZERO` is a
        // simplification, since the actual waiting happens in `acquire`.
        Duration::ZERO
    }
}

/// A middleware that applies rate limiting to outgoing requests, keeping one
/// limiter per scope key (global or per domain) in an expiring cache.
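///
/// # Examples
///
/// A minimal sketch of building the middleware with a fixed per-domain rate;
/// the import paths are assumed, so the example is not compiled:
///
/// ```ignore
/// let middleware = RateLimitMiddleware::builder()
///     .scope(Scope::Domain)
///     .use_token_bucket_limiter(2) // at most 2 requests per second per domain
///     .build();
/// ```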
pub struct RateLimitMiddleware {
    scope: Scope,
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}

impl RateLimitMiddleware {
    /// Creates a new `RateLimitMiddlewareBuilder`.
    pub fn builder() -> RateLimitMiddlewareBuilder {
        RateLimitMiddlewareBuilder::default()
    }

    fn scope_key(&self, request: &Request) -> String {
        match self.scope {
            Scope::Global => "global".to_string(),
            Scope::Domain => utils::normalize_origin(request),
        }
    }
}

#[async_trait]
impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
    fn name(&self) -> &str {
        "RateLimitMiddleware"
    }

    async fn process_request(
        &mut self,
        _client: &C,
        request: Request,
    ) -> Result<MiddlewareAction<Request>, SpiderError> {
        let key = self.scope_key(&request);

        // Look up the limiter for this scope key, creating it lazily on first use.
        let limiter = self
            .limiters
            .get_with(key.clone(), async { (self.limiter_factory)() })
            .await;

        let current_delay = limiter.current_delay().await;
        debug!(
            "Acquiring rate limit slot for key '{}' (delay: {:?})",
            key, current_delay
        );

        limiter.acquire().await;
        Ok(MiddlewareAction::Continue(request))
    }

    async fn process_response(
        &mut self,
        response: Response,
    ) -> Result<MiddlewareAction<Response>, SpiderError> {
        let key = self.scope_key(&response.request_from_response());

        if let Some(limiter) = self.limiters.get(&key).await {
            let old_delay = limiter.current_delay().await;
            limiter.adjust(&response).await;
            let new_delay = limiter.current_delay().await;
            if old_delay != new_delay {
                debug!(
                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
                    key, old_delay, new_delay
                );
            }
        }

        Ok(MiddlewareAction::Continue(response))
    }
}

/// Builder for `RateLimitMiddleware`.
pub struct RateLimitMiddlewareBuilder {
    scope: Scope,
    cache_ttl: Duration,
    cache_capacity: u64,
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}

impl Default for RateLimitMiddlewareBuilder {
    fn default() -> Self {
        Self {
            scope: Scope::Domain,
            cache_ttl: Duration::from_secs(3600),
            cache_capacity: 10_000,
            limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
        }
    }
}

impl RateLimitMiddlewareBuilder {
    /// Sets the scope for the rate limiter.
    pub fn scope(mut self, scope: Scope) -> Self {
        self.scope = scope;
        self
    }

    /// Configures the builder to use a `TokenBucketLimiter` with the specified requests per second.
    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
        self.limiter_factory =
            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
        self
    }

    /// Sets a specific rate limiter instance to be used. The same instance is
    /// shared across all scope keys.
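    ///
    /// # Examples
    ///
    /// A sketch of pinning a single adaptive limiter, most useful together with
    /// `Scope::Global`; the import paths are assumed, so the example is not compiled:
    ///
    /// ```ignore
    /// let middleware = RateLimitMiddleware::builder()
    ///     .scope(Scope::Global)
    ///     .limiter(AdaptiveLimiter::new(Duration::from_secs(1), false))
    ///     .build();
    /// ```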
    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
        let arc = Arc::new(limiter);
        self.limiter_factory = Box::new(move || arc.clone());
        self
    }

    /// Sets a factory function for creating rate limiters. The factory is
    /// invoked once per scope key when that key is first seen.
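    ///
    /// # Examples
    ///
    /// A sketch of a custom factory handing out an adaptive limiter without
    /// jitter; it assumes `RateLimitMiddleware`, `AdaptiveLimiter`, and
    /// `RateLimiter` are in scope, so the example is not compiled:
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use std::time::Duration;
    ///
    /// let builder = RateLimitMiddleware::builder().limiter_factory(|| {
    ///     Arc::new(AdaptiveLimiter::new(Duration::from_secs(1), false)) as Arc<dyn RateLimiter>
    /// });
    /// ```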
    pub fn limiter_factory(
        mut self,
        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
    ) -> Self {
        self.limiter_factory = Box::new(factory);
        self
    }

    /// Builds the `RateLimitMiddleware`.
    pub fn build(self) -> RateLimitMiddleware {
        info!(
            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
            self.scope, self.cache_ttl, self.cache_capacity
        );
        RateLimitMiddleware {
            scope: self.scope,
            // The cache uses time-to-idle expiry, so a limiter is evicted after
            // `cache_ttl` of inactivity for its key.
            limiters: Cache::builder()
                .time_to_idle(self.cache_ttl)
                .max_capacity(self.cache_capacity)
                .build(),
            limiter_factory: self.limiter_factory.into(),
        }
    }
}